
Optical flow estimation predicts the per-pixel 2D displacement field between two consecutive video frames, describing where each pixel of the first frame moved in the second. TorchVision provides RAFT (Recurrent All-Pairs Field Transforms, Teed & Deng, ECCV 2020), a recurrent architecture that was state of the art at publication and iteratively refines its flow prediction using a 4D all-pairs correlation volume. Two model variants are available: the full-capacity raft_large and the lightweight raft_small.
All optical flow models live in torchvision.models.optical_flow. RAFT outputs a list of flow tensors — one per recurrent iteration — with the final prediction at list_of_flows[-1].

Models

RAFT Large

Full RAFT architecture. 5.26 M parameters, 211 GFLOPs. Multiple weight checkpoints covering different fine-tuning stages (FlyingChairs + FlyingThings3D, Sintel, KITTI).

RAFT Small

Compact RAFT variant. 0.99 M parameters — roughly 5× smaller. Suitable for latency-sensitive applications. Weights available for FlyingChairs + FlyingThings3D training.

Pretrained Weights

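RAFT follows a multi-stage training curriculum, and the weight enum names encode it: each underscore-separated group is one training stage, the letters within a group are the datasets mixed into that stage, and the suffix (V1 or V2) is the checkpoint version. The letters are: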
  • C = FlyingChairs
  • T = FlyingThings3D
  • S = Sintel
  • K = KITTI
  • H = HD1K
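To make the naming scheme concrete, here is a small pure-Python decoder. It assumes the convention described above (underscore-separated stages, one letter per dataset, trailing version tag). decode_weight_name is a hypothetical helper that only parses strings; it does not query TorchVision.

```python
from typing import List, Tuple

# Letter codes used in RAFT weight enum names.
DATASET_CODES = {
    "C": "FlyingChairs",
    "T": "FlyingThings3D",
    "S": "Sintel",
    "K": "KITTI",
    "H": "HD1K",
}


def decode_weight_name(name: str) -> Tuple[List[List[str]], str]:
    """Hypothetical helper: split 'C_T_SKHT_V2' into per-stage datasets and a version."""
    *stages, version = name.split("_")
    decoded = [[DATASET_CODES[letter] for letter in stage] for stage in stages]
    return decoded, version


for name in ["C_T_V2", "C_T_SKHT_V2", "C_T_SKHT_K_V1"]:
    stages, version = decode_weight_name(name)
    # One line per checkpoint: stages in training order, datasets within a stage joined by '+'
    print(name, version, " then ".join(" + ".join(stage) for stage in stages))
```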

Raft_Large_Weights

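| Enum key | Training stages | Reported result | Origin |
|---|---|---|---|
| C_T_V1 | FlyingChairs, then FlyingThings3D | Sintel-Train EPE: 1.44 (Clean), 2.79 (Final) | Ported from the original paper |
| C_T_V2 | FlyingChairs, then FlyingThings3D | Sintel-Train EPE: 1.38 (Clean), 2.72 (Final) | Trained from scratch |
| C_T_SKHT_V1 | C, T, then a Sintel fine-tune (Sintel + KITTI + HD1K + FlyingThings3D mix) | Sintel-Test EPE: 1.94 (Clean), 3.18 (Final) | Ported |
| C_T_SKHT_V2 (DEFAULT) | C, T, then a Sintel fine-tune (Sintel + KITTI + HD1K + FlyingThings3D mix) | Sintel-Test EPE: 1.82 (Clean), 3.07 (Final) | Trained from scratch |
| C_T_SKHT_K_V1 | C, T, SKHT, then a KITTI fine-tune | KITTI-Test fl-all: 5.10% | Ported |
| C_T_SKHT_K_V2 | C, T, SKHT, then a KITTI fine-tune | KITTI-Test fl-all: 5.19% | Trained from scratch |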
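EPE (End-Point Error) is the average Euclidean distance in pixels between predicted and ground-truth flow vectors. KITTI's fl-all is the percentage of pixels whose end-point error exceeds both 3 px and 5% of the ground-truth flow magnitude. Lower is better for both.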
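Both metrics are simple to compute from a predicted and a ground-truth flow of shape [B, 2, H, W], where fl-all counts a pixel as an outlier when its EPE exceeds both 3 px and 5% of the ground-truth magnitude. The sketch below averages over all pixels; official benchmarks also apply validity masks (KITTI ground truth is sparse) and their own averaging rules, so treat these hypothetical helpers as illustrations rather than benchmark-exact implementations.

```python
import torch


def end_point_error(pred: torch.Tensor, gt: torch.Tensor) -> torch.Tensor:
    """Mean Euclidean distance (pixels) between [B, 2, H, W] flow fields."""
    epe_map = torch.linalg.vector_norm(pred - gt, dim=1)  # [B, H, W]
    return epe_map.mean()


def fl_all(pred: torch.Tensor, gt: torch.Tensor) -> torch.Tensor:
    """Percentage of outliers: EPE > 3 px and EPE > 5% of the ground-truth magnitude."""
    epe_map = torch.linalg.vector_norm(pred - gt, dim=1)
    gt_mag = torch.linalg.vector_norm(gt, dim=1)
    outliers = (epe_map > 3.0) & (epe_map > 0.05 * gt_mag)
    return outliers.float().mean() * 100.0


# Toy check: ground truth is a uniform 10 px shift to the right.
gt = torch.zeros(1, 2, 4, 4)
gt[:, 0] = 10.0
pred = gt.clone()
pred[:, 0, 0, 0] += 4.0  # one pixel off by 4 px horizontally (an outlier)
pred[:, 1, 1, 1] += 3.0  # one pixel off by exactly 3 px vertically (not > 3 px)

print(end_point_error(pred, gt).item())  # (4 + 3) / 16 = 0.4375
print(fl_all(pred, gt).item())           # 1 outlier out of 16 pixels = 6.25
```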

Raft_Small_Weights

| Enum key | Trained on | Sintel-Train Clean EPE | Sintel-Train Final EPE | Notes |
|---|---|---|---|---|
| C_T_V1 | FlyingChairs + FlyingThings3D (ported) | 2.12 | 3.28 | Ported from original paper |
| C_T_V2 (DEFAULT) | FlyingChairs + FlyingThings3D | 1.99 | 3.28 | ~5× fewer params than Large |

Quick Start

Step 1: Load model with pretrained weights

import torch
from torchvision.models.optical_flow import raft_large, Raft_Large_Weights

weights = Raft_Large_Weights.DEFAULT
model = raft_large(weights=weights)
model.eval()

Step 2: Build preprocessing transform

preprocess = weights.transforms()
# OpticalFlow preset: converts images to float and normalizes them to [-1, 1]

Step 3: Load and preprocess two consecutive frames

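from torchvision.io import read_image, ImageReadMode

# Two consecutive video frames, decoded as uint8 RGB tensors [3, H, W]
# (ImageReadMode.RGB drops any alpha channel)
img1 = read_image("frame1.png", mode=ImageReadMode.RGB)
img2 = read_image("frame2.png", mode=ImageReadMode.RGB)

# The preprocessing transform takes both frames of the pair at once
img1_batch, img2_batch = preprocess(img1, img2)
# Add a batch dimension: [3, H, W] -> [1, 3, H, W]
img1_batch = img1_batch.unsqueeze(0)
img2_batch = img2_batch.unsqueeze(0)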
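RAFT expects the height and width of both frames to be divisible by 8. If your frames do not satisfy this, one option is to pad the preprocessed batches and crop the predicted flow back afterwards; resizing is the other common option, but then the flow values must be rescaled as well. pad_to_multiple_of_8 and crop_flow are hypothetical helpers, and replicate padding is just one reasonable choice.

```python
import torch
import torch.nn.functional as F


def pad_to_multiple_of_8(img1: torch.Tensor, img2: torch.Tensor):
    """Replicate-pad a [B, 3, H, W] frame pair so H and W are multiples of 8.

    Returns the padded frames plus (top, left, H, W) for cropping the flow back.
    """
    _, _, h, w = img1.shape
    pad_h = (8 - h % 8) % 8
    pad_w = (8 - w % 8) % 8
    top, left = pad_h // 2, pad_w // 2
    # F.pad takes (left, right, top, bottom) for the last two dimensions.
    padding = (left, pad_w - left, top, pad_h - top)
    img1 = F.pad(img1, padding, mode="replicate")
    img2 = F.pad(img2, padding, mode="replicate")
    return img1, img2, (top, left, h, w)


def crop_flow(flow: torch.Tensor, crop) -> torch.Tensor:
    """Undo pad_to_multiple_of_8 on a [B, 2, H_pad, W_pad] flow tensor."""
    top, left, h, w = crop
    return flow[..., top:top + h, left:left + w]


# Example with a 375 x 1242 pair already in [-1, 1] (neither side divisible by 8).
a = torch.rand(1, 3, 375, 1242) * 2 - 1
b = torch.rand(1, 3, 375, 1242) * 2 - 1
a_pad, b_pad, crop = pad_to_multiple_of_8(a, b)
print(a_pad.shape)  # torch.Size([1, 3, 376, 1248])
```

With real frames, the padded batches go to the model and crop_flow(list_of_flows[-1], crop) restores the original spatial size.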

Step 4: Run inference and extract final flow

with torch.no_grad():
    list_of_flows = model(img1_batch, img2_batch)

final_flow = list_of_flows[-1]  # Tensor[1, 2, H, W]: final flow prediction
# final_flow[0, 0]: horizontal displacement (u)
# final_flow[0, 1]: vertical displacement (v)
print(final_flow.shape)  # torch.Size([1, 2, H, W])
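One way to sanity-check the (u, v) convention is backward warping: if the flow is correct, sampling the second frame at (x + u, y + v) should approximately reconstruct the first frame. The sketch below implements this with torch.nn.functional.grid_sample and tests it on a synthetic 2-pixel horizontal shift rather than real RAFT output. warp_with_flow is a hypothetical helper, and occluded or out-of-frame pixels will not reconstruct correctly.

```python
import torch
import torch.nn.functional as F


def warp_with_flow(img2: torch.Tensor, flow: torch.Tensor) -> torch.Tensor:
    """Sample img2 [B, C, H, W] at (x + u, y + v) using flow [B, 2, H, W]."""
    _, _, h, w = flow.shape
    ys, xs = torch.meshgrid(
        torch.arange(h, dtype=flow.dtype, device=flow.device),
        torch.arange(w, dtype=flow.dtype, device=flow.device),
        indexing="ij",
    )
    x_new = xs + flow[:, 0]  # [B, H, W]
    y_new = ys + flow[:, 1]
    # grid_sample expects (x, y) in [-1, 1]; with align_corners=True,
    # -1 and 1 are the centers of the first and last pixels.
    grid = torch.stack((2.0 * x_new / (w - 1) - 1.0, 2.0 * y_new / (h - 1) - 1.0), dim=-1)
    return F.grid_sample(img2, grid, mode="bilinear", padding_mode="zeros", align_corners=True)


# Synthetic pair: every pixel of img1 moves 2 px to the right in img2.
img1 = torch.rand(1, 3, 16, 16)
img2 = torch.roll(img1, shifts=2, dims=3)
flow = torch.zeros(1, 2, 16, 16)
flow[:, 0] = 2.0  # u = +2 px, v = 0

warped = warp_with_flow(img2, flow)
# Columns whose target x + 2 stays inside the frame reconstruct img1.
print(torch.allclose(warped[..., :14], img1[..., :14], atol=1e-5))  # True
```

With real frames, warp_with_flow(img2_batch, final_flow) should roughly resemble img1_batch away from occlusions and the image border.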

Output Format

RAFT returns a list of Tensor[B, 2, H, W] tensors — one per recurrent refinement step. The two channels represent:
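| Channel | Meaning |
|---|---|
| flow[:, 0, :, :] | Horizontal displacement u in pixels (positive = rightward) |
| flow[:, 1, :, :] | Vertical displacement v in pixels (positive = downward) |

The list length equals the number of refinement iterations, set by the num_flow_updates argument of the forward call (12 by default for both raft_large and raft_small). Use list_of_flows[-1] for the highest-quality prediction.
# Inspect individual iteration outputs
for i, flow in enumerate(list_of_flows):
    # Largest per-pixel displacement magnitude sqrt(u^2 + v^2)
    max_disp = torch.linalg.vector_norm(flow, dim=1).max().item()
    print(f"Iteration {i}: {flow.shape}, max displacement: {max_disp:.2f}px")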

Visualizing Flow

TorchVision provides a built-in utility to convert a flow tensor to an RGB image using the HSV color wheel convention (hue = direction, saturation = magnitude):
from torchvision.utils import flow_to_image

flow_img = flow_to_image(final_flow)  # Tensor[1, 3, H, W] uint8
Save or display the result:
from torchvision.io import write_png

# flow_to_image returns Tensor[B, 3, H, W] uint8
write_png(flow_img[0], "flow_visualization.png")

Using RAFT Small

raft_small has the same API as raft_large but with a reduced encoder and correlation volume. The snippet below reuses img1 and img2 from the Quick Start:
import torch
from torchvision.models.optical_flow import raft_small, Raft_Small_Weights

weights = Raft_Small_Weights.DEFAULT
model = raft_small(weights=weights)
model.eval()

preprocess = weights.transforms()

img1_batch, img2_batch = preprocess(img1, img2)
img1_batch = img1_batch.unsqueeze(0)
img2_batch = img2_batch.unsqueeze(0)

with torch.no_grad():
    list_of_flows = model(img1_batch, img2_batch)

final_flow = list_of_flows[-1]  # Tensor[1, 2, H, W]

Complete Example with Visualization

import torch
from torchvision.models.optical_flow import raft_large, Raft_Large_Weights
from torchvision.io import read_image, write_png, ImageReadMode
from torchvision.utils import flow_to_image

# 1. Load model
weights = Raft_Large_Weights.DEFAULT
model = raft_large(weights=weights)
model.eval()

preprocess = weights.transforms()

# 2. Load consecutive frames as uint8 RGB tensors [3, H, W]
img1 = read_image("frame1.png", mode=ImageReadMode.RGB)
img2 = read_image("frame2.png", mode=ImageReadMode.RGB)

# 3. Preprocess and batch
img1_batch, img2_batch = preprocess(img1, img2)
img1_batch = img1_batch.unsqueeze(0)
img2_batch = img2_batch.unsqueeze(0)

# 4. Predict flow
with torch.no_grad():
    list_of_flows = model(img1_batch, img2_batch)

final_flow = list_of_flows[-1]  # Tensor[1, 2, H, W]

# 5. Visualize
flow_img = flow_to_image(final_flow)  # Tensor[1, 3, H, W] uint8
write_png(flow_img[0], "flow_output.png")

# 6. Read displacement at a specific pixel (row=100, col=200)
u = final_flow[0, 0, 100, 200].item()
v = final_flow[0, 1, 100, 200].item()
print(f"Pixel (100, 200) moved by ({u:.2f}, {v:.2f}) pixels")

Selecting the Right Weights

Sintel evaluation

Use Raft_Large_Weights.C_T_SKHT_V2 (DEFAULT). It is fine-tuned on a Sintel + KITTI + HD1K + FlyingThings3D mix and has the lowest Sintel-Test EPE among the listed checkpoints (1.82 Clean, 3.07 Final).

KITTI evaluation

Use one of the KITTI fine-tuned checkpoints. Raft_Large_Weights.C_T_SKHT_K_V1 (ported) reports the lowest KITTI-Test fl-all (5.10%); C_T_SKHT_K_V2 is its V2 counterpart at 5.19%.

Fast prototyping

Use raft_small with Raft_Small_Weights.DEFAULT: roughly 5× fewer parameters than raft_large (0.99M vs 5.26M), which typically makes it noticeably faster.

Reproducibility

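Use the V1 variants (C_T_V1, C_T_SKHT_V1, and C_T_SKHT_K_V1 for raft_large; C_T_V1 for raft_small). They are ported from the original Princeton RAFT repo, so they correspond to the checkpoints behind the paper's reported numbers.
RAFT expects images preprocessed by weights.transforms(), which converts them to float and normalizes pixel values to [-1, 1], and it expects frame height and width to be divisible by 8. Passing raw uint8 tensors directly to the model fails or produces incorrect results.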
