TorchRL is built on a single unifying idea: every piece of data in the training loop — observations, actions, rewards, recurrent states, priorities, agent groupings — lives inside a TensorDict. TensorDict is a dictionary-like tensor container that supports PyTorch operations, device transfers, shared-memory storage, memmaps, lazy views, and nn.Module wrappers. Rather than passing parallel Python lists or positional tuples between components, TorchRL threads one structured object through the entire pipeline so that environments, collectors, replay buffers, and loss modules can all consume and produce the same type without any glue code.

What is TensorDict?

A TensorDict is essentially a dict[str | tuple[str, ...], Tensor] that knows its own batch dimensions and device. Every value shares the same leading batch shape; individual tensors may have additional trailing dimensions. The container ships with a full suite of PyTorch-like operations so that code that used to work on raw tensors works on TensorDicts with almost no changes.
import torch
from tensordict import TensorDict

# Build a TensorDict with batch size [32].
td = TensorDict(
    {
        "observation": torch.randn(32, 8),
        "action": torch.randn(32, 2),
        "reward": torch.randn(32, 1),
    },
    batch_size=[32],
)

print(td.batch_size)    # torch.Size([32])
print(td["observation"].shape)  # torch.Size([32, 8])
TensorDict is a separate PyTorch library (tensordict) that TorchRL depends on. It is maintained at github.com/pytorch/tensordict and can be used independently of TorchRL.

Key operations

Because TensorDict mirrors the PyTorch tensor API, each of the following operations applies to every field at once and keeps the batch dimensions consistent across fields.
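# Assumes `torch` and `td` from the previous snippet.
# Eight copies of `td` stand in for separately collected TensorDicts.
list_of_tensordicts = [td.clone() for _ in range(8)]

# Stack a list of TensorDicts along a new dimension.
batch = torch.stack(list_of_tensordicts, dim=0)  # batch_size [8, 32]

# Reshape the batch dimensions.
batch = batch.reshape(-1)  # batch_size [256]

# Move every tensor to a device in one call (falls back to CPU without a GPU).
device = "cuda" if torch.cuda.is_available() else "cpu"
batch = batch.to(device)

# Standard indexing slices the batch dimension across all fields at once.
mini_batch = batch[:128]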
These operations keep the autograd graph intact for any tensor that has requires_grad=True. You can keep log-probabilities or value estimates inside the same TensorDict and call .backward() on a scalar derived from any of them.

Nested keys and structured data

TorchRL uses nested keys — tuples of strings — to represent structured sub-fields. This single convention handles multi-agent data, recurrent hidden states, and next-step observations without any schema changes or special-casing in component code.
# Next-state data lives under the ("next", ...) prefix by convention.
reward   = batch["next", "reward"]          # shape: [B, 1]
next_obs = batch["next", "observation"]     # shape: [B, obs_dim]

# Multi-agent observations sit under an agent-group key.
agent_obs = batch["agents", "observation"]  # shape: [B, n_agents, obs_dim]

# Recurrent hidden states are just another nested entry.
h = batch["recurrent_state", "h"]          # shape: [B, hidden_dim]
TorchRL environments write ("next", "observation"), ("next", "reward"), and ("next", "done") by convention, and loss modules read those same keys. No component needs separate shape metadata, because each entry carries its own shape and the TensorDict carries the shared batch size.

Next-state convention

Next observations and rewards always live under the "next" sub-key, making multi-step transitions, value bootstrapping, and n-step returns unambiguous.

Agent grouping

Multi-agent environments place per-agent data under a group key such as "agents". Losses and modules can target the right sub-tree without changing any other code.

Recurrent states

Hidden states for GRUs and LSTMs are stored by name so they survive replay buffer round-trips and can be properly zero-initialised at episode starts.

Custom fields

Any algorithm can attach task-specific tensors (e.g. "advantage", "td_error", "goal") and they flow through unchanged unless a transform or loss removes them.

TensorDict as the composability backbone

TorchRL components compose because each one reads and writes only the keys it declares; every other entry passes through untouched. The pipeline looks like this:
TensorDict
  -> policy module    writes "action", "log_prob"
  -> environment      reads "action", writes ("next", "observation"), ("next", "reward"), ("next", "done")
  -> collector        batches transitions, appends device / batch metadata
  -> replay buffer    stores, samples, prioritizes (reads "td_error")
  -> loss module      reads ("next", "reward"), "advantage", "log_prob" etc., writes "loss_*" scalars
  -> optimizer        updates ordinary PyTorch parameters
No component needs to know about the others. A collector can emit a TensorDict, the replay buffer stores it without losing structure, a transform can add or remove keys, and a loss reads exactly the keys it needs.

A complete rollout example

The following snippet is taken directly from the TorchRL README. It shows a full rollout from a TransformedEnv into a batched TensorDict — no unpacking required.
import torch
from tensordict.nn import TensorDictModule
from torch import nn

from torchrl.envs import PendulumEnv, StepCounter, TransformedEnv

# A PyTorch-native environment with a transform stack.
env = TransformedEnv(PendulumEnv(), StepCounter(max_steps=200))

# Policies are regular nn.Modules wrapped with explicit TensorDict keys.
policy = TensorDictModule(
    nn.Sequential(
        nn.LazyLinear(64),
        nn.Tanh(),
        nn.Linear(64, 1),
        nn.Tanh(),
    ),
    in_keys=["observation"],
    out_keys=["action"],
)

rollout = env.rollout(max_steps=32, policy=policy)

assert rollout.batch_size == torch.Size([32])
assert rollout["next", "reward"].shape[:1] == torch.Size([32])
rollout is a single TensorDict with batch size [32]. Every step’s observation, action, reward, done flag, and step count are aligned at the same index. You can index, slice, or stack it just like a tensor.

Batch operations on collected data

When a collector returns data, the same tensor-level operations work on the entire TensorDict:
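# `list_of_tensordicts` stands in for TensorDicts gathered from a collector.
# After stacking and flattening, the batch has a single dimension, e.g. [1024] (frames_per_batch).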
batch = torch.stack(list_of_tensordicts, dim=0)
batch = batch.reshape(-1)
batch = batch.to("cuda")

# Pull out specific fields.
reward = batch["next", "reward"]
agent_obs = batch["agents", "observation"]
hidden = batch["recurrent_state", "h"]

# Mini-batch indexing.
for i in range(0, len(batch), 128):
    mini = batch[i : i + 128]
    # mini is still a TensorDict — no re-packaging needed.
Indexing a TensorDict with integers or slices always operates on the batch dimensions, while string keys select entries. If your TensorDict has batch_size=[T, B], then td[0] returns a TensorDict with batch_size=[B], not a raw tensor.

TensorDictModule: modules with explicit key contracts

TensorDictModule wraps any nn.Module with explicit in_keys and out_keys. This makes the data contract of each wrapped module visible at construction time rather than buried in a forward signature.
from tensordict.nn import TensorDictModule
from torch import nn

# A critic that reads "observation" and writes "state_value".
critic_net = TensorDictModule(
    nn.Sequential(nn.LazyLinear(256), nn.Tanh(), nn.Linear(256, 1)),
    in_keys=["observation"],
    out_keys=["state_value"],
)

# Forward pass receives and returns a TensorDict.
td = critic_net(td)
print(td["state_value"].shape)  # [B, 1]
All TorchRL modules — ProbabilisticActor, ValueOperator, loss modules — follow this same pattern. See the Modules & Policies page for details.
