TorchRL replay buffers are built around a simple principle: storage, sampling, writing, and post-sampling transforms are independent, swappable components. You can combine any storage backend with any sampler, attach transforms that run at sample time, enable multi-threaded prefetching, or make the buffer remote and distributed — all without changing the calling code. The same extend / sample API works for simple in-memory replay, disk-backed memmap storage, prioritized replay, hindsight experience relabeling, and offline dataset loading.

Architecture

ReplayBuffer
  ├── Storage    — where tensors are physically kept (RAM, mmap, GPU)
  ├── Sampler    — how indices are chosen (uniform, prioritized, slice)
  ├── Writer     — how new data enters the buffer (round-robin, max-value)
  └── Transform  — post-processing applied to every sampled batch
Each piece is instantiated independently and composed at construction time. TorchRL ships defaults for all four so you only need to override the parts you care about.
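
To make the composition idea concrete, here is a minimal pure-Python sketch of a buffer that delegates to four swappable parts. The class names (`ListStore`, `RoundRobin`, `Uniform`, `ToyBuffer`) are hypothetical and exist only for illustration; this is not how TorchRL implements its components.

```python
import random


class ListStore:
    """Storage: keeps items in a Python list with a fixed capacity."""

    def __init__(self, max_size):
        self.max_size = max_size
        self.items = []

    def set(self, index, item):
        # Append while filling up, overwrite once the slot already exists.
        if index == len(self.items):
            self.items.append(item)
        else:
            self.items[index] = item

    def get(self, index):
        return self.items[index]

    def __len__(self):
        return len(self.items)


class RoundRobin:
    """Writer: chooses the next slot and wraps around when storage is full."""

    def __init__(self):
        self.cursor = 0

    def add(self, storage, item):
        index = self.cursor
        storage.set(index, item)
        self.cursor = (self.cursor + 1) % storage.max_size
        return index


class Uniform:
    """Sampler: draws indices uniformly at random, with replacement."""

    def __init__(self, seed=None):
        self.rng = random.Random(seed)

    def sample(self, storage, batch_size):
        return [self.rng.randrange(len(storage)) for _ in range(batch_size)]


class ToyBuffer:
    """Glue object: every operation is delegated to one of the components."""

    def __init__(self, storage, sampler, writer, transform=None, batch_size=4):
        self.storage = storage
        self.sampler = sampler
        self.writer = writer
        self.transform = transform
        self.batch_size = batch_size

    def extend(self, items):
        return [self.writer.add(self.storage, item) for item in items]

    def sample(self):
        indices = self.sampler.sample(self.storage, self.batch_size)
        batch = [self.storage.get(i) for i in indices]
        return self.transform(batch) if self.transform else batch


buf = ToyBuffer(
    ListStore(max_size=3),
    Uniform(seed=0),
    RoundRobin(),
    transform=lambda batch: [x * 10 for x in batch],  # runs at sample time only
)
buf.extend([1, 2, 3, 4])  # the fourth item overwrites slot 0
batch = buf.sample()
```

Swapping the sampler or the writer changes the behavior without touching `extend` or `sample`, which is the property the architecture above relies on.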

ReplayBuffer: the generic base

ReplayBuffer is the base class that works with any Python object. It uses a ListStorage and RandomSampler by default.
import torch
from torchrl.data import ReplayBuffer, ListStorage, RandomSampler

rb = ReplayBuffer(
    storage=ListStorage(max_size=10_000),
    sampler=RandomSampler(),
    batch_size=128,
)

rb.add({"obs": torch.randn(8), "action": torch.randn(2)})    # single transition
rb.extend([{"obs": torch.randn(8), "action": torch.randn(2)} for _ in range(100)])
batch = rb.sample()  # 128 sampled items; the default collate function stacks their tensors

TensorDictReplayBuffer

TensorDictReplayBuffer is the TensorDict-aware wrapper you should use in almost all TorchRL workflows. It understands nested keys, preserves structure across extend / sample round-trips, and integrates with transforms that operate on TensorDicts.
from torchrl.data import TensorDictReplayBuffer, LazyTensorStorage

buffer = TensorDictReplayBuffer(
    storage=LazyTensorStorage(max_size=100_000),
    batch_size=256,
    prefetch=4,   # background threads prefetch next batches
)

# Extend from a collector batch (TensorDict).
buffer.extend(collector_batch)

# Sample returns a TensorDict with structure preserved.
sample = buffer.sample()
print(sample["observation"].shape)      # [256, obs_dim]
print(sample["next", "reward"].shape)   # [256, 1]
LazyTensorStorage infers tensor shapes and dtypes from the first extend call. You never need to pre-declare the schema — the buffer adapts to whatever TensorDict structure it receives.

TensorDictPrioritizedReplayBuffer

TensorDictPrioritizedReplayBuffer samples transitions in proportion to a priority read from the "td_error" key (configurable). Once the TD error has been written into the sampled TensorDict, call update_tensordict_priority() to register the new priorities.
from torchrl.data import LazyMemmapStorage, TensorDictPrioritizedReplayBuffer

buffer = TensorDictPrioritizedReplayBuffer(
    storage=LazyMemmapStorage(1_000_000),
    alpha=0.7,          # prioritization exponent (0 = uniform)
    beta=0.5,           # importance-sampling exponent
    eps=1e-6,           # added to priorities to avoid zeros
    batch_size=256,
    prefetch=2,
)

buffer.extend(collector_batch)
sample = buffer.sample()

# After computing TD errors, update priorities.
sample["td_error"] = td_errors.abs().detach()
buffer.update_tensordict_priority(sample)
The returned sample contains an "index" key that update_tensordict_priority uses to write updated priorities back to the correct tree positions.
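
The roles of alpha, beta, and eps can be illustrated with the standard prioritized-replay formulas. The sketch below is plain Python under that assumption; the function names are hypothetical, and TorchRL's internal normalization may differ in detail.

```python
def per_probabilities(td_errors, alpha=0.7, eps=1e-6):
    """Turn absolute TD errors into sampling probabilities.

    eps keeps zero-error transitions sampleable; alpha=0 gives uniform sampling.
    """
    priorities = [(abs(e) + eps) ** alpha for e in td_errors]
    total = sum(priorities)
    return [p / total for p in priorities]


def importance_weights(probs, beta=0.5):
    """Importance-sampling weights that correct for non-uniform sampling.

    Weights are divided by their maximum so they only ever scale losses down.
    """
    n = len(probs)
    weights = [(n * p) ** (-beta) for p in probs]
    max_w = max(weights)
    return [w / max_w for w in weights]


probs = per_probabilities([0.1, 1.0, 2.0, 0.0])
weights = importance_weights(probs)
uniform = per_probabilities([0.1, 1.0, 2.0], alpha=0.0)
```

Transitions with a larger TD error are drawn more often and receive a smaller weight, so their more frequent appearance does not bias the gradient estimate as strongly.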

Storage backends

from torchrl.data import ListStorage

# Stores data as a Python list of objects.
# Flexible but not optimized for large batches or GPU workflows.
storage = ListStorage(max_size=10_000)
Best for: small buffers, non-tensor data, heterogeneous batch shapes.

Samplers

RandomSampler

Uniform random sampling with replacement. The default sampler.
from torchrl.data import RandomSampler
sampler = RandomSampler()

SamplerWithoutReplacement

Samples all stored indices once before repeating — useful for offline RL or when you want a proper epoch structure.
from torchrl.data import SamplerWithoutReplacement
sampler = SamplerWithoutReplacement(drop_last=True)
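
As a rough illustration of what sampling without replacement means here, the following stdlib-only sketch walks a shuffled permutation of the indices once per epoch. The helper name `epoch_batches` is hypothetical and the code is not TorchRL's implementation.

```python
import random


def epoch_batches(num_items, batch_size, drop_last=True, seed=None):
    """Yield index batches that cover one shuffled pass over the data."""
    rng = random.Random(seed)
    order = list(range(num_items))
    rng.shuffle(order)
    for start in range(0, num_items, batch_size):
        batch = order[start:start + batch_size]
        # With drop_last=True the final, shorter batch is discarded.
        if len(batch) < batch_size and drop_last:
            return
        yield batch


full_only = list(epoch_batches(num_items=10, batch_size=4, drop_last=True, seed=0))
with_tail = list(epoch_batches(num_items=10, batch_size=4, drop_last=False, seed=0))
```

With 10 items and a batch size of 4, `drop_last=True` yields two batches of 4 distinct indices, while `drop_last=False` adds a final batch of 2 so that every index is visited exactly once.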

PrioritizedSampler

Samples proportionally to stored priorities using a sum-tree.
from torchrl.data import PrioritizedSampler
sampler = PrioritizedSampler(
    max_capacity=100_000, alpha=0.7, beta=0.5
)
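
For readers unfamiliar with sum-trees, here is a small pure-Python sketch of the data structure: leaves hold priorities, internal nodes hold the sum of their children, and sampling descends from the root in O(log n). The `SumTree` class is hypothetical and only illustrates the idea; TorchRL's own tree is implemented differently.

```python
import random


class SumTree:
    """Binary tree with priorities at the leaves and partial sums above them."""

    def __init__(self, capacity):
        # Round up to a power of two so the tree is perfectly balanced.
        size = 1
        while size < capacity:
            size *= 2
        self.capacity = size
        # Node 1 is the root; leaves live at positions [capacity, 2 * capacity).
        self.nodes = [0.0] * (2 * size)

    def update(self, index, priority):
        pos = index + self.capacity
        self.nodes[pos] = priority
        pos //= 2
        while pos >= 1:  # refresh the sums on the path back to the root
            self.nodes[pos] = self.nodes[2 * pos] + self.nodes[2 * pos + 1]
            pos //= 2

    def total(self):
        return self.nodes[1]

    def find(self, mass):
        """Return the leaf whose cumulative-priority interval contains mass."""
        pos = 1
        while pos < self.capacity:
            left = 2 * pos
            if mass < self.nodes[left]:
                pos = left
            else:
                mass -= self.nodes[left]
                pos = left + 1
        return pos - self.capacity


tree = SumTree(capacity=4)
for i, priority in enumerate([1.0, 2.0, 3.0, 4.0]):
    tree.update(i, priority)

rng = random.Random(0)
counts = [0, 0, 0, 0]
for _ in range(2000):
    counts[tree.find(rng.random() * tree.total())] += 1
```

Index 3 holds 40% of the total priority and index 0 only 10%, so index 3 is drawn roughly four times as often.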

SliceSampler

Samples contiguous sub-trajectories of a fixed length. Pairs naturally with LazyStackStorage for recurrent RL.
from torchrl.data import SliceSampler
# Pass either num_slices or slice_len, not both; the other is derived from the batch size.
sampler = SliceSampler(num_slices=32)
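
The idea of slice sampling can be sketched without TorchRL: pick a trajectory that is long enough, then pick a window of consecutive steps inside it. The helper `sample_slices` below is hypothetical and assumes each trajectory is stored contiguously; it is an illustration, not the library's algorithm.

```python
import random


def sample_slices(traj_ids, num_slices, slice_len, seed=None):
    """Return num_slices windows of slice_len consecutive storage indices."""
    rng = random.Random(seed)
    # Record the (start, stop) span each trajectory occupies in storage.
    spans = {}
    for pos, tid in enumerate(traj_ids):
        start, _ = spans.get(tid, (pos, pos))
        spans[tid] = (start, pos + 1)
    # Only trajectories with at least slice_len steps can host a full window.
    eligible = [s for s in spans.values() if s[1] - s[0] >= slice_len]
    if not eligible:
        raise ValueError("no trajectory is long enough for the requested slice_len")
    slices = []
    for _ in range(num_slices):
        start, stop = rng.choice(eligible)
        first = rng.randrange(start, stop - slice_len + 1)
        slices.append(list(range(first, first + slice_len)))
    return slices


# Three trajectories of length 5, 3 and 6 stored back to back.
traj_ids = [0] * 5 + [1] * 3 + [2] * 6
slices = sample_slices(traj_ids, num_slices=3, slice_len=4, seed=0)
```

The flattened batch has `num_slices * slice_len` elements, which is why the batch size and the slice parameters have to agree.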

PrioritizedSliceSampler

Combines priority-based sampling with trajectory slicing — each slice is sampled from an episode with probability proportional to the maximum TD error in that episode.
from torchrl.data import PrioritizedSliceSampler, TensorDictReplayBuffer, LazyMemmapStorage

buffer = TensorDictReplayBuffer(
    storage=LazyMemmapStorage(max_size=500_000),
    sampler=PrioritizedSliceSampler(
        max_capacity=500_000,
        alpha=0.7,
        beta=0.5,
        num_slices=16,   # set num_slices or slice_len, not both; slice length = batch_size // num_slices = 80
    ),
    batch_size=16 * 80,
)

Writers

Writer | Behaviour
--- | ---
RoundRobinWriter | Overwrites oldest data when the buffer is full (circular buffer). Default for ReplayBuffer.
TensorDictRoundRobinWriter | TensorDict-aware round-robin writer. Default for TensorDictReplayBuffer.
TensorDictMaxValueWriter | Only writes a new transition if its priority exceeds the minimum priority currently stored. Useful for competitive replay.
ImmutableDatasetWriter | Raises an error on any write attempt; intended for read-only offline datasets.
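
The max-value policy is the least obvious of the writers listed above, so here is a stdlib-only sketch of the idea using a min-heap. The `MaxValueWriter` class and its `score` argument are hypothetical stand-ins; TorchRL's writer reads the ranking value from the data itself.

```python
import heapq


class MaxValueWriter:
    """Keep only the max_size highest-scoring items seen so far."""

    def __init__(self, max_size):
        self.max_size = max_size
        self._heap = []   # min-heap of (score, insertion_order, item)
        self._count = 0   # tie-breaker so items themselves are never compared

    def add(self, item, score):
        """Return True if the item was stored, False if it was rejected."""
        entry = (score, self._count, item)
        self._count += 1
        if len(self._heap) < self.max_size:
            heapq.heappush(self._heap, entry)
            return True
        if score > self._heap[0][0]:
            # Evict the current minimum and store the new item in its place.
            heapq.heapreplace(self._heap, entry)
            return True
        return False

    def items(self):
        return [item for _, _, item in sorted(self._heap, reverse=True)]


writer = MaxValueWriter(max_size=2)
accepted = [
    writer.add("a", 1.0),
    writer.add("b", 3.0),
    writer.add("c", 2.0),  # evicts "a", the lowest score stored
    writer.add("d", 0.5),  # rejected: below the current minimum
]
```

A round-robin writer, by contrast, ignores the score entirely and always overwrites the oldest slot.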

HERReplayBuffer: hindsight experience replay

HERReplayBuffer applies goal relabeling at sample time. For a configurable fraction of each batch, the desired goal is replaced with an achieved goal from the same episode, and the reward is recomputed via a user-supplied function. The relabeling strategy can be "future", "final", "episode", or "random".
import torch
from tensordict import TensorDict
from torchrl.data import HERReplayBuffer, LazyMemmapStorage


def reward_fn(td):
    dist = (td["achieved_goal"] - td["desired_goal"]).norm(dim=-1, keepdim=True)
    return (dist < 0.05).float()


buffer = HERReplayBuffer(
    reward_fn=reward_fn,
    storage=LazyMemmapStorage(max_size=1_000_000),
    batch_size=256,
    her_ratio=0.8,          # fraction of batch to relabel
    strategy="future",      # use future achieved goals
    goal_key="desired_goal",
    achieved_goal_key="achieved_goal",
)

# Store a transition.
td = TensorDict(
    {
        "observation": torch.randn(4),
        "desired_goal": torch.zeros(3),
        "achieved_goal": torch.randn(3),
        "action": torch.randn(2),
        "next": {
            "observation": torch.randn(4),
            "desired_goal": torch.zeros(3),
            "achieved_goal": torch.randn(3),
            "reward": torch.zeros(1),
            "done": torch.tensor([False]),
        },
    },
    batch_size=[],
)
buffer.add(td)
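
To show what "future" relabeling does to a stored episode, here is a toy version in plain Python. The function `relabel_future` and the tuple-based goal representation are assumptions made for illustration; the sketch also allows the current step itself to be chosen as the "future" step, which is a simplification.

```python
import math
import random


def reward_fn(achieved, desired, tol=0.05):
    """Sparse reward: 1.0 when the achieved goal is within tol of the desired one."""
    return 1.0 if math.dist(achieved, desired) < tol else 0.0


def relabel_future(episode, reward_fn, her_ratio=0.8, seed=None):
    """Relabel a fraction of steps with a goal achieved later in the episode."""
    rng = random.Random(seed)
    relabeled = []
    for t, step in enumerate(episode):
        step = dict(step)  # copy so the stored episode stays untouched
        if rng.random() < her_ratio:
            # Pick a step at or after t and pretend its outcome was the goal.
            future = rng.randrange(t, len(episode))
            step["desired_goal"] = episode[future]["achieved_goal"]
            step["reward"] = reward_fn(step["achieved_goal"], step["desired_goal"])
        relabeled.append(step)
    return relabeled


# A 5-step episode that never reaches the original goal at x = 10.
episode = [
    {"achieved_goal": (float(t),), "desired_goal": (10.0,), "reward": 0.0}
    for t in range(5)
]
relabeled = relabel_future(episode, reward_fn, her_ratio=1.0, seed=0)
untouched = relabel_future(episode, reward_fn, her_ratio=0.0, seed=0)
```

Every original reward in this episode is zero, yet the relabeled copy contains successful transitions, which is what gives the agent a learning signal under sparse rewards.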

Buffer transforms

Transforms can be attached to a replay buffer and execute at sample() time. This is useful for data augmentation, observation normalization that depends on the sampled batch, or on-the-fly device transfers.
from torchrl.data import TensorDictReplayBuffer, LazyTensorStorage
from torchrl.envs import ObservationNorm, Compose

buffer = TensorDictReplayBuffer(
    storage=LazyTensorStorage(max_size=100_000),
    batch_size=256,
    transform=Compose(
        ObservationNorm(in_keys=["observation"], loc=0.0, scale=1.0),
    ),
)
Transforms attached to a replay buffer run on the sampled mini-batch, not at insert time. This means the normalization statistics can be updated between samples without re-writing the buffer.
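
The consequence of running transforms at sample time can be seen in a toy example. The `AffineNorm` class below is a hypothetical stand-in for a normalization transform, not TorchRL's ObservationNorm.

```python
class AffineNorm:
    """Hypothetical sample-time transform: x -> x * scale + loc."""

    def __init__(self, loc=0.0, scale=1.0):
        self.loc = loc
        self.scale = scale

    def __call__(self, batch):
        return [x * self.scale + self.loc for x in batch]


stored = [2.0, 4.0, 6.0]   # raw values, exactly as written to storage
norm = AffineNorm()

first = norm(stored)        # identity with the initial statistics

# Update the statistics between two samples; storage is never rewritten.
norm.loc, norm.scale = -2.0, 0.5
second = norm(stored)
```

The same stored values yield differently normalized batches before and after the update, while the stored data itself stays raw.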

Offline datasets

TorchRL supports loading and sampling from offline datasets using the same buffer interface. The ImmutableDatasetWriter prevents any further writes after loading.
from torchrl.data import (
    TensorDictReplayBuffer,
    LazyMemmapStorage,
    ImmutableDatasetWriter,
    SamplerWithoutReplacement,
)

# Load a pre-collected offline dataset from a previously saved buffer.
storage = LazyMemmapStorage(max_size=1_000_000)
storage.loads("/path/to/dataset")   # loads checkpoint into storage in-place

dataset = TensorDictReplayBuffer(
    storage=storage,
    sampler=SamplerWithoutReplacement(drop_last=True),
    writer=ImmutableDatasetWriter(),
    batch_size=256,
)

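# loss_fn and optimizer are assumed to be defined elsewhere; loss_fn is assumed to return a scalar loss.
for batch in dataset:
    offline_loss = loss_fn(batch)
    optimizer.zero_grad()
    offline_loss.backward()
    optimizer.step()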

Prefetching and distributed replay

Setting prefetch > 0 launches background threads that pre-load the next prefetch batches while your training code processes the current one. For distributed setups, RayReplayBuffer and RemoteTensorDictReplayBuffer expose the same extend / sample API across Ray actors or RPC processes.
from torchrl.data import TensorDictPrioritizedReplayBuffer, LazyMemmapStorage

# prefetch=2 launches 2 background sampler threads.
buffer = TensorDictPrioritizedReplayBuffer(
    storage=LazyMemmapStorage(1_000_000),
    alpha=0.7,
    beta=0.5,
    batch_size=256,
    prefetch=2,
)
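
The prefetching mechanism described above can be approximated with the standard library: keep a fixed number of sample calls in flight on a thread pool and hand out the oldest one on request. The `PrefetchingSampler` class is a hypothetical sketch, not TorchRL's implementation.

```python
import itertools
from collections import deque
from concurrent.futures import ThreadPoolExecutor


class PrefetchingSampler:
    """Keep `prefetch` calls to sample_fn in flight on background threads."""

    def __init__(self, sample_fn, prefetch):
        self._sample_fn = sample_fn
        self._pool = ThreadPoolExecutor(max_workers=prefetch)
        # Start filling the queue immediately so the first sample is ready early.
        self._queue = deque(self._pool.submit(sample_fn) for _ in range(prefetch))

    def sample(self):
        future = self._queue.popleft()
        # Schedule a replacement before blocking on the oldest pending batch.
        self._queue.append(self._pool.submit(self._sample_fn))
        return future.result()

    def close(self):
        self._pool.shutdown(wait=True)


# Stand-in for an expensive sampling call: each call returns a new batch id.
counter = itertools.count()
sampler = PrefetchingSampler(lambda: next(counter), prefetch=2)
results = [sampler.sample() for _ in range(5)]
sampler.close()
```

One design consequence worth noting: a prefetched batch was drawn before the most recent writes and priority updates, so it can be slightly stale by the time the training loop consumes it.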

Putting it together: SAC replay buffer setup

import torch
from torchrl.data import (
    LazyMemmapStorage,
    TensorDictPrioritizedReplayBuffer,
)
from torchrl.collectors import MultiAsyncCollector
from torchrl.envs.libs.gym import GymEnv

# Off-policy buffer with prioritized sampling and disk-backed storage.
buffer = TensorDictPrioritizedReplayBuffer(
    storage=LazyMemmapStorage(max_size=1_000_000),
    alpha=0.7,
    beta=0.5,
    batch_size=256,
    prefetch=4,
)

collector = MultiAsyncCollector(
    create_env_fn=[lambda: GymEnv("HalfCheetah-v4")] * 4,
    policy=policy,
    frames_per_batch=256,
    total_frames=5_000_000,
)

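# policy, loss_fn (e.g. a TorchRL loss module) and optimizer are assumed to be defined elsewhere.
for batch in collector:
    buffer.extend(batch)
    if len(buffer) > 10_000:
        sample = buffer.sample()
        losses = loss_fn(sample)
        # Sum the loss terms (assumed to be stored under keys prefixed with "loss_").
        loss = sum(value for key, value in losses.items() if key.startswith("loss_"))
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        # Write the new TD errors back as priorities.
        sample["td_error"] = losses["td_error"].detach()
        buffer.update_tensordict_priority(sample)
        collector.update_policy_weights_()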