TorchVision’s Multi-Weight Support API, introduced in v0.13, represents every pre-trained checkpoint as a self-contained object that carries its download URL, the preprocessing pipeline the checkpoint expects at inference time, and a rich metadata dictionary, all in one place. Because the preprocessing travels with the checkpoint, it is much harder to accidentally pair a checkpoint with the wrong resize, crop, or normalization parameters.

Core Data Structures

Weights dataclass

Every individual checkpoint is represented as a Weights dataclass instance with three fields:
```python
from dataclasses import dataclass
from typing import Any, Callable

@dataclass
class Weights:
    url: str              # HTTPS URL to the .pth checkpoint file
    transforms: Callable  # Constructor for the preprocessing pipeline
    meta: dict[str, Any]  # Accuracy, parameter count, recipe link, etc.
```

| Field | Type | Description |
|---|---|---|
| url | str | Direct download URL for the serialised state_dict. Used internally by weights.get_state_dict(). |
| transforms | Callable | A constructor (not an already-built object) for the preprocessing transforms. Call weights.transforms() to instantiate. |
| meta | dict | Arbitrary metadata: accuracy metrics, parameter count, FLOP count, training recipe URL, output class list, and more. |
transforms is stored as a constructor (typically a functools.partial) rather than a live object to avoid holding unnecessary memory until the transforms are actually needed.
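The lazy-constructor design can be illustrated without TorchVision. The sketch below is a simplified stand-in, not TorchVision's own classes: FakePreset is a hypothetical placeholder for a preset such as ImageClassification, and the URL and metadata are made-up examples. It shows that storing a functools.partial records the configuration while deferring construction until transforms() is called.

```python
from dataclasses import dataclass
from functools import partial
from typing import Any, Callable


class FakePreset:
    """Hypothetical stand-in for a preprocessing preset (not a TorchVision class)."""

    instances_built = 0  # counts how many pipelines have actually been constructed

    def __init__(self, *, crop_size: int, resize_size: int = 256) -> None:
        FakePreset.instances_built += 1
        self.crop_size = crop_size
        self.resize_size = resize_size


@dataclass
class Weights:
    url: str
    transforms: Callable[..., Any]
    meta: dict[str, Any]


# The partial stores the constructor and its arguments; nothing is built yet.
w = Weights(
    url="https://example.com/checkpoints/tiny-0000.pth",  # hypothetical URL
    transforms=partial(FakePreset, crop_size=224, resize_size=232),
    meta={"num_params": 0},  # placeholder metadata
)

print(FakePreset.instances_built)  # 0
preprocess = w.transforms()        # the pipeline is built only on demand
print(FakePreset.instances_built)  # 1
print(preprocess.crop_size, preprocess.resize_size)  # 224 232
```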

WeightsEnum

WeightsEnum is a subclass of Python’s enum.Enum whose members each wrap a Weights instance. Every model builder has its own WeightsEnum subclass, named <ModelName>_Weights (for example, resnet50 uses ResNet50_Weights).
```python
from torchvision.models import ResNet50_Weights

# Inspect all available checkpoints for ResNet-50
print(list(ResNet50_Weights))
# [ResNet50_Weights.IMAGENET1K_V1, ResNet50_Weights.IMAGENET1K_V2]

# DEFAULT is an alias for the best available checkpoint
print(ResNet50_Weights.DEFAULT)        # ResNet50_Weights.IMAGENET1K_V2
print(ResNet50_Weights.IMAGENET1K_V1)  # ResNet50_Weights.IMAGENET1K_V1
print(ResNet50_Weights.IMAGENET1K_V2)  # ResNet50_Weights.IMAGENET1K_V2
```
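The naming convention IMAGENET1K_V1, IMAGENET1K_V2, … reflects successive training recipe improvements on the same dataset. DEFAULT is an Enum alias for the best available checkpoint rather than a separate member, which is why iterating over the enum does not list it. The name DEFAULT is stable, but the checkpoint it points to may change across TorchVision releases.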
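The alias behaviour comes from Python's enum module itself: a member defined with the same value as an earlier member becomes an alias of it. The stdlib sketch below reproduces the pattern with a hypothetical TinyNet_Weights enum whose string values are placeholders for real Weights objects.

```python
from enum import Enum


class TinyNet_Weights(Enum):
    """Hypothetical enum mirroring the alias pattern (values are placeholders)."""

    IMAGENET1K_V1 = "v1-checkpoint"
    IMAGENET1K_V2 = "v2-checkpoint"
    DEFAULT = IMAGENET1K_V2  # same value as V2, so Enum treats DEFAULT as an alias


print([m.name for m in TinyNet_Weights])  # ['IMAGENET1K_V1', 'IMAGENET1K_V2']
print(TinyNet_Weights.DEFAULT)            # TinyNet_Weights.IMAGENET1K_V2
print(TinyNet_Weights["DEFAULT"] is TinyNet_Weights.IMAGENET1K_V2)  # True
```

Moving DEFAULT to a newer checkpoint only requires changing the right-hand side of one assignment, which is how the alias can advance between releases without renaming anything.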

Loading a Model with Weights

Pass any WeightsEnum member (or a string shorthand) to the model builder’s weights= argument:
```python
from torchvision.models import resnet50, ResNet50_Weights

# Load the improved V2 checkpoint (80.9% top-1 on ImageNet-1K)
weights = ResNet50_Weights.IMAGENET1K_V2
model = resnet50(weights=weights)
model.eval()
```
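Always call model.eval() after loading pre-trained weights and before running inference. It switches batch-norm and dropout layers to inference mode; without it, batch-norm uses per-batch statistics and dropout zeroes activations at random, so predictions are neither correct nor reproducible.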
String shorthands are also accepted — useful for command-line or config-file driven workflows:
```python
# These three calls are equivalent while DEFAULT aliases IMAGENET1K_V2
resnet50(weights=ResNet50_Weights.IMAGENET1K_V2)
resnet50(weights="IMAGENET1K_V2")
resnet50(weights="DEFAULT")
```
Pass weights=None to get a randomly initialised model without downloading any checkpoint:
```python
# Architecture only, no pretrained weights
model = resnet50(weights=None)
```
The legacy pretrained=True boolean has been deprecated since v0.13; it emits a deprecation warning and may be removed in a future release. Update any existing code to use weights=ModelName_Weights.DEFAULT (pretrained) or weights=None (random init).
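Accepting an enum member, a member name, or None amounts to a normalisation step performed before any checkpoint is loaded. The stdlib sketch below is a simplified approximation, loosely modeled on the verify classmethod of TorchVision's WeightsEnum; TinyNet_Weights and resolve_weights are hypothetical names, and the real implementation differs in its details.

```python
from enum import Enum
from typing import Optional, Union


class TinyNet_Weights(Enum):
    """Hypothetical weights enum; values are placeholders for Weights objects."""

    IMAGENET1K_V1 = "v1-checkpoint"
    IMAGENET1K_V2 = "v2-checkpoint"
    DEFAULT = IMAGENET1K_V2  # alias of V2


def resolve_weights(
    weights: Union[None, str, TinyNet_Weights],
) -> Optional[TinyNet_Weights]:
    """Normalise a user-supplied weights argument to an enum member or None."""
    if weights is None:
        return None  # random initialisation, nothing to download
    if isinstance(weights, TinyNet_Weights):
        return weights
    if isinstance(weights, str):
        # Accept both "IMAGENET1K_V2" and "TinyNet_Weights.IMAGENET1K_V2".
        name = weights.removeprefix(TinyNet_Weights.__name__ + ".")
        try:
            return TinyNet_Weights[name]  # name lookup also resolves aliases
        except KeyError:
            raise ValueError(f"Unknown weights name: {weights!r}") from None
    raise TypeError(f"Unsupported weights type: {type(weights).__name__}")


print(resolve_weights("DEFAULT"))  # TinyNet_Weights.IMAGENET1K_V2
print(resolve_weights(None))       # None
```

Resolving strings through the enum's name lookup means aliases such as DEFAULT need no special-casing.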

weights.transforms() — Preprocessing Pipeline

Each WeightsEnum member exposes a transforms property that returns a callable constructor. Invoke it (with no arguments) to get the fully-built preprocessing pipeline for that checkpoint:
```python
from torchvision.models import ResNet50_Weights

weights = ResNet50_Weights.IMAGENET1K_V2
preprocess = weights.transforms()

print(preprocess)
# ImageClassification(
#     crop_size=[224]
#     resize_size=[232]
#     mean=[0.485, 0.456, 0.406]
#     std=[0.229, 0.224, 0.225]
#     interpolation=InterpolationMode.BILINEAR
# )
```
The returned object is an nn.Module (here, ImageClassification) that accepts a PIL image or a torch.Tensor image (uint8 or float, shape [C, H, W] or batched [..., C, H, W]) and returns a normalised float32 tensor ready for the model.
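```python
import torch
from torchvision.io import ImageReadMode, decode_image

# Reuses `model` (loaded earlier) and `preprocess` (built above)
img = decode_image("photo.jpg", mode=ImageReadMode.RGB)  # uint8 Tensor, shape [3, H, W]
batch = preprocess(img).unsqueeze(0)                      # float32 Tensor, shape [1, 3, 224, 224]

with torch.inference_mode():  # skip autograd bookkeeping during inference
    prediction = model(batch).squeeze(0).softmax(0)
```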
Different weight versions for the same model may use different resize and crop sizes. For example, ResNet50_Weights.IMAGENET1K_V1 keeps the default resize_size=256 before center-cropping to 224, while IMAGENET1K_V2 resizes to 232 before cropping to 224. Always retrieve transforms from the specific weights object you loaded.
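To see how much the resize value changes what the model sees, the sketch below reproduces the geometry of a shorter-side resize followed by a center crop, plus the per-channel normalisation, in plain Python. It is an approximation for illustration only: it assumes the shorter side is scaled to resize_size with the longer side truncated to an integer, and that crop offsets are rounded to the nearest pixel, which is how TorchVision's tensor resize and center_crop arithmetic is understood to work; interpolation and antialiasing are ignored. The 480x640 image size is an arbitrary example.

```python
def resized_hw(h: int, w: int, resize_size: int) -> tuple[int, int]:
    """Scale the shorter side to resize_size, keeping the aspect ratio."""
    if h <= w:
        return resize_size, int(resize_size * w / h)
    return int(resize_size * h / w), resize_size


def center_crop_box(h: int, w: int, crop_size: int) -> tuple[int, int, int, int]:
    """Return (top, left, height, width) of a centred crop_size x crop_size window."""
    top = int(round((h - crop_size) / 2.0))
    left = int(round((w - crop_size) / 2.0))
    return top, left, crop_size, crop_size


def normalise(pixel_uint8: int, mean: float, std: float) -> float:
    """Map a uint8 value to [0, 1], then standardise with the checkpoint's mean/std."""
    return (pixel_uint8 / 255.0 - mean) / std


# A 480x640 (H x W) photo under the V1 (resize 256) and V2 (resize 232) settings.
for label, resize_size in [("V1", 256), ("V2", 232)]:
    h, w = resized_hw(480, 640, resize_size)
    box = center_crop_box(h, w, crop_size=224)
    kept = (224 / h) * (224 / w)  # share of the image area that survives the crop
    print(f"{label}: resized to {h}x{w}, crop box {box}, keeps {kept:.1%}")
# V1: resized to 256x341, crop box (16, 58, 224, 224), keeps 57.5%
# V2: resized to 232x309, crop box (4, 42, 224, 224), keeps 70.0%

# Red-channel normalisation with the ImageNet statistics shown above.
print(round(normalise(255, mean=0.485, std=0.229), 4))  # 2.2489
```

With the smaller V2 resize, the same 224x224 crop covers noticeably more of the original photo, which is one reason the two checkpoints are not interchangeable with respect to preprocessing.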

weights.meta — Checkpoint Metadata

The meta dictionary carries all information about how the checkpoint was produced and what it expects at runtime.
```python
from torchvision.models import ResNet50_Weights

weights = ResNet50_Weights.IMAGENET1K_V2

# Output class list (1000 ImageNet-1K categories)
print(weights.meta["categories"][:5])
# ['tench', 'goldfish', 'great white shark', 'tiger shark', 'hammerhead']

# Accuracy metrics on the evaluation dataset
print(weights.meta["_metrics"])
# {'ImageNet-1K': {'acc@1': 80.858, 'acc@5': 95.434}}

# Total trainable parameters
print(weights.meta["num_params"])
# 25557032

# Minimum accepted input spatial size
print(weights.meta["min_size"])
# (1, 1)

# Link to the training recipe / PR
print(weights.meta["recipe"])
# 'https://github.com/pytorch/vision/issues/3995#issuecomment-1013906621'
```

Common meta keys

| Key | Type | Description |
|---|---|---|
| categories | list[str] | Output class names in label-index order. Use weights.meta["categories"][class_id] to decode predictions. |
| _metrics | dict | Nested dict of {dataset: {metric: value}}, e.g. {"ImageNet-1K": {"acc@1": 80.858, "acc@5": 95.434}}. |
| num_params | int | Total number of trainable parameters in the model. |
| min_size | tuple[int, int] | Minimum input (height, width) the model accepts. |
| _ops | float | Approximate compute for a single forward pass, in GFLOPs. |
| _file_size | float | Checkpoint file size, in MB. |
| recipe | str | URL to the training recipe or pull request. |
| _docs | str | Human-readable description of what makes this checkpoint different. |
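Because _metrics is nested by dataset, a small helper that flattens it can be handy when logging or comparing checkpoints whose metric names differ by task. The sketch below operates on a plain dictionary shaped like weights.meta, with values copied from the ResNet-50 V2 output above; flatten_metrics is a hypothetical helper, not a TorchVision function, and it assumes only the {dataset: {metric: value}} layout described in the table.

```python
from typing import Any


def flatten_metrics(meta: dict[str, Any]) -> dict[str, float]:
    """Turn meta["_metrics"] ({dataset: {metric: value}}) into "dataset/metric" keys."""
    flat: dict[str, float] = {}
    for dataset, metrics in meta.get("_metrics", {}).items():
        for metric, value in metrics.items():
            flat[f"{dataset}/{metric}"] = value
    return flat


# Values copied from the ResNet50_Weights.IMAGENET1K_V2 output shown above.
meta = {
    "_metrics": {"ImageNet-1K": {"acc@1": 80.858, "acc@5": 95.434}},
    "num_params": 25557032,
}
print(flatten_metrics(meta))
# {'ImageNet-1K/acc@1': 80.858, 'ImageNet-1K/acc@5': 95.434}
print(flatten_metrics({}))  # {}
```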

Full Inference Example

```python
import torch
from torchvision.io import ImageReadMode, decode_image
from torchvision.models import resnet50, ResNet50_Weights

# 1. Resolve weights and build the model
weights = ResNet50_Weights.IMAGENET1K_V2
model = resnet50(weights=weights)
model.eval()

# 2. Build the preprocessing pipeline tied to this checkpoint
preprocess = weights.transforms()
print(preprocess)
# ImageClassification(crop_size=[224], resize_size=[232], ...)

# 3. Inspect metadata before running
print(weights.meta["categories"][:5])   # ['tench', 'goldfish', ...]
print(weights.meta["_metrics"])          # {'ImageNet-1K': {'acc@1': 80.858, 'acc@5': 95.434}}
print(weights.meta["num_params"])        # 25557032

# 4. Run inference (RGB mode guards against grayscale or RGBA files)
img = decode_image("photo.jpg", mode=ImageReadMode.RGB)
batch = preprocess(img).unsqueeze(0)

with torch.inference_mode():
    prediction = model(batch).squeeze(0).softmax(0)
class_id = prediction.argmax().item()
score = prediction[class_id].item()
category = weights.meta["categories"][class_id]
print(f"{category}: {100 * score:.1f}%")
```
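The example above reports only the single best class. A top-k readout uses the same categories lookup; with PyTorch you would typically call prediction.topk(5) and map each returned index through weights.meta["categories"]. The stdlib sketch below shows the same logic on a plain list so it runs without a model; the probabilities and the four-entry category list are made-up placeholders, and top_k is a hypothetical helper.

```python
import heapq


def top_k(probs: list[float], categories: list[str], k: int = 5) -> list[tuple[str, float]]:
    """Return the k (category, probability) pairs with the highest probability."""
    best = heapq.nlargest(k, range(len(probs)), key=probs.__getitem__)
    return [(categories[i], probs[i]) for i in best]


# Placeholder values: in practice probs = prediction.tolist() and
# categories = weights.meta["categories"].
categories = ["tench", "goldfish", "great white shark", "tiger shark"]
probs = [0.05, 0.70, 0.20, 0.05]

for name, p in top_k(probs, categories, k=2):
    print(f"{name}: {100 * p:.1f}%")
# goldfish: 70.0%
# great white shark: 20.0%
```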

get_model_weights — Resolve WeightsEnum by Model Name

get_model_weights returns the WeightsEnum class for a given model, identified either by its string name or by the builder function itself. This is useful when you want to enumerate available checkpoints without hard-coding the enum class name.
```python
import torchvision.models as models
from torchvision.models import get_model_weights

# By string name
weights_enum = get_model_weights("resnet50")
print(weights_enum)         # <enum 'ResNet50_Weights'>

# Iterate all available checkpoints
for w in weights_enum:
    acc = w.meta["_metrics"]["ImageNet-1K"]["acc@1"]
    print(f"{w}  -  acc@1: {acc}%")
# ResNet50_Weights.IMAGENET1K_V1  -  acc@1: 76.13%
# ResNet50_Weights.IMAGENET1K_V2  -  acc@1: 80.858%

# By builder function (also accepted)
weights_enum2 = get_model_weights(models.resnet50)
assert weights_enum is weights_enum2
```
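Accepting a builder function works because the weights enum can be recovered from the builder's type annotations. The stdlib sketch below illustrates that idea: it reads the annotation of the builder's weights parameter and unwraps Optional[...]. TinyNet_Weights, tiny_net, and enum_from_builder are hypothetical; TorchVision's internal helper inspects the signature in a similar spirit but is more thorough, so treat this as an approximation of the approach rather than its implementation.

```python
import inspect
from enum import Enum
from typing import Optional, get_args


class TinyNet_Weights(Enum):
    """Hypothetical weights enum (placeholder value)."""

    IMAGENET1K_V1 = "v1-checkpoint"


def tiny_net(*, weights: Optional[TinyNet_Weights] = None) -> str:
    """Hypothetical model builder; returns a string instead of a real model."""
    return f"TinyNet(weights={weights})"


def enum_from_builder(fn) -> type[Enum]:
    """Find the Enum subclass named in the annotation of fn's `weights` parameter."""
    params = inspect.signature(fn).parameters
    if "weights" not in params:
        raise ValueError(f"{fn.__name__} has no 'weights' parameter")
    ann = params["weights"].annotation
    # Handle a bare annotation (TinyNet_Weights) or a union like Optional[TinyNet_Weights].
    candidates = (ann,) if isinstance(ann, type) else get_args(ann)
    for t in candidates:
        if isinstance(t, type) and issubclass(t, Enum):
            return t
    raise ValueError(f"No Enum found in the annotation of {fn.__name__}")


print(enum_from_builder(tiny_net))  # <enum 'TinyNet_Weights'>
```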

get_weight — Load a Specific Checkpoint by Dot-Notation String

get_weight resolves a weights instance from a fully-qualified dot-notation string of the form "<WeightsEnumClass>.<MEMBER>". This makes it convenient for specifying a checkpoint in configuration files or command-line arguments.
```python
from torchvision.models import get_weight, get_model

# Resolve by fully-qualified string
weights = get_weight("ResNet50_Weights.DEFAULT")
print(weights)   # ResNet50_Weights.IMAGENET1K_V2

# Works with any registered WeightsEnum across all submodules
quant_weights = get_weight("ResNet50_QuantizedWeights.DEFAULT")

# Combine with get_model for a fully string-driven pipeline
weights = get_weight("ResNet50_Weights.IMAGENET1K_V2")
model = get_model("resnet50", weights=weights)
preprocess = weights.transforms()
model.eval()
```
get_weight searches torchvision.models and its subpackages (such as torchvision.models.detection, torchvision.models.segmentation, torchvision.models.quantization, and torchvision.models.video), so quantized, detection, and segmentation weights are all accessible through the same function.
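The lookup boils down to splitting the string on the dot, finding the enum class in one of several namespaces, and indexing the member by name. The sketch below mimics that with types.SimpleNamespace objects standing in for torchvision.models and one of its subpackages; all class and namespace names are hypothetical, get_weight_sketch is not a TorchVision function, and error handling is simplified.

```python
from enum import Enum
from types import SimpleNamespace


class TinyNet_Weights(Enum):
    IMAGENET1K_V1 = "v1-checkpoint"
    IMAGENET1K_V2 = "v2-checkpoint"
    DEFAULT = IMAGENET1K_V2


class TinyNet_QuantizedWeights(Enum):
    IMAGENET1K_V1 = "quantized-v1-checkpoint"
    DEFAULT = IMAGENET1K_V1


# Stand-ins for a base package and a subpackage that each export enums.
NAMESPACES = [
    SimpleNamespace(TinyNet_Weights=TinyNet_Weights),
    SimpleNamespace(TinyNet_QuantizedWeights=TinyNet_QuantizedWeights),
]


def get_weight_sketch(name: str) -> Enum:
    """Resolve "<EnumClass>.<MEMBER>" by searching each namespace in order."""
    try:
        enum_name, member_name = name.split(".")
    except ValueError:
        raise ValueError(f"Expected '<EnumClass>.<MEMBER>', got {name!r}") from None
    for ns in NAMESPACES:
        candidate = getattr(ns, enum_name, None)
        if isinstance(candidate, type) and issubclass(candidate, Enum):
            return candidate[member_name]  # KeyError if the member does not exist
    raise ValueError(f"No weights enum named {enum_name!r}")


print(get_weight_sketch("TinyNet_Weights.DEFAULT"))           # TinyNet_Weights.IMAGENET1K_V2
print(get_weight_sketch("TinyNet_QuantizedWeights.DEFAULT"))  # TinyNet_QuantizedWeights.IMAGENET1K_V1
```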

Programmatic Weight Inspection

```python
from torchvision.models import get_model_weights

for weights in get_model_weights("resnet50"):
    m = weights.meta
    metrics = m["_metrics"].get("ImageNet-1K", {})
    print(
        f"{weights}\n"
        f"  acc@1={metrics.get('acc@1')}  "
        f"acc@5={metrics.get('acc@5')}  "
        f"params={m['num_params']:,}  "
        f"size={m['_file_size']} MB"
    )
```

Downloading Weights Manually

Weights are downloaded automatically (through torch.hub) the first time a model builder needs them, then cached locally. To pre-download them (e.g. in a Docker build step), call get_state_dict explicitly:
```python
from torchvision.models import ResNet50_Weights

# Downloads to $TORCH_HOME/hub/checkpoints/ (TORCH_HOME defaults to ~/.cache/torch)
state_dict = ResNet50_Weights.DEFAULT.get_state_dict(progress=True)
```
Set TORCH_HOME to control the cache location:
```shell
export TORCH_HOME=/mnt/model-cache
python download_weights.py   # your script containing the get_state_dict call
```
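When baking weights into an image, it can help to know in advance which file a download will produce. The stdlib sketch below approximates where torch.hub places a checkpoint: $TORCH_HOME if set, otherwise $XDG_CACHE_HOME/torch or ~/.cache/torch, followed by hub/checkpoints/ and the file name taken from the URL. It is an approximation of torch.hub's default behaviour that ignores overrides such as torch.hub.set_dir() or an explicit model_dir; expected_checkpoint_path is a hypothetical helper and the URL is a made-up example.

```python
import os
from typing import Mapping, Optional
from urllib.parse import urlparse


def expected_checkpoint_path(url: str, env: Optional[Mapping[str, str]] = None) -> str:
    """Approximate the cache file torch.hub would use for a checkpoint URL."""
    env = os.environ if env is None else env
    torch_home = env.get(
        "TORCH_HOME", os.path.join(env.get("XDG_CACHE_HOME", "~/.cache"), "torch")
    )
    model_dir = os.path.join(os.path.expanduser(torch_home), "hub", "checkpoints")
    filename = os.path.basename(urlparse(url).path)  # last path component of the URL
    return os.path.join(model_dir, filename)


url = "https://example.com/models/tinynet-1234abcd.pth"  # hypothetical URL
print(expected_checkpoint_path(url, {"TORCH_HOME": "/mnt/model-cache"}))
# /mnt/model-cache/hub/checkpoints/tinynet-1234abcd.pth
```

A build step could compare this path against the files present in the image to confirm the cache was populated.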

API Reference Summary

| Function / Class | Where to import | Description |
|---|---|---|
| Weights | torchvision.models | Dataclass holding url, transforms, and meta for one checkpoint. |
| WeightsEnum | torchvision.models | Base Enum class for all model weight enumerations. |
| ResNet50_Weights | torchvision.models | Example per-model weights enum; one exists for every supported architecture. |
| get_model_weights(name) | torchvision.models | Returns the WeightsEnum subclass for a model given its string name or builder function. |
| get_weight(name) | torchvision.models | Returns a specific WeightsEnum member from a dot-notation string like "ResNet50_Weights.DEFAULT". |
| get_model(name, **cfg) | torchvision.models | Instantiates a model by registered string name, passing weights= and other kwargs through. |
| list_models(...) | torchvision.models | Returns a sorted list of all registered model names, optionally filtered by module or glob pattern. |
