Networks in TorchRL are standard PyTorch nn.Module objects wrapped with an explicit declaration of which TensorDict keys they read and which they write. This key contract makes the data flow of the entire pipeline visible at construction time, lets components be reconfigured without editing network code, and lets collectors, replay buffers, and loss modules compose with any policy without knowing its architecture. The tensordict.nn package provides TensorDictModule, TensorDictSequential, and TensorDictModuleBase as the building blocks; TorchRL wrappers such as ProbabilisticActor and ValueOperator are built on these classes.

TensorDictModule: explicit key contracts

TensorDictModule wraps any nn.Module with in_keys and out_keys. It reads the listed keys from an input TensorDict, calls the wrapped module’s forward, and writes the outputs back under the out_keys.
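import torch
from tensordict import TensorDict
from tensordict.nn import TensorDictModule
from torch import nn

# A simple MLP that reads "observation" and writes "action".
net = TensorDictModule(
    nn.Sequential(nn.LazyLinear(256), nn.Tanh(), nn.Linear(256, 2)),
    in_keys=["observation"],
    out_keys=["action"],
)

# A batch of B observations with 8 features each.
B = 4
td = TensorDict({"observation": torch.randn(B, 8)}, batch_size=[B])

# Forward: td["action"] is written in place and the same TensorDict is returned.
td = net(td)
print(td["action"].shape)  # [B, 2]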
Nested keys work as both in_keys and out_keys:
agent_net = TensorDictModule(
    nn.LazyLinear(64),
    in_keys=[("agents", "observation")],
    out_keys=[("agents", "action_logits")],
)
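The contract itself is simple enough to sketch without any dependencies. The hypothetical MiniTDModule below mimics what a TensorDictModule-style wrapper does on each call: gather the in_keys (a string, or a tuple for a nested key), call the wrapped function, and write its results under the out_keys. Nested Python dicts stand in for a TensorDict and lists stand in for tensors; this illustrates the idea and is not tensordict's implementation.

```python
from typing import Any, Callable, Dict, Sequence, Tuple, Union

Key = Union[str, Tuple[str, ...]]


def get_key(data: Dict[str, Any], key: Key) -> Any:
    """Read a flat key ("obs") or a nested key (("agents", "obs"))."""
    path = (key,) if isinstance(key, str) else key
    for part in path:
        data = data[part]
    return data


def set_key(data: Dict[str, Any], key: Key, value: Any) -> None:
    """Write a value, creating intermediate sub-dicts for nested keys."""
    path = (key,) if isinstance(key, str) else key
    for part in path[:-1]:
        data = data.setdefault(part, {})
    data[path[-1]] = value


class MiniTDModule:
    """Hypothetical stand-in for TensorDictModule: fn(*in_values) -> out_values."""

    def __init__(self, fn: Callable[..., Any], in_keys: Sequence[Key], out_keys: Sequence[Key]):
        self.fn = fn
        self.in_keys = list(in_keys)
        self.out_keys = list(out_keys)

    def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]:
        inputs = [get_key(data, k) for k in self.in_keys]
        outputs = self.fn(*inputs)
        if len(self.out_keys) == 1:
            outputs = (outputs,)
        for key, value in zip(self.out_keys, outputs):
            set_key(data, key, value)
        return data  # written in place and returned


# Usage: a "network" that doubles each agent's observation.
double = MiniTDModule(
    lambda obs: [2 * x for x in obs],
    in_keys=[("agents", "observation")],
    out_keys=[("agents", "action_logits")],
)
td = {"agents": {"observation": [1.0, -0.5]}}
td = double(td)
print(td["agents"]["action_logits"])  # [2.0, -1.0]
```

Because the wrapper only knows key names, swapping the lambda for a different function leaves every caller unchanged, which is the point of the key contract.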

ProbabilisticActor: stochastic policies

ProbabilisticActor combines a parameter network with a distribution class to produce stochastic actions. It is the standard way to build actors for PPO, SAC, REINFORCE, and any other algorithm that needs log-probabilities.
import torch
from tensordict import TensorDict
from tensordict.nn import TensorDictModule
from tensordict.nn.distributions import NormalParamExtractor
from torch import nn
from torchrl.modules import ProbabilisticActor, TanhNormal

action_dim = 1

# Step 1: a network that produces distribution parameters.
params_net = TensorDictModule(
    nn.Sequential(
        nn.LazyLinear(256),
        nn.Tanh(),
        nn.Linear(256, 2 * action_dim),  # first half -> loc, second half -> raw scale
        NormalParamExtractor(),          # splits into "loc" and a positive "scale"
    ),
    in_keys=["observation"],
    out_keys=["loc", "scale"],
)

# Step 2: wrap with ProbabilisticActor.
actor = ProbabilisticActor(
    params_net,
    in_keys=["loc", "scale"],
    out_keys=["action"],
    distribution_class=TanhNormal,
    distribution_kwargs={"low": -1.0, "high": 1.0},
    return_log_prob=True,   # writes "sample_log_prob" to the TensorDict
)

# Forward: samples an action and writes its log-probability.
B = 4
td = TensorDict({"observation": torch.randn(B, 8)}, batch_size=[B])
td = actor(td)
print(td["action"].shape)          # [B, action_dim]
print(td["sample_log_prob"].shape) # [B]
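Because TanhNormal squashes a Gaussian sample through tanh, the log-probability written with return_log_prob=True must include a change-of-variables term. The stdlib sketch below shows that general technique for one action dimension, assuming the squashed value is rescaled affinely from (-1, 1) to [low, high]. All helper names are hypothetical, and TorchRL's own TanhNormal (which handles batched, multi-dimensional actions) may differ in numerical details.

```python
import math
import random
from typing import Optional, Tuple


def softplus(x: float) -> float:
    """Numerically stable log(1 + exp(x))."""
    return max(x, 0.0) + math.log1p(math.exp(-abs(x)))


def normal_log_prob(u: float, loc: float, scale: float) -> float:
    z = (u - loc) / scale
    return -0.5 * z * z - math.log(scale) - 0.5 * math.log(2 * math.pi)


def squash(u: float, low: float, high: float) -> float:
    """Map a real number into (low, high) through tanh."""
    return math.tanh(u) * (high - low) / 2 + (high + low) / 2


def log_abs_det_jacobian(u: float, low: float, high: float) -> float:
    """log |d squash(u) / du| = log((high - low) / 2) + log(1 - tanh(u)^2).

    Uses log(1 - tanh(u)^2) = 2 * (log 2 - u - softplus(-2u)), which stays
    finite even when tanh(u) rounds to +-1 in floating point.
    """
    return math.log((high - low) / 2) + 2 * (math.log(2) - u - softplus(-2 * u))


def sample_tanh_normal(loc: float, scale: float, low: float = -1.0, high: float = 1.0,
                       rng: Optional[random.Random] = None) -> Tuple[float, float]:
    """Return (action, log-density of that action)."""
    rng = rng if rng is not None else random.Random()
    u = rng.gauss(loc, scale)                      # pre-squash Gaussian sample
    action = squash(u, low, high)
    log_prob = normal_log_prob(u, loc, scale) - log_abs_det_jacobian(u, low, high)
    return action, log_prob


action, log_prob = sample_tanh_normal(loc=0.3, scale=0.5, rng=random.Random(0))
print(action, log_prob)
```

Dropping the Jacobian term would give the density of the unsquashed sample instead of the action, which is why losses such as SAC's rely on the corrected value.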

Available distributions

| Class | Use case |
| --- | --- |
| TanhNormal | Continuous actions in a bounded range (e.g., SAC, PPO) |
| IndependentNormal | Unbounded continuous actions |
| TruncatedNormal | Bounded continuous actions from a normal distribution truncated to [low, high] |
| TanhDelta | Deterministic actions squashed into a bounded range by tanh (e.g., DDPG-style actors) |
| OneHotCategorical | Discrete actions (one-hot encoded) |
| MaskedCategorical | Discrete actions with an action mask |
| MaskedOneHotCategorical | One-hot discrete actions with an action mask |
| Delta | Deterministic action (Dirac delta) |
NormalParamExtractor splits the last output dimension in half: the first half becomes loc, and the second half is mapped to a positive scale (by default with a softplus shifted so that a raw value of 0 gives a scale of 1). This avoids having to build two separate output heads manually.
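The split-and-map step can be sketched in plain Python. This assumes the default mapping is a softplus shifted so that a raw value of 0 maps to a scale of 1.0, plus a small lower bound on the scale (1e-4 here, an assumed value); extract_normal_params is a hypothetical helper, so check the tensordict documentation for the defaults in your version.

```python
import math
from typing import List, Tuple


def softplus(x: float) -> float:
    """Numerically stable log(1 + exp(x))."""
    return max(x, 0.0) + math.log1p(math.exp(-abs(x)))


# Shift chosen so that softplus(0 + BIAS) == 1.0, i.e. BIAS = log(e - 1).
BIAS = math.log(math.expm1(1.0))


def extract_normal_params(raw: List[float], scale_lb: float = 1e-4) -> Tuple[List[float], List[float]]:
    """Split the last dimension in half: first half -> loc, second half -> scale > 0."""
    if len(raw) % 2:
        raise ValueError("last dimension must be even")
    half = len(raw) // 2
    loc = raw[:half]
    scale = [max(softplus(x + BIAS), scale_lb) for x in raw[half:]]
    return loc, scale


loc, scale = extract_normal_params([0.5, -1.0, 0.0, -50.0])
print(loc)    # [0.5, -1.0]
print(scale)  # first entry is 1.0 up to rounding; the second is clamped to 1e-4
```

The lower bound keeps the scale away from zero, where the Gaussian log-density would blow up.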

ValueOperator: critic networks

ValueOperator wraps a critic nn.Module with the conventional key contract for value functions. By default it reads "observation" and writes "state_value".
import torch
from tensordict import TensorDict
from torch import nn
from torchrl.modules import ValueOperator

critic = ValueOperator(
    nn.Sequential(
        nn.LazyLinear(256),
        nn.Tanh(),
        nn.Linear(256, 1),
    ),
    in_keys=["observation"],
    out_keys=["state_value"],
)

B = 4
td = TensorDict({"observation": torch.randn(B, 8)}, batch_size=[B])
td = critic(td)
print(td["state_value"].shape)  # [B, 1]
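For Q-value critics that also consume the action, the wrapped network receives one argument per in_key. A plain nn.Sequential accepts only a single input, whereas TorchRL's MLP concatenates multiple inputs along the last dimension:
from torchrl.modules import MLP

qvalue_net = ValueOperator(
    MLP(out_features=1, num_cells=[256], activation_class=nn.Tanh),
    in_keys=["observation", "action"],
    out_keys=["state_action_value"],
)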
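With several in_keys, the network sees one positional argument per key, so a Q-value critic has to merge observation and action itself, typically by concatenating them along the feature dimension before the first layer. A plain-Python sketch of that contract, with a hypothetical single-layer linear critic standing in for the network:

```python
from typing import Dict, List


def linear_q(observation: List[float], action: List[float], weights: List[float], bias: float) -> float:
    """Hypothetical Q(s, a): one linear layer over concat([s, a])."""
    features = observation + action  # concatenation along the feature dimension
    if len(features) != len(weights):
        raise ValueError("weights must match obs_dim + action_dim")
    return sum(w * x for w, x in zip(weights, features)) + bias


def q_critic(td: Dict[str, List[float]], weights: List[float], bias: float = 0.0) -> Dict[str, List[float]]:
    # Reads "observation" and "action"; writes "state_action_value" with a trailing dim of 1.
    td["state_action_value"] = [linear_q(td["observation"], td["action"], weights, bias)]
    return td


td = {"observation": [1.0, 2.0], "action": [0.5]}
q_critic(td, weights=[0.1, 0.2, 1.0], bias=0.05)
print(td["state_action_value"])  # approximately [1.05]
```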

ActorCriticWrapper and ActorValueOperator

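TorchRL provides two helper wrappers for actor-critic setups. ActorCriticWrapper bundles an independent actor and critic into a single module. ActorValueOperator instead takes a common (shared) operator followed by a policy head and a value head, for algorithms that share parameters between actor and critic (e.g., A2C with a shared backbone).
from torchrl.modules import ActorCriticWrapper

# Independent actor and critic: no shared parameters.
actor_critic = ActorCriticWrapper(actor, critic)
td = actor_critic(td)
# td["action"], td["sample_log_prob"], and td["state_value"] are all written.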
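The practical difference between the two wrappers is where computation is shared. The plain-Python sketch below mirrors the structure that ActorValueOperator describes: a common operator writes an intermediate entry once, and separate policy and value heads read it. The class, the "hidden" key, and the toy heads are hypothetical; the counter shows the trunk runs once per call rather than once per head.

```python
from typing import Callable, Dict, List

Data = Dict[str, List[float]]
Fn = Callable[[List[float]], List[float]]


class SharedTrunkActorValue:
    """Hypothetical shared-trunk actor-critic: trunk -> {policy head, value head}."""

    def __init__(self, trunk: Fn, policy_head: Fn, value_head: Fn):
        self.trunk, self.policy_head, self.value_head = trunk, policy_head, value_head
        self.trunk_calls = 0

    def __call__(self, td: Data) -> Data:
        td["hidden"] = self.trunk(td["observation"])  # shared features, computed once
        self.trunk_calls += 1
        td["action"] = self.policy_head(td["hidden"])
        td["state_value"] = self.value_head(td["hidden"])
        return td


model = SharedTrunkActorValue(
    trunk=lambda obs: [max(0.0, x) for x in obs],  # ReLU features
    policy_head=lambda h: [sum(h)],                 # toy "action"
    value_head=lambda h: [0.5 * sum(h)],            # toy value estimate
)
td = model({"observation": [1.0, -2.0, 3.0]})
print(td["hidden"], td["action"], td["state_value"], model.trunk_calls)
# [1.0, 0.0, 3.0] [4.0] [2.0] 1
```

With independent networks the observation would be encoded twice; sharing the trunk saves that work at the cost of coupling the actor and critic gradients.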

Model builders: MLP, ConvNet, and friends

TorchRL ships opinionated model builders that handle common architectural patterns so you can focus on algorithm design.

MLP

from torch import nn
from torchrl.modules import MLP

net = MLP(
    in_features=8,
    out_features=2,
    num_cells=[256, 256],   # two hidden layers of 256 units
    activation_class=nn.Tanh,
    activate_last_layer=False,
)
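To see what those arguments produce, the stdlib sketch below lists the linear layers implied by in_features, num_cells, and out_features and counts their parameters. It assumes every linear layer has a bias and that an activation follows every layer except the last when activate_last_layer=False; mlp_layout is a hypothetical helper, not part of TorchRL.

```python
from typing import List, Tuple


def mlp_layout(in_features: int, out_features: int, num_cells: List[int],
               activate_last_layer: bool = False, activation: str = "Tanh") -> Tuple[List[str], int]:
    """Return a readable layer list and the total parameter count."""
    sizes = [in_features, *num_cells, out_features]
    layers: List[str] = []
    n_params = 0
    for i, (fan_in, fan_out) in enumerate(zip(sizes[:-1], sizes[1:])):
        layers.append(f"Linear({fan_in}, {fan_out})")
        n_params += fan_in * fan_out + fan_out  # weights + bias
        is_last = i == len(sizes) - 2
        if not is_last or activate_last_layer:
            layers.append(activation)  # activations have no parameters
    return layers, n_params


layers, n_params = mlp_layout(8, 2, [256, 256])
print(" -> ".join(layers))  # Linear(8, 256) -> Tanh -> Linear(256, 256) -> Tanh -> Linear(256, 2)
print(n_params)             # 68610
```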

ConvNet

from torchrl.modules import ConvNet

cnn = ConvNet(
    num_cells=[32, 64, 64],
    kernel_sizes=[8, 4, 3],
    strides=[4, 2, 1],
)
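ConvNet is given no input size above, so the size of its output depends on the frames you feed it. Standard convolution arithmetic, out = floor((in + 2 * padding - kernel) / stride) + 1, gives the spatial size after each layer. The sketch assumes zero padding and square 84x84 inputs (a common Atari preprocessing choice, used here only as an example) and reports the feature count if the final maps are flattened; convnet_shapes is a hypothetical helper.

```python
from typing import List, Tuple


def conv_out(size: int, kernel: int, stride: int, padding: int = 0) -> int:
    """Spatial output size of one convolution along one axis."""
    return (size + 2 * padding - kernel) // stride + 1


def convnet_shapes(size: int, num_cells: List[int], kernel_sizes: List[int],
                   strides: List[int]) -> Tuple[List[int], int]:
    """Return per-layer spatial sizes and the flattened feature count (square inputs)."""
    sizes = []
    for kernel, stride in zip(kernel_sizes, strides):
        size = conv_out(size, kernel, stride)
        if size <= 0:
            raise ValueError("input too small for this kernel/stride stack")
        sizes.append(size)
    return sizes, num_cells[-1] * size * size


sizes, flat = convnet_shapes(84, num_cells=[32, 64, 64], kernel_sizes=[8, 4, 3], strides=[4, 2, 1])
print(sizes, flat)  # [20, 9, 7] 3136
```

A downstream MLP would then need 3136 input features under these assumptions, which is why lazy layers are convenient when the frame size changes.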

MultiAgentMLP

from torchrl.modules import MultiAgentMLP

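net = MultiAgentMLP(
    n_agent_inputs=8,     # input features per agent
    n_agent_outputs=2,    # outputs per agent
    n_agents=4,
    centralised=False,    # each agent sees only its own input
    share_params=True,    # one set of weights shared by all agents
    num_cells=[256, 256],
)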
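The two flags change the input size and the number of weight sets. Assuming each agent's network is a plain MLP with biases and the given num_cells, a centralised network reads all agents' inputs concatenated (n_agents * n_agent_inputs features), and share_params=False keeps one copy of the weights per agent. The helpers below are hypothetical and only do that arithmetic.

```python
from typing import List


def per_agent_params(n_in: int, n_out: int, num_cells: List[int]) -> int:
    sizes = [n_in, *num_cells, n_out]
    return sum(a * b + b for a, b in zip(sizes[:-1], sizes[1:]))  # weights + biases


def multi_agent_params(n_agent_inputs: int, n_agent_outputs: int, n_agents: int,
                       centralised: bool, share_params: bool, num_cells: List[int]) -> int:
    # A centralised agent network reads every agent's input, concatenated.
    n_in = n_agent_inputs * n_agents if centralised else n_agent_inputs
    one_net = per_agent_params(n_in, n_agent_outputs, num_cells)
    return one_net if share_params else one_net * n_agents


cells = [256, 256]
print(multi_agent_params(8, 2, 4, centralised=False, share_params=True, num_cells=cells))   # 68610
print(multi_agent_params(8, 2, 4, centralised=False, share_params=False, num_cells=cells))  # 274440
print(multi_agent_params(8, 2, 4, centralised=True, share_params=True, num_cells=cells))    # 74754
```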

NoisyLinear

from torchrl.modules import NoisyLinear

# NoisyNet-style layer: learnable noise on the weights and biases drives exploration.
layer = NoisyLinear(256, 256, std_init=0.5)
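A noisy layer keeps a learnable mean and standard deviation for every weight and bias and perturbs them with sampled noise. The stdlib sketch below follows the factorized Gaussian scheme from the NoisyNet paper (Noisy Networks for Exploration): draw one noise vector over inputs and one over outputs, pass both through f(x) = sign(x) * sqrt(|x|), and use their outer product as weight noise. The helper names are hypothetical, and initialization of the means and sigmas is omitted.

```python
import math
import random
from typing import List, Tuple

Matrix = List[List[float]]


def f(x: float) -> float:
    """Noise transform for factorized NoisyNet: sign(x) * sqrt(|x|)."""
    return math.copysign(math.sqrt(abs(x)), x)


def factorized_noise(n_in: int, n_out: int, rng: random.Random) -> Tuple[Matrix, List[float]]:
    eps_in = [f(rng.gauss(0.0, 1.0)) for _ in range(n_in)]
    eps_out = [f(rng.gauss(0.0, 1.0)) for _ in range(n_out)]
    weight_eps = [[o * i for i in eps_in] for o in eps_out]  # outer product, [n_out][n_in]
    return weight_eps, eps_out  # bias noise reuses the output vector


def noisy_forward(x: List[float], w_mu: Matrix, w_sigma: Matrix, b_mu: List[float],
                  b_sigma: List[float], w_eps: Matrix, b_eps: List[float]) -> List[float]:
    """y = (w_mu + w_sigma * w_eps) x + (b_mu + b_sigma * b_eps), noise applied elementwise."""
    out = []
    for row in range(len(b_mu)):
        w = [m + s * e for m, s, e in zip(w_mu[row], w_sigma[row], w_eps[row])]
        b = b_mu[row] + b_sigma[row] * b_eps[row]
        out.append(sum(wi * xi for wi, xi in zip(w, x)) + b)
    return out


rng = random.Random(0)
w_eps, b_eps = factorized_noise(n_in=3, n_out=2, rng=rng)
zeros = [[0.0] * 3 for _ in range(2)]
ones = [[1.0] * 3 for _ in range(2)]
y = noisy_forward([1.0, 0.0, 0.0], w_mu=zeros, w_sigma=ones, b_mu=[0.0, 0.0],
                  b_sigma=[0.0, 0.0], w_eps=w_eps, b_eps=b_eps)
print(y)  # equals the first column of w_eps, since only x[0] is non-zero
```

Factorizing needs n_in + n_out random draws per resample instead of n_in * n_out, which keeps noisy layers cheap.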

DDPG-style actor / critic nets

from torchrl.modules import DdpgMlpActor, DdpgMlpQNet

actor_net = DdpgMlpActor(
    action_dim=2,
    mlp_net_kwargs={"num_cells": [400, 300]},
)
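# DdpgMlpQNet has two branches: net1 processes the observation, and net2
# processes net1's output concatenated with the action.
qvalue_net = DdpgMlpQNet(
    mlp_net_kwargs_net1={"out_features": 400},
    mlp_net_kwargs_net2={"num_cells": [300]},
)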
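In the DDPG paper's critic, the action is not fed to the first layer; it joins the network at the second hidden layer. The shape bookkeeping for that two-branch layout is sketched below with hypothetical dimensions (obs_dim=17, action_dim=6) and the 400/300 widths from the paper; ddpg_q_layout is an illustrative helper, not TorchRL code.

```python
from typing import List, Tuple


def ddpg_q_layout(obs_dim: int, action_dim: int, h1: int = 400,
                  h2: int = 300) -> Tuple[List[Tuple[int, int]], int]:
    """Return (fan_in, fan_out) for each linear layer and the total parameter count."""
    layers = [
        (obs_dim, h1),          # branch 1: observation only
        (h1 + action_dim, h2),  # branch 2: [hidden, action] concatenated
        (h2, 1),                # scalar Q-value
    ]
    n_params = sum(fan_in * fan_out + fan_out for fan_in, fan_out in layers)
    return layers, n_params


layers, n_params = ddpg_q_layout(obs_dim=17, action_dim=6)
print(layers)    # [(17, 400), (406, 300), (300, 1)]
print(n_params)  # 129601
```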

Recurrent modules

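TorchRL provides GRUModule and LSTMModule as TensorDictModuleBase subclasses. They read and write hidden states under TensorDict keys, so collectors and replay buffers store those states alongside the rest of the data. The environment still needs the primer returned by make_tensordict_primer() so that the hidden-state entries exist from the first reset.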
from torchrl.modules import GRUModule

gru = GRUModule(
    input_size=64,
    hidden_size=128,
    in_key="embed",
    out_key="hidden",
    default_recurrent_mode=True,  # treat inputs as sequences unless set_recurrent_mode overrides it
)
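What a recurrent module does at each step can be sketched with a scalar GRU cell in plain Python, using the gate equations from PyTorch's GRU documentation. The loop records the state each step starts from under "recurrent_state" and the state it hands on under ("next", "recurrent_state"), and resets it after a done flag. Those key names, the scalar weights, and the helpers are illustrative assumptions rather than GRUModule's exact layout.

```python
import math
from typing import Any, Dict, List


def sigmoid(x: float) -> float:
    return 1.0 / (1.0 + math.exp(-x))


def gru_cell(x: float, h: float, w: Dict[str, float]) -> float:
    """One step of a scalar GRU, following PyTorch's gate equations."""
    r = sigmoid(w["ir"] * x + w["hr"] * h + w["br"])                      # reset gate
    z = sigmoid(w["iz"] * x + w["hz"] * h + w["bz"])                      # update gate
    n = math.tanh(w["in"] * x + w["bin"] + r * (w["hn"] * h + w["bhn"]))  # candidate
    return (1.0 - z) * n + z * h


def rollout(inputs: List[float], dones: List[bool], w: Dict[str, float]) -> List[Dict[str, Any]]:
    """Run the cell over a trajectory, recording hidden states as named entries."""
    h = 0.0
    steps = []
    for x, done in zip(inputs, dones):
        h_next = gru_cell(x, h, w)
        steps.append({
            "embed": x,
            "recurrent_state": h,                 # state the step started from
            "hidden": h_next,                     # module output
            "next": {"recurrent_state": h_next},  # state handed to the next step
            "done": done,
        })
        h = 0.0 if done else h_next  # reset memory at episode boundaries
    return steps


weights = dict.fromkeys(["ir", "hr", "br", "iz", "hz", "bz", "in", "bin", "hn", "bhn"], 0.0)
weights["in"] = 1.0
for step in rollout([1.0, 1.0, 1.0], [False, True, False], weights):
    print(round(step["recurrent_state"], 4), "->", round(step["next"]["recurrent_state"], 4))
# 0.0 -> 0.3808
# 0.3808 -> 0.5712
# 0.0 -> 0.3808
```

Storing both the incoming and outgoing state with every transition is what lets a replay buffer later replay sequences without re-running the whole episode.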

Exploration modules

Exploration modules are composable wrappers that inject noise or randomness at data-collection time. They act only when the exploration type is random, so setting a deterministic exploration type with set_exploration_type (as is usual for evaluation) switches them off.
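from tensordict.nn import TensorDictSequential
from torchrl.modules import EGreedyModule

# ε-greedy for discrete actions.
# Assumes `env` (with a discrete action spec) and a greedy `qvalue_actor`
# (e.g., a QValueActor) have already been built.
explorer = EGreedyModule(
    spec=env.action_spec,
    annealing_num_steps=100_000,
    eps_init=1.0,
    eps_end=0.05,
)

policy_explore = TensorDictSequential(qvalue_actor, explorer)
# During training, call explorer.step(frames) to anneal epsilon.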
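Before evaluation, use set_exploration_type(ExplorationType.DETERMINISTIC) (or MEAN / MODE) as a context manager around the rollout to disable all noise modules. The setting is global rather than attached to a module, so every exploration and probabilistic module in the policy reads it without further wiring.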
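The schedule set by eps_init, eps_end, and annealing_num_steps can be sketched in plain Python. This assumes a linear decay from eps_init to eps_end over annealing_num_steps annealing steps, then a constant eps_end; epsilon_at and egreedy_action are hypothetical helpers, not the TorchRL API.

```python
import random
from typing import List


def epsilon_at(step: int, eps_init: float = 1.0, eps_end: float = 0.05,
               annealing_num_steps: int = 100_000) -> float:
    """Linear decay from eps_init to eps_end, then constant."""
    frac = min(step / annealing_num_steps, 1.0)
    return eps_init + frac * (eps_end - eps_init)


def egreedy_action(q_values: List[float], eps: float, rng: random.Random) -> int:
    """With probability eps pick a uniformly random action, else the greedy one."""
    if rng.random() < eps:
        return rng.randrange(len(q_values))
    return max(range(len(q_values)), key=q_values.__getitem__)


rng = random.Random(0)
print(epsilon_at(0), epsilon_at(50_000), epsilon_at(200_000))  # about 1.0, 0.525, 0.05
print(egreedy_action([0.1, 0.9, 0.3], eps=0.0, rng=rng))       # 1 (greedy when eps == 0)
```

Evaluating with eps = 0 reproduces the purely greedy policy, which is the behaviour the deterministic exploration type is meant to give.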

Value normalization

TorchRL provides three value-normalization utilities that stabilize training by adaptively rescaling critic targets.
from torchrl.modules import PopArtValueNorm, RunningValueNorm, ValueNorm

# PopArt: normalizes targets and rescales last critic layer weights.
pop_art = PopArtValueNorm(in_keys=["reward"], beta=0.001)

# RunningValueNorm: exponential moving average normalization.
running_norm = RunningValueNorm(in_keys=["reward"], beta=0.001)

# ValueNorm: simple running mean/variance normalization.
value_norm = ValueNorm(in_keys=["state_value"])
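The PopArt comment above compresses two steps: update running statistics of the targets, then rescale the critic's last linear layer so its unnormalized predictions stay the same. The stdlib sketch below implements the output-preserving rule from the PopArt paper (w' = sigma * w / sigma', b' = (sigma * b + mu - mu') / sigma'), with exponential moving averages of the first and second moments controlled by beta. PopArtSketch is hypothetical and is not TorchRL's PopArtValueNorm.

```python
import math
from typing import List


class PopArtSketch:
    """Hypothetical PopArt helper: EMA target statistics + output-preserving rescale."""

    def __init__(self, n_features: int, beta: float = 0.001, min_var: float = 1e-8):
        self.w = [0.0] * n_features  # last critic layer, in normalized space
        self.b = 0.0
        self.mu, self.nu = 0.0, 1.0  # first and second moments of the targets
        self.beta, self.min_var = beta, min_var

    @property
    def sigma(self) -> float:
        return math.sqrt(max(self.nu - self.mu ** 2, self.min_var))

    def normalized(self, h: List[float]) -> float:
        return sum(wi * hi for wi, hi in zip(self.w, h)) + self.b

    def unnormalized(self, h: List[float]) -> float:
        return self.sigma * self.normalized(h) + self.mu

    def update_stats(self, target: float) -> None:
        old_mu, old_sigma = self.mu, self.sigma
        self.mu = (1 - self.beta) * self.mu + self.beta * target
        self.nu = (1 - self.beta) * self.nu + self.beta * target ** 2
        new_sigma = self.sigma
        # Rescale the last layer so unnormalized outputs are unchanged.
        self.w = [wi * old_sigma / new_sigma for wi in self.w]
        self.b = (old_sigma * self.b + old_mu - self.mu) / new_sigma

    def normalize_target(self, target: float) -> float:
        return (target - self.mu) / self.sigma


pa = PopArtSketch(n_features=2, beta=0.1)
pa.w, pa.b = [0.5, -1.0], 0.2
h = [1.0, 2.0]
before = pa.unnormalized(h)
pa.update_stats(target=100.0)      # a large target shifts mu and sigma
after = pa.unnormalized(h)
print(abs(before - after) < 1e-9)  # True: predictions are preserved
print(pa.normalize_target(100.0))  # regression target for the critic, in normalized space
```

Without the rescale, every jump in the target statistics would silently change all value predictions, which is the instability PopArt is designed to avoid.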

Distributional Q-value actors

For categorical distributional RL (e.g., C51), TorchRL provides DistributionalQValueActor and DistributionalQValueModule.
import torch
from torchrl.modules import DistributionalQValueActor, MLP

n_actions, n_atoms = 4, 51

# The network outputs one logit per (atom, action) pair:
# out_features=(n_atoms, n_actions) gives a [..., n_atoms, n_actions] output.
q_net = MLP(in_features=8, out_features=(n_atoms, n_actions), num_cells=[256])
actor = DistributionalQValueActor(
    q_net,
    in_keys=["observation"],
    support=torch.linspace(-10, 10, n_atoms),  # return value of each atom
    action_space="one-hot",
)
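How a categorical (C51-style) distributional actor turns network outputs into an action can be sketched in plain Python: softmax over the atoms for each action, take the expectation over the support to get Q(s, a), then act greedily. This assumes logits laid out as [n_atoms][n_actions]; the helpers and the three-atom support are hypothetical.

```python
import math
from typing import List


def softmax(xs: List[float]) -> List[float]:
    m = max(xs)
    exps = [math.exp(x - m) for x in xs]
    total = sum(exps)
    return [e / total for e in exps]


def expected_q_values(logits: List[List[float]], support: List[float]) -> List[float]:
    """logits[atom][action] -> Q(s, a) = sum_i z_i * p_i(s, a)."""
    n_atoms, n_actions = len(logits), len(logits[0])
    if n_atoms != len(support):
        raise ValueError("one support value per atom is required")
    q = []
    for a in range(n_actions):
        probs = softmax([logits[i][a] for i in range(n_atoms)])  # normalize over atoms
        q.append(sum(z * p for z, p in zip(support, probs)))
    return q


def linspace(lo: float, hi: float, n: int) -> List[float]:
    return [lo + (hi - lo) * i / (n - 1) for i in range(n)]


support = linspace(-10.0, 10.0, 3)  # [-10.0, 0.0, 10.0]
logits = [[0.0, -5.0],   # atom z = -10
          [0.0, 0.0],    # atom z = 0
          [0.0, 5.0]]    # atom z = +10
q = expected_q_values(logits, support)
greedy_action = max(range(len(q)), key=q.__getitem__)
print(q, greedy_action)  # action 0 has Q = 0 by symmetry; action 1 puts its mass on +10 and wins
```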

Complete stochastic actor example

The following builds a stochastic actor and a Q-value critic ready for SAC: the actor combines an MLP backbone, NormalParamExtractor, ProbabilisticActor, and TanhNormal, and the critic is a ValueOperator over the observation and action.
import torch
from tensordict import TensorDict
from torch import nn
from tensordict.nn import TensorDictModule
from tensordict.nn.distributions import NormalParamExtractor
from torchrl.modules import MLP, ProbabilisticActor, TanhNormal, ValueOperator

obs_dim = 17   # e.g., HalfCheetah observation
action_dim = 6

# Actor: observation → (loc, scale) → sample TanhNormal action
param_net = TensorDictModule(
    nn.Sequential(
        MLP(in_features=obs_dim, out_features=2 * action_dim, num_cells=[256, 256]),
        NormalParamExtractor(),
    ),
    in_keys=["observation"],
    out_keys=["loc", "scale"],
)

actor = ProbabilisticActor(
    param_net,
    in_keys=["loc", "scale"],
    out_keys=["action"],
    distribution_class=TanhNormal,
    distribution_kwargs={"low": -1.0, "high": 1.0},
    return_log_prob=True,   # needed by SACLoss
)

# Critic: (observation, action) → state-action value
qvalue = ValueOperator(
    MLP(
        in_features=obs_dim + action_dim,
        out_features=1,
        num_cells=[256, 256],
    ),
    in_keys=["observation", "action"],
    out_keys=["state_action_value"],
)

# Verify shapes on a single unbatched observation (the output of env.reset()
# for an environment with an "observation" entry would play the same role).
td = TensorDict({"observation": torch.randn(obs_dim)}, batch_size=[])
td = actor(td)
print("action:", td["action"].shape)           # [action_dim]
print("log_prob:", td["sample_log_prob"].shape) # []
td = qvalue(td)
print("q-value:", td["state_action_value"].shape)  # [1]
