TorchRL’s replay buffers are fully composable: you independently choose a storage backend (how data is held in memory), a sampler (how indices are selected at sample time), and a writer (how new data is inserted). Optional transforms are applied after sampling, mirroring the environment transform API. This composability means you can swap any component without touching the rest of your training loop.
Architecture Overview
Every replay buffer is built from four orthogonal components wired together at construction time: a storage, a sampler, a writer, and an optional transform. All four components can be passed either as instances or as callables (zero-argument factories). Passing a factory is useful when the buffer needs to be pickled and sent to remote workers, because the component is then re-created inside the worker process.
Core Buffer Classes
ReplayBuffer
The generic, composable replay buffer base class. It accepts any kind of data — plain Python objects, tensors, PyTree structures, and TensorDicts — as long as the chosen storage backend can hold them. The default storage is ListStorage(max_size=1_000) and the default sampler is RandomSampler.
Import path: torchrl.data.ReplayBuffer
Constructor arguments:
- The storage backend. If a callable is passed, it is used as a constructor. Defaults to ListStorage(max_size=1_000).
- The index sampler. Defaults to RandomSampler.
- The write policy. Defaults to RoundRobinWriter.
- Fixed batch size used when sample() is called without arguments. Specifying this in advance enables prefetching. Cannot be combined with samplers that have a drop_last argument unless set here.
- A torchrl.envs.Transform (or a plain callable for non-TensorDict data) applied to every sampled batch. Chain multiple transforms with torchrl.envs.Compose. Mutually exclusive with transform_factory.
- A zero-argument factory for the transform. Mutually exclusive with transform. When provided, delayed_init defaults to True so the buffer can be safely pickled before initialization.
- Whether to call pin_memory() on sampled batches. Useful for CPU→GPU transfers.
- Number of batches to prefetch using a background thread pool. Requires batch_size to be set at construction time.
- A function that merges a list of individual samples into a mini-batch. The default is inferred from the storage type.
- Dimension along which extend() iterates when inserting multi-dimensional data. Defaults to storage.ndim - 1. Has no effect on add().
- Dedicated RNG for sampling. Enables reproducible per-buffer seeds in distributed jobs without affecting the global generator.
- Remove items from the sampleable set after they have been returned this many times. A value of 1 turns the buffer into a one-shot queue. Cannot be combined with prefetching.
- Share the buffer across processes via shared memory. Incompatible with prefetching.
- Make the writer compatible with torch.compile. Disables multi-process sharing.
- Defer component initialization until the buffer is first used. Defaults to True when transform_factory is provided, False otherwise.

| Method | Description |
|---|---|
add(data) | Insert a single item; returns its storage index. |
extend(data) | Insert a batch of items; returns a tensor of indices. |
sample(batch_size=None) | Return a batch. Uses constructor batch_size if not provided. |
__len__() | Current number of items stored. |
update_priority(index, priority) | Update priorities (only relevant when using PrioritizedSampler). |
dumps(path) / loads(path) | Checkpoint and restore the buffer to/from disk. |
append_transform(t) | Append a transform to the existing transform chain. |
TensorDictReplayBuffer
A TensorDict-aware wrapper around ReplayBuffer. The main differences from the base class are:
- The default writer is TensorDictRoundRobinWriter, which properly handles nested TensorDict keys.
- Sampled batches automatically include an "index" key containing the storage indices used, enabling in-place priority updates.
- It accepts a priority_key argument pointing to where the TD-error lives inside each TensorDict, so that update_tensordict_priority(sample) can read it automatically.
Import path: torchrl.data.TensorDictReplayBuffer
Constructor arguments:
- Key inside TensorDicts where the TD-error (or other priority signal) is stored. Used by update_tensordict_priority.
- All other ReplayBuffer keyword arguments are accepted.
TensorDictPrioritizedReplayBuffer
Convenience class that wires TensorDictReplayBuffer together with a PrioritizedSampler. All PER-related plumbing (priority tree initialization, weight computation, tensordict key injection) is handled automatically.
Import path: torchrl.data.TensorDictPrioritizedReplayBuffer
Constructor arguments:
- Prioritization exponent. 0 = uniform sampling, 1 = fully prioritized. Typical range: 0.4–0.7.
- Importance-sampling correction exponent. Start at 0.4 and anneal to 1.0 over training.
- Minimum priority added to every entry to prevent zero-probability sampling.
- Key from which update_tensordict_priority reads TD-errors.
- Storage backend. Defaults to ListStorage(max_size=1_000).
- Device where the priority sampler trees are stored. Defaults to None (inferred from the storage device).
- If False, writer processes use a RandomSampler and the learner owns a local prioritized sampler, enabling async PER across processes.
- How to reduce multi-step priorities ("max", "min", "mean", "median").
HERReplayBuffer
Hindsight Experience Replay buffer (Andrychowicz et al., NeurIPS 2017). At sample time, a fraction her_ratio of the batch has its desired goal replaced with an achieved goal drawn from the same episode (or from the full buffer when strategy="random"). The reward is then recomputed by calling reward_fn.
Import path: torchrl.data.HERReplayBuffer
Constructor arguments:
- Reward function (reward_fn). Receives a relabeled TensorDict (with the new goal already set at goal_key) and must return a reward tensor of shape (*batch, 1) or (*batch,).
- Fraction of the batch to relabel (her_ratio). Must be in [0, 1].
- Goal selection strategy. One of "future", "final", "episode", "random".
- Key for the desired goal field in each transition TensorDict.
- Key for the achieved goal field in each transition TensorDict.
- Key where the recomputed reward is written. Follows the TorchRL TED convention.
- Key containing the episode-boundary signal (done flag).
- Underlying index sampler. Defaults to RandomSampler.
- All other TensorDictReplayBuffer keyword arguments are accepted.
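The relabeling step can be sketched in plain PyTorch for the "future" strategy on a single episode. This illustrates the technique and is not TorchRL's HERReplayBuffer implementation. The helpers her_future_relabel and sparse_reward are hypothetical, and goals are compared with a Euclidean tolerance. The reward for step t is computed from the achieved goal after that step, which is a common HER convention.

```python
import torch

def sparse_reward(achieved, goal, tol=0.05):
    # Hypothetical reward_fn: 0.0 when within tol of the goal, -1.0 otherwise.
    reached = (achieved - goal).norm(dim=-1) <= tol
    return reached.float() - 1.0

def her_future_relabel(next_achieved, desired, t, her_ratio, generator=None):
    """Relabel sampled steps t of one episode with the "future" strategy.

    next_achieved: (T, G) achieved goal after each step.
    desired: (T, G) original desired goal at each step.
    t: (B,) sampled time indices within the episode.
    """
    T = next_achieved.shape[0]
    B = t.shape[0]
    goals = desired[t].clone()
    # Pick which batch entries receive a hindsight goal.
    relabel = torch.rand(B, generator=generator) < her_ratio
    # Draw a future step t' uniformly from [t, T - 1].
    offset = (torch.rand(B, generator=generator) * (T - t)).long()
    future = t + offset
    goals[relabel] = next_achieved[future[relabel]]
    # Recompute the reward against the (possibly relabeled) goal.
    rewards = sparse_reward(next_achieved[t], goals)
    return goals, rewards, relabel

torch.manual_seed(0)
next_achieved = torch.randn(10, 2)
desired = torch.full((10, 2), 100.0)  # original goal is never reached
t = torch.randint(0, 10, (6,))
goals, rewards, relabel = her_future_relabel(next_achieved, desired, t, her_ratio=0.8)
```

Relabeled transitions whose future step equals t receive the success reward by construction. This is why hindsight goals turn otherwise uninformative episodes into a useful learning signal.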
HindsightStrategy
Other Buffer Variants
| Class | Description |
|---|---|
PrioritizedReplayBuffer | Base-class version of the prioritized buffer for non-TensorDict data. |
ReplayBufferEnsemble | Samples from multiple buffers simultaneously, useful for multi-task setups. |
OfflineToOnlineReplayBuffer | Combines a static offline dataset with an online buffer, with a configurable mixing ratio. |
RayReplayBuffer | Distributed replay buffer backed by a Ray actor, for multi-machine training. |
RemoteTensorDictReplayBuffer | RPC-based remote replay buffer using torch.distributed.rpc. |
Storage Classes
Storage backends determine how data is laid out in memory and what data types are supported.
ListStorage
Stores items in a Python list. Supports arbitrary Python objects — the only storage that accepts non-tensor, non-TensorDict data. Cannot be extended with PyTree structures via extend().
Import path: torchrl.data.ListStorage
Constructor arguments:
- Maximum number of elements. Defaults to torch.iinfo(torch.int64).max (effectively unlimited).
- Enable torch.compile compatibility. Disables multi-process sharing.
- If provided, data with a .to() method is moved to this device on write.
LazyTensorStorage
A pre-allocated contiguous tensor (or TensorDict) storage. Allocation is deferred until the first call to set(), at which point the storage infers shape from the first batch. All subsequent writes must be shape-compatible. Supports PyTree structures as well as TensorDict and tensorclass.
Import path: torchrl.data.LazyTensorStorage
Constructor arguments:
- Maximum number of stored elements along the primary (index) dimension.
- Device for stored tensors. Pass "auto" to infer it from the first inserted batch.
- Number of dimensions defining storage capacity. A storage of shape [3, 4] has capacity 3 with ndim=1 and 12 with ndim=2. Keep ndim=1 when using trajs_per_batch collectors.
- Enable torch.compile compatibility.
- Consolidate the storage after the first expansion. Reduces memory fragmentation for TensorDict data.
LazyMemmapStorage
Memory-mapped variant of LazyTensorStorage. Data is written to disk and read back as MemoryMappedTensor objects. Ideal for very large replay buffers that exceed available RAM. Automatic cleanup of temporary files is handled on process exit.
Import path: torchrl.data.LazyMemmapStorage
Constructor arguments:
- Maximum number of stored elements.
- Directory for memory-mapped files. If omitted, a temporary directory is created and cleaned up automatically on exit.
- Device for sampled tensors.
- Number of dimensions defining storage capacity. See the ndim argument of LazyTensorStorage.
- Allow re-opening existing memmap files without overwriting them.
- Whether to delete the memmap files when the process exits. Defaults to True for temporary directories and False for user-specified paths.
Other Storage Classes
| Class | Description |
|---|---|
TensorStorage | Base class for contiguous tensor-backed storages. |
CompressedListStorage | List storage with on-the-fly compression/decompression. |
LazyStackStorage | Stacks heterogeneous tensor structures lazily. |
StoreStorage | Redis-backed remote storage (requires redis and tensordict.store). |
Sampler Classes
Samplers control which buffer indices are returned on each call to sample().
RandomSampler
Uniform random sampling with replacement. The fastest sampler and the default. Pass replacement=False to automatically dispatch to SamplerWithoutReplacement.
Import path: torchrl.data.RandomSampler
SamplerWithoutReplacement
Visits every buffer index once per epoch before any index repeats. Useful for supervised-style offline training. When drop_last=True, any remainder batch smaller than batch_size is discarded.
Import path: torchrl.data.SamplerWithoutReplacement
Constructor arguments:
- Discard the last incomplete batch within an epoch.
- Shuffle indices at the start of each epoch.
PrioritizedSampler
Implements Prioritized Experience Replay (Schaul et al., ICLR 2016). Samples index i with probability P(i) = p_i^alpha / sum_k p_k^alpha and returns importance-sampling weights w_i = (N * P(i))^(-beta), where N is the number of stored items.
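The probabilities and weights can be worked through in plain PyTorch. In this sketch the per_probs_and_weights helper is hypothetical. It adds eps before raising to alpha. Following Schaul et al., it normalizes the weights by 1/max w, here taking the maximum over every stored item so the weights only scale updates down. TorchRL's segment-tree implementation may differ in details.

```python
import torch

def per_probs_and_weights(priorities, alpha, beta, eps=1e-8):
    # P(i) = (p_i + eps)^alpha / sum_k (p_k + eps)^alpha
    scaled = (priorities + eps) ** alpha
    probs = scaled / scaled.sum()
    n = priorities.numel()
    # w_i = (N * P(i))^(-beta), divided by the largest weight so max(w) == 1
    weights = (n * probs) ** (-beta)
    return probs, weights / weights.max()

priorities = torch.tensor([1.0, 2.0, 4.0, 8.0])
probs, weights = per_probs_and_weights(priorities, alpha=0.6, beta=0.4)

# Draw a batch proportionally to P(i) and gather the matching IS weights.
idx = torch.multinomial(probs, num_samples=3, replacement=True)
batch_weights = weights[idx]
```

The lowest-priority item gets the largest weight, because it is sampled least often and its updates must be up-weighted the most. With alpha=0 the distribution is uniform and every weight is 1.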
Import path: torchrl.data.PrioritizedSampler
Constructor arguments:
- Sampler capacity. Must match the storage’s max_size.
- Prioritization exponent. 0 = uniform, 1 = fully proportional. Typical: 0.4–0.7.
- Importance-sampling (IS) correction exponent. Anneal from ~0.4 to 1.0 over training. 0 = no correction.
- Small constant added to every priority to prevent zero probabilities.
- Aggregation over trajectory dimensions: "max", "min", "mean", or "median".
- Device for the internal segment trees. Defaults to matching the storage device.
SliceSampler
Samples contiguous sub-trajectories (slices) from a buffer containing episode data. Useful for recurrent policies and sequence models. Use num_slices to control how many trajectories appear in each batch, or slice_len to fix the length of each slice.
Import path: torchrl.data.SliceSampler
Constructor arguments:
- Number of separate trajectories per batch. Mutually exclusive with slice_len.
- Length of each sampled slice. Mutually exclusive with num_slices.
- Key that marks episode boundaries.
- Key storing trajectory IDs. Used when end_key is absent or expensive to read. Defaults to None.
- Exclude trajectories shorter than slice_len. Set to False to allow variable-length slices.
- Cache episode-boundary indices. Safe to use for static datasets.
- JIT-compile the sampling kernel via torch.compile.
- Pass replacement=False to dispatch to SliceSamplerWithoutReplacement.
PrioritizedSliceSampler
Combines SliceSampler and PrioritizedSampler: slices are sampled proportionally to trajectory-level priorities.
Import path: torchrl.data.PrioritizedSliceSampler
Other Sampler Classes
| Class | Description |
|---|---|
SliceSamplerWithoutReplacement | Epoch-based slice sampler without replacement. |
ConsumingSampler | Removes items after max_sample_count returns. |
StalenessAwareSampler | Filters items that are stale relative to a target network update counter. |
SamplerEnsemble | Combines samplers from multiple storages. |
Writer Classes
Writers decide where in the storage each new item lands.
RoundRobinWriter
The default writer. Maintains a circular cursor that wraps around when the buffer is full. The oldest item is always overwritten next.
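The circular-cursor policy can be sketched in plain Python. RoundRobinBuffer is a hypothetical class written for illustration and is not TorchRL's writer.

```python
class RoundRobinBuffer:
    """Hypothetical fixed-capacity buffer with a circular write cursor."""

    def __init__(self, capacity):
        self.capacity = capacity
        self.items = []
        self.cursor = 0  # next slot to write

    def add(self, item):
        index = self.cursor
        if len(self.items) < self.capacity:
            self.items.append(item)  # still filling: cursor == len(items)
        else:
            self.items[index] = item  # full: overwrite the oldest entry
        self.cursor = (self.cursor + 1) % self.capacity
        return index

buf = RoundRobinBuffer(capacity=3)
indices = [buf.add(x) for x in ["a", "b", "c", "d", "e"]]
```

After five writes into three slots, "d" and "e" have replaced "a" and "b", the two oldest items.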
Import path: torchrl.data.RoundRobinWriter
Constructor arguments:
- Enable torch.compile compatibility.
TensorDictRoundRobinWriter
TensorDict-aware version of RoundRobinWriter. The default writer for TensorDictReplayBuffer. Supports writing nested TensorDict keys correctly.
Import path: torchrl.data.TensorDictRoundRobinWriter
TensorDictMaxValueWriter
Once the buffer is full, inserts a new item only if its priority value exceeds the current minimum in the buffer, replacing that minimum entry. Useful for maintaining a “best-of” buffer of high-quality experiences.
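This plain-Python sketch implements the "best-of" policy with a min-heap. The current minimum can be read in O(1) and replaced in O(log n). MaxValueBuffer is hypothetical and ranks items by an explicit rank argument instead of a TensorDict key. It is not TorchRL's implementation.

```python
import heapq

class MaxValueBuffer:
    """Hypothetical "best-of" buffer keeping the highest-ranked items."""

    def __init__(self, capacity):
        self.capacity = capacity
        self.heap = []  # min-heap of (rank, insertion_id, item)
        self._next_id = 0  # tie-breaker so items themselves are never compared

    def add(self, item, rank):
        entry = (rank, self._next_id, item)
        self._next_id += 1
        if len(self.heap) < self.capacity:
            heapq.heappush(self.heap, entry)  # not full yet: always insert
            return True
        if rank > self.heap[0][0]:
            heapq.heapreplace(self.heap, entry)  # evict the current minimum
            return True
        return False  # rank too low: the item is dropped

    def items(self):
        # Highest rank first.
        return [item for _, _, item in sorted(self.heap, reverse=True)]

buf = MaxValueBuffer(capacity=3)
ranked = [("a", 1.0), ("b", 5.0), ("c", 3.0), ("d", 0.5), ("e", 4.0)]
accepted = [buf.add(name, rank) for name, rank in ranked]
```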
Import path: torchrl.data.TensorDictMaxValueWriter
Constructor arguments:
- Key containing the scalar priority used for comparison. Defaults to "reward".
Other Writer Classes
| Class | Description |
|---|---|
ImmutableDatasetWriter | Raises an error on any write attempt — used for read-only offline datasets. |
WriterEnsemble | Dispatches writes across multiple storages. |
Complete Training Loop Example
The following example shows a typical online RL training loop with a TensorDictReplayBuffer and Prioritized Experience Replay.
Checkpointing
All replay buffers support save/restore via dumps and loads: