
TorchVision provides a rich set of advanced augmentation strategies that go beyond simple geometric and photometric transforms. Many of these techniques come from published research, where they were shown to improve model accuracy and robustness by exposing the network to a wider distribution of training samples. All of the transforms described on this page are available in torchvision.transforms.v2. Not all of them accept every TVTensor type: the automatic augmentation policies (AutoAugment, RandAugment, TrivialAugmentWide, AugMix) and the batch-level CutMix and MixUp do not support bounding boxes or masks, so check a transform's documentation before adding it to a detection or segmentation pipeline.

AutoAugment

AutoAugment implements the learned policies from AutoAugment: Learning Augmentation Strategies from Data. Each policy consists of 25 sub-policies, and every call applies one sub-policy chosen at random. A sub-policy is a pair of operations, each with its own probability and magnitude, applied sequentially. Three dataset-specific policies are available via AutoAugmentPolicy:
| Policy | Learned on |
| --- | --- |
| AutoAugmentPolicy.IMAGENET | ImageNet classification |
| AutoAugmentPolicy.CIFAR10 | CIFAR-10 classification |
| AutoAugmentPolicy.SVHN | SVHN digit recognition |
```python
import torch
import torchvision.transforms.v2 as T

# AutoAugment for ImageNet
transform = T.Compose([
    T.ToImage(),  # PIL image or uint8 tensor -> tv_tensors.Image (still uint8)
    T.RandomResizedCrop(224, antialias=True),
    T.AutoAugment(T.AutoAugmentPolicy.IMAGENET),
    T.ToDtype(torch.float32, scale=True),
    T.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
])
```
AutoAugment expects torch.uint8 tensor images or PIL images in "L" or "RGB" mode. Apply it before ToDtype so the input is still in the expected integer range.
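
The sub-policy mechanism can be sketched in pure Python. This is a simplified illustration, not torchvision's implementation: the (name, probability, magnitude) encoding, the helper name apply_autoaugment, and the two toy operations acting on a flat list of 8-bit values are all assumptions made for the example.

```python
import random
from typing import Callable, Dict, List, Optional, Tuple

# Hypothetical encoding: an operation is (name, probability, magnitude) and a
# sub-policy is a pair of operations, mirroring the structure described above.
Op = Tuple[str, float, Optional[int]]
SubPolicy = Tuple[Op, Op]
OpFn = Callable[[List[int], Optional[int]], List[int]]


def apply_autoaugment(image: List[int], sub_policies: List[SubPolicy],
                      ops: Dict[str, OpFn], rng: random.Random):
    """Pick one sub-policy uniformly at random, then run its two operations in
    order, each gated by its own probability."""
    sub_policy = rng.choice(sub_policies)
    applied = []
    for name, probability, magnitude in sub_policy:
        if rng.random() <= probability:  # skipped with prob. 1 - probability
            image = ops[name](image, magnitude)
            applied.append(name)
    return image, applied


# Toy operations on a flat list of 8-bit pixel values (not torchvision kernels).
TOY_OPS: Dict[str, OpFn] = {
    "Invert": lambda img, _m: [255 - v for v in img],
    # Keep only the `bits` most significant bits of each 8-bit value.
    "Posterize": lambda img, bits: [v & (0xFF << (8 - bits)) & 0xFF for v in img],
}

if __name__ == "__main__":
    policy = [(("Invert", 1.0, None), ("Posterize", 1.0, 4))]
    out, applied = apply_autoaugment([200, 10], policy, TOY_OPS, random.Random(0))
    print(out, applied)  # [48, 240] ['Invert', 'Posterize']
```

Real policies use many more operations, but the control flow is the same: one random sub-policy per call, and each of its two operations fires independently.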

RandAugment

RandAugment (from RandAugment: Practical automated data augmentation with a reduced search space) simplifies AutoAugment by eliminating the policy search. On each call it picks num_ops operations at random from a fixed pool and applies all of them at the same magnitude:
```python
import torch
import torchvision.transforms.v2 as T

# Apply 2 random ops at magnitude index 9 (default grid of 31 bins, i.e. a 0-30 scale)
transform = T.Compose([
    T.ToImage(),  # PIL image or uint8 tensor -> tv_tensors.Image (still uint8)
    T.RandomResizedCrop(224, antialias=True),
    T.RandAugment(num_ops=2, magnitude=9),
    T.ToDtype(torch.float32, scale=True),
    T.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
])
```
Key constructor parameters:
| Parameter | Default | Description |
| --- | --- | --- |
| num_ops | 2 | Number of augmentation operations applied per call. |
| magnitude | 9 | Magnitude index, from 0 to num_magnitude_bins - 1, shared by all selected operations. |
| num_magnitude_bins | 31 | Number of discrete magnitude levels. |
| interpolation | InterpolationMode.NEAREST | Interpolation mode for geometric operations. |
| fill | None | Fill value for areas outside the transformed region. |
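
To illustrate how magnitude and num_magnitude_bins interact, here is a pure-Python sketch. It assumes each operation's strength grid runs linearly from 0 to a per-operation maximum, with signed operations flipped to a negative value half the time; the operation names and maxima in EXAMPLE_RANGES are illustrative assumptions, not values read from torchvision.

```python
import random
from typing import Dict, List, Optional, Tuple

# Illustrative (max_value, signed) pairs; not copied from torchvision.
EXAMPLE_RANGES: Dict[str, Tuple[float, bool]] = {
    "Rotate": (30.0, True),       # degrees
    "ShearX": (0.3, True),
    "Brightness": (0.9, True),
}


def magnitude_value(max_value: float, magnitude: int,
                    num_magnitude_bins: int = 31) -> float:
    """Value at index `magnitude` of an evenly spaced grid from 0 to max_value."""
    if not 0 <= magnitude < num_magnitude_bins:
        raise ValueError("magnitude must be in [0, num_magnitude_bins - 1]")
    if num_magnitude_bins == 1:
        return 0.0
    return max_value * magnitude / (num_magnitude_bins - 1)


def rand_augment_plan(ranges: Dict[str, Tuple[float, bool]], num_ops: int = 2,
                      magnitude: int = 9, num_magnitude_bins: int = 31,
                      rng: Optional[random.Random] = None) -> List[Tuple[str, float]]:
    """Draw num_ops operations (with replacement), all at the same magnitude index."""
    rng = rng or random.Random()
    plan = []
    for _ in range(num_ops):
        name = rng.choice(sorted(ranges))
        max_value, signed = ranges[name]
        value = magnitude_value(max_value, magnitude, num_magnitude_bins)
        if signed and rng.random() < 0.5:
            value = -value  # apply the operation in the opposite direction
        plan.append((name, value))
    return plan


if __name__ == "__main__":
    print(magnitude_value(30.0, 9))  # 9.0
    print(rand_augment_plan(EXAMPLE_RANGES, rng=random.Random(0)))
```

Because every selected operation uses the same index, a single number controls the strength of the whole policy, which is what keeps RandAugment's search space down to num_ops and magnitude.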

TrivialAugmentWide

TrivialAugmentWide (from TrivialAugment: Tuning-free Yet State-of-the-Art Data Augmentation) is the simplest strong baseline: it picks one random operation and applies it at a random magnitude drawn uniformly from the full range. No hyperparameter tuning is needed:
```python
import torch
import torchvision.transforms.v2 as T

transform = T.Compose([
    T.ToImage(),  # PIL image or uint8 tensor -> tv_tensors.Image (still uint8)
    T.RandomResizedCrop(224, antialias=True),
    T.TrivialAugmentWide(),
    T.ToDtype(torch.float32, scale=True),
    T.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
])
```
TrivialAugmentWide is a good starting point when you want strong regularisation with no augmentation hyperparameters to tune. The TrivialAugment paper reports accuracy comparable to or better than AutoAugment and RandAugment on its benchmarks, which include ImageNet.
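
The per-call sampling that TrivialAugmentWide performs can be sketched in pure Python. This is a simplified illustration: the operation names and maximum strengths in WIDE_RANGES are assumptions chosen for the example, and the linear grid of magnitude bins is likewise an assumption about how strengths are spaced.

```python
import random
from typing import Dict, Optional, Tuple

# Illustrative (max_value, signed) pairs; "Identity" has no strength.
WIDE_RANGES: Dict[str, Tuple[float, bool]] = {
    "Identity": (0.0, False),
    "Rotate": (135.0, True),      # degrees
    "Brightness": (0.99, True),
}


def trivial_augment_sample(ranges: Dict[str, Tuple[float, bool]],
                           num_magnitude_bins: int = 31,
                           rng: Optional[random.Random] = None):
    """Pick ONE operation and ONE magnitude bin, both uniformly at random."""
    rng = rng or random.Random()
    name = rng.choice(sorted(ranges))
    max_value, signed = ranges[name]
    bin_index = rng.randrange(num_magnitude_bins)
    # Linear grid from 0 to max_value over num_magnitude_bins steps.
    value = max_value * bin_index / max(num_magnitude_bins - 1, 1)
    if signed and rng.random() < 0.5:
        value = -value
    return name, bin_index, value


if __name__ == "__main__":
    rng = random.Random(0)
    for _ in range(3):
        print(trivial_augment_sample(WIDE_RANGES, rng=rng))
```

Unlike RandAugment, nothing here is fixed per training run: both the operation and its strength are resampled on every call, which is why there is nothing left to tune.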

AugMix

AugMix (from AugMix: A Simple Data Processing Method to Improve Robustness and Uncertainty) mixes several independently augmented versions of the same image with the original, improving robustness to distribution shift:
```python
import torch
import torchvision.transforms.v2 as T

transform = T.Compose([
    T.ToImage(),  # PIL image or uint8 tensor -> tv_tensors.Image (still uint8)
    T.RandomResizedCrop(224, antialias=True),
    T.AugMix(
        severity=3,        # base augmentation severity (1-10)
        mixture_width=3,   # number of augmentation chains to mix
        chain_depth=-1,    # depth per chain; -1 = random in [1, 3]
        alpha=1.0,         # Dirichlet / Beta distribution parameter
    ),
    T.ToDtype(torch.float32, scale=True),
    T.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
])
```
| Parameter | Default | Description |
| --- | --- | --- |
| severity | 3 | Severity of each base augmentation operation (1 to 10). |
| mixture_width | 3 | Number of independently augmented chains mixed together. |
| chain_depth | -1 | Depth of each chain. A negative value means the depth is sampled randomly from [1, 3]. |
| alpha | 1.0 | Concentration parameter for the Dirichlet and Beta distributions. |
| all_ops | True | Include brightness, contrast, color, and sharpness operations in addition to the core set. |
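
The mixing step these parameters control can be written out in a few lines. The sketch below is a simplified pure-Python illustration on a flat list of float pixel values, not torchvision's implementation: chain weights come from a Dirichlet(alpha, ..., alpha) draw, and the weight kept by the original image from a Beta(alpha, alpha) draw. The function names and toy operations are hypothetical.

```python
import random
from typing import Callable, List, Optional, Sequence

Image = List[float]
Op = Callable[[Image], Image]


def sample_dirichlet(alphas: Sequence[float], rng: random.Random) -> List[float]:
    # Normalised independent Gamma(alpha_i, 1) draws form a Dirichlet sample.
    draws = [rng.gammavariate(a, 1.0) for a in alphas]
    total = sum(draws)
    return [d / total for d in draws]


def augmix(image: Image, ops: Sequence[Op], mixture_width: int = 3,
           chain_depth: int = -1, alpha: float = 1.0,
           rng: Optional[random.Random] = None) -> Image:
    rng = rng or random.Random()
    chain_weights = sample_dirichlet([alpha] * mixture_width, rng)
    m = rng.betavariate(alpha, alpha)  # weight kept by the original image
    mixed = [0.0] * len(image)
    for w in chain_weights:
        # Negative chain_depth: sample the depth uniformly from {1, 2, 3}.
        depth = chain_depth if chain_depth > 0 else rng.randint(1, 3)
        aug = image
        for _ in range(depth):
            aug = rng.choice(ops)(aug)
        mixed = [acc + w * a for acc, a in zip(mixed, aug)]
    # Convex combination of the original image and the mixture of chains.
    return [m * x + (1.0 - m) * y for x, y in zip(image, mixed)]


if __name__ == "__main__":
    toy_ops = [lambda img: [1.0 - v for v in img], lambda img: [0.5 * v for v in img]]
    print(augmix([0.2, 0.8], toy_ops, rng=random.Random(0)))
```

Because every weight is non-negative and the weights sum to one, the result stays within the value range of the augmented images, so the output remains a valid image.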

CutMix and MixUp

CutMix and MixUp are batch-level augmentations that interpolate between pairs of training samples. They are applied after the DataLoader assembles a batch, not to individual images.
CutMix and MixUp operate on batched inputs of shape (N, C, H, W). They do not support PIL images, bounding boxes, keypoints, or masks. Labels must be integer class indices of shape (N,) or already one-hot encoded with shape (N, num_classes). The output labels are always soft targets of shape (N, num_classes), so your loss function must accept class probabilities as targets; torch.nn.CrossEntropyLoss supports this natively since PyTorch 1.10.
```python
import torch
import torchvision.transforms.v2 as T

# CutMix / MixUp applied to a batch
cutmix = T.CutMix(num_classes=100, alpha=1.0)
mixup = T.MixUp(num_classes=100, alpha=0.2)
cutmix_or_mixup = T.RandomChoice([cutmix, mixup])

# CrossEntropyLoss accepts soft (probability) targets
criterion = torch.nn.CrossEntropyLoss()

# model and dataloader are assumed to be defined elsewhere
for images, labels in dataloader:
    images, labels = cutmix_or_mixup(images, labels)
    # labels now has shape (N, 100): use a soft-target-compatible loss
    loss = criterion(model(images), labels)
    loss.backward()
```
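
For reference, the per-sample quantity that torch.nn.CrossEntropyLoss computes for probability targets (ignoring its optional class weights and label smoothing) is minus the target-weighted sum of log-softmax outputs. The pure-Python sketch below spells this out and checks a useful consequence: for a mixed target λ·onehot(a) + (1 - λ)·onehot(b), the loss equals the same λ-weighted combination of the two ordinary cross-entropy losses. The function names are hypothetical.

```python
import math
from typing import List, Sequence


def log_softmax(logits: Sequence[float]) -> List[float]:
    # Subtract the max logit before exponentiating for numerical stability.
    mx = max(logits)
    log_sum_exp = mx + math.log(sum(math.exp(z - mx) for z in logits))
    return [z - log_sum_exp for z in logits]


def soft_cross_entropy(logits: Sequence[float], target_probs: Sequence[float]) -> float:
    """-sum_c p_c * log_softmax(logits)_c for a single sample."""
    return -sum(p * lp for p, lp in zip(target_probs, log_softmax(logits)))


def one_hot(label: int, num_classes: int) -> List[float]:
    return [1.0 if c == label else 0.0 for c in range(num_classes)]


if __name__ == "__main__":
    logits = [2.0, 0.5, -1.0]
    lam = 0.7
    mixed = [lam * a + (1 - lam) * b for a, b in zip(one_hot(0, 3), one_hot(2, 3))]
    print(soft_cross_entropy(logits, mixed))
```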

CutMix

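CutMix pastes a rectangular patch from one image onto another. The patch size is driven by a coefficient λ drawn from Beta(alpha, alpha), and the labels are mixed in proportion to the area the patch actually covers. With alpha=1.0, λ is uniform on [0, 1]; larger alpha concentrates λ around 0.5 (patches covering about half the image), while smaller alpha pushes λ toward 0 or 1 (very large or very small patches).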
```python
cutmix = T.CutMix(num_classes=1000, alpha=1.0)
```
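
The box sampling behind CutMix can be sketched following the scheme described in the CutMix paper: pick a random box centre, give the box sides proportional to sqrt(1 - λ), clip it to the image, and recompute λ from the clipped area. This is a pure-Python illustration with a hypothetical function name, not torchvision's code; the caller passes λ (and optionally the centre) so the geometry is easy to check.

```python
import math
import random
from typing import Optional, Tuple

Box = Tuple[int, int, int, int]  # (x1, y1, x2, y2), right/bottom exclusive


def cutmix_box(lam: float, width: int, height: int,
               center: Optional[Tuple[int, int]] = None,
               rng: Optional[random.Random] = None) -> Tuple[Box, float]:
    rng = rng or random.Random()
    if center is None:
        center = (rng.randrange(width), rng.randrange(height))
    cx, cy = center
    # Half-sides chosen so an unclipped box covers a fraction (1 - lam) of the image.
    r = 0.5 * math.sqrt(1.0 - lam)
    half_w, half_h = int(r * width), int(r * height)
    x1, y1 = max(cx - half_w, 0), max(cy - half_h, 0)
    x2, y2 = min(cx + half_w, width), min(cy + half_h, height)
    # Label weight of the ORIGINAL image = fraction of pixels NOT covered by the patch.
    lam_adjusted = 1.0 - (x2 - x1) * (y2 - y1) / (width * height)
    return (x1, y1, x2, y2), lam_adjusted


if __name__ == "__main__":
    lam = random.Random(0).betavariate(1.0, 1.0)  # alpha = 1.0 -> uniform lambda
    print(cutmix_box(lam, 224, 224, rng=random.Random(1)))
    print(cutmix_box(0.75, 100, 100, center=(50, 50)))  # ((25, 25, 75, 75), 0.75)
```

The label of the image the patch comes from then receives weight 1 - lam_adjusted. Clipping at the border is why the adjusted value, rather than the sampled λ, is used for the labels: a box centred at a corner covers less area than requested.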

MixUp

MixUp linearly interpolates pixel values and labels between two images using a mixing coefficient λ sampled from a Beta(alpha, alpha) distribution.
```python
mixup = T.MixUp(num_classes=1000, alpha=0.2)
```
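Both CutMix and MixUp take the following parameters:

| Parameter | Description |
| --- | --- |
| alpha | Parameter of the Beta(alpha, alpha) distribution that λ is drawn from. Larger values concentrate λ around 0.5 (more even mixing); smaller values keep λ close to 0 or 1. |
| num_classes | Number of classes used to build one-hot label vectors from integer labels. Can be None if the labels are already one-hot encoded. |
| labels_getter | Callable, or "default", indicating how to find the labels in the input. |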
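
MixUp's interpolation of inputs and labels can be sketched for a single pair of samples as below. This pure-Python version is illustrative only (hypothetical function names, images as flat lists of floats); it shows that the mixed label is a probability vector whose two non-zero entries are λ and 1 - λ.

```python
import random
from typing import List, Sequence, Tuple


def one_hot(label: int, num_classes: int) -> List[float]:
    return [1.0 if c == label else 0.0 for c in range(num_classes)]


def mixup_pair(x_a: Sequence[float], y_a: int, x_b: Sequence[float], y_b: int,
               num_classes: int, alpha: float,
               rng: random.Random) -> Tuple[List[float], List[float], float]:
    lam = rng.betavariate(alpha, alpha)  # mixing coefficient in [0, 1]
    x = [lam * a + (1.0 - lam) * b for a, b in zip(x_a, x_b)]
    y = [lam * a + (1.0 - lam) * b
         for a, b in zip(one_hot(y_a, num_classes), one_hot(y_b, num_classes))]
    return x, y, lam


if __name__ == "__main__":
    x, y, lam = mixup_pair([0.0, 1.0], 3, [1.0, 0.0], 7, num_classes=10,
                           alpha=0.2, rng=random.Random(0))
    print(round(lam, 3), x, y)
```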

RandomErasing

RandomErasing (from Random Erasing Data Augmentation) randomly selects a rectangular region in an image and fills it with a constant or random value:
```python
import torch
import torchvision.transforms.v2 as T

transform = T.Compose([
    T.ToImage(),  # PIL image or uint8 tensor -> tv_tensors.Image (still uint8)
    T.RandomResizedCrop(224, antialias=True),
    T.RandomHorizontalFlip(),
    T.ToDtype(torch.float32, scale=True),
    T.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
    T.RandomErasing(
        p=0.5,              # probability of applying the transform
        scale=(0.02, 0.33), # range of the erased area, as a fraction of the image
        ratio=(0.3, 3.3),   # range of aspect ratios of the erased region
        value=0.0,          # fill value; use 'random' for per-pixel noise
    ),
])
```
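Apply RandomErasing after ToDtype and Normalize so the fill value is interpreted in the normalised space. With the default value=0, the erased region is filled with 0.0, which after Normalize corresponds to the per-channel dataset mean rather than black. Setting value='random' fills the region with per-pixel standard-normal noise instead, which can sometimes improve robustness further.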
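
The erased region is chosen by rejection sampling over area and aspect ratio. The pure-Python sketch below is an illustration, not torchvision's code: it draws the erased area as a uniform fraction of the image from scale, draws the aspect ratio log-uniformly from ratio (so that r and 1/r are equally likely), and retries up to a fixed number of times (10 here, an assumption) before giving up and leaving the image untouched. The function name is hypothetical.

```python
import math
import random
from typing import Optional, Tuple


def sample_erase_box(img_h: int, img_w: int,
                     scale: Tuple[float, float] = (0.02, 0.33),
                     ratio: Tuple[float, float] = (0.3, 3.3),
                     rng: Optional[random.Random] = None,
                     max_attempts: int = 10) -> Optional[Tuple[int, int, int, int]]:
    """Return (top, left, h, w) of a region to erase, or None if no box fits."""
    rng = rng or random.Random()
    area = img_h * img_w
    log_lo, log_hi = math.log(ratio[0]), math.log(ratio[1])
    for _ in range(max_attempts):
        erase_area = area * rng.uniform(scale[0], scale[1])
        aspect = math.exp(rng.uniform(log_lo, log_hi))  # log-uniform aspect ratio
        h = int(round(math.sqrt(erase_area * aspect)))
        w = int(round(math.sqrt(erase_area / aspect)))
        if not (h < img_h and w < img_w):
            continue  # box does not fit inside the image: resample
        top = rng.randint(0, img_h - h)
        left = rng.randint(0, img_w - w)
        return top, left, h, w
    return None


if __name__ == "__main__":
    print(sample_erase_box(224, 224, rng=random.Random(0)))
```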

JPEG Compression Augmentation

JPEG simulates lossy JPEG compression artefacts, which helps models generalise to real-world images that have been compressed and decompressed. quality controls the JPEG quality level (1 = maximum compression / lowest quality, 100 = minimum compression / highest quality):
```python
import torch
import torchvision.transforms.v2 as T

transform = T.Compose([
    T.ToImage(),  # PIL image or uint8 tensor -> tv_tensors.Image (still uint8)
    T.RandomResizedCrop(224, antialias=True),
    # Randomly choose JPEG quality between 50 and 95
    T.JPEG(quality=(50, 95)),
    T.ToDtype(torch.float32, scale=True),
    T.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
])
```
Pass a single integer for a fixed quality level, or a (min, max) pair to sample the quality uniformly from that inclusive range on each call:
```python
T.JPEG(quality=75)          # always compress at quality=75
T.JPEG(quality=(30, 100))   # random quality in [30, 100]
```
JPEG expects a torch.uint8 tensor on CPU with shape [..., 1 or 3, H, W]. Apply it before ToDtype so the input is still in the expected integer range.
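
The two accepted forms of quality can be mirrored in a small pure-Python helper. This sketch only illustrates the sampling semantics described above (a fixed integer, or a uniform draw from an inclusive [min, max] range, both within 1 to 100); it does not perform any compression, and sample_jpeg_quality is a hypothetical name.

```python
import random
from typing import Sequence, Union


def sample_jpeg_quality(quality: Union[int, Sequence[int]],
                        rng: random.Random) -> int:
    """Return the JPEG quality to use for one call."""
    if isinstance(quality, int):
        low = high = quality  # fixed quality
    else:
        low, high = quality   # (min, max), both inclusive
    if not 1 <= low <= high <= 100:
        raise ValueError("quality must satisfy 1 <= min <= max <= 100")
    return rng.randint(low, high)


if __name__ == "__main__":
    rng = random.Random(0)
    print(sample_jpeg_quality(75, rng))  # 75
    print([sample_jpeg_quality((30, 100), rng) for _ in range(5)])
```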

Composing Augmentations

Besides Compose, the v2 API includes three composition helpers that let you build complex stochastic augmentation schedules from simpler building blocks.

RandomApply

Apply a sequence of transforms with probability p. The entire sequence is either applied or skipped:
```python
import torch
import torchvision.transforms.v2 as T

transform = T.Compose([
    T.ToImage(),  # PIL image or uint8 tensor -> tv_tensors.Image (still uint8)
    T.RandomResizedCrop(224, antialias=True),
    T.RandomApply(
        [T.ColorJitter(brightness=0.4, contrast=0.4, saturation=0.2, hue=0.1)],
        p=0.8,
    ),
    T.RandomApply([T.GaussianBlur(kernel_size=23)], p=0.1),
    T.ToDtype(torch.float32, scale=True),
    T.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
])
```
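
The all-or-nothing behaviour of RandomApply can be sketched in a few lines of pure Python. This is an illustration of the semantics described above, with a hypothetical function name and toy numeric transforms, not torchvision's implementation.

```python
import random
from typing import Any, Callable, Sequence


def random_apply(x: Any, transforms: Sequence[Callable[[Any], Any]],
                 p: float, rng: random.Random) -> Any:
    """With probability p run every transform in order; otherwise return x unchanged."""
    if rng.random() < p:  # a single draw decides for the whole sequence
        for t in transforms:
            x = t(x)
    return x


if __name__ == "__main__":
    steps = [lambda v: v + 1, lambda v: v * 10]
    rng = random.Random(0)
    print([random_apply(1, steps, p=0.5, rng=rng) for _ in range(6)])  # each entry is 1 or 20
```

Only the two extremes ever appear: the sequence never runs partially, which is the key difference from wrapping each transform in its own RandomApply.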

RandomChoice

Apply exactly one transform chosen at random from a list. Optionally pass p, a list of per-transform weights; they are normalised to sum to 1 automatically, so they do not need to be probabilities:
```python
import torch
import torchvision.transforms.v2 as T

transform = T.Compose([
    T.ToImage(),  # PIL image or uint8 tensor -> tv_tensors.Image (still uint8)
    T.RandomResizedCrop(224, antialias=True),
    T.RandomChoice(
        [T.AutoAugment(), T.RandAugment(), T.TrivialAugmentWide()],
        p=[0.5, 0.3, 0.2],   # AutoAugment picked 50% of the time
    ),
    T.ToDtype(torch.float32, scale=True),
])
```
RandomChoice is also the recommended way to alternate between CutMix and MixUp at the batch level:
```python
cutmix = T.CutMix(num_classes=1000, alpha=1.0)
mixup = T.MixUp(num_classes=1000, alpha=0.2)
batch_augment = T.RandomChoice([cutmix, mixup])

for images, labels in dataloader:
    images, labels = batch_augment(images, labels)
```
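
RandomChoice's selection rule, including the normalisation of p, can be sketched in pure Python. The function name and toy transforms below are hypothetical; the sketch forwards all positional inputs to the chosen transform, mirroring the (images, labels) call in the loop above.

```python
import random
from typing import Any, Callable, Optional, Sequence


def random_choice(transforms: Sequence[Callable[..., Any]], *inputs: Any,
                  p: Optional[Sequence[float]] = None,
                  rng: Optional[random.Random] = None) -> Any:
    """Apply exactly one transform, picked with probability proportional to p."""
    rng = rng or random.Random()
    weights = list(p) if p is not None else [1.0] * len(transforms)
    total = sum(weights)
    probs = [w / total for w in weights]  # normalise unnormalised weights
    chosen = rng.choices(transforms, weights=probs, k=1)[0]
    return chosen(*inputs)  # forward every input, e.g. (images, labels)


if __name__ == "__main__":
    double = lambda x, y: (2 * x, y)
    negate = lambda x, y: (-x, y)
    rng = random.Random(0)
    print([random_choice([double, negate], 3, "label", p=[3, 1], rng=rng) for _ in range(4)])
```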

RandomOrder

Apply all transforms in the list but in a randomly shuffled order on each call:
```python
import torch
import torchvision.transforms.v2 as T

transform = T.Compose([
    T.ToImage(),  # PIL image or uint8 tensor -> tv_tensors.Image (still uint8)
    T.RandomResizedCrop(224, antialias=True),
    T.RandomOrder([
        T.ColorJitter(brightness=0.2),
        T.RandomGrayscale(p=0.2),
        T.RandomEqualize(),
    ]),
    T.ToDtype(torch.float32, scale=True),
])
```
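
RandomOrder's behaviour amounts to shuffling the list and then applying every transform. The pure-Python sketch below (hypothetical function name, toy numeric transforms) makes the consequence visible: when the transforms do not commute, the output varies from call to call.

```python
import random
from typing import Any, Callable, Sequence


def random_order(x: Any, transforms: Sequence[Callable[[Any], Any]],
                 rng: random.Random) -> Any:
    """Apply all transforms, in a freshly shuffled order on every call."""
    order = list(range(len(transforms)))
    rng.shuffle(order)
    for i in order:
        x = transforms[i](x)
    return x


if __name__ == "__main__":
    add_one = lambda v: v + 1
    double = lambda v: v * 2
    rng = random.Random(0)
    # (3 + 1) * 2 = 8 or 3 * 2 + 1 = 7, depending on the sampled order
    print({random_order(3, [add_one, double], rng) for _ in range(20)})
```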

Putting It Together: A Full Training Pipeline

The following pipeline combines several of the strategies above for a robust ImageNet training setup:
```python
import torch
import torchvision.transforms.v2 as T

train_transform = T.Compose([
    T.ToImage(),  # PIL image or uint8 tensor -> tv_tensors.Image (still uint8)
    # Spatial augmentation
    T.RandomResizedCrop(224, antialias=True),
    T.RandomHorizontalFlip(p=0.5),
    # Advanced photometric augmentation (expects uint8 input)
    T.RandAugment(num_ops=2, magnitude=9),
    # Optional: JPEG artefact simulation (expects a uint8 CPU tensor)
    T.RandomApply([T.JPEG(quality=(60, 100))], p=0.3),
    # Convert to float in [0, 1] and normalise
    T.ToDtype(torch.float32, scale=True),
    T.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
    # Erase a random patch from the normalised tensor
    T.RandomErasing(p=0.25),
])

# Batch-level CutMix / MixUp
cutmix_or_mixup = T.RandomChoice([
    T.CutMix(num_classes=1000, alpha=1.0),
    T.MixUp(num_classes=1000, alpha=0.2),
])

# model, optimizer and dataloader are assumed to be defined elsewhere;
# CrossEntropyLoss accepts the soft (N, 1000) targets produced above.
criterion = torch.nn.CrossEntropyLoss()

for images, labels in dataloader:
    # Batch augmentation applied after collation
    images, labels = cutmix_or_mixup(images, labels)
    outputs = model(images)
    loss = criterion(outputs, labels)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
```
