TorchRL is designed to scale from a quick local experiment to a production distributed run without changing the data model. Every major component — collectors, replay buffers, value estimators, recurrent modules, and loss functions — has explicit performance knobs. This guide walks through the most impactful options: vectorised return computation, torch.compile integration, async multi-process data collection, CUDA-accelerated priority trees, memory-mapped offline storage, and Triton-backed recurrent RL.

Vectorised Return and Advantage Computation

TorchRL’s value estimation functions come in two flavours: a reference Python implementation and a vectorised implementation that dispatches to a fused C++ or scan-based kernel when gamma and lambda are scalars.
import torch
from torchrl.objectives.value.functional import (
    vec_generalized_advantage_estimate,  # vectorised GAE
    vec_td_lambda_return_estimate,        # vectorised TD(λ) returns
    vec_td_lambda_advantage_estimate,     # vectorised TD(λ) advantages
    vec_td1_return_estimate,              # vectorised TD(1) returns
    vec_td1_advantage_estimate,           # vectorised TD(1) advantages
)

# Dummy batch: B=4 trajectories of T=100 steps, time on dim -2
B, T = 4, 100
state_values = torch.randn(B, T, 1)
next_values = torch.randn(B, T, 1)
rewards = torch.randn(B, T, 1)
dones = torch.zeros(B, T, 1, dtype=torch.bool)
terminated = torch.zeros(B, T, 1, dtype=torch.bool)

# Vectorised GAE: preferred for scalar gamma/lambda
advantage, value_target = vec_generalized_advantage_estimate(
    gamma=0.99,
    lmbda=0.95,
    state_value=state_values,          # shape [B, T, 1]
    next_state_value=next_values,      # shape [B, T, 1]
    reward=rewards,                    # shape [B, T, 1]
    done=dones,                        # shape [B, T, 1]
    terminated=terminated,             # shape [B, T, 1]
    time_dim=-2,
)
The high-level GAE estimator class wraps this function and integrates with any TensorDictModule-based value network:
from torchrl.objectives.value import GAE

gae = GAE(
    gamma=0.99,
    lmbda=0.95,
    value_network=value_net,
    average_gae=False,
)

# Returns tensordict augmented with "advantage" and "value_target" keys
batch = gae(batch)
For multi-agent settings, MultiAgentGAE handles a batch dimension that includes an agent axis without reducing over it:
from torchrl.objectives.value import MultiAgentGAE

ma_gae = MultiAgentGAE(
    gamma=0.99,
    lmbda=0.95,
    value_network=centralised_critic,
    average_gae=False,
)
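The following is a minimal pure-Python sketch of the quantity these estimators compute, for readers who want to check it by hand. It makes several simplifying assumptions: a single uninterrupted trajectory, plain lists instead of tensors, and no mid-trajectory resets. The backward recursion is the textbook GAE definition. The closed form writes each advantage as a discounted sum of TD errors, which shows why all timesteps can be evaluated at once. Both functions are hypothetical illustrations, not TorchRL's implementation.

```python
from typing import List, Tuple


def gae_loop(rewards: List[float], values: List[float], next_values: List[float],
             gamma: float, lmbda: float) -> Tuple[List[float], List[float]]:
    """Reference GAE for one uninterrupted trajectory: backward recursion over time."""
    advantages = [0.0] * len(rewards)
    running = 0.0
    for t in reversed(range(len(rewards))):
        delta = rewards[t] + gamma * next_values[t] - values[t]  # one-step TD error
        running = delta + gamma * lmbda * running                # A_t = delta_t + gamma*lambda*A_{t+1}
        advantages[t] = running
    value_targets = [a + v for a, v in zip(advantages, values)]
    return advantages, value_targets


def gae_closed_form(rewards: List[float], values: List[float], next_values: List[float],
                    gamma: float, lmbda: float) -> Tuple[List[float], List[float]]:
    """Same quantity written as A_t = sum_k (gamma*lambda)**k * delta_{t+k}.

    Each A_t is computed independently of the others, so every timestep can be evaluated at once.
    """
    deltas = [r + gamma * nv - v for r, v, nv in zip(rewards, values, next_values)]
    T = len(deltas)
    advantages = [sum((gamma * lmbda) ** k * deltas[t + k] for k in range(T - t)) for t in range(T)]
    value_targets = [a + v for a, v in zip(advantages, values)]
    return advantages, value_targets


# Hypothetical 3-step trajectory; the last step is terminal, so its next value is 0
rewards, values, next_values = [1.0, 0.0, 1.0], [0.5, 0.5, 0.5], [0.5, 0.5, 0.0]
adv_loop, target_loop = gae_loop(rewards, values, next_values, gamma=0.9, lmbda=0.8)
adv_vec, target_vec = gae_closed_form(rewards, values, next_values, gamma=0.9, lmbda=0.8)
```

In TorchRL, `done` and `terminated` masks additionally cut the sums at trajectory boundaries. The sketch leaves them out to keep the recursion visible.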

torch.compile Compatibility

TorchRL components are torch.compile-aware. The vectorised GAE implementation detects compile tracing via torch.compiler.is_dynamo_compiling() and skips data-dependent truncation optimisations that would produce graph breaks.

Compiling Loss Modules

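import torch
from torchrl.objectives import PPOLoss

loss_fn = PPOLoss(actor_network=actor, critic_network=critic)
optimizer = torch.optim.Adam(loss_fn.parameters(), lr=3e-4)

# Compile the full forward pass
compiled_loss = torch.compile(loss_fn, mode="reduce-overhead")

# Training loop: the first call triggers compilation
for batch in dataloader:
    loss_td = compiled_loss(batch)
    # PPOLoss returns several terms (objective, critic, entropy); optimise their sum
    loss = sum(v for k, v in loss_td.items() if k.startswith("loss_"))
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()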

compile_with_warmup

TorchRL ships compile_with_warmup — a torch.compile wrapper that runs the model eagerly for a fixed number of warm-up calls before switching to the compiled version. This prevents slow first-batch latency from affecting logged metrics:
from torchrl._utils import compile_with_warmup

# Compile after 3 warm-up calls
compiled_value_fn = compile_with_warmup(value_net, warmup=3, mode="reduce-overhead")

for i, batch in enumerate(dataloader):
    # First 3 calls: eager; call 4 onward: compiled
    values = compiled_value_fn(batch)
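The following pure-Python sketch shows the warm-up pattern the paragraph describes: eager calls first, then a single compilation that is reused. `compile_after_warmup` and `fake_compile` are hypothetical stand-ins. In practice the compiler argument would be `torch.compile`. This is not the source of `torchrl._utils.compile_with_warmup`.

```python
from typing import Any, Callable, Dict, List


def compile_after_warmup(fn: Callable[..., Any], warmup: int,
                         compiler: Callable[[Callable[..., Any]], Callable[..., Any]]) -> Callable[..., Any]:
    """Run `fn` eagerly for the first `warmup` calls, then compile it once and reuse the result."""
    state: Dict[str, Any] = {"calls": 0, "compiled": None}

    def wrapper(*args: Any, **kwargs: Any) -> Any:
        if state["calls"] < warmup:
            state["calls"] += 1
            return fn(*args, **kwargs)  # eager path
        if state["compiled"] is None:
            state["compiled"] = compiler(fn)  # compiled exactly once, on call warmup + 1
        return state["compiled"](*args, **kwargs)

    return wrapper


compile_log: List[str] = []


def fake_compile(f: Callable[..., Any]) -> Callable[..., Any]:
    """Stand-in for torch.compile that tags its outputs so the switch is visible."""
    compile_log.append(f.__name__)

    def compiled(*args: Any, **kwargs: Any) -> Any:
        return ("compiled", f(*args, **kwargs))

    return compiled


def double(x: int) -> int:
    return 2 * x


step = compile_after_warmup(double, warmup=3, compiler=fake_compile)
results = [step(i) for i in range(5)]  # [0, 2, 4, ("compiled", 6), ("compiled", 8)]
```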

RSSMRollout with torch.compile

The RSSMRollout module supports two compile-friendly modes:
from torchrl.modules.models.model_based import RSSMRollout

# scan mode — fewer graph breaks than the Python loop
rssm_rollout = RSSMRollout(
    rssm_prior=prior_module,
    rssm_posterior=posterior_module,
    use_scan=True,          # uses torch._higher_order_ops.scan
    compile_step=False,     # OR compile individual step functions
)

# Compile the full rollout
compiled_rollout = torch.compile(rssm_rollout, mode="reduce-overhead")
use_scan=True is more torch.compile-friendly because the Python for-loop in the default mode creates one graph per timestep. With use_scan, the time dimension is handled inside a single operator, resulting in a single graph.
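The contract of a scan operator makes the benefit of `use_scan` easier to see. A step function maps (carry, input) to (new carry, output), and the operator applies it along the time dimension as one unit. The sketch below is a pure-Python model of those semantics, with a hypothetical toy belief update standing in for an RSSM step. It is not `torch._higher_order_ops.scan`, which is a private PyTorch API whose exact signature may change between releases.

```python
from typing import Callable, List, Sequence, Tuple, TypeVar

Carry = TypeVar("Carry")
X = TypeVar("X")
Y = TypeVar("Y")


def scan(step: Callable[[Carry, X], Tuple[Carry, Y]], init: Carry,
         xs: Sequence[X]) -> Tuple[Carry, List[Y]]:
    """Reference semantics: thread a carry through `step` over time and stack the outputs."""
    carry = init
    ys: List[Y] = []
    for x in xs:
        carry, y = step(carry, x)
        ys.append(y)
    return carry, ys


def toy_belief_step(belief: float, action: float) -> Tuple[float, float]:
    # Hypothetical stand-in for one RSSM prior step: decay the belief and add the input
    next_belief = 0.5 * belief + action
    return next_belief, next_belief


final_belief, beliefs = scan(toy_belief_step, init=0.0, xs=[1.0, 2.0, 3.0])
# final_belief == 4.25, beliefs == [1.0, 2.5, 4.25]
```

Only `step` has to be traced, because the operator owns the loop. A plain Python `for` loop, by contrast, exposes every iteration to the tracer.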

Async Collectors and Weight Synchronisation

MultiSyncCollector and MultiAsyncCollector

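MultiCollector spawns worker processes, each running its own Collector, environment, and policy copy. In sync mode, the main process waits until every worker has finished its share of the batch. In async mode, workers write data to shared memory as they finish, and the main process consumes whichever batch is ready first. The trainer therefore does not wait on the slowest worker, but some batches may have been collected with slightly outdated policy weights.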
from torchrl.collectors import MultiCollector, MultiSyncCollector, MultiAsyncCollector

# Sync — deterministic, easier to debug
sync_collector = MultiCollector(
    [env_fn_1, env_fn_2, env_fn_3, env_fn_4],
    policy=policy,
    frames_per_batch=1000,
    total_frames=1_000_000,
    sync=True,                   # True → MultiSyncCollector
    num_workers_per_collector=1,
)

# Async — higher throughput, accepts stale data
async_collector = MultiCollector(
    [env_fn_1, env_fn_2, env_fn_3, env_fn_4],
    policy=policy,
    frames_per_batch=1000,
    total_frames=1_000_000,
    sync=False,                  # False → MultiAsyncCollector
)
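The difference between the two modes can be sketched with threads standing in for worker processes. The sync path blocks until every worker has returned and concatenates in worker order. The async path yields each worker's batch in completion order. `rollout`, `collect_sync` and `collect_async` are hypothetical. The real collectors use processes and shared memory, not a thread pool.

```python
import random
import time
from concurrent.futures import ThreadPoolExecutor, as_completed
from typing import Iterator, List, Tuple


def rollout(worker_id: int, frames: int) -> Tuple[int, List[int]]:
    """Hypothetical worker: simulate an environment with variable step latency."""
    time.sleep(random.uniform(0.0, 0.02))
    return worker_id, [worker_id] * frames


def collect_sync(num_workers: int, frames: int) -> List[int]:
    """Sync: wait for every worker, then concatenate their frames in worker order."""
    with ThreadPoolExecutor(max_workers=num_workers) as pool:
        futures = [pool.submit(rollout, w, frames) for w in range(num_workers)]
        batch: List[int] = []
        for future in futures:  # blocks until the slowest worker is done
            batch.extend(future.result()[1])
    return batch


def collect_async(num_workers: int, frames: int) -> Iterator[Tuple[int, List[int]]]:
    """Async: yield each worker's batch as soon as it is ready, in completion order."""
    with ThreadPoolExecutor(max_workers=num_workers) as pool:
        futures = [pool.submit(rollout, w, frames) for w in range(num_workers)]
        for future in as_completed(futures):
            yield future.result()  # the trainer can start on this batch immediately


print(collect_sync(4, 2))  # always [0, 0, 1, 1, 2, 2, 3, 3]
for worker_id, frames in collect_async(4, 2):
    print("got batch from worker", worker_id)  # order varies from run to run
```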

Weight Synchronisation

After each optimizer step, push the updated policy parameters to the worker processes. The collector delegates this to the WeightUpdaterBase subclass passed as weight_updater:
from torchrl.collectors import VanillaWeightUpdater, MultiProcessedWeightUpdater

# VanillaWeightUpdater: copy state_dict to all workers
updater = VanillaWeightUpdater()
collector = MultiSyncCollector(
    ...,
    weight_updater=updater,
)

# After each optimizer step
optimizer.step()
collector.update_policy_weights_()   # pushes the new weights to every worker via the weight_updater
For Ray-distributed collectors:
from torchrl.collectors import RayWeightUpdater

ray_updater = RayWeightUpdater()
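A weight updater's role can be sketched with plain Python objects. The trainer bumps a version number on every optimizer step and pushes a copy of its parameters to each worker. Each worker tags the data it produces with the version it used, which makes staleness measurable. `Trainer`, `Worker`, `Batch` and the version tag are hypothetical and mirror only the idea. This is not the WeightUpdaterBase interface.

```python
from dataclasses import dataclass, field
from typing import Dict, List


@dataclass
class Batch:
    policy_version: int  # weight version that produced this data
    w: float


@dataclass
class Worker:
    """Hypothetical collector worker holding its own copy of the policy weights."""

    weights: Dict[str, float] = field(default_factory=dict)
    version: int = 0

    def collect(self) -> Batch:
        return Batch(policy_version=self.version, w=self.weights.get("w", 0.0))


@dataclass
class Trainer:
    """Hypothetical trainer that owns the up-to-date weights."""

    weights: Dict[str, float] = field(default_factory=lambda: {"w": 0.0})
    version: int = 0

    def optimizer_step(self) -> None:
        self.weights["w"] += 1.0  # stand-in for a gradient update
        self.version += 1

    def push_weights(self, workers: List[Worker]) -> None:
        # Send a copy, not a reference, mirroring the process boundary
        for worker in workers:
            worker.weights = dict(self.weights)
            worker.version = self.version


def staleness(trainer: Trainer, batch: Batch) -> int:
    """Number of optimizer steps between the batch's policy and the current one."""
    return trainer.version - batch.policy_version


trainer = Trainer()
workers = [Worker(), Worker()]
trainer.push_weights(workers)
trainer.optimizer_step()
stale_batch = workers[0].collect()   # collected before the next push: 1 step stale
trainer.push_weights(workers)
fresh_batch = workers[0].collect()   # collected after the push: up to date
```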

CUDA Prioritized Replay Buffers

TorchRL ships optional CUDA extension wheels that accelerate the sum-tree operations underlying PrioritizedSampler. When a CUDA-enabled TorchRL wheel is installed and the storage is on GPU, the sampler automatically routes to the CUDA tree.
from torchrl.data import TensorDictPrioritizedReplayBuffer
from torchrl.data.replay_buffers.storages import LazyTensorStorage
from torchrl.data.replay_buffers.samplers import PrioritizedSampler

# GPU storage → CUDA sum-tree if available; CPU otherwise
buffer = TensorDictPrioritizedReplayBuffer(
    alpha=0.6,
    beta=0.4,
    storage=LazyTensorStorage(max_size=100_000, device="cuda"),
    batch_size=256,
)

buffer.extend(batch)
sample, info = buffer.sample(return_info=True)

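# Update priorities with TD errors; "td_error" is assumed to have been written
# into the sample by a loss module (e.g. DQNLoss) before this point
td_errors = sample["td_error"].detach().abs()
buffer.update_priority(info["index"], td_errors)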
Install the CUDA wheel matching your PyTorch build:
# Example — match your CUDA variant
pip install torchrl-nightly[cuda] --index-url https://download.pytorch.org/whl/cu121
CUDA prioritized replay requires TorchRL to be built with FORCE_CUDA=1 or installed via a CUDA wheel. Without the extension the sampler silently falls back to a CPU tree, which involves device transfers on every sample call.
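The data structure behind a prioritized sampler can be sketched in pure Python. A binary sum tree stores priorities in its leaves and partial sums in its internal nodes. Updating one priority and drawing an index by cumulative mass therefore both cost O(log N). This is an illustrative CPU sketch, not TorchRL's C++/CUDA tree, and `SumTree` and `sample` are hypothetical names. The priority (|δ| + ε)^α and the importance weights (N·P(i))^(-β) follow the standard prioritized experience replay formulation that the `alpha` and `beta` arguments above refer to.

```python
import random
from typing import List, Tuple


class SumTree:
    """Binary sum tree over `capacity` leaves (capacity must be a power of two)."""

    def __init__(self, capacity: int) -> None:
        if capacity <= 0 or capacity & (capacity - 1):
            raise ValueError("capacity must be a positive power of two")
        self.capacity = capacity
        self.nodes: List[float] = [0.0] * (2 * capacity)  # nodes[1] is the root; leaves start at `capacity`

    def update(self, index: int, priority: float) -> None:
        node = index + self.capacity
        self.nodes[node] = priority
        node //= 2
        while node >= 1:  # refresh partial sums on the path to the root: O(log N)
            self.nodes[node] = self.nodes[2 * node] + self.nodes[2 * node + 1]
            node //= 2

    def total(self) -> float:
        return self.nodes[1]

    def priority(self, index: int) -> float:
        return self.nodes[index + self.capacity]

    def find(self, mass: float) -> int:
        """Return the leaf whose cumulative-priority interval contains `mass`, in O(log N)."""
        node = 1
        while node < self.capacity:
            left = 2 * node
            # Descend left if the mass falls there (or if the right subtree is empty)
            if mass < self.nodes[left] or self.nodes[left + 1] == 0.0:
                node = left
            else:
                mass -= self.nodes[left]
                node = left + 1
        return node - self.capacity


def sample(tree: SumTree, batch_size: int, size: int, beta: float,
           rng: random.Random) -> Tuple[List[int], List[float]]:
    """Draw indices proportionally to priority and return normalised importance weights."""
    total = tree.total()
    indices = [tree.find(rng.random() * total) for _ in range(batch_size)]
    # w_i = (N * P(i)) ** -beta, scaled so the largest weight is 1
    weights = [(size * tree.priority(i) / total) ** (-beta) for i in indices]
    max_w = max(weights)
    return indices, [w / max_w for w in weights]


alpha, eps = 0.6, 1e-6
td_errors = [0.5, -2.0, 0.1, 1.0]  # hypothetical TD errors for 4 stored transitions
tree = SumTree(capacity=4)
for i, delta in enumerate(td_errors):
    tree.update(i, (abs(delta) + eps) ** alpha)  # p_i = (|delta_i| + eps) ** alpha
indices, weights = sample(tree, batch_size=8, size=len(td_errors), beta=0.4, rng=random.Random(0))
```

The CPU fallback mentioned above is costly for the following reason: every `find` walk happens where the tree lives, so a CPU tree paired with GPU storage pays a device transfer on every sample.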

Memory-Mapped Storage for Large Datasets

LazyMemmapStorage keeps tensors on disk as memory-mapped files — the OS maps them into the process address space on demand. This is the right choice for offline datasets too large to fit in RAM, or for distributed runs where multiple processes need shared access.
from torchrl.data import TensorDictReplayBuffer
from torchrl.data.replay_buffers.storages import LazyMemmapStorage

storage = LazyMemmapStorage(
    max_size=10_000_000,
    scratch_dir="./replay_data",     # persistent directory (set for checkpointing)
    device="cpu",
    auto_cleanup=False,              # keep data after process exits
)

buffer = TensorDictReplayBuffer(
    storage=storage,
    batch_size=512,
    prefetch=8,                      # number of batches prefetched by background threads
)

# Zero-copy checkpoint when scratch_dir matches the checkpoint dir
buffer.dumps("./replay_data")       # effectively a no-op if already at scratch_dir
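Python's standard `mmap` module can illustrate the persistence property that `scratch_dir` and `auto_cleanup=False` rely on. Rows written through the map land in a file on disk, and they are still there when the file is mapped again later. The fixed three-float row layout and the `MemmapRows` helper are assumptions made for this sketch. LazyMemmapStorage itself stores tensors, not struct-packed rows.

```python
import mmap
import os
import struct
import tempfile
from typing import Tuple


class MemmapRows:
    """Hypothetical fixed-width float32 row store backed by a memory-mapped file."""

    def __init__(self, path: str, max_size: int, fields: int = 3) -> None:
        self.fmt = f"<{fields}f"
        self.row_size = struct.calcsize(self.fmt)
        if not os.path.exists(path):
            with open(path, "wb") as f:
                f.truncate(max_size * self.row_size)  # zero-filled; typically sparse on disk
        self._file = open(path, "r+b")
        self._map = mmap.mmap(self._file.fileno(), 0)  # pages are loaded on demand by the OS

    def __setitem__(self, index: int, values: Tuple[float, ...]) -> None:
        struct.pack_into(self.fmt, self._map, index * self.row_size, *values)

    def __getitem__(self, index: int) -> Tuple[float, ...]:
        return struct.unpack_from(self.fmt, self._map, index * self.row_size)

    def close(self) -> None:
        self._map.flush()  # make sure dirty pages reach the file
        self._map.close()
        self._file.close()


scratch_dir = tempfile.mkdtemp()  # stands in for a persistent scratch_dir
path = os.path.join(scratch_dir, "rows.bin")

rows = MemmapRows(path, max_size=1_000)
rows[5] = (1.0, 2.5, -3.0)
rows.close()  # e.g. the process exits

reopened = MemmapRows(path, max_size=1_000)  # a later process maps the same file
restored = reopened[5]
untouched = reopened[0]
reopened.close()
```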

Shared Memory for Multi-Process Environments

For ParallelEnv workers on a single node, keep shared_memory=True (the default). Each worker then writes its outputs into shared-memory tensors that the main process reads directly, which avoids inter-process copies:
from torchrl.envs import GymEnv, ParallelEnv

# shared_memory=True (the default) places the workers' output tensors in shared memory
vec_env = ParallelEnv(
    num_workers=8,
    create_env_fn=lambda: GymEnv("HalfCheetah-v4"),
    shared_memory=True,
)
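Python's standard `multiprocessing.shared_memory` module shows the mechanism behind this setting. A named block is allocated once. Any handle that attaches by name reads and writes the same bytes, rather than receiving a pickled copy. To keep the demo self-contained, both handles live in one process. In ParallelEnv the writer would be a worker process, and TorchRL works with shared-memory tensors rather than this raw-bytes API.

```python
import struct
from multiprocessing import shared_memory

# Producer side: allocate a named shared block sized for 4 float32 observations
shm = shared_memory.SharedMemory(create=True, size=4 * 4)
try:
    struct.pack_into("<4f", shm.buf, 0, 0.5, 1.0, 1.5, 2.0)

    # Consumer side: attach to the same block by name (as another process would)
    view = shared_memory.SharedMemory(name=shm.name)
    try:
        observed = struct.unpack_from("<4f", view.buf, 0)
        # A later write through one handle is visible through the other: nothing was copied
        struct.pack_into("<f", shm.buf, 0, 9.0)
        updated_first = struct.unpack_from("<f", view.buf, 0)[0]
    finally:
        view.close()
finally:
    shm.close()
    shm.unlink()  # release the block once no process needs it
```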

Triton GRU/LSTM Scan Backends for Recurrent RL

LSTMModule and GRUModule support three execution backends, plus an "auto" selector, for handling trajectory resets inside a batch (when done=True in the middle of a sequence):

| Backend | Description | Best for |
| --- | --- | --- |
| "pad" | Default; splits, pads, and concatenates around resets | Eager mode, highest compatibility |
| "scan" | Scan loop via a torch higher-order op; avoids materialising padded chunks | torch.compile |
| "triton" | CUDA Triton kernels with fused forward/backward | GPU training, recurrent backprop speed |
| "auto" | "pad" in eager mode, "scan" under torch.compile | Development default |
import torch
from torchrl.modules.tensordict_module.rnn import LSTMModule, GRUModule

# Triton backend — fastest on GPU, requires `pip install triton`
lstm = LSTMModule(
    input_size=64,
    hidden_size=256,
    num_layers=2,
    in_key="observation",
    out_key="lstm_out",
    recurrent_backend="triton",
    recurrent_recompute="full",          # trade compute for lower activation memory
    recurrent_compute_dtype=torch.float32,
)

gru = GRUModule(
    input_size=64,
    hidden_size=256,
    in_key="observation",
    out_key="gru_out",
    recurrent_backend="scan",            # compile-friendly scan loop
)
recurrent_recompute options:
  • "none" (default) — save all gate activations for backward; uses more memory.
  • "full" with Triton — drops per-step gate buffers and replays forward kernel during backward; significantly reduces VRAM for long sequences.
  • "full" with scan — wraps the scan in torch.utils.checkpoint.checkpoint.
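A toy scalar recurrence shows why the reset-handling strategies in the backend table agree. Approach one splits the sequence at resets and runs each segment from a zero state, which is the idea behind "pad". Approach two makes a single pass that zeroes the carry wherever a new trajectory starts, which is the idea behind a scan with a reset mask. Both produce the same outputs. This is a hypothetical pure-Python illustration with a made-up tanh cell, not TorchRL's LSTM/GRU code.

```python
import math
from typing import List


def rnn_step(h: float, x: float, w: float = 0.5, u: float = 1.0) -> float:
    # Toy scalar recurrent cell standing in for an LSTM/GRU cell
    return math.tanh(w * h + u * x)


def run_split(xs: List[float], is_init: List[bool]) -> List[float]:
    """'pad'-style idea: cut the sequence at resets and run each segment from a zero state."""
    segments: List[List[float]] = []
    current: List[float] = []
    for x, start in zip(xs, is_init):
        if start and current:
            segments.append(current)
            current = []
        current.append(x)
    if current:
        segments.append(current)
    outputs: List[float] = []
    for segment in segments:
        h = 0.0
        for x in segment:
            h = rnn_step(h, x)
            outputs.append(h)
    return outputs


def run_masked_scan(xs: List[float], is_init: List[bool]) -> List[float]:
    """'scan'-style idea: one pass over time, zeroing the carry where a new trajectory starts."""
    h = 0.0
    outputs: List[float] = []
    for x, start in zip(xs, is_init):
        h = 0.0 if start else h  # reset mask applied to the carry
        h = rnn_step(h, x)
        outputs.append(h)
    return outputs


xs = [0.1, 0.2, 0.3, 0.4, 0.5]
is_init = [True, False, True, False, False]  # a second trajectory starts at t=2
split_out = run_split(xs, is_init)
scan_out = run_masked_scan(xs, is_init)
```

The masked pass never builds per-segment lists. This mirrors the table's note that "scan" avoids materialising padded chunks.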

Recurrent Precision Tuning

For the Triton backend you can control matmul precision:
# Control globally
from torchrl.modules import set_recurrent_matmul_precision
set_recurrent_matmul_precision("tf32")   # fastest on Ampere+

# Or per-module
lstm = LSTMModule(
    ...,
    recurrent_backend="triton",
    recurrent_matmul_precision="tf32x3",  # 22-bit compensated TF32
)
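A "TF32x3"-style mode is usually meant to recover near-FP32 accuracy from TF32 hardware. Each operand is split into a TF32 high part and a TF32 residual, and three TF32 products are summed. The sketch below emulates TF32 rounding (11 significant bits) in pure Python to show how much error this removes. It reflects an assumption about what the `tf32x3` setting does, not a description of TorchRL's Triton kernels, and the helper names are hypothetical.

```python
import math
import random


def to_tf32(x: float) -> float:
    """Round x to 11 significant bits (TF32's 10 stored mantissa bits + implicit 1)."""
    if x == 0.0:
        return 0.0
    mantissa, exponent = math.frexp(x)  # x = mantissa * 2**exponent, 0.5 <= |mantissa| < 1
    return math.ldexp(round(mantissa * 2**11) / 2**11, exponent)


def product_1xtf32(a: float, b: float) -> float:
    # One reduced-precision product, accumulated at higher precision
    return to_tf32(a) * to_tf32(b)


def product_3xtf32(a: float, b: float) -> float:
    # Split each operand into a TF32 "high" part and a TF32 residual
    a_hi, b_hi = to_tf32(a), to_tf32(b)
    a_lo, b_lo = to_tf32(a - a_hi), to_tf32(b - b_hi)
    # Three reduced-precision products; the tiny a_lo * b_lo term is dropped
    return a_hi * b_hi + a_hi * b_lo + a_lo * b_hi


rng = random.Random(0)
pairs = [(rng.uniform(-1, 1), rng.uniform(-1, 1)) for _ in range(1_000)]
err_1x = sum(abs(product_1xtf32(a, b) - a * b) for a, b in pairs)
err_3x = sum(abs(product_3xtf32(a, b) - a * b) for a, b in pairs)
print(f"total abs error  1xTF32: {err_1x:.3e}  3xTF32: {err_3x:.3e}")
```

The trade-off is three matmuls instead of one. It pays off only when plain TF32 error is actually a problem for the recurrence.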

Profiling Collector Workers with ProfileConfig

ProfileConfig attaches PyTorch Profiler hooks to any collector — single or multi-process — and exports Chrome-trace JSON files per worker:
from torchrl.collectors import MultiSyncCollector, ProfileConfig

collector = MultiSyncCollector(
    env_fns=[env_fn] * 4,
    policy=policy,
    frames_per_batch=1000,
)

# Enable profiling for worker 0
collector.enable_profile(
    workers=[0],
    num_rollouts=5,          # profile for 5 rollouts (including warmup)
    warmup_rollouts=2,       # skip 2 rollouts for JIT warmup
    save_path="./traces/worker_{worker_idx}.json",
    activities=["cpu", "cuda"],
    record_shapes=True,
    profile_memory=True,
)

for data in collector:
    process(data)
    # After num_rollouts batches, the profiler stops and saves the trace
Open the JSON file in chrome://tracing or Perfetto UI to inspect kernel timings, memory allocations, and operator shapes. You can also construct a ProfileConfig object directly for more control:
from torchrl.collectors import ProfileConfig

config = ProfileConfig(
    workers=[0, 1],
    num_rollouts=10,
    warmup_rollouts=3,
    activities=["cpu", "cuda"],
    record_shapes=True,
    profile_memory=False,
    with_stack=True,
    with_flops=True,
)
collector.enable_profile(config=config)
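Two behaviours from the example above can be mimicked without the PyTorch profiler. Warm-up rollouts count towards `num_rollouts` but are not recorded, and `save_path` is a per-worker filename template. The sketch times each rollout with `time.perf_counter` and writes the Chrome Trace Event format that chrome://tracing and Perfetto load. `profile_rollouts` is a hypothetical helper. ProfileConfig records operator-level events through torch.profiler, not whole-rollout timings.

```python
import json
import os
import tempfile
import time
from typing import Any, Callable, Dict, List


def profile_rollouts(rollout_fn: Callable[[], Any], num_rollouts: int, warmup_rollouts: int,
                     worker_idx: int, save_path: str) -> str:
    """Run num_rollouts rollouts, record those after the warm-up, and write a Chrome trace."""
    events: List[Dict[str, Any]] = []
    origin = time.perf_counter()
    for i in range(num_rollouts):
        start = time.perf_counter()
        rollout_fn()
        end = time.perf_counter()
        if i < warmup_rollouts:
            continue  # warm-up rollouts run (and count towards num_rollouts) but are not recorded
        events.append({
            "name": f"rollout_{i}",
            "ph": "X",                     # "complete" event: start + duration
            "ts": (start - origin) * 1e6,  # timestamps are in microseconds
            "dur": (end - start) * 1e6,
            "pid": worker_idx,             # one process row per worker in the viewer
            "tid": 0,
        })
    path = save_path.format(worker_idx=worker_idx)
    with open(path, "w") as f:
        json.dump({"traceEvents": events}, f)
    return path


trace_dir = tempfile.mkdtemp()
trace_path = profile_rollouts(
    lambda: sum(range(10_000)),  # stand-in for one collector rollout
    num_rollouts=5,
    warmup_rollouts=2,
    worker_idx=0,
    save_path=os.path.join(trace_dir, "worker_{worker_idx}.json"),
)
with open(trace_path) as f:
    trace = json.load(f)
```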

Compilable Replay Buffer Samplers

The PrioritizedSampler and SliceSampler inner index-computation methods can be compiled with torch.compile to reduce Python overhead at high batch rates:
from torchrl.data.replay_buffers.samplers import PrioritizedSampler

sampler = PrioritizedSampler(
    max_capacity=100_000,
    alpha=0.6,
    beta=0.4,
    compilable=True,         # wrap _get_index with torch.compile
)

End-to-End Performance Checklist

import torch
from torchrl.collectors import MultiCollector, VanillaWeightUpdater
from torchrl.data import TensorDictPrioritizedReplayBuffer
from torchrl.data.replay_buffers.storages import LazyTensorStorage
from torchrl.objectives import PPOLoss
from torchrl.objectives.value import GAE
from torchrl._utils import compile_with_warmup

# 1. Vectorised advantage estimator
gae = GAE(gamma=0.99, lmbda=0.95, value_network=value_net)

# 2. Async multi-process collection
collector = MultiCollector(
    [env_fn] * 8,
    policy=policy,
    frames_per_batch=2048,
    sync=False,                          # async workers
    weight_updater=VanillaWeightUpdater(),
)

# 3. GPU prioritized replay (CUDA wheel required)
buffer = TensorDictPrioritizedReplayBuffer(
    alpha=0.6,
    beta=0.4,
    storage=LazyTensorStorage(max_size=500_000, device="cuda"),
    batch_size=512,
    prefetch=4,
)

# 4. Compiled loss with warmup
loss_fn = PPOLoss(actor_network=actor, critic_network=value_net)
compiled_loss = compile_with_warmup(loss_fn, warmup=2, mode="reduce-overhead")

# 5. Recurrent backbone with Triton backend
from torchrl.modules.tensordict_module.rnn import LSTMModule
lstm = LSTMModule(
    input_size=64, hidden_size=256,
    in_key="observation", out_key="lstm_out",
    recurrent_backend="triton",
    recurrent_recompute="full",
)

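# Training loop
for data in collector:
    # GAE needs contiguous trajectories, so compute advantages on the freshly
    # collected batch before it is shuffled into the replay buffer
    with torch.no_grad():
        data = gae(data)
    buffer.extend(data.reshape(-1))  # store individual transitions
    if len(buffer) < 512:
        continue

    for _ in range(10):
        batch = buffer.sample()
        loss_td = compiled_loss(batch)
        loss = sum(v for k, v in loss_td.items() if k.startswith("loss_"))
        optimizer.zero_grad()
        loss.backward()
        torch.nn.utils.clip_grad_norm_(loss_fn.parameters(), 0.5)
        optimizer.step()

    collector.update_policy_weights_()  # sync to workers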
Run a short profiling session (5–10 rollouts) with ProfileConfig before committing to a long training run. The most common bottlenecks are CPU-GPU data transfers in the replay buffer, uncompiled loss functions, and padded recurrent batches containing many short trajectories (switch to the "scan" or "triton" backend in that case).

Key Imports Reference

# Value estimation
from torchrl.objectives.value import GAE, MultiAgentGAE
from torchrl.objectives.value.functional import (
    vec_generalized_advantage_estimate,
    vec_td_lambda_return_estimate,
    vec_td_lambda_advantage_estimate,
    vec_td1_return_estimate,
    vec_td1_advantage_estimate,
)

# Compilation utilities
from torchrl._utils import compile_with_warmup

# Collectors
from torchrl.collectors import (
    Collector,
    MultiCollector,
    MultiSyncCollector,
    MultiAsyncCollector,
    VanillaWeightUpdater,
    MultiProcessedWeightUpdater,
    RayWeightUpdater,
    ProfileConfig,
)

# Replay buffers and storage
from torchrl.data import TensorDictReplayBuffer, TensorDictPrioritizedReplayBuffer
from torchrl.data.replay_buffers.storages import LazyMemmapStorage, LazyTensorStorage
from torchrl.data.replay_buffers.samplers import PrioritizedSampler

# Recurrent modules
from torchrl.modules.tensordict_module.rnn import LSTMModule, GRUModule, LSTM, GRU
from torchrl.modules import set_recurrent_matmul_precision
