TorchVision’s Multi-Weight Support API, introduced in v0.13, makes every pre-trained checkpoint a self-contained object that carries its download URL, the inference preprocessing pipeline the checkpoint expects, and a rich metadata dictionary, all in one place. Because the preprocessing travels with the checkpoint, it is much harder to accidentally pair a model with the wrong resize, crop, or normalization parameters.
Core Data Structures
Weights dataclass
Every individual checkpoint is represented as a Weights dataclass instance with three fields:
| Field | Type | Description |
|---|---|---|
| url | str | Direct download URL for the serialised state_dict. Used internally by weights.get_state_dict(). |
| transforms | Callable | A constructor (not an already-built object) for the preprocessing transforms. Call weights.transforms() to instantiate. |
| meta | dict | Arbitrary metadata: accuracy metrics, parameter count, FLOP count, training recipe URL, output class list, and more. |
transforms is stored as a constructor (typically a functools.partial) rather than a live object to avoid holding unnecessary memory until the transforms are actually needed.
WeightsEnum
WeightsEnum is a subclass of Python’s enum.Enum where every member wraps a Weights instance. Each model builder has its own WeightsEnum subclass, named <ModelName>_Weights (for example, resnet50 pairs with ResNet50_Weights).
The version suffixes in IMAGENET1K_V1, IMAGENET1K_V2, … reflect successive training-recipe improvements on the same dataset. DEFAULT is an alias that always points to the best available weights for that model; which member it resolves to may advance across TorchVision releases.
Loading a Model with Weights
Pass any WeightsEnum member (or its string name, such as "IMAGENET1K_V2" or "DEFAULT") to the model builder’s weights= argument.
Pass weights=None to get a randomly initialised model; no checkpoint is downloaded.
weights.transforms() — Preprocessing Pipeline
Each WeightsEnum member exposes a transforms property that returns a callable constructor. Invoke it with no arguments to get the fully built preprocessing pipeline for that checkpoint.
The result is an nn.Module-compatible transform that accepts a PIL image or a torch.Tensor image (uint8 or float32, shape [C, H, W] or batched [N, C, H, W]) and returns a normalised float32 tensor ready for the model.
Different weight versions for the same model may use different resize and crop sizes. For example, ResNet50_Weights.IMAGENET1K_V1 uses crop_size=224 with the default resize of 256, while IMAGENET1K_V2 resizes to 232 before cropping to 224 (a FixRes-style adjustment). Always retrieve transforms from the specific weights object you loaded.
weights.meta — Checkpoint Metadata
The meta dictionary carries all information about how the checkpoint was produced and what it expects at runtime.
Common meta keys
| Key | Type | Description |
|---|---|---|
| categories | list[str] | Output class names in label-index order. Use weights.meta["categories"][class_id] to decode predictions. |
| _metrics | dict | Nested dict of {dataset: {metric: value}}, e.g. {"ImageNet-1K": {"acc@1": 80.858, "acc@5": 95.434}}. |
| num_params | int | Total number of trainable parameters in the model. |
| min_size | tuple[int, int] | Minimum (height, width) the model can accept. |
| _ops | float | Approximate GFLOPs for a single forward pass. |
| _file_size | float | Checkpoint file size in MB. |
| recipe | str | URL to the training recipe or pull request. |
| _docs | str | Human-readable description of what makes this checkpoint different. |
Full Inference Example
get_model_weights — Resolve WeightsEnum by Model Name
get_model_weights returns the WeightsEnum class for a given model, identified either by its string name or by the builder function itself. This is useful when you want to enumerate available checkpoints without hard-coding the enum class name.
get_weight — Load a Specific Checkpoint by Dot-Notation String
get_weight resolves a weights instance from a fully-qualified dot-notation string of the form "<WeightsEnumClass>.<MEMBER>". This is the canonical way to specify a checkpoint in configuration files or command-line arguments.
get_weight searches across all registered model submodules (torchvision.models, torchvision.models.detection, torchvision.models.segmentation, etc.), so classification, detection, and quantized weights are all accessible through the same function.
Programmatic Weight Inspection
Downloading Weights Manually
Weights are fetched automatically on first use via torch.hub. To pre-download them (e.g. in a Docker build step), call get_state_dict on the weights member explicitly.
Set the TORCH_HOME environment variable to control the cache location; downloaded checkpoints are stored under $TORCH_HOME/hub/checkpoints.
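From Python, setting the variable before the first download has the same effect as exporting it in a shell or with an ENV line in a Dockerfile. /opt/torch-cache is a hypothetical path.

```python
import os

import torch

# Hypothetical cache location; set it before any checkpoint is requested.
os.environ["TORCH_HOME"] = "/opt/torch-cache"

# torch.hub resolves its directory from TORCH_HOME on each call
# (unless torch.hub.set_dir() was used), so this reflects the new value.
checkpoint_dir = os.path.join(torch.hub.get_dir(), "checkpoints")
print(checkpoint_dir)
```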
API Reference Summary
| Function / Class | Where to import | Description |
|---|---|---|
| Weights | torchvision.models | Dataclass holding url, transforms, and meta for one checkpoint. |
| WeightsEnum | torchvision.models | Base Enum class for all model weight enumerations. |
| ResNet50_Weights | torchvision.models | Example per-model weights enum; one exists for every supported architecture. |
| get_model_weights(name) | torchvision.models | Returns the WeightsEnum subclass for a model given its string name or builder function. |
| get_weight(name) | torchvision.models | Returns a specific WeightsEnum member from a dot-notation string like "ResNet50_Weights.DEFAULT". |
| get_model(name, **cfg) | torchvision.models | Instantiates a model by registered string name, passing weights= and other kwargs through. |
| list_models(...) | torchvision.models | Returns a sorted list of all registered model names, optionally filtered by module or glob pattern. |