torchvision.models.feature_extraction provides two utilities that let you extract activations from any intermediate layer of a model without writing custom forward hooks. The implementation rewrites the model’s computation graph using torch.fx, pruning away unused nodes so only the required subgraph is executed. This is particularly useful for building detection and segmentation heads, multi-scale feature pyramids, or transfer learning pipelines that need representations from several depths simultaneously.

API Reference

get_graph_node_names

Discovers and returns all traceable node names for a model by tracing it in both train and eval modes.
get_graph_node_names(
    model: nn.Module,
    tracer_kwargs: Optional[dict] = None,
    suppress_diff_warning: bool = False,
    concrete_args: Optional[dict] = None,
) -> tuple[list[str], list[str]]
Returns a (train_nodes, eval_nodes) tuple. Nodes are named using dot-separated paths through the module hierarchy, with a _{counter} suffix when the same operation appears more than once.
Train and eval node lists may differ when a model has branches gated by self.training. If you only use the model in eval mode, read from eval_nodes.

create_feature_extractor

Creates a new fx.GraphModule that returns specified intermediate nodes as an OrderedDict.
create_feature_extractor(
    model: nn.Module,
    return_nodes: Optional[Union[list[str], dict[str, str]]] = None,
    train_return_nodes: Optional[Union[list[str], dict[str, str]]] = None,
    eval_return_nodes: Optional[Union[list[str], dict[str, str]]] = None,
    tracer_kwargs: Optional[dict] = None,
    suppress_diff_warning: bool = False,
    concrete_args: Optional[dict] = None,
) -> fx.GraphModule
Key arguments:
return_nodes (list[str] or dict[str, str]): Node names to extract. If a dict, keys are node names and values are output keys. Mutually exclusive with train_return_nodes / eval_return_nodes.
train_return_nodes (list[str] or dict[str, str]): Nodes to extract in train mode only (requires eval_return_nodes too).
eval_return_nodes (list[str] or dict[str, str]): Nodes to extract in eval mode only (requires train_return_nodes too).
tracer_kwargs (dict): Passed to NodePathTracer / torch.fx.Tracer for customising tracing behaviour. User-supplied values are merged with TorchVision defaults.
suppress_diff_warning (bool): Suppress the warning emitted when train and eval graphs differ. Default: False.
concrete_args (dict): Concrete (non-proxy) arguments for torch.fx.Tracer.trace. API stability is not guaranteed by PyTorch.

Step-by-Step Usage

1. Discover node names

Use get_graph_node_names to see all available intermediate nodes. Check the eval list for inference use cases.
from torchvision.models import resnet50, ResNet50_Weights
from torchvision.models.feature_extraction import get_graph_node_names

model = resnet50(weights=ResNet50_Weights.DEFAULT)
train_nodes, eval_nodes = get_graph_node_names(model)
print(eval_nodes)  # lists all intermediate node names
Example output (truncated):
['x', 'conv1', 'bn1', 'relu', 'maxpool',
 'layer1.0.conv1', 'layer1.0.bn1', ..., 'layer1.2.relu_2',
 'layer2.0.conv1', ..., 'layer2.3.relu_2',
 'layer3.0.conv1', ..., 'layer3.5.relu_2',
 'layer4.0.conv1', ..., 'layer4.2.relu_2',
 'avgpool', 'flatten', 'fc']

2. Create a feature extractor

Pass a dict mapping node names to user-defined output keys:
from torchvision.models.feature_extraction import create_feature_extractor

return_nodes = {
    "layer1.2.relu_2": "layer1",
    "layer2.3.relu_2": "layer2",
    "layer3.5.relu_2": "layer3",
    "layer4.2.relu_2": "layer4",
}
extractor = create_feature_extractor(model, return_nodes=return_nodes)

3. Run and collect intermediate features

The extractor behaves like a regular nn.Module. Its output is an OrderedDict keyed by your chosen names.
import torch

extractor.eval()  # inference mode: BatchNorm uses its running statistics
x = torch.rand(1, 3, 224, 224)
with torch.no_grad():
    features = extractor(x)

# features is an OrderedDict:
# features["layer1"]: Tensor[1, 256, 56, 56]
# features["layer2"]: Tensor[1, 512, 28, 28]
# features["layer3"]: Tensor[1, 1024, 14, 14]
# features["layer4"]: Tensor[1, 2048, 7, 7]
for name, feat in features.items():
    print(f"{name}: {feat.shape}")

Building a Feature Pyramid Network

Feature extractors integrate directly with torchvision.ops.FeaturePyramidNetwork for multi-scale detection heads:
from torchvision.models import resnet50, ResNet50_Weights
from torchvision.models.feature_extraction import create_feature_extractor
from torchvision.ops import FeaturePyramidNetwork
import torch

# Build backbone extractor
model = resnet50(weights=ResNet50_Weights.DEFAULT)
return_nodes = {
    "layer1.2.relu_2": "layer1",
    "layer2.3.relu_2": "layer2",
    "layer3.5.relu_2": "layer3",
    "layer4.2.relu_2": "layer4",
}
extractor = create_feature_extractor(model, return_nodes=return_nodes)

# Attach FPN head
fpn = FeaturePyramidNetwork(
    in_channels_list=[256, 512, 1024, 2048],
    out_channels=256,
)

# Forward pass
x = torch.rand(1, 3, 224, 224)
features = extractor(x)     # OrderedDict of four feature maps
fpn_output = fpn(features)  # OrderedDict of four 256-channel maps

for name, feat in fpn_output.items():
    print(f"{name}: {feat.shape}")
# layer1: torch.Size([1, 256, 56, 56])
# layer2: torch.Size([1, 256, 28, 28])
# layer3: torch.Size([1, 256, 14, 14])
# layer4: torch.Size([1, 256, 7, 7])

Separate Train / Eval Return Nodes

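Some models contain branches gated by self.training, so their train and eval graphs differ. For example, the auxiliary classifiers of GoogLeNet and Inception v3 only run in train mode. Use train_return_nodes and eval_return_nodes when the nodes you want differ between modes. The ResNet-50 example below has identical graphs in both modes and simply requests a different node in each: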
extractor = create_feature_extractor(
    model,
    train_return_nodes={"layer3.5.relu_2": "layer3"},
    eval_return_nodes={"layer4.2.relu_2": "layer4"},
)

# In train mode: returns {"layer3": ...}
extractor.train()
train_feats = extractor(x)

# In eval mode: returns {"layer4": ...}
extractor.eval()
eval_feats = extractor(x)

Handling Non-Traceable Models

create_feature_extractor uses torch.fx symbolic tracing. Some models call Python builtins (e.g., int(x.shape[0])) or contain dynamic control flow that breaks tracing. Use tracer_kwargs to work around these:
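import torch
from torchvision.models.feature_extraction import create_feature_extractor

class ProblemModule(torch.nn.Module):
    def forward(self, x):
        n = int(x.shape[0])  # <- int() on a traced value breaks symbolic tracing
        return x.repeat(n, 1, 1, 1)

class MyModel(torch.nn.Module):  # hypothetical model that contains ProblemModule
    def __init__(self):
        super().__init__()
        self.conv = torch.nn.Conv2d(3, 8, kernel_size=3, padding=1)
        self.problem = ProblemModule()
    def forward(self, x):
        return self.problem(self.conv(x))

# Make ProblemModule a leaf: it is recorded as a single call and not traced through
extractor = create_feature_extractor(
    MyModel(),
    return_nodes={"problem": "feat"},
    tracer_kwargs={"leaf_modules": [ProblemModule]},
)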

leaf_modules

Pass a list of nn.Module subclasses that should be treated as atomic leaves. Their forward is not traced. Instead, each call is recorded as a single call_module node that refers to the whole module.

autowrap_functions

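Pass a tuple of standalone, module-level functions that should be recorded as single call nodes instead of being symbolically traced. This is useful for small helpers that call builtins such as int or len on tensor shapes, which would otherwise break tracing.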

Tips and Caveats

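create_feature_extractor does not modify the original model. The returned fx.GraphModule is a new object, but it references the original submodules, so parameters and buffers are shared. Training the extractor, or running it in train mode (which updates BatchNorm running statistics), also changes the original model.
The extractor starts in whatever train/eval mode the model was in when the extractor was created. Use extractor.train() / extractor.eval() as normal. Both graphs are traced up front, and switching modes swaps between them; this is how separate train_return_nodes / eval_return_nodes take effect.
By default, tracer_kwargs is pre-populated so that every nn.Module class exported by torchvision.ops (e.g., RoIAlign, DeformConv2d) is treated as a leaf. Functions from math and torchvision.ops are also auto-wrapped, so tracing does not descend into their implementations. User-supplied tracer_kwargs are merged with these defaults.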
