TorchRL’s actor and critic modules are thin wrappers around ordinary nn.Module objects that give them a TensorDict interface: each module declares which keys it reads from and writes to a TensorDict, making policy construction composable and data-flow explicit. Rather than wiring raw tensor arguments through function calls, you describe your policy as a sequence of named transformations on a shared dictionary. This page covers all the building blocks — from the base TensorDictModule through stochastic actors, value operators, and combined actor-critic architectures.

TensorDictModule and SafeModule

Every TorchRL module is rooted in TensorDictModule from the tensordict library. SafeModule (torchrl.modules.SafeModule) extends it with an optional output-validation step: when safe=True, the module calls spec.project() on any out-of-bounds sample before writing it to the TensorDict.
from torchrl.modules import SafeModule
from torchrl.data import Unbounded
import torch
from tensordict import TensorDict

module = torch.nn.GRUCell(4, 8)
td_module = SafeModule(
    module=module,
    in_keys=["input", "hidden"],
    out_keys=["output"],
    spec=Unbounded(8),
    safe=False,
)
td = TensorDict({"input": torch.randn(3, 4), "hidden": torch.randn(3, 8)}, [3])
td_out = td_module(td)
print(td_out["output"].shape)  # torch.Size([3, 8])
SafeModule is a subclass of tensordict.nn.TensorDictModule. You can use either class; SafeModule simply adds the spec and safe keyword arguments.

Actor

Actor is a convenience subclass of SafeModule for deterministic policies. It sets default keys in_keys=["observation"] and out_keys=["action"] and, if a non-Composite spec is given for the action, wraps it automatically as Composite(action=spec).
from torchrl.modules import Actor
from torchrl.data import Unbounded
import torch
from tensordict import TensorDict

action_spec = Unbounded(4)
module = torch.nn.Linear(4, 4)
actor = Actor(module=module, spec=action_spec)

td = TensorDict({"observation": torch.randn(3, 4)}, [3])
print(actor(td)["action"].shape)  # torch.Size([3, 4])

ProbabilisticActor

ProbabilisticActor is the stochastic-policy workhorse in TorchRL. It wraps a TensorDictModule backbone together with a distribution class, sampling an action from the distribution and (optionally) writing the log-probability back to the TensorDict. The constructor accepts two distinct module patterns:
  • Split backbone: A TensorDictModule that writes distribution parameters to dedicated keys (e.g. "loc" and "scale"), followed by sampling inside ProbabilisticActor.
  • Composite distribution: A single TensorDictModule writing to a nested "params" key, paired with CompositeDistribution.
Key arguments:
  • module (TensorDictModule, required): The backbone module that computes distribution parameters and writes them to the TensorDict.
  • in_keys (NestedKey | Sequence[NestedKey] | dict, required): Keys to read from the TensorDict as distribution inputs. They must match the constructor keyword names of distribution_class (e.g. "loc" and "scale" for Normal). When given as a dict, the keys are distribution parameter names and the values are the TensorDict keys that supply them.
  • out_keys (Sequence[NestedKey], default: ["action"]): Keys where sampled values are written. Sampling is skipped if these keys already exist in the input TensorDict.
  • spec (TensorSpec): Output spec for the first sampled tensor. Non-Composite specs are automatically wrapped as Composite(action=spec).
  • distribution_class (type[Distribution], default: Delta): A torch.distributions.Distribution subclass used for sampling. Common choices: TanhNormal, MaskedCategorical, Normal, CompositeDistribution.
  • distribution_kwargs (dict): Extra keyword arguments forwarded to distribution_class at construction time, for example {"low": -1.0, "high": 1.0} when using TanhNormal.
  • return_log_prob (bool, default: False): When True, writes sample_log_prob (the log-probability of the sampled action) into the output TensorDict.
  • default_interaction_type (InteractionType, default: InteractionType.DETERMINISTIC): Fallback interaction mode when the global interaction_type() returns None. Options: DETERMINISTIC, RANDOM, MODE, MEAN, MEDIAN. Collectors override this to RANDOM automatically.
  • cache_dist (bool, default: False): Experimental. When True, writes the distribution parameters to the TensorDict so the original distribution can be reconstructed later (useful for PPO's KL / ratio computation).
  • generator (torch.Generator | int | NestedKey | None, default: None): Routes sampling through an explicit RNG instead of the global PyTorch RNG. Pass an int as a shorthand for Generator().manual_seed(int), or a NestedKey to read the generator from the TensorDict on each forward call.

Example: stochastic actor with TanhNormal

import torch
from torch import nn
from tensordict import TensorDict
from tensordict.nn import TensorDictModule
from torchrl.data import Bounded
from torchrl.modules import ProbabilisticActor, NormalParamExtractor, TanhNormal

# 1. Backbone: maps observation -> (loc, scale)
backbone = nn.Sequential(nn.Linear(4, 8), NormalParamExtractor())
td_backbone = TensorDictModule(
    backbone,
    in_keys=["observation"],
    out_keys=["loc", "scale"],
)

# 2. Stochastic actor: samples action ~ TanhNormal(loc, scale)
action_spec = Bounded(shape=torch.Size([4]), low=-1, high=1)
actor = ProbabilisticActor(
    module=td_backbone,
    in_keys=["loc", "scale"],
    out_keys=["action"],
    spec=action_spec,
    distribution_class=TanhNormal,
    return_log_prob=True,
)

td = TensorDict({"observation": torch.randn(3, 4)}, [3])
td = actor(td)
print(td["action"].shape)          # torch.Size([3, 4])
print(td["sample_log_prob"].shape) # torch.Size([3])
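Use NormalParamExtractor (defined in tensordict.nn and re-exported by torchrl.modules) to split the output of a linear layer into loc and scale halves along the last dimension. By default, it maps the scale half through a biased softplus (scale_mapping="biased_softplus_1.0") and clamps it to a small lower bound, so the scale stays strictly positive.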

ValueOperator

ValueOperator wraps a value-function network with sensible default keys. When "action" is present in in_keys, it defaults to out_keys=["state_action_value"] (for Q-functions); otherwise it defaults to out_keys=["state_value"] (for state-value functions V(s)).
Key arguments:
  • module (nn.Module, required): A neural network that computes value estimates.
  • in_keys (Sequence[NestedKey], default: ["observation"]): Keys read from the TensorDict. Include "action" to build a Q-function.
  • out_keys (Sequence[NestedKey], default: ["state_value"] or ["state_action_value"]): Keys written to the TensorDict. Auto-selected based on whether "action" is in in_keys.
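import torch
from torch import nn
from tensordict import TensorDict
from torchrl.modules import MLP, ValueOperator

# State-value function V(s)
critic = ValueOperator(
    module=nn.Linear(4, 1),
    in_keys=["observation"],
)
td = TensorDict({"observation": torch.randn(3, 4)}, [3])
critic(td)
print(td["state_value"].shape)  # torch.Size([3, 1])

# Q-function Q(s, a). ValueOperator passes each in_key to the module as a
# separate argument, so the network must accept two tensors. TorchRL's MLP
# concatenates its inputs along the last dimension (4 + 2 = 6 features).
q_critic = ValueOperator(
    module=MLP(in_features=6, out_features=1, num_cells=[32]),
    in_keys=["observation", "action"],
)
td["action"] = torch.randn(3, 2)
q_critic(td)
print(td["state_action_value"].shape)  # torch.Size([3, 1])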

Combined Actor-Critic Architectures

ActorValueOperator

ActorValueOperator composes three sub-modules that share a common observation encoder: a common_operator that produces a hidden state, a policy_operator that turns the hidden state into an action, and a value_operator that turns it into a value estimate. Use get_policy_operator() and get_value_operator() to extract standalone operators for collection and loss computation.
from torchrl.modules import (
    ActorValueOperator, ProbabilisticActor, ValueOperator,
    SafeModule, TanhNormal, NormalParamExtractor,
)
from tensordict.nn import TensorDictModule
import torch
from torch import nn

common = SafeModule(nn.Linear(4, 4), in_keys=["observation"], out_keys=["hidden"])

policy_net = TensorDictModule(
    nn.Sequential(nn.Linear(4, 8), NormalParamExtractor()),
    in_keys=["hidden"], out_keys=["loc", "scale"],
)
actor = ProbabilisticActor(
    module=policy_net,
    in_keys=["loc", "scale"],
    distribution_class=TanhNormal,
    return_log_prob=True,
)
critic = ValueOperator(module=nn.Linear(4, 1), in_keys=["hidden"])

actor_critic = ActorValueOperator(common, actor, critic)

# Separate operators for collection / loss:
policy_op = actor_critic.get_policy_operator()
value_op  = actor_critic.get_value_operator()

ActorCriticOperator

ActorCriticOperator is like ActorValueOperator but wires the action into the critic, producing Q(s, a) instead of V(s). The critic receives both the hidden state and the action produced by the policy.

ActorCriticWrapper

ActorCriticWrapper bundles an actor and a critic that do not share parameters. It accepts any two TensorDictModule objects and exposes the same get_policy_operator() / get_value_operator() interface.
from torchrl.modules import ActorCriticWrapper, Actor, ValueOperator
import torch
from torch import nn

actor = Actor(module=nn.Linear(4, 4))
critic = ValueOperator(module=nn.Linear(4, 1))
wrapper = ActorCriticWrapper(actor, critic)

Discrete Action Policies

QValueActor and QValueModule

QValueActor converts the raw action values (Q-values) produced by a backbone into a greedy action. It wraps the backbone module with a QValueModule that takes the argmax over actions and writes the selected action into the TensorDict.
from torchrl.modules import QValueActor
from torchrl.data import OneHot
import torch
from torch import nn

spec = OneHot(n=4)
net = nn.Linear(3, 4)
actor = QValueActor(module=net, spec=spec, in_keys=["observation"])

DistributionalQValueActor

DistributionalQValueActor implements distributional RL in the style of C51: for each action, the network outputs scores over a fixed set of return atoms (the support). The module normalizes these scores over the atoms with a softmax, computes each action's expected Q-value as the probability-weighted sum of the support values, and selects the greedy action.
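The expected-value step can be written out in plain PyTorch. greedy_from_atoms is a hypothetical helper, not TorchRL's implementation. It assumes the network output is laid out as [batch, n_atoms, n_actions] and uses a 51-atom support on [-10, 10], as in the original C51 setup.

```python
import torch


def greedy_from_atoms(logits: torch.Tensor, support: torch.Tensor):
    """Return expected Q-values and greedy actions.

    logits:  [batch, n_atoms, n_actions] unnormalized scores (assumed layout).
    support: [n_atoms] return values z_i attached to each atom.
    """
    probs = torch.softmax(logits, dim=-2)                   # p_i(s, a), sums to 1 over atoms
    q_values = (probs * support.unsqueeze(-1)).sum(dim=-2)  # Q(s, a) = sum_i z_i * p_i(s, a)
    return q_values, q_values.argmax(dim=-1)


support = torch.linspace(-10.0, 10.0, 51)  # 51 atoms on [-10, 10]
logits = torch.randn(2, 51, 3)             # batch of 2 states, 3 actions
q_values, actions = greedy_from_atoms(logits, support)
print(q_values.shape, actions.shape)       # torch.Size([2, 3]) torch.Size([2])
```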

Model Builders

MLP

MLP is a flexible multi-layer perceptron builder that inherits from nn.Sequential. It supports lazy input inference (no in_features needed), per-layer normalization, and dropout.
Key arguments:
  • in_features (int): Input feature dimension. If omitted, uses LazyLinear for the first layer.
  • out_features (int | torch.Size, required): Output feature dimension. If a torch.Size, the output is reshaped to that shape.
  • depth (int): Number of hidden layers. depth=0 produces a single linear layer; depth=N produces N+1 linear layers. Defaults to None, which derives depth from the length of num_cells (or 0 if num_cells is also omitted).
  • num_cells (int | Sequence[int], default: 32): Width of each hidden layer. Can be a list to set per-layer widths; the list length must equal depth.
  • activation_class (type[nn.Module], default: nn.Tanh): Activation applied after every hidden layer.
  • dropout (float): Dropout probability applied after each activation. Omit for no dropout.
  • layer_class (type[nn.Module], default: nn.Linear): Linear layer class. Pass NoisyLinear for noisy networks.
from torchrl.modules import MLP

# 3-layer network: 8 → 32 → 32 → 6
mlp = MLP(in_features=8, out_features=6, depth=2, num_cells=32)

# Lazy input inference + per-layer widths
mlp = MLP(out_features=6, num_cells=[64, 64, 32])

ConvNet

ConvNet is a configurable 2-D convolutional network builder with a terminal SquashDims aggregator that flattens spatial dimensions.
Key arguments:
  • in_features (int): Input channel count. Omit for lazy initialization.
  • num_cells (int | Sequence[int], default: [32, 32, 32]): Output channel counts per convolutional layer.
  • kernel_sizes (int | Sequence[int], default: 3): Kernel sizes per layer. Can be rectangular tuples, e.g. (2, 3).
  • strides (int | Sequence[int], default: 1): Stride per layer.
  • activation_class (type[nn.Module], default: nn.ELU): Activation after each convolutional layer.
from torchrl.modules import ConvNet

cnet = ConvNet(in_features=3, num_cells=[32, 64, 64], kernel_sizes=[8, 4, 3], strides=[4, 2, 1])

Value Normalization

Value normalization stabilizes critic training by keeping value targets on a fixed scale throughout training.

ValueNorm (abstract base)

ValueNorm defines a common interface with three abstract methods:
  • update(value_target) — fold a batch of targets into the running statistics.
  • normalize(value_target) — standardize using current statistics.
  • denormalize(normalised_value) — invert the normalization.

PopArtValueNorm

Exponentially-weighted moving-average normalizer based on PopArt (van Hasselt et al., NIPS 2016; applied to multi-task RL by Hessel et al., AAAI 2019). It uses a debiasing term so that early estimates are not biased toward the initial value. Recommended for multi-task or curriculum settings where the reward scale can drift.
from torchrl.modules import PopArtValueNorm
import torch

vn = PopArtValueNorm(shape=1, beta=0.99999)
target = torch.randn(64, 1) * 5.0 + 2.0
for _ in range(100):
    vn.update(target)
normed = vn.normalize(target)        # ≈ N(0, 1)
recovered = vn.denormalize(normed)   # ≈ original scale
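The debiasing mentioned above can be illustrated with a hypothetical DebiasedEMA class, which is not PopArtValueNorm's actual implementation. It tracks exponential moving averages of the first and second moments and divides each by 1 - beta^t, which cancels the pull toward the zero initialization. Here beta is treated as the decay factor, matching how beta=0.99999 is used above. The output-layer rescaling that gives PopArt its name is not shown.

```python
import torch


class DebiasedEMA:
    """Hypothetical sketch of debiased EMA statistics (not TorchRL's class)."""

    def __init__(self, beta: float = 0.99, eps: float = 1e-5):
        self.beta = beta
        self.eps = eps
        self.step = 0
        self.m1 = torch.zeros(())  # biased EMA of E[x]
        self.m2 = torch.zeros(())  # biased EMA of E[x^2]

    def update(self, x: torch.Tensor) -> None:
        self.step += 1
        self.m1 = self.beta * self.m1 + (1 - self.beta) * x.mean()
        self.m2 = self.beta * self.m2 + (1 - self.beta) * (x ** 2).mean()

    def stats(self):
        # Dividing by (1 - beta^t) removes the bias toward the zero init.
        correction = 1 - self.beta ** self.step
        mean = self.m1 / correction
        var = (self.m2 / correction - mean ** 2).clamp_min(self.eps)
        return mean, var.sqrt()

    def normalize(self, x: torch.Tensor) -> torch.Tensor:
        mean, std = self.stats()
        return (x - mean) / std

    def denormalize(self, y: torch.Tensor) -> torch.Tensor:
        mean, std = self.stats()
        return y * std + mean


tracker = DebiasedEMA(beta=0.999)
targets = torch.tensor([1.0, 2.0, 3.0, 4.0])
tracker.update(targets)
# The raw EMA m1 is only 0.0025 after one step, but the debiased mean is 2.5.
mean, std = tracker.stats()
```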

RunningValueNorm

Welford exact running mean and variance with no exponential decay. Cheaper and more stable than PopArtValueNorm when value targets are stationary. Good default for single-task, fixed-reward-scale runs.
from torchrl.modules import RunningValueNorm
vn = RunningValueNorm(shape=1, epsilon=1e-5)  # no beta: Welford statistics do not decay
Always call vn.update(targets) before vn.normalize(targets) in each training step. Forgetting to call update will freeze the running statistics and degrade performance silently.
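The exact statistics behind RunningValueNorm can be sketched with Welford's algorithm in its batched form (Chan et al.), which merges a whole batch of targets at once. WelfordNorm below is a hypothetical stand-in, not TorchRL's class. It reports the population variance and adds epsilon before the square root.

```python
import torch


class WelfordNorm:
    """Hypothetical sketch of exact running mean/variance (not TorchRL's class)."""

    def __init__(self, epsilon: float = 1e-5):
        self.count = 0
        self.mean = torch.zeros(())
        self.m2 = torch.zeros(())  # running sum of squared deviations
        self.epsilon = epsilon

    def update(self, x: torch.Tensor) -> None:
        # Merge one batch using the parallel (Chan et al.) form of Welford's update.
        n_b = x.numel()
        mean_b = x.mean()
        m2_b = ((x - mean_b) ** 2).sum()
        total = self.count + n_b
        delta = mean_b - self.mean
        self.mean = self.mean + delta * n_b / total
        self.m2 = self.m2 + m2_b + delta ** 2 * self.count * n_b / total
        self.count = total

    def std(self) -> torch.Tensor:
        var = self.m2 / max(self.count, 1)
        return (var + self.epsilon).sqrt()

    def normalize(self, x: torch.Tensor) -> torch.Tensor:
        return (x - self.mean) / self.std()

    def denormalize(self, y: torch.Tensor) -> torch.Tensor:
        return y * self.std() + self.mean


torch.manual_seed(0)
wn = WelfordNorm()
batches = [torch.randn(64) * 5.0 + 2.0 for _ in range(10)]
for batch in batches:
    wn.update(batch)             # update before normalizing, every step
data = torch.cat(batches)
normed = wn.normalize(data)
```

Unlike an exponential average, every batch contributes equally, so the statistics converge to the exact mean and variance of all targets seen so far.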
