TorchRL replay buffers are built around a simple principle: storage, sampling, writing, and post-sampling transforms are independent, swappable components. You can combine any storage backend with any sampler, attach transforms that run at sample time, enable multi-threaded prefetching, or make the buffer remote and distributed — all without changing the calling code. The same extend / sample API works for simple in-memory replay, disk-backed memmap storage, prioritized replay, hindsight experience relabeling, and offline dataset loading.
```text
ReplayBuffer
├── Storage:   where tensors are physically kept (RAM, mmap, GPU)
├── Sampler:   how indices are chosen (uniform, prioritized, slice)
├── Writer:    how new data enters the buffer (round-robin, max-value)
└── Transform: post-processing applied to every sampled batch
```
Each piece is instantiated independently and composed at construction time. TorchRL ships defaults for all four so you only need to override the parts you care about.
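The sketch below makes the split concrete using only the standard library. It implements the same four roles. The classes `ListStore`, `UniformSampler`, `RoundRobinWriter` and `ToyReplayBuffer` are hypothetical toys, not TorchRL's implementation. They show how a buffer can hand off placement, index selection and post-processing to separate objects.

```python
import random
from typing import Any, Callable, List, Optional


class ListStore:
    """Storage: where items live (a fixed-capacity Python list)."""

    def __init__(self, max_size: int) -> None:
        self.max_size = max_size
        self.data: List[Any] = []

    def set(self, index: int, item: Any) -> None:
        if index == len(self.data):
            self.data.append(item)  # still filling up
        else:
            self.data[index] = item  # overwrite an existing slot

    def get(self, indices: List[int]) -> List[Any]:
        return [self.data[i] for i in indices]


class UniformSampler:
    """Sampler: which indices to read (uniform, with replacement)."""

    def __init__(self, seed: int = 0) -> None:
        self.rng = random.Random(seed)

    def sample(self, storage: ListStore, batch_size: int) -> List[int]:
        return [self.rng.randrange(len(storage.data)) for _ in range(batch_size)]


class RoundRobinWriter:
    """Writer: where new items go (next slot, wrapping to overwrite the oldest)."""

    def __init__(self) -> None:
        self.cursor = 0

    def add(self, storage: ListStore, item: Any) -> None:
        storage.set(self.cursor, item)
        self.cursor = (self.cursor + 1) % storage.max_size


class ToyReplayBuffer:
    """Composes the three parts plus an optional sample-time transform."""

    def __init__(self, storage, sampler, writer, batch_size: int,
                 transform: Optional[Callable[[List[Any]], List[Any]]] = None) -> None:
        self.storage, self.sampler, self.writer = storage, sampler, writer
        self.batch_size, self.transform = batch_size, transform

    def extend(self, items: List[Any]) -> None:
        for item in items:
            self.writer.add(self.storage, item)

    def sample(self) -> List[Any]:
        indices = self.sampler.sample(self.storage, self.batch_size)
        batch = self.storage.get(indices)
        return self.transform(batch) if self.transform else batch


buffer = ToyReplayBuffer(
    ListStore(max_size=3), UniformSampler(seed=0), RoundRobinWriter(),
    batch_size=4, transform=lambda batch: [x * 10 for x in batch],
)
buffer.extend([1, 2, 3, 4])  # capacity 3, so 4 overwrites 1
print(buffer.storage.data)  # [4, 2, 3]
print(buffer.sample())  # four values drawn from {20, 30, 40}
```

Replacing `UniformSampler` with another sampler changes which indices `sample()` reads. The storage and the writer stay untouched.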
TensorDictReplayBuffer is the TensorDict-aware wrapper you should use in almost all TorchRL workflows. It understands nested keys, preserves structure across extend / sample round-trips, and integrates with transforms that operate on TensorDicts.
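```python
import torch
from tensordict import TensorDict
from torchrl.data import TensorDictReplayBuffer, LazyTensorStorage

buffer = TensorDictReplayBuffer(
    storage=LazyTensorStorage(max_size=100_000),
    batch_size=256,
    prefetch=4,  # background threads prefetch up to 4 upcoming batches
)

# Extend from a collector batch (TensorDict). A synthetic batch stands in
# for collector output here; the reward keeps a trailing singleton dimension.
obs_dim = 8
collector_batch = TensorDict(
    {"observation": torch.randn(1_000, obs_dim), "next": {"reward": torch.randn(1_000, 1)}},
    batch_size=[1_000],
)
buffer.extend(collector_batch)

# Sample returns a TensorDict with structure preserved.
sample = buffer.sample()
print(sample["observation"].shape)  # torch.Size([256, 8]), i.e. [256, obs_dim]
print(sample["next", "reward"].shape)  # torch.Size([256, 1])
```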
LazyTensorStorage infers tensor shapes and dtypes from the first extend call, so you never need to pre-declare the schema. The storage is allocated to match the structure of that first TensorDict. Later batches must have the same keys and per-item shapes.
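The following toy illustrates lazy allocation using only the standard library. It assumes items are dicts of flat float vectors rather than TensorDicts. `LazyFieldStore` is hypothetical:

- It reads the layout from the first item and preallocates slots.
- It rejects later items whose layout differs.
- The `ValueError` it raises is the toy's own choice; TorchRL's error type and message will differ.

```python
from typing import Dict, List, Optional

Item = Dict[str, List[float]]  # field name -> flat feature vector


class LazyFieldStore:
    """Toy storage that allocates per-field slots on the first extend()."""

    def __init__(self, max_size: int) -> None:
        self.max_size = max_size
        self.schema: Optional[Dict[str, int]] = None  # field -> vector length
        self.fields: Dict[str, List[List[float]]] = {}
        self.cursor = 0

    def _allocate(self, first: Item) -> None:
        # Infer the layout from the first item and preallocate zero-filled slots.
        self.schema = {key: len(value) for key, value in first.items()}
        self.fields = {
            key: [[0.0] * dim for _ in range(self.max_size)]
            for key, dim in self.schema.items()
        }

    def extend(self, items: List[Item]) -> None:
        for item in items:
            if self.schema is None:
                self._allocate(item)
            layout = {key: len(value) for key, value in item.items()}
            if layout != self.schema:
                # The slots already exist with a fixed layout, so reject the item.
                raise ValueError(f"expected layout {self.schema}, got {layout}")
            slot = self.cursor % self.max_size
            for key, value in item.items():
                self.fields[key][slot] = list(value)
            self.cursor += 1


store = LazyFieldStore(max_size=4)
store.extend([{"observation": [0.1, 0.2, 0.3], "reward": [1.0]}])
print(store.schema)  # {'observation': 3, 'reward': 1}
try:
    store.extend([{"observation": [0.1, 0.2], "reward": [0.0]}])
except ValueError as err:
    print("rejected:", err)
```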
TensorDictPrioritizedReplayBuffer weights samples by a priority stored under "td_error" (configurable). Losses that compute TD error write it back to the TensorDict; the buffer uses update_tensordict_priority() to register the new weights.
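```python
import torch
from tensordict import TensorDict
from torchrl.data import LazyMemmapStorage, TensorDictPrioritizedReplayBuffer

buffer = TensorDictPrioritizedReplayBuffer(
    storage=LazyMemmapStorage(1_000_000),
    alpha=0.7,  # prioritization exponent (0 = uniform)
    beta=0.5,  # importance-sampling exponent
    eps=1e-6,  # added to priorities to avoid zeros
    batch_size=256,
    prefetch=2,
)

# Placeholder for collector output.
collector_batch = TensorDict({"observation": torch.randn(1_000, 8)}, batch_size=[1_000])
buffer.extend(collector_batch)
sample = buffer.sample()

# After computing TD errors (random placeholders here), update priorities.
td_errors = torch.randn(sample.shape[0])
sample["td_error"] = td_errors.abs().detach()
buffer.update_tensordict_priority(sample)
```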
The returned sample contains an "index" key that update_tensordict_priority uses to write updated priorities back to the correct tree positions.
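The sketch below assumes the buffer follows the standard proportional prioritized-replay scheme (Schaul et al.). Under that assumption, the parameters above mean:

- priority p_i = |δ_i| + eps
- sampling probability P(i) = p_i^α / Σ_k p_k^α
- importance weight w_i = (N · P(i))^(-β), divided by the largest possible weight

`ToyPrioritizedSampler` is hypothetical. It uses a flat Python list, whereas TorchRL keeps priorities in a tree structure.

```python
import random
from typing import List, Tuple


class ToyPrioritizedSampler:
    """Proportional prioritized sampling over a flat list of priorities."""

    def __init__(self, size: int, alpha: float, beta: float, eps: float,
                 seed: int = 0) -> None:
        self.alpha, self.beta, self.eps = alpha, beta, eps
        self.priorities = [1.0] * size  # equal priority until TD errors arrive
        self.rng = random.Random(seed)

    def probabilities(self) -> List[float]:
        # P(i) = p_i^alpha / sum_k p_k^alpha; alpha = 0 gives uniform sampling.
        scaled = [p ** self.alpha for p in self.priorities]
        total = sum(scaled)
        return [s / total for s in scaled]

    def sample(self, batch_size: int) -> Tuple[List[int], List[float]]:
        probs = self.probabilities()
        n = len(probs)
        indices = self.rng.choices(range(n), weights=probs, k=batch_size)
        # w_i = (n * P(i))^-beta, divided by the largest possible weight
        # (that of the least likely item) so every weight is <= 1.
        max_weight = (n * min(probs)) ** (-self.beta)
        weights = [(n * probs[i]) ** (-self.beta) / max_weight for i in indices]
        return indices, weights

    def update(self, indices: List[int], td_errors: List[float]) -> None:
        # p_i = |delta_i| + eps, so a zero TD error never makes p_i zero.
        for i, err in zip(indices, td_errors):
            self.priorities[i] = abs(err) + self.eps


sampler = ToyPrioritizedSampler(size=4, alpha=0.7, beta=0.5, eps=1e-6)
sampler.update([0, 1, 2, 3], [0.0, 0.1, 0.5, 2.0])
probs = sampler.probabilities()
indices, weights = sampler.sample(batch_size=8)
```

Items with larger TD error are drawn more often. Their importance weights are correspondingly smaller, which offsets the sampling bias when the weights scale the loss.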
```python
from torchrl.data import ListStorage

# Stores data as a Python list of objects.
# Flexible, but not optimized for large batches or GPU workflows.
storage = ListStorage(max_size=10_000)
```
Best for: small buffers, non-tensor data, heterogeneous batch shapes.
```python
from torchrl.data import LazyTensorStorage

# Allocates a contiguous tensor per field on the first extend().
# Fast random access; lives in CPU RAM or GPU memory, depending on `device`.
storage = LazyTensorStorage(max_size=100_000, device="cuda:0")
```
Best for: standard on-policy or off-policy RL with fixed observation shapes.
```python
from torchrl.data import LazyMemmapStorage

# Memory-maps tensors to disk.
# Can exceed RAM, supports asynchronous I/O, and is safe
# to share across processes via shared memory.
storage = LazyMemmapStorage(
    max_size=5_000_000,
    scratch_dir="/tmp/replay",
    device="cpu",
)
```
Best for: large offline datasets, distributed jobs, multi-process collectors.
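```python
from torchrl.data import LazyStackStorage

# Keeps each written element as its own TensorDict and returns samples as a
# lazily stacked TensorDict, so elements may have different shapes
# (for example, whole trajectories of varying length).
storage = LazyStackStorage(max_size=50_000)
```
Best for: heterogeneous data, such as whole trajectories of varying length (for example, for recurrent policies).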
PrioritizedSliceSampler combines priority-based sampling with trajectory slicing: each slice is sampled from an episode with probability proportional to the maximum TD error in that episode.
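The following standard-library sketch implements the episode-level scheme as this paragraph describes it. It does not reproduce TorchRL's internals. `sample_prioritized_slices` is a hypothetical helper that works in two steps:

1. It picks an episode with probability proportional to the episode's largest |TD error|.
2. It draws a contiguous slice uniformly inside that episode.

The helper assumes each episode occupies a contiguous run of storage indices.

```python
import random
from typing import Dict, List


def sample_prioritized_slices(episode_ids: List[int], td_errors: List[float],
                              slice_len: int, num_slices: int,
                              eps: float = 1e-6, seed: int = 0) -> List[List[int]]:
    """Draw `num_slices` runs of `slice_len` consecutive storage indices."""
    rng = random.Random(seed)
    # Group storage indices by episode (episodes assumed stored contiguously).
    episodes: Dict[int, List[int]] = {}
    for index, episode in enumerate(episode_ids):
        episodes.setdefault(episode, []).append(index)
    # Only episodes that can hold a full slice are eligible.
    eligible = [ep for ep, idxs in episodes.items() if len(idxs) >= slice_len]
    # Episode priority = max |TD error| in the episode (+ eps so it is never zero).
    priorities = [max(abs(td_errors[i]) for i in episodes[ep]) + eps for ep in eligible]
    slices = []
    for ep in rng.choices(eligible, weights=priorities, k=num_slices):
        idxs = episodes[ep]
        start = rng.randrange(len(idxs) - slice_len + 1)  # uniform start in episode
        slices.append(idxs[start:start + slice_len])
    return slices


episode_ids = [0, 0, 0, 0, 1, 1, 1, 2, 2]
td_errors = [0.1, 0.2, 0.1, 0.1, 0.0, 0.0, 0.0, 3.0, 0.5]
slices = sample_prioritized_slices(episode_ids, td_errors, slice_len=2, num_slices=6)
```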
HERReplayBuffer applies goal relabeling at sample time. For a configurable fraction of each batch, the desired goal is replaced with an achieved goal from the same episode, and the reward is recomputed via a user-supplied function. The relabeling strategy can be "future", "final", "episode", or "random".
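This sketch implements the "future" strategy using only the standard library. It makes two assumptions:

- Each stored transition is a dict with `achieved_goal` (the goal reached after the step), `desired_goal` and `reward`.
- The whole batch comes from one episode.

`relabel_future` and its arguments are hypothetical. `reward_fn` stands in for the user-supplied reward function.

```python
import random
from typing import Callable, Dict, List, Sequence

Transition = Dict[str, object]


def relabel_future(episode: List[Transition], batch_indices: Sequence[int],
                   reward_fn: Callable[[object, object], float],
                   relabel_fraction: float = 0.8, seed: int = 0) -> List[Transition]:
    """Return copies of episode[t] for t in batch_indices, some with relabelled goals."""
    rng = random.Random(seed)
    batch = []
    for t in batch_indices:
        transition = dict(episode[t])  # copy: the stored episode is never modified
        if rng.random() < relabel_fraction:
            future = rng.randrange(t, len(episode))  # any step from t to episode end
            new_goal = episode[future]["achieved_goal"]
            transition["desired_goal"] = new_goal
            # Recompute the reward against the substituted goal.
            transition["reward"] = reward_fn(transition["achieved_goal"], new_goal)
        batch.append(transition)
    return batch


def sparse_reward(achieved: object, desired: object) -> float:
    return 0.0 if achieved == desired else -1.0


episode = [{"achieved_goal": g, "desired_goal": 10, "reward": -1.0} for g in (1, 2, 3, 4)]
batch = relabel_future(episode, [0, 1, 2, 3], sparse_reward, relabel_fraction=1.0)
```

With `relabel_fraction=1.0`, every sampled transition receives a goal it actually reached at or after step t. As a result, the sparse reward is no longer always -1. The other strategies differ mainly in where the substitute goal comes from. For "final", for example, it is the last achieved goal of the episode.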
Transforms can be attached to a replay buffer and execute at sample() time. This is useful for data augmentation, observation normalization that depends on the sampled batch, or on-the-fly device transfers.
Transforms attached to a replay buffer run on the sampled mini-batch, not at insert time. This means the normalization statistics can be updated between samples without re-writing the buffer.
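The idea can be sketched with the standard library. Raw values are written once, and a hypothetical `RunningNormalize` transform is applied only to each sampled batch. The transform tracks a running mean and variance with Welford's algorithm. Updating those statistics changes what the next sample returns, while the stored data stays untouched. In TorchRL the transform would be attached to the buffer (for example through its `transform=` argument) instead of being called by hand as it is here.

```python
import random
from typing import List


class RunningNormalize:
    """Normalises values with a running mean/variance (Welford's algorithm)."""

    def __init__(self) -> None:
        self.count, self.mean, self.m2 = 0, 0.0, 0.0

    def update(self, values: List[float]) -> None:
        for x in values:
            self.count += 1
            delta = x - self.mean
            self.mean += delta / self.count
            self.m2 += delta * (x - self.mean)

    def __call__(self, batch: List[float]) -> List[float]:
        var = self.m2 / self.count if self.count > 1 else 1.0
        std = max(var ** 0.5, 1e-8)
        return [(x - self.mean) / std for x in batch]


raw_storage = [float(x) for x in range(10)]  # written once, never rewritten
normalize = RunningNormalize()
rng = random.Random(0)


def sample(batch_size: int) -> List[float]:
    batch = [raw_storage[rng.randrange(len(raw_storage))] for _ in range(batch_size)]
    return normalize(batch)  # the transform runs on the sampled batch only


normalize.update(raw_storage[:5])
first = sample(4)
normalize.update(raw_storage[5:])  # statistics change; storage does not
second = sample(4)
```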
TorchRL supports loading and sampling from offline datasets using the same buffer interface. The ImmutableDatasetWriter prevents any further writes after loading.
```python
from torchrl.data import (
    TensorDictReplayBuffer,
    LazyMemmapStorage,
    ImmutableDatasetWriter,
    SamplerWithoutReplacement,
)

# Load a pre-collected offline dataset from a previously saved buffer.
storage = LazyMemmapStorage(max_size=1_000_000)
storage.loads("/path/to/dataset")  # loads checkpoint into storage in-place

dataset = TensorDictReplayBuffer(
    storage=storage,
    sampler=SamplerWithoutReplacement(drop_last=True),
    writer=ImmutableDatasetWriter(),
    batch_size=256,
)

# loss_fn and optimizer are assumed to be defined elsewhere;
# loss_fn is assumed to return a scalar loss tensor.
for batch in dataset:
    offline_loss = loss_fn(batch)
    optimizer.zero_grad()
    offline_loss.backward()
    optimizer.step()
```
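One pass of `SamplerWithoutReplacement(drop_last=True)` can be modelled as follows: shuffle all indices, cut them into full batches, and drop the short remainder. The generator below is a hypothetical standard-library sketch under that assumption, not TorchRL's sampler.

```python
import random
from typing import Iterator, List


def epoch_batches(n_items: int, batch_size: int, drop_last: bool = True,
                  seed: int = 0) -> Iterator[List[int]]:
    """Yield index batches that visit each item at most once (one epoch)."""
    order = list(range(n_items))
    random.Random(seed).shuffle(order)
    for start in range(0, n_items, batch_size):
        batch = order[start:start + batch_size]
        if drop_last and len(batch) < batch_size:
            break  # discard the incomplete final batch
        yield batch


batches = list(epoch_batches(10, batch_size=4))  # two batches of 4; 2 items dropped
```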
Setting prefetch > 0 launches background threads that pre-load the next prefetch batches while your training code processes the current one. For distributed setups, RayReplayBuffer and RemoteTensorDictReplayBuffer expose the same extend / sample API across Ray actors or RPC processes.
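The prefetch pattern can be sketched with the standard library:

1. Keep `prefetch` sampling calls submitted to a thread pool.
2. Hand out the oldest finished result.
3. Submit a new call to keep the queue full.

`Prefetcher` and `sample_batch` are hypothetical names. The sketch illustrates the idea and is not TorchRL's prefetching code.

```python
import random
import threading
from collections import deque
from concurrent.futures import Future, ThreadPoolExecutor
from typing import Callable, Deque, List


class Prefetcher:
    """Keeps `prefetch` sample() calls in flight on background threads."""

    def __init__(self, sample_fn: Callable[[], List[int]], prefetch: int) -> None:
        self.sample_fn = sample_fn
        self.pool = ThreadPoolExecutor(max_workers=prefetch)
        # Start `prefetch` requests immediately so batches are ready early.
        self.pending: Deque[Future] = deque(
            self.pool.submit(sample_fn) for _ in range(prefetch)
        )

    def next(self) -> List[int]:
        batch = self.pending.popleft().result()  # blocks only if not ready yet
        self.pending.append(self.pool.submit(self.sample_fn))  # keep queue full
        return batch

    def close(self) -> None:
        self.pool.shutdown(wait=True)


data = list(range(100))
rng = random.Random(0)
rng_lock = threading.Lock()  # the RNG is shared by the worker threads


def sample_batch() -> List[int]:
    with rng_lock:
        return rng.sample(data, 8)


prefetcher = Prefetcher(sample_batch, prefetch=2)
prefetched = [prefetcher.next() for _ in range(5)]
prefetcher.close()
```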