TorchVision provides a rich set of advanced augmentation strategies that go beyond simple geometric and photometric transforms. These techniques, many of them backed by published research, have been shown to improve model accuracy and robustness by exposing the network to a wider distribution of training samples. All of the transforms described on this page are available in torchvision.transforms.v2. Support for other TVTensor types varies, though: CutMix and MixUp, for instance, accept only batched images and labels, so check each transform's documentation before adding it to a detection or segmentation pipeline.
AutoAugment
AutoAugment applies the augmentation policies learned in AutoAugment: Learning Augmentation Strategies from Data. Each policy consists of 25 sub-policies. On every call one sub-policy is chosen at random, and its two operations are applied in sequence, each with its own learned probability and magnitude. Three dataset-specific policies are available via AutoAugmentPolicy (the default is AutoAugmentPolicy.IMAGENET):
| Policy | Learned on |
|---|---|
| AutoAugmentPolicy.IMAGENET | ImageNet classification |
| AutoAugmentPolicy.CIFAR10 | CIFAR-10 classification |
| AutoAugmentPolicy.SVHN | SVHN digit recognition |
```python
import torch
import torchvision.transforms.v2 as T

# AutoAugment for ImageNet.
# Expects uint8 tensor images; prepend T.ToImage() if your dataset yields PIL images.
transform = T.Compose([
    T.RandomResizedCrop(224, antialias=True),
    T.AutoAugment(T.AutoAugmentPolicy.IMAGENET),
    T.ToDtype(torch.float32, scale=True),
    T.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
])
```
AutoAugment expects torch.uint8 tensor images or PIL images in "L" or "RGB" mode. Apply it before ToDtype so the input is still in the expected integer range.
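The sub-policy mechanics described above can be sketched in plain Python. This is a toy model, not torchvision's implementation: the two sub-policies, the op names Invert and Brighten, and the single-integer "pixel" they act on are all hypothetical, chosen only to show that one sub-policy is drawn per call and that each of its two operations fires independently with its own probability.

```python
import random

# Hypothetical toy operations acting on a single 8-bit "pixel" value.
def invert(x, magnitude):
    return 255 - x  # magnitude is unused for this op

def brighten(x, magnitude):
    return min(255, x + 10 * magnitude)

OPS = {"Invert": invert, "Brighten": brighten}

# Illustrative sub-policies: each is a pair of (op_name, probability, magnitude).
SUB_POLICIES = [
    (("Invert", 0.6, 0), ("Brighten", 0.4, 5)),
    (("Brighten", 0.8, 3), ("Invert", 0.2, 0)),
]

def auto_augment(x, sub_policies=SUB_POLICIES, rng=random):
    """Pick one sub-policy uniformly, then run each of its ops with its own probability."""
    sub_policy = rng.choice(sub_policies)
    for name, prob, magnitude in sub_policy:
        if rng.random() < prob:
            x = OPS[name](x, magnitude)
    return x

always = [(("Brighten", 1.0, 2), ("Invert", 1.0, 0))]
print(auto_augment(100, always))  # 135: brighten to 120, then invert
print(auto_augment(100, rng=random.Random(0)))  # random sub-policy and coin flips
```

Note that the ops run in the listed order, so a sub-policy is a short sequence rather than an unordered set.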
RandAugment
RandAugment (from RandAugment: Practical automated data augmentation with a reduced search space) simplifies AutoAugment by eliminating the policy search. On each call it picks num_ops operations at random from a fixed pool and applies each of them at the same shared magnitude:
```python
import torch
import torchvision.transforms.v2 as T

# Apply 2 random ops at magnitude index 9 (indices 0 to 30 with the default num_magnitude_bins=31)
transform = T.Compose([
    T.RandomResizedCrop(224, antialias=True),
    T.RandAugment(num_ops=2, magnitude=9),
    T.ToDtype(torch.float32, scale=True),
    T.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
])
```
Key constructor parameters:
| Parameter | Default | Description |
|---|---|---|
| num_ops | 2 | Number of augmentation operations applied per forward pass. |
| magnitude | 9 | Magnitude index into the num_magnitude_bins-step grid for all operations. |
| num_magnitude_bins | 31 | Total number of magnitude steps. |
| interpolation | NEAREST | Interpolation mode for geometric ops. |
| fill | None | Fill value for areas outside the transformed region. |
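A rough way to picture the magnitude index is as a position on a linear grid of num_magnitude_bins values spanning each operation's strength range. The sketch below assumes exactly that; the op table (OPS) and its ranges are illustrative assumptions rather than torchvision's internal values, and the random sign flip mimics how signed ops such as rotation can go in either direction.

```python
import random

def linspace(lo, hi, n):
    """n evenly spaced values from lo to hi inclusive (like torch.linspace)."""
    return [lo + (hi - lo) * i / (n - 1) for i in range(n)]

# Hypothetical op table: name -> (maximum strength, whether the sign is randomised).
OPS = {
    "Identity": (0.0, False),
    "Rotate": (30.0, True),      # degrees (assumed range)
    "ShearX": (0.3, True),
    "Brightness": (0.9, True),
}

def op_strength(name, magnitude, num_magnitude_bins=31):
    """Strength of op `name` at grid index `magnitude`."""
    max_strength, _ = OPS[name]
    return linspace(0.0, max_strength, num_magnitude_bins)[magnitude]

def rand_augment_plan(num_ops=2, magnitude=9, num_magnitude_bins=31, rng=random):
    """Return the (op, signed strength) pairs one RandAugment-style call would apply."""
    plan = []
    for _ in range(num_ops):
        name = rng.choice(sorted(OPS))  # drawn independently, so repeats are possible
        strength = op_strength(name, magnitude, num_magnitude_bins)
        if OPS[name][1] and rng.random() < 0.5:
            strength = -strength
        plan.append((name, strength))
    return plan

print(op_strength("Rotate", 9))  # 9.0: index 9 on an assumed 0 to 30 degree grid
print(rand_augment_plan(rng=random.Random(0)))
```

With the defaults, magnitude=9 selects the tenth of 31 grid points, so under the assumed 0 to 30 degree range a rotation would be by 9 degrees in either direction.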
TrivialAugmentWide
TrivialAugmentWide (from TrivialAugment: Tuning-free Yet State-of-the-Art Data Augmentation) is a simple but strong baseline: on each call it picks one random operation and applies it at a magnitude drawn uniformly at random from the full range. No hyperparameter tuning is needed:
```python
import torch
import torchvision.transforms.v2 as T

transform = T.Compose([
    T.RandomResizedCrop(224, antialias=True),
    T.TrivialAugmentWide(),
    T.ToDtype(torch.float32, scale=True),
    T.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
])
```
TrivialAugmentWide is an excellent starting point when you want strong regularisation with zero tuning overhead. The TrivialAugment paper reports accuracy on par with or better than AutoAugment and RandAugment across its image-classification benchmarks, ImageNet included.
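For contrast with RandAugment, here is a minimal sketch of a single TrivialAugmentWide-style draw: exactly one op, and a magnitude index sampled uniformly from all bins on every call. The op table and its strengths are hypothetical placeholders, and sign randomisation is omitted for brevity.

```python
import random

# Hypothetical op table: name -> maximum strength (illustrative values only).
OPS = {"Identity": 0.0, "Rotate": 135.0, "ShearX": 0.99, "Brightness": 0.99}

def trivial_augment_plan(num_magnitude_bins=31, rng=random):
    """Pick exactly one op and a magnitude index drawn uniformly from all bins."""
    name = rng.choice(sorted(OPS))
    magnitude = rng.randrange(num_magnitude_bins)  # uniform over 0 .. num_magnitude_bins - 1
    strength = OPS[name] * (magnitude / (num_magnitude_bins - 1))
    return name, magnitude, strength

print(trivial_augment_plan(rng=random.Random(0)))
```

Because the magnitude is resampled on every call, there is no single magnitude setting to tune, which is where the "tuning-free" property comes from.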
AugMix
AugMix (from AugMix: A Simple Data Processing Method to Improve Robustness and Uncertainty) mixes several independently augmented versions of the same image with the original, improving robustness to distribution shift:
```python
import torch
import torchvision.transforms.v2 as T

transform = T.Compose([
    T.RandomResizedCrop(224, antialias=True),
    T.AugMix(
        severity=3,       # base augmentation severity (1–10)
        mixture_width=3,  # number of augmentation chains to mix
        chain_depth=-1,   # depth per chain; -1 = random in [1, 3]
        alpha=1.0,        # Dirichlet / Beta distribution parameter
    ),
    T.ToDtype(torch.float32, scale=True),
    T.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
])
```
| Parameter | Default | Description |
|---|---|---|
| severity | 3 | Severity of each base augmentation operation (1–10). |
| mixture_width | 3 | Number of independently augmented chains mixed together. |
| chain_depth | -1 | Depth of each chain. Negative value means sample randomly from [1, 3]. |
| alpha | 1.0 | Concentration parameter for the Dirichlet and Beta distributions. |
| all_ops | True | Include brightness, contrast, color, and sharpness ops in addition to the core set. |
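The way AugMix combines its chains can be sketched on a flat list of pixel values. Chain weights come from a Dirichlet(alpha, ..., alpha) draw, and the blend with the original image uses m ~ Beta(alpha, alpha), following the formulation in the AugMix paper. The brighten/darken ops and their linear use of severity are hypothetical stand-ins; torchvision's actual op set and severity mapping differ.

```python
import random

def sample_dirichlet(alpha, k, rng):
    """Draw k weights from Dirichlet(alpha, ..., alpha) by normalising Gamma samples."""
    draws = [rng.gammavariate(alpha, 1.0) for _ in range(k)]
    total = sum(draws)
    return [d / total for d in draws]

# Hypothetical ops on a flat list of pixel values in [0, 255].
def brighten(img, severity):
    return [min(255.0, p + 5.0 * severity) for p in img]

def darken(img, severity):
    return [max(0.0, p - 5.0 * severity) for p in img]

def augmix(image, ops, severity=3, mixture_width=3, chain_depth=-1, alpha=1.0, rng=random):
    """Blend the original image with a weighted mix of randomly augmented chains."""
    chain_weights = sample_dirichlet(alpha, mixture_width, rng)  # sums to 1
    m = rng.betavariate(alpha, alpha)                            # weight of the original
    mix = [0.0] * len(image)
    for w in chain_weights:
        depth = chain_depth if chain_depth > 0 else rng.randint(1, 3)  # -1 -> random in [1, 3]
        aug = list(image)
        for _ in range(depth):
            aug = rng.choice(ops)(aug, severity)
        mix = [acc + w * a for acc, a in zip(mix, aug)]
    return [m * x + (1.0 - m) * y for x, y in zip(image, mix)]

print(augmix([10.0, 128.0, 250.0], [brighten, darken], rng=random.Random(0)))
```

Since every weight is non-negative and the weights sum to one, the output is a convex combination: it never leaves the range spanned by the original and its augmented copies.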
CutMix and MixUp
CutMix and MixUp are batch-level augmentations that interpolate between pairs of training samples. They are applied after the DataLoader assembles a batch, not to individual images.
CutMix and MixUp operate on batched inputs of shape (N, C, H, W). They do not support PIL images, bounding boxes, keypoints, or masks. Labels must be integer class indices of shape (N,) or already one-hot encoded with shape (N, num_classes). The output labels are always soft (probability) vectors of shape (N, num_classes), so your loss function must accept soft targets; torch.nn.CrossEntropyLoss, for example, has accepted class-probability targets since PyTorch 1.10.
```python
import torch
import torchvision.transforms.v2 as T

# CutMix / MixUp applied to a batch
cutmix = T.CutMix(num_classes=100, alpha=1.0)
mixup = T.MixUp(num_classes=100, alpha=0.2)
cutmix_or_mixup = T.RandomChoice([cutmix, mixup])

# Assumes `dataloader`, `model` and `criterion` (e.g. torch.nn.CrossEntropyLoss())
# are already defined and that each batch is (images, integer labels).
for images, labels in dataloader:
    images, labels = cutmix_or_mixup(images, labels)
    # labels now has shape (N, 100): use a soft-target-compatible loss
    loss = criterion(model(images), labels)
    loss.backward()
```
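The loss in the loop above has to handle probability vectors instead of class indices. The sketch below spells out soft-target cross-entropy for a single sample in plain Python; for a probability target it computes the same per-sample quantity as torch.nn.CrossEntropyLoss without class weights or label smoothing. The logits and label values are made up for illustration.

```python
import math

def log_softmax(logits):
    m = max(logits)  # subtract the max for numerical stability
    log_z = m + math.log(sum(math.exp(z - m) for z in logits))
    return [z - log_z for z in logits]

def soft_cross_entropy(logits, target_probs):
    """Cross-entropy against a probability vector: -sum_i p_i * log_softmax(z)_i."""
    return -sum(p * lp for p, lp in zip(target_probs, log_softmax(logits)))

logits = [2.0, 0.5, -1.0]
hard = [1.0, 0.0, 0.0]  # one-hot label for class 0
soft = [0.7, 0.0, 0.3]  # e.g. a mixed label weighting classes 0 and 2
print(soft_cross_entropy(logits, hard))
print(soft_cross_entropy(logits, soft))
```

Because the loss is linear in the target, the mixed label above yields exactly 0.7 times the class-0 loss plus 0.3 times the class-2 loss, and a one-hot target reduces to ordinary cross-entropy.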
CutMix
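CutMix pastes a rectangular patch from one image over another. The label weight given to the pasted image equals the fraction of the image area the patch covers. The patch area is driven by λ ~ Beta(alpha, alpha): alpha = 1.0 makes λ uniform on [0, 1], larger values concentrate the patch area around half the image, and values below 1 favour very small or very large patches.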
```python
cutmix = T.CutMix(num_classes=1000, alpha=1.0)
```
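The patch geometry can be sketched as follows, following the recipe in the CutMix paper: a box whose nominal area is (1 - lam) * H * W is centred at a uniformly random pixel, clipped to the image, and lam is then recomputed from the clipped area so the labels match what was actually pasted. The function name cutmix_box and its return layout are hypothetical; torchvision's internals may differ in detail.

```python
import math
import random

def cutmix_box(lam, H, W, rng=random, center=None):
    """Return a clipped box (x1, y1, x2, y2) and the corrected weight of the original image."""
    cx, cy = center if center is not None else (rng.randrange(W), rng.randrange(H))
    r = 0.5 * math.sqrt(1.0 - lam)          # half-side as a fraction of W and H
    half_w, half_h = int(r * W), int(r * H)
    x1, y1 = max(cx - half_w, 0), max(cy - half_h, 0)
    x2, y2 = min(cx + half_w, W), min(cy + half_h, H)
    lam_adjusted = 1.0 - (x2 - x1) * (y2 - y1) / (W * H)
    return (x1, y1, x2, y2), lam_adjusted

lam = random.Random(0).betavariate(1.0, 1.0)  # alpha = 1.0 -> lam uniform on [0, 1]
box, lam_adj = cutmix_box(lam, H=224, W=224, rng=random.Random(1))
print(box, lam_adj)
# Mixed label = lam_adj * y_original + (1 - lam_adj) * y_pasted
```

Recomputing the weight after clipping matters near the borders: a box centred in a corner loses up to three quarters of its nominal area.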
MixUp
MixUp linearly interpolates pixel values and labels between two images using a mixing coefficient λ sampled from a Beta(alpha, alpha) distribution.
```python
mixup = T.MixUp(num_classes=1000, alpha=0.2)
```
| Parameter | Description |
|---|---|
| alpha | Beta distribution hyperparameter. Higher values → more aggressive mixing. |
| num_classes | Number of classes used to construct one-hot label vectors from integer labels. Can be None if labels are already one-hot. |
| labels_getter | Callable or "default" indicating how to find the labels in the input. |
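MixUp's interpolation and the effect of alpha can be sketched in a few lines of plain Python. The helpers one_hot, mixup_pair and mean_mixing are hypothetical; mean_mixing estimates the average weight given to the minority sample, which grows with alpha, in line with the "more aggressive mixing" entry in the table.

```python
import random

def one_hot(index, num_classes):
    return [1.0 if i == index else 0.0 for i in range(num_classes)]

def mixup_pair(x_a, y_a, x_b, y_b, lam):
    """Blend pixels and labels of two samples with the same coefficient lam."""
    x = [lam * a + (1.0 - lam) * b for a, b in zip(x_a, x_b)]
    y = [lam * a + (1.0 - lam) * b for a, b in zip(y_a, y_b)]
    return x, y

def mean_mixing(alpha, n=5000, seed=0):
    """Average of min(lam, 1 - lam) for lam ~ Beta(alpha, alpha)."""
    rng = random.Random(seed)
    draws = (rng.betavariate(alpha, alpha) for _ in range(n))
    return sum(min(lam, 1.0 - lam) for lam in draws) / n

x, y = mixup_pair([0.0, 1.0], one_hot(0, 3), [1.0, 1.0], one_hot(2, 3), lam=0.7)
print(x, y)  # pixels approx. [0.3, 1.0], labels approx. [0.7, 0.0, 0.3]
print(mean_mixing(0.2), mean_mixing(1.0))  # small alpha mixes far less on average
```

With alpha = 0.2, most draws of lam sit close to 0 or 1, so the typical MixUp sample is dominated by one image with a faint trace of the other.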
RandomErasing
RandomErasing (from Random Erasing Data Augmentation) randomly selects a rectangular region in an image and fills it with a constant or random value:
```python
import torch
import torchvision.transforms.v2 as T

transform = T.Compose([
    T.RandomResizedCrop(224, antialias=True),
    T.RandomHorizontalFlip(),
    T.ToDtype(torch.float32, scale=True),
    T.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
    T.RandomErasing(
        p=0.5,               # probability of applying the transform
        scale=(0.02, 0.33),  # range of proportion of erased area
        ratio=(0.3, 3.3),    # range of aspect ratio of erased region
        value=0.0,           # fill value; use 'random' for noise
    ),
])
```
Apply RandomErasing after normalisation: a fill value of 0.0 then corresponds to the per-channel dataset mean rather than to black in the original pixel range. Setting value='random' instead fills the region with per-pixel standard Gaussian noise, which can sometimes improve robustness further.
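The erased region is chosen by rejection sampling over an area fraction and an aspect ratio. The sketch below follows the procedure described in the Random Erasing paper, with the aspect ratio drawn log-uniformly so that tall and wide boxes are equally likely; the helper name sample_erase_box and the limit of 10 attempts are assumptions of this sketch.

```python
import math
import random

def sample_erase_box(img_h, img_w, scale=(0.02, 0.33), ratio=(0.3, 3.3), rng=random, attempts=10):
    """Return (top, left, h, w) of a region to erase, or None if no valid box was found."""
    area = img_h * img_w
    log_lo, log_hi = math.log(ratio[0]), math.log(ratio[1])
    for _ in range(attempts):
        erase_area = area * rng.uniform(scale[0], scale[1])
        aspect = math.exp(rng.uniform(log_lo, log_hi))  # log-uniform h / w ratio
        h = int(round(math.sqrt(erase_area * aspect)))
        w = int(round(math.sqrt(erase_area / aspect)))
        if not (h < img_h and w < img_w):
            continue  # box does not fit: try again
        top = rng.randint(0, img_h - h)
        left = rng.randint(0, img_w - w)
        return top, left, h, w
    return None

print(sample_erase_box(224, 224, rng=random.Random(0)))
```

When no attempt fits inside the image, the sketch returns None, meaning the image would be left unchanged.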
JPEG Compression Augmentation
JPEG simulates lossy JPEG compression artefacts, which helps models generalise to real-world images that have been compressed and decompressed. quality controls the JPEG quality level (1 = maximum compression / lowest quality, 100 = minimum compression / highest quality):
```python
import torch
import torchvision.transforms.v2 as T

transform = T.Compose([
    T.RandomResizedCrop(224, antialias=True),
    # Randomly choose JPEG quality between 50 and 95
    T.JPEG(quality=(50, 95)),
    T.ToDtype(torch.float32, scale=True),
    T.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
])
```
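Pass a single integer to always apply a fixed quality level, or a (min, max) tuple to sample the quality uniformly at random from that range (both ends included):
```python
T.JPEG(quality=75)         # always compress at quality=75
T.JPEG(quality=(30, 100))  # random quality in [30, 100]
```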
JPEG expects a torch.uint8 tensor on CPU with shape [..., 1 or 3, H, W].
Apply it before ToDtype so the input is still in the expected integer range.
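The two quality forms differ only in how the per-call quality is picked. The sketch below models that with hypothetical helpers normalize_quality and sample_quality; it assumes the range is inclusive on both ends and must lie within 1 to 100, and it performs no actual JPEG encoding.

```python
import random

def normalize_quality(quality):
    """Turn an int or a (min, max) pair into an inclusive (lo, hi) range within 1..100."""
    lo, hi = (quality, quality) if isinstance(quality, int) else quality
    if not (1 <= lo <= hi <= 100):
        raise ValueError(f"quality must satisfy 1 <= min <= max <= 100, got {quality!r}")
    return lo, hi

def sample_quality(quality, rng=random):
    """Draw the quality used for one call."""
    lo, hi = normalize_quality(quality)
    return rng.randint(lo, hi)  # inclusive on both ends

print(sample_quality(75))  # 75
print(sample_quality((30, 100), rng=random.Random(0)))
```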
Composing Augmentations
Besides Compose, the v2 API includes three composition helpers that let you build complex stochastic augmentation pipelines from simpler building blocks.
RandomApply
Apply a sequence of transforms with probability p. The entire sequence is either applied or skipped:
```python
import torch
import torchvision.transforms.v2 as T

transform = T.Compose([
    T.RandomResizedCrop(224, antialias=True),
    T.RandomApply(
        [T.ColorJitter(brightness=0.4, contrast=0.4, saturation=0.2, hue=0.1)],
        p=0.8,
    ),
    T.RandomApply([T.GaussianBlur(kernel_size=23)], p=0.1),
    T.ToDtype(torch.float32, scale=True),
    T.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
])
```
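The all-or-nothing behaviour of RandomApply comes down to a single coin flip per call for the whole list. A minimal plain-Python sketch, using a hypothetical random_apply helper and integer transforms in place of image transforms:

```python
import random

def random_apply(x, transforms, p=0.5, rng=random):
    """Apply every transform in order with probability p, otherwise return x unchanged."""
    if rng.random() >= p:  # one coin flip for the whole list
        return x
    for t in transforms:
        x = t(x)
    return x

add_one = lambda v: v + 1
double = lambda v: v * 2
print(random_apply(3, [add_one, double], p=1.0))  # 8: (3 + 1) * 2
```

A partial result such as 4 (only the first transform applied) can never occur, which is the difference from wrapping each transform in its own RandomApply.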
RandomChoice
Apply exactly one transform chosen at random from a list. Optionally supply a p list of unnormalized probabilities:
```python
import torch
import torchvision.transforms.v2 as T

transform = T.Compose([
    T.RandomResizedCrop(224, antialias=True),
    T.RandomChoice(
        [T.AutoAugment(), T.RandAugment(), T.TrivialAugmentWide()],
        p=[0.5, 0.3, 0.2],  # AutoAugment picked 50% of the time
    ),
    T.ToDtype(torch.float32, scale=True),
])
```
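RandomChoice's weighting can be sketched as a single weighted draw after normalising p. The helper random_choice and the toy transforms below are hypothetical; the point is that p=[5, 3, 2] behaves the same as p=[0.5, 0.3, 0.2].

```python
import random

def random_choice(x, transforms, p=None, rng=random):
    """Apply exactly one transform; p gives relative (unnormalised) weights."""
    weights = p if p is not None else [1.0] * len(transforms)
    total = sum(weights)
    probs = [w / total for w in weights]  # normalise so weights need not sum to 1
    (chosen,) = rng.choices(transforms, weights=probs, k=1)
    return chosen(x)

ts = [lambda v: "a", lambda v: "b", lambda v: "c"]
print(random_choice(None, ts, p=[5, 3, 2], rng=random.Random(0)))
```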
RandomChoice is also the recommended way to alternate between CutMix and MixUp at the batch level:
```python
cutmix = T.CutMix(num_classes=1000, alpha=1.0)
mixup = T.MixUp(num_classes=1000, alpha=0.2)
batch_augment = T.RandomChoice([cutmix, mixup])

# `dataloader` is assumed to yield batches of (images, integer labels)
for images, labels in dataloader:
    images, labels = batch_augment(images, labels)
```
RandomOrder
Apply all transforms in the list but in a randomly shuffled order on each call:
```python
import torch
import torchvision.transforms.v2 as T

transform = T.Compose([
    T.RandomResizedCrop(224, antialias=True),
    T.RandomOrder([
        T.ColorJitter(brightness=0.2),
        T.RandomGrayscale(p=0.2),
        T.RandomEqualize(),
    ]),
    T.ToDtype(torch.float32, scale=True),
])
```
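RandomOrder matters only when the transforms do not commute. This plain-Python sketch, with a hypothetical random_order helper and add-one and doubling transforms standing in for image ops, shuffles the list on every call, so the same input can produce different outputs depending on the order drawn:

```python
import random

def random_order(x, transforms, rng=random):
    """Apply all transforms, each exactly once, in a freshly shuffled order."""
    order = list(range(len(transforms)))
    rng.shuffle(order)
    for i in order:
        x = transforms[i](x)
    return x

add_one = lambda v: v + 1
double = lambda v: v * 2
print(random_order(1, [add_one, double], rng=random.Random(0)))  # 3 or 4 depending on order
```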
Putting It Together: A Full Training Pipeline
The following pipeline combines several of the strategies above for a robust ImageNet training setup:
```python
import torch
import torchvision.transforms.v2 as T

train_transform = T.Compose([
    # Convert PIL images or plain tensors to tv_tensors.Image
    T.ToImage(),
    # Spatial augmentation
    T.RandomResizedCrop(224, antialias=True),
    T.RandomHorizontalFlip(p=0.5),
    # Advanced photometric augmentation (expects uint8 input)
    T.RandAugment(num_ops=2, magnitude=9),
    # Optional: JPEG artefact simulation (expects uint8 input)
    T.RandomApply([T.JPEG(quality=(60, 100))], p=0.3),
    # Convert to float and normalise
    T.ToDtype(torch.float32, scale=True),
    T.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
    # Erase a patch after normalisation
    T.RandomErasing(p=0.25),
])

# Batch-level CutMix / MixUp
cutmix_or_mixup = T.RandomChoice([
    T.CutMix(num_classes=1000, alpha=1.0),
    T.MixUp(num_classes=1000, alpha=0.2),
])

# Assumes `dataloader` applies `train_transform` to each sample and that
# `model`, `criterion` and `optimizer` are already defined.
for images, labels in dataloader:
    # Batch augmentation applied after collation
    images, labels = cutmix_or_mixup(images, labels)
    outputs = model(images)
    loss = criterion(outputs, labels)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
```