TorchRL’s replay buffers are fully composable: you independently choose a storage backend (how data is held in memory), a sampler (how indices are selected at sample time), and a writer (how new data is inserted). Optional transforms are applied after sampling, mirroring the environment transform API. This composability means you can swap any component without touching the rest of your training loop.

Architecture Overview

Every replay buffer is built from four orthogonal components wired together at construction time.
ReplayBuffer
├── Storage  — holds the raw data (list, tensor, memory-mapped, ...)
├── Sampler  — selects which indices to return (uniform, prioritized, slice, ...)
├── Writer   — decides where new data lands (round-robin, max-value, immutable, ...)
└── Transform — post-processes samples before returning them to the caller
All four components can be passed as either instances or callables (zero-argument factories). Passing a factory is useful when the buffer needs to be pickled and sent to remote workers, because the component is re-created inside the worker process.
from torchrl.data import (
    ReplayBuffer,
    TensorDictReplayBuffer,
    LazyTensorStorage,
    PrioritizedSampler,
    RoundRobinWriter,
)
from torchrl.envs import Compose, ObservationNorm

rb = TensorDictReplayBuffer(
    storage=LazyTensorStorage(max_size=100_000),
    sampler=PrioritizedSampler(max_capacity=100_000, alpha=0.6, beta=0.4),
    writer=RoundRobinWriter(),
    batch_size=256,
)
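The instance-or-factory convention described above can be illustrated without TorchRL. The sketch below is a toy stand-in, not TorchRL's implementation: `ToyStorage`, `ToyBuffer`, and `_resolve` are hypothetical names. It only shows the idea that a zero-argument factory (here a `functools.partial`) is a small recipe, and that the component it describes is built on first use, in whichever process touches it first.

```python
from functools import partial


class ToyStorage:
    """Hypothetical stand-in for a storage backend."""

    def __init__(self, max_size):
        self.max_size = max_size
        self.data = []


def _resolve(component):
    # Instances pass through unchanged; classes and other callables are
    # treated as zero-argument factories and called to build the component.
    return component() if callable(component) else component


class ToyBuffer:
    """Hypothetical buffer that defers building its storage until first use."""

    def __init__(self, storage):
        self._storage_arg = storage
        self._storage = None

    @property
    def storage(self):
        if self._storage is None:
            self._storage = _resolve(self._storage_arg)
        return self._storage


make_storage = partial(ToyStorage, max_size=100)
lazy = ToyBuffer(storage=make_storage)               # nothing allocated yet
eager = ToyBuffer(storage=ToyStorage(max_size=100))  # instance is used as-is
lazy.storage.data.append("transition")               # first access builds the storage
```

Because the factory carries only the class and its arguments, each buffer built from it gets its own fresh component rather than a copy of one that was already populated.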

Core Buffer Classes

ReplayBuffer

The generic, composable replay buffer base class. It accepts any kind of data — plain Python objects, tensors, PyTree structures, and TensorDicts — as long as the chosen storage backend can hold them. The default storage is ListStorage(max_size=1_000) and the default sampler is RandomSampler. Import path: torchrl.data.ReplayBuffer
storage
Storage | Callable[[], Storage]
The storage backend. If a callable is passed it is used as a constructor. Defaults to ListStorage(max_size=1_000).
sampler
Sampler | Callable[[], Sampler]
The index sampler. Defaults to RandomSampler.
writer
Writer | Callable[[], Writer]
The write policy. Defaults to RoundRobinWriter.
batch_size
int
Fixed batch size used when sample() is called without arguments. Specifying this in advance enables prefetching. Samplers that take a drop_last argument also need the batch size set here rather than passed to sample().
transform
Transform | Callable
A torchrl.envs.Transform (or plain callable for non-TensorDict data) applied to every sampled batch. Chain multiple transforms with torchrl.envs.Compose. Mutually exclusive with transform_factory.
transform_factory
Callable[[], Transform | Callable]
A zero-argument factory for the transform. Mutually exclusive with transform. When provided, delayed_init defaults to True so the buffer can be safely pickled before initialization.
pin_memory
bool
default:"False"
Whether to call pin_memory() on sampled batches. Useful for CPU→GPU transfers.
prefetch
int
Number of batches to prefetch using a background thread pool. Requires batch_size to be set at construction time.
collate_fn
Callable
Merges a list of individual samples into a mini-batch. Defaults are inferred from the storage type.
dim_extend
int
Dimension along which extend() iterates when inserting multi-dimensional data. Defaults to storage.ndim - 1. Has no effect on add().
generator
torch.Generator
Dedicated RNG for sampling. Enables reproducible, per-buffer seeds in distributed jobs without affecting the global generator.
consume_after_n_samples
int
Remove items from the sampleable set after they have been returned this many times. 1 turns the buffer into a one-shot queue. Cannot be combined with prefetching.
shared
bool
default:"False"
Share the buffer across processes via shared memory. Incompatible with prefetching.
compilable
bool
default:"False"
Make the writer compatible with torch.compile. Disables multi-process sharing.
delayed_init
bool
Defer component initialization until the buffer is first used. Defaults to True when transform_factory is provided, False otherwise.
Key methods:
| Method | Description |
| --- | --- |
| add(data) | Insert a single item; returns its storage index. |
| extend(data) | Insert a batch of items; returns a tensor of indices. |
| sample(batch_size=None) | Return a batch. Uses the constructor batch_size if not provided. |
| __len__() | Current number of items stored. |
| update_priority(index, priority) | Update priorities (only relevant when using PrioritizedSampler). |
| dumps(path) / loads(path) | Checkpoint and restore the buffer to/from disk. |
| append_transform(t) | Append a transform to the existing transform chain. |
import torch
from torchrl.data import ReplayBuffer, ListStorage

rb = ReplayBuffer(
    storage=ListStorage(max_size=1000),
    batch_size=32,
)

# extend with any iterable
rb.extend(range(200))

# sample returns a tensor of 32 items
batch = rb.sample()
print(batch.shape)  # torch.Size([32])

# iterate — exhausts if SamplerWithoutReplacement is used
for i, batch in enumerate(rb):
    print(i, batch)
    if i == 2:
        break

TensorDictReplayBuffer

A TensorDict-aware wrapper around ReplayBuffer. It differs from the base class in three ways:
  • The default writer is TensorDictRoundRobinWriter, which properly handles nested TensorDict keys.
  • Sampled batches automatically include an "index" key containing the storage indices used, enabling in-place priority updates.
  • The constructor accepts a priority_key argument naming the entry that holds the TD-error inside each TensorDict, so that update_tensordict_priority(sample) can read it automatically.
Import path: torchrl.data.TensorDictReplayBuffer
priority_key
str
default:"td_error"
Key inside TensorDicts where the TD-error (or other priority signal) is stored. Used by update_tensordict_priority.
All other keyword arguments are identical to ReplayBuffer.

TensorDictPrioritizedReplayBuffer

Convenience class that wires TensorDictReplayBuffer together with a PrioritizedSampler. All PER-related plumbing (priority tree initialization, weight computation, tensordict key injection) is handled automatically. Import path: torchrl.data.TensorDictPrioritizedReplayBuffer
alpha
float
required
Prioritization exponent. 0 = uniform sampling, 1 = fully prioritized. Typical range: 0.4–0.7.
beta
float
required
Importance-sampling correction exponent. Start at 0.4 and anneal to 1.0 over training.
eps
float
default:"1e-8"
Minimum priority added to every entry to prevent zero-probability sampling.
priority_key
NestedKey
default:"td_error"
Key from which update_tensordict_priority reads TD-errors.
storage
Storage | Callable
Storage backend. Defaults to ListStorage(max_size=1_000).
sampler_device
str | torch.device
Device where the priority sampler trees are stored. Defaults to None (inferred from the storage device).
sync
bool
default:"True"
If False, writer processes use a RandomSampler and the learner owns a local prioritized sampler, enabling async PER across processes.
reduction
str
default:"max"
How to reduce multi-step priorities ("max", "min", "mean", "median").
import torch
from tensordict import TensorDict
from torchrl.data import TensorDictPrioritizedReplayBuffer, LazyTensorStorage

rb = TensorDictPrioritizedReplayBuffer(
    alpha=0.7,
    beta=0.4,  # start low and anneal toward 1.0 during training
    storage=LazyTensorStorage(max_size=10_000),
    batch_size=256,
)

data = TensorDict({"obs": torch.randn(1000, 8), "action": torch.randn(1000, 2)}, [1000])
rb.extend(data)

sample = rb.sample()
# sample["index"] contains buffer indices
# sample["priority_weight"] contains IS weights

# Update priorities after computing TD-errors:
sample.set("td_error", torch.rand(sample.shape))
rb.update_tensordict_priority(sample)

HERReplayBuffer

Hindsight Experience Replay buffer (Andrychowicz et al., NeurIPS 2017). At sample time, a fraction her_ratio of the batch has its desired goal replaced with an achieved goal drawn from the same episode (or from the full buffer when strategy="random"). The reward is then recomputed by calling reward_fn. Import path: torchrl.data.HERReplayBuffer
reward_fn
Callable[[TensorDictBase], Tensor]
required
Receives a relabeled TensorDict (with the new goal already set at goal_key) and must return a reward tensor of shape (*batch, 1) or (*batch,).
her_ratio
float
default:"0.8"
Fraction of the batch to relabel. Must be in [0, 1].
strategy
HindsightStrategy | str
default:"future"
Goal selection strategy. One of "future", "final", "episode", "random".
goal_key
NestedKey
default:"desired_goal"
Key for the desired goal field in each transition TensorDict.
achieved_goal_key
NestedKey
default:"achieved_goal"
Key for the achieved goal field in each transition TensorDict.
reward_key
NestedKey
default:"('next', 'reward')"
Key where the recomputed reward is written. Follows the TorchRL TED convention.
end_key
NestedKey
default:"('next', 'done')"
Key containing the episode-boundary signal (done flag).
sampler
Sampler
Underlying index sampler. Defaults to RandomSampler.
All remaining TensorDictReplayBuffer keyword arguments are accepted.

HindsightStrategy

from torchrl.data import HindsightStrategy

# Available strategies:
HindsightStrategy.FUTURE   # Sample goal from a future step in the same episode (recommended)
HindsightStrategy.FINAL    # Use the final achieved state as goal
HindsightStrategy.EPISODE  # Sample any achieved state from the same episode uniformly
HindsightStrategy.RANDOM   # Sample a random achieved state from the entire buffer
import torch
from tensordict import TensorDict
from torchrl.data import HERReplayBuffer, HindsightStrategy, LazyTensorStorage

def my_reward_fn(td):
    # Compute sparse reward: 1 if achieved_goal ≈ desired_goal
    dist = (td["achieved_goal"] - td["desired_goal"]).norm(dim=-1, keepdim=True)
    return (dist < 0.05).float()

rb = HERReplayBuffer(
    reward_fn=my_reward_fn,
    her_ratio=0.8,
    strategy=HindsightStrategy.FUTURE,
    storage=LazyTensorStorage(max_size=50_000),
    batch_size=256,
)
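The relabeling step described for the "future" strategy can be sketched in plain Python. This is an illustration of the idea, not TorchRL's code: `relabel_future` and `sparse_reward` are hypothetical helpers, transitions are plain dicts, each transition's `achieved_goal` is assumed to be the goal reached after its action, and relabeling is decided per transition with probability `her_ratio` (so the relabeled fraction matches `her_ratio` only in expectation).

```python
import random


def relabel_future(episode, her_ratio, reward_fn, rng):
    """Return a relabeled copy of `episode` (a list of transition dicts)."""
    relabeled = []
    last = len(episode) - 1
    for t, transition in enumerate(episode):
        new = dict(transition)
        if rng.random() < her_ratio:
            # "future" strategy: take the goal achieved at step t or later
            # in the same episode and treat it as the desired goal.
            j = rng.randint(t, last)
            new["desired_goal"] = episode[j]["achieved_goal"]
            # The reward is recomputed after the new goal is in place.
            new["reward"] = reward_fn(new)
        relabeled.append(new)
    return relabeled


def sparse_reward(transition):
    # 1.0 when the achieved goal matches the (possibly relabeled) desired goal.
    return 1.0 if transition["achieved_goal"] == transition["desired_goal"] else 0.0


# A 4-step episode that never reaches its original goal (position 10).
episode = [
    {"achieved_goal": pos, "desired_goal": 10, "reward": 0.0}
    for pos in (1, 2, 3, 4)
]
relabeled = relabel_future(
    episode, her_ratio=1.0, reward_fn=sparse_reward, rng=random.Random(0)
)
```

The last transition can only be relabeled with its own achieved goal, so it always receives a reward of 1.0 here. That is how hindsight relabeling turns a failed episode into a useful learning signal.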

Other Buffer Variants

| Class | Description |
| --- | --- |
| PrioritizedReplayBuffer | Base-class version of the prioritized buffer for non-TensorDict data. |
| ReplayBufferEnsemble | Samples from multiple buffers simultaneously, useful for multi-task setups. |
| OfflineToOnlineReplayBuffer | Combines a static offline dataset with an online buffer, with a configurable mixing ratio. |
| RayReplayBuffer | Distributed replay buffer backed by a Ray actor, for multi-machine training. |
| RemoteTensorDictReplayBuffer | RPC-based remote replay buffer using torch.distributed.rpc. |

Storage Classes

Storage backends determine how data is laid out in memory and what data types are supported.

ListStorage

Stores items in a Python list. Supports arbitrary Python objects — the only storage that accepts non-tensor, non-TensorDict data. Cannot be extended with PyTree structures via extend(). Import path: torchrl.data.ListStorage
max_size
int
Maximum number of elements. Defaults to torch.iinfo(torch.int64).max (effectively unlimited).
compilable
bool
default:"False"
Enable torch.compile compatibility. Disables multi-process sharing.
device
str | torch.device
If provided, data with a .to() method is moved to this device on write.

LazyTensorStorage

A pre-allocated contiguous tensor (or TensorDict) storage. Allocation is deferred until the first call to set(), at which point the storage infers shape from the first batch. All subsequent writes must be shape-compatible. Supports PyTree structures as well as TensorDict and tensorclass. Import path: torchrl.data.LazyTensorStorage
max_size
int
required
Maximum number of stored elements along the primary (index) dimension.
device
str | torch.device
default:"cpu"
Device for stored tensors. Pass "auto" to infer from the first inserted batch.
ndim
int
default:"1"
Number of dimensions defining storage capacity. A storage of shape [3, 4] has capacity 3 with ndim=1 and 12 with ndim=2. Keep ndim=1 when using trajs_per_batch collectors.
compilable
bool
default:"False"
Enable torch.compile compatibility.
consolidated
bool
default:"False"
Consolidate the storage after first expansion. Reduces memory fragmentation for TensorDict data.
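The capacity arithmetic given for ndim can be written out as a one-line rule: the capacity is the product of the first ndim dimensions of the storage shape. The helper below is a hypothetical illustration of that rule, not a TorchRL function.

```python
from math import prod


def storage_capacity(shape, ndim=1):
    """Number of addressable items when the first `ndim` dims index the storage."""
    return prod(shape[:ndim])


capacity_1d = storage_capacity((3, 4), ndim=1)  # 3: only the first dim is indexed
capacity_2d = storage_capacity((3, 4), ndim=2)  # 12: both dims are indexed
```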

LazyMemmapStorage

Memory-mapped variant of LazyTensorStorage. Data is written to disk and read back as MemoryMappedTensor objects. Ideal for very large replay buffers that exceed available RAM. Automatic cleanup of temporary files is handled on process exit. Import path: torchrl.data.LazyMemmapStorage
max_size
int
required
Maximum number of stored elements.
scratch_dir
str | Path
Directory for memory-mapped files. If omitted, a temporary directory is created and cleaned up automatically on exit.
device
str | torch.device
default:"cpu"
Device for sampled tensors.
ndim
int
default:"1"
Number of dimensions defining storage capacity. See LazyTensorStorage.ndim.
existsok
bool
default:"False"
Allow re-opening existing memmap files without overwriting.
auto_cleanup
bool
Whether to delete the memmap files when the process exits. Defaults to True for temporary directories and False for user-specified paths.
from torchrl.data import LazyMemmapStorage, TensorDictReplayBuffer

# 10-million-step buffer backed by disk: capacity is bounded by disk space rather than RAM
rb = TensorDictReplayBuffer(
    storage=LazyMemmapStorage(
        max_size=10_000_000,
        scratch_dir="/scratch/my_experiment/rb",
    ),
    batch_size=512,
)
When checkpointing a LazyMemmapStorage, pass the same directory that the storage already uses as scratch_dir. TorchRL detects this and performs a zero-copy checkpoint — the data is already on disk.

Other Storage Classes

| Class | Description |
| --- | --- |
| TensorStorage | Base class for contiguous tensor-backed storages. |
| CompressedListStorage | List storage with on-the-fly compression/decompression. |
| LazyStackStorage | Stacks heterogeneous tensor structures lazily. |
| StoreStorage | Redis-backed remote storage (requires redis and tensordict.store). |

Sampler Classes

Samplers control which buffer indices are returned on each call to sample().

RandomSampler

Uniform random sampling with replacement. The fastest sampler and the default. Pass replacement=False to automatically dispatch to SamplerWithoutReplacement. Import path: torchrl.data.RandomSampler
from torchrl.data import RandomSampler, SamplerWithoutReplacement

# These two are equivalent:
s1 = RandomSampler(replacement=False)
s2 = SamplerWithoutReplacement()
isinstance(s1, SamplerWithoutReplacement)  # True

SamplerWithoutReplacement

Iterates through every buffer index once per epoch before any index repeats. Useful for supervised-style offline training. When drop_last=True, any remainder batch smaller than batch_size is discarded. Import path: torchrl.data.SamplerWithoutReplacement
drop_last
bool
default:"False"
Discard the last incomplete batch within an epoch.
shuffle
bool
default:"True"
Shuffle indices at the start of each epoch.
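The interaction between epoch-style sampling, shuffle, and drop_last can be sketched with a small generator. This is a plain-Python illustration under simplifying assumptions (a fixed number of items, one epoch per call); `epoch_batches` is a hypothetical helper and not part of TorchRL.

```python
import random


def epoch_batches(num_items, batch_size, drop_last=False, shuffle=True, rng=None):
    """Yield index batches that visit each of range(num_items) at most once."""
    indices = list(range(num_items))
    if shuffle:
        (rng or random).shuffle(indices)
    for start in range(0, num_items, batch_size):
        batch = indices[start:start + batch_size]
        if drop_last and len(batch) < batch_size:
            return  # discard the incomplete remainder
        yield batch


# 10 items in batches of 4: the last batch holds only 2 indices.
full = list(epoch_batches(10, 4, rng=random.Random(0)))
trimmed = list(epoch_batches(10, 4, drop_last=True, rng=random.Random(0)))
```

With drop_last=True, the two leftover indices are skipped for that epoch, which keeps every batch the same size at the cost of not visiting every item.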

PrioritizedSampler

Implements Prioritized Experience Replay (Schaul et al., ICLR 2016). Samples proportionally to p_i^alpha and returns importance-sampling weights w_i = (N * P(i))^(-beta). Import path: torchrl.data.PrioritizedSampler
max_capacity
int
required
Must match the storage’s max_size.
alpha
float
required
Prioritization exponent. 0 = uniform, 1 = fully proportional. Typical: 0.4–0.7.
beta
float
required
IS correction exponent. Anneal from ~0.4 to 1.0 over training. 0 = no correction.
eps
float
default:"1e-8"
Small constant added to every priority to prevent zero probabilities.
reduction
str
default:"max"
Aggregation over trajectory dimensions: "max", "min", "mean", or "median".
device
str | torch.device
Device for the internal segment trees. Defaults to matching the storage device.
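The two formulas above can be checked numerically with a few lines of plain Python. The sketch assumes eps is added to each priority before exponentiation, which is one common convention and may differ from how TorchRL applies it. `per_probabilities` and `importance_weights` are hypothetical helpers, and the weights are left unnormalized exactly as in w_i = (N * P(i))^(-beta).

```python
def per_probabilities(priorities, alpha, eps=1e-8):
    """P(i) = (p_i + eps)^alpha / sum_k (p_k + eps)^alpha."""
    scaled = [(p + eps) ** alpha for p in priorities]
    total = sum(scaled)
    return [s / total for s in scaled]


def importance_weights(probs, beta):
    """w_i = (N * P(i))^(-beta), with N the number of stored items."""
    n = len(probs)
    return [(n * p) ** (-beta) for p in probs]


priorities = [1.0, 2.0, 4.0, 8.0]
uniform = per_probabilities(priorities, alpha=0.0)       # every item equally likely
proportional = per_probabilities(priorities, alpha=1.0)  # roughly p_i / 15
weights = importance_weights(proportional, beta=1.0)     # rare items get larger weights
```

With alpha=0 the distribution is uniform and every importance weight equals 1, so no correction is applied. With alpha=1 and beta=1 the weights exactly undo the sampling bias, since w_i * P(i) is the same for every item.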

SliceSampler

Samples contiguous sub-trajectories (slices) from a buffer containing episode data. Useful for recurrent policies and sequence models. Use num_slices to control how many trajectories appear in each batch, or slice_len to fix the length of each slice. Import path: torchrl.data.SliceSampler
num_slices
int
Number of separate trajectories per batch. Mutually exclusive with slice_len.
slice_len
int
Length of each sampled slice. Mutually exclusive with num_slices.
end_key
NestedKey
default:"('next', 'done')"
Key that marks episode boundaries.
traj_key
NestedKey
Key storing trajectory IDs. Used when end_key is absent or expensive to read. Defaults to None.
strict_length
bool
default:"True"
Exclude trajectories shorter than slice_len. Set to False to allow variable-length slices.
cache_values
bool
default:"False"
Cache episode-boundary indices. Safe to use for static datasets.
compile
bool | dict
default:"False"
JIT-compile the sampling kernel via torch.compile.
Pass replacement=False to dispatch to SliceSamplerWithoutReplacement.
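The idea behind slice sampling can be sketched in plain Python: recover trajectory boundaries from the done flags, then draw fixed-length windows that never cross a boundary. This is a conceptual illustration only. `trajectory_bounds` and `sample_slices` are hypothetical helpers, and unlike the real sampler (which takes either num_slices or slice_len and derives the other from the batch size) the sketch takes both for clarity.

```python
import random


def trajectory_bounds(done_flags):
    """Return (start, stop) index pairs, one per trajectory, stop exclusive."""
    bounds, start = [], 0
    for i, done in enumerate(done_flags):
        if done:
            bounds.append((start, i + 1))
            start = i + 1
    if start < len(done_flags):
        bounds.append((start, len(done_flags)))  # trailing unfinished trajectory
    return bounds


def sample_slices(done_flags, num_slices, slice_len, rng):
    """Sample `num_slices` windows of `slice_len` steps, each inside one trajectory."""
    # Mirrors strict_length=True: trajectories shorter than slice_len are skipped.
    eligible = [(a, b) for a, b in trajectory_bounds(done_flags) if b - a >= slice_len]
    if not eligible:
        raise ValueError("no trajectory is long enough for the requested slice_len")
    slices = []
    for _ in range(num_slices):
        a, b = rng.choice(eligible)
        start = rng.randint(a, b - slice_len)
        slices.append(list(range(start, start + slice_len)))
    return slices


# Three trajectories of lengths 4, 2, and 5 stored back to back.
done = [False, False, False, True, False, True, False, False, False, False, True]
bounds = trajectory_bounds(done)
slices = sample_slices(done, num_slices=8, slice_len=3, rng=random.Random(0))
```

The two-step trajectory is too short for a three-step slice, so none of the sampled windows contains indices 4 or 5.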

PrioritizedSliceSampler

Combines SliceSampler and PrioritizedSampler: slices are sampled proportionally to trajectory-level priorities. Import path: torchrl.data.PrioritizedSliceSampler

Other Sampler Classes

| Class | Description |
| --- | --- |
| SliceSamplerWithoutReplacement | Epoch-based slice sampler without replacement. |
| ConsumingSampler | Removes items after max_sample_count returns. |
| StalenessAwareSampler | Filters items that are stale relative to a target network update counter. |
| SamplerEnsemble | Combines samplers from multiple storages. |

Writer Classes

Writers decide where in the storage each new item lands.

RoundRobinWriter

The default writer. Maintains a circular cursor that wraps around when the buffer is full, so once capacity is reached each new item overwrites the oldest one. Import path: torchrl.data.RoundRobinWriter
compilable
bool
default:"False"
Enable torch.compile compatibility.
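The circular-cursor behavior can be sketched with a fixed-size list and a modulo increment. `RingWriter` is a hypothetical toy class used only to illustrate the write order; it is not TorchRL's writer.

```python
class RingWriter:
    """Hypothetical circular writer over a fixed-size list."""

    def __init__(self, max_size):
        self.storage = [None] * max_size
        self.cursor = 0  # next slot to write
        self.size = 0    # number of filled slots

    def add(self, item):
        index = self.cursor
        self.storage[index] = item
        # Advance and wrap around once the end of the storage is reached.
        self.cursor = (self.cursor + 1) % len(self.storage)
        self.size = min(self.size + 1, len(self.storage))
        return index


writer = RingWriter(max_size=3)
indices = [writer.add(item) for item in "abcde"]
```

After five writes into three slots, the two oldest items ("a" and "b") have been overwritten by "d" and "e", and the cursor points at the slot holding "c", which is now the oldest entry.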

TensorDictRoundRobinWriter

TensorDict-aware version of RoundRobinWriter. The default writer for TensorDictReplayBuffer. Supports writing nested TensorDict keys correctly. Import path: torchrl.data.TensorDictRoundRobinWriter

TensorDictMaxValueWriter

Inserts a new item only if its priority value exceeds the current minimum in the buffer. Useful for maintaining a “best-of” buffer of high-quality experiences. Import path: torchrl.data.TensorDictMaxValueWriter
rank_key
NestedKey
Key containing the scalar priority used for comparison. Defaults to "reward".
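The admission rule described above (accept a new item only if it beats the current minimum) is the classic top-k pattern, which a min-heap expresses compactly. The sketch below is a plain-Python illustration under that assumption; `TopKWriter` is a hypothetical class and says nothing about how TorchRL stores or ranks entries internally.

```python
import heapq
from itertools import count


class TopKWriter:
    """Hypothetical writer that keeps the `max_size` highest-ranked items."""

    def __init__(self, max_size):
        self.max_size = max_size
        self._heap = []           # min-heap of (rank, tiebreak, item)
        self._tiebreak = count()  # avoids comparing items when ranks tie

    def add(self, item, rank):
        entry = (rank, next(self._tiebreak), item)
        if len(self._heap) < self.max_size:
            heapq.heappush(self._heap, entry)
            return True
        if rank > self._heap[0][0]:
            heapq.heapreplace(self._heap, entry)  # evict the current minimum
            return True
        return False  # not better than the worst stored item: rejected

    def ranks(self):
        return sorted(entry[0] for entry in self._heap)


writer = TopKWriter(max_size=3)
accepted = [writer.add({"reward": r}, rank=r) for r in (1.0, 5.0, 2.0, 0.5, 7.0)]
```

Here the item ranked 0.5 is rejected because the buffer is full and its minimum is 1.0, while the item ranked 7.0 is accepted and evicts that minimum.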

Other Writer Classes

| Class | Description |
| --- | --- |
| ImmutableDatasetWriter | Raises an error on any write attempt; used for read-only offline datasets. |
| WriterEnsemble | Dispatches writes across multiple storages. |

Complete Training Loop Example

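The following example shows a typical online RL training loop that uses TensorDictPrioritizedReplayBuffer for Prioritized Experience Replay. Actions are random and the TD-errors are placeholders; substitute your own policy and loss.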
import torch
from tensordict import TensorDict
from torchrl.data import (
    TensorDictPrioritizedReplayBuffer,
    LazyTensorStorage,
)
from torchrl.envs import GymEnv, TransformedEnv, StepCounter
from torchrl.envs.utils import step_mdp

# --- Environment ---
env = TransformedEnv(GymEnv("CartPole-v1"), StepCounter())

# --- Replay buffer ---
rb = TensorDictPrioritizedReplayBuffer(
    alpha=0.7,
    beta=0.5,
    storage=LazyTensorStorage(max_size=50_000),
    batch_size=256,
)

# --- Warm-up: fill buffer with random transitions ---
td = env.reset()
for _ in range(1_000):
    td = env.rand_step(td)
    rb.add(td.clone())
    # Reset at episode end; otherwise promote the "next" entries to the root
    # so the following transition starts from the new observation.
    td = env.reset() if td["next", "done"].any() else step_mdp(td)

# --- Training loop ---
for step in range(10_000):
    # Collect one transition
    td = env.rand_step(td)
    rb.add(td.clone())
    td = env.reset() if td["next", "done"].any() else step_mdp(td)

    # Sample and train
    sample = rb.sample()
    # ... compute loss, td_errors ...
    td_errors = torch.rand(sample.shape)  # placeholder
    sample.set("td_error", td_errors)
    rb.update_tensordict_priority(sample)

    # Anneal beta toward 1.0. This writes to private attributes (_sampler, _beta),
    # which may change between TorchRL releases.
    rb._sampler._beta = min(1.0, 0.5 + step / 10_000 * 0.5)
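The inline expression used above to anneal beta is a clipped linear schedule. Pulling it into a small function makes the start value, end value, and horizon explicit. `linear_beta` is a hypothetical helper written for illustration, not a TorchRL utility.

```python
def linear_beta(step, total_steps, start=0.5, end=1.0):
    """Linearly interpolate from `start` to `end`, then hold at `end`."""
    fraction = min(max(step / total_steps, 0.0), 1.0)
    return start + fraction * (end - start)


schedule = [linear_beta(s, total_steps=10_000) for s in (0, 5_000, 10_000, 20_000)]
```

The schedule starts at 0.5, reaches 0.75 halfway through, and stays at 1.0 once the horizon is passed.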

Checkpointing

All replay buffers support save/restore via dumps and loads:
from pathlib import Path

# Save the full buffer state (storage + sampler + writer metadata)
rb.dumps(Path("./checkpoints/replay_buffer"))

# Restore from disk
rb.loads(Path("./checkpoints/replay_buffer"))
LazyMemmapStorage buffers can be checkpointed zero-copy by passing the same directory used as scratch_dir. If a different path is given, TorchRL copies the data, which can be slow for large buffers.
