Recurrent policies are essential whenever the environment is partially observed and a single frame is not enough to determine the optimal action. Standard PyTorch nn.LSTM and nn.GRU modules require you to thread hidden states manually between calls, which conflicts with TorchRL’s TensorDict-based data flow. TorchRL solves this with LSTMModule and GRUModule: they store hidden states as ordinary TensorDict keys under well-defined names, making recurrent state part of the same structured dictionary that carries observations, actions, and rewards. Collectors automatically carry hidden states forward and reset them at episode boundaries.

Why Hidden States Need Special Handling

In a standard collector loop, each step calls policy(tensordict) and the tensordict flows through the whole pipeline — transforms, replay buffers, samplers. If the hidden state is just a Python variable held outside the tensordict, it gets silently lost when the tensordict is sliced, shuffled, or sent across processes. By placing hidden states in the tensordict under keys like ("next", "recurrent_state_h"), TorchRL ensures they are:
  • Automatically carried forward: each step reads recurrent_state_h and writes ("next", "recurrent_state_h"), and the environment's step_mdp logic (applied by env.rollout and collectors) moves the "next" entry to the root before the following step.
  • Reset at episode boundaries: InitTracker and TensorDictPrimer transforms zero them when is_init=True.
  • Batch-compatible: they live inside the same batched tensordict as observations, so they are correctly sliced for sub-batches and padded for variable-length sequences.
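To make the reset bullet concrete, here is a minimal sketch in plain PyTorch of zeroing hidden states for the batch entries that start a new episode. It is not TorchRL's implementation; the tensor names and shapes (h, is_init) are assumptions chosen for illustration.

```python
import torch

batch, num_layers, hidden_size = 4, 1, 8

# Hidden states for a batch of 4 environments (illustrative layout).
h = torch.randn(batch, num_layers, hidden_size)

# is_init marks the environments that just started a new episode.
is_init = torch.tensor([False, True, False, True])

# Broadcast the flag over the layer and feature dimensions and zero those rows.
mask = is_init.view(batch, 1, 1)
h_reset = torch.where(mask, torch.zeros_like(h), h)
```

Because the flag and the hidden state share the same leading batch dimension, any slice or shuffle applied to both keeps them aligned, which is the point of storing them in one structure.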

LSTMModule

LSTMModule wraps torch.nn.LSTM (or TorchRL’s Python-native LSTM implementation) with TensorDict-compatible I/O. It has two modes of operation:
  • Single-step mode (default): processes one time step at a time, updating hidden states in-place. Used during environment collection.
  • Recurrent mode: processes a full time sequence of shape [B, T, *], enabling truncated BPTT during training. Enabled with the set_recurrent_mode(True) context manager or the default_recurrent_mode constructor argument.
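The relationship between the two modes can be checked with a plain torch.nn.LSTM. The sketch below uses no TorchRL code and its sizes are arbitrary assumptions. It shows that stepping through a sequence one frame at a time while threading (h, c) by hand gives the same outputs as a single call over the whole [B, T, *] input.

```python
import torch
from torch import nn

torch.manual_seed(0)
lstm = nn.LSTM(input_size=3, hidden_size=8, batch_first=True)
x = torch.randn(2, 5, 3)  # [B, T, features]

with torch.no_grad():
    # Sequence processing: one call over the whole [B, T, *] input.
    out_seq, _ = lstm(x)

    # Single-step processing: one call per time step, threading (h, c) by hand.
    h = torch.zeros(1, 2, 8)  # [num_layers, B, hidden_size]
    c = torch.zeros(1, 2, 8)
    steps = []
    for t in range(x.shape[1]):
        out_t, (h, c) = lstm(x[:, t : t + 1], (h, c))
        steps.append(out_t)
    out_steps = torch.cat(steps, dim=1)

same = torch.allclose(out_seq, out_steps, atol=1e-5)
```

The equivalence holds only while no episode boundary falls inside the sequence, which is why resets need the dedicated handling described in the rest of this page.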
Constructor parameters:
  • input_size (int, required): Number of expected input features (observation / embedding dimensionality).
  • hidden_size (int, required): Number of features in the LSTM hidden state h (and cell state c).
  • num_layers (int, default 1): Number of stacked LSTM layers.
  • bias (bool, default True): Whether to include bias terms b_ih and b_hh.
  • dropout (float, default 0): Dropout probability on outputs of each LSTM layer except the last.
  • python_based (bool, default False): When True, uses TorchRL’s fully Python-implemented LSTMCell instead of the cuDNN kernel. Required for torch.vmap and torch.compile.
  • recurrent_backend (str, default "pad"): Backend used when trajectories reset mid-batch. Options:
      • "pad": splits trajectories and pads to uniform length (default).
      • "scan": uses a scan loop via hoptorch; avoids materialization of padded chunks.
      • "triton": prototype Triton kernels (CUDA only, requires Triton ≥ 2.2).
      • "auto": uses "pad" in eager mode and "scan" under torch.compile.
  • in_key (str | tuple[str]): Shorthand for in_keys when hidden state key names follow the default convention. Exclusive with in_keys.
  • in_keys (list[str]): A triplet [input_key, hidden_h_key, hidden_c_key] specifying what to read from the input TensorDict. Exclusive with in_key.
  • out_key (str | tuple[str]): Shorthand for out_keys. Exclusive with out_keys.
  • out_keys (list[str]): A triplet [output_key, next_hidden_h_key, next_hidden_c_key]. For correct rollout behavior, hidden output keys should be nested under "next", e.g. [("next", "rs_h"), ("next", "rs_c")].
  • default_recurrent_mode (bool, default False): Default value for recurrent_mode when no context manager is active.

Hidden State Key Convention

TorchRL uses the "next" nesting convention to propagate state between steps. The pattern is:
TensorDict at step t:
  "rs_h"          ← current hidden h (shape [num_layers, hidden_size])
  "rs_c"          ← current cell   c
  "next" / "rs_h" ← updated hidden h after applying LSTM
  "next" / "rs_c" ← updated cell   c
Between steps, the environment's step_mdp logic (applied by env.rollout and collectors) copies td["next"]["rs_h"] → td["rs_h"], and likewise for "rs_c", automatically before the next step.
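The promotion of "next" entries can be pictured with plain dictionaries. This is a deliberately simplified sketch, not TorchRL code: promote_next is a hypothetical helper, and the real machinery also handles rewards, done flags, and batch dimensions.

```python
def promote_next(td: dict) -> dict:
    """Hypothetical helper: build the next step's input from a finished step."""
    carried = {k: v for k, v in td.items() if k != "next"}
    carried.update(td["next"])  # entries under "next" overwrite root entries
    return carried


# Strings stand in for tensors so the data flow is easy to read.
step_t = {
    "observation": [0.1, 0.2, 0.3],
    "rs_h": "h_t",
    "rs_c": "c_t",
    "next": {"observation": [0.4, 0.5, 0.6], "rs_h": "h_t+1", "rs_c": "c_t+1"},
}
step_t1 = promote_next(step_t)
```

After the call, step_t1["rs_h"] holds the state the LSTM wrote under "next" at step t, so the policy reads it as its current hidden state at step t+1.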

set_recurrent_mode

set_recurrent_mode (from torchrl.modules) switches between single-step and sequence-processing behavior. Use it as a context manager during training or set the default via the default_recurrent_mode constructor argument:
from torchrl.modules import LSTMModule, set_recurrent_mode

# Context manager (preferred for training).
# lstm_module is an existing LSTMModule; trajectory_td has shape [B, T].
with set_recurrent_mode(True):
    output = lstm_module(trajectory_td)

# Alternatively, set the default recurrent mode at construction time
lstm_module = LSTMModule(
    input_size=64, hidden_size=64,
    in_key="x", out_key="y",
    default_recurrent_mode=True,
)
The instance method lstm_module.set_recurrent_mode() was removed in TorchRL v0.8 and now raises RuntimeError. Use the set_recurrent_mode context manager from torchrl.modules or the default_recurrent_mode constructor argument instead.

make_tensordict_primer

LSTMModule.make_tensordict_primer() returns a TensorDictPrimer transform that initializes hidden-state keys with zeros in the environment’s observation tensordict. Apply it to the environment via TransformedEnv.
from torchrl.envs import TransformedEnv, InitTracker, GymEnv

env = TransformedEnv(GymEnv("Pendulum-v1"), InitTracker())
env.append_transform(lstm_module.make_tensordict_primer())

Full Example: LSTM-Based Actor

import torch
from torch import nn
from tensordict.nn import TensorDictSequential as Seq, TensorDictModule as Mod
from torchrl.envs import TransformedEnv, InitTracker, GymEnv
from torchrl.modules import MLP, LSTMModule, set_recurrent_mode

# Environment
env = TransformedEnv(GymEnv("Pendulum-v1"), InitTracker())
obs_size = env.observation_spec["observation"].shape[-1]  # 3

# LSTM encoder
lstm = LSTMModule(
    input_size=obs_size,
    hidden_size=64,
    num_layers=1,
    in_keys=["observation", "rs_h", "rs_c"],
    out_keys=["intermediate", ("next", "rs_h"), ("next", "rs_c")],
)

# MLP head
head = MLP(in_features=64, out_features=1, depth=1, num_cells=64)
head_mod = Mod(head, in_keys=["intermediate"], out_keys=["action"])

# Combined policy
policy = Seq(lstm, head_mod)

# Add hidden-state priming to the environment
env.append_transform(lstm.make_tensordict_primer())

# Single-step collection (default)
td = env.reset()
td = policy(td)
print(td["action"].shape)          # torch.Size([1])
print(td["next", "rs_h"].shape)    # torch.Size([1, 64])

# Multi-step training (recurrent mode)
traj = env.rollout(max_steps=10)   # 10 steps with random actions
with set_recurrent_mode(True):
    traj = policy(traj)
    print(traj["action"].shape)    # torch.Size([10, 1])

GRUModule

GRUModule is analogous to LSTMModule but wraps torch.nn.GRU. GRU has a single hidden state (no cell state), so its key triplets become pairs:
  • in_keys = [input_key, hidden_key]
  • out_keys = [output_key, ("next", hidden_key)]
All other parameters and behaviors — set_recurrent_mode, make_tensordict_primer, recurrent backends, and the python_based flag — are identical to LSTMModule.
Constructor parameters:
  • input_size (int, required): Number of expected input features.
  • hidden_size (int, required): Number of features in the GRU hidden state.
  • num_layers (int, default 1): Number of stacked GRU layers.
  • bias (bool, default True): Whether to include bias terms.
  • dropout (float, default 0): Dropout on intermediate GRU layer outputs.
  • python_based (bool, default False): Use the Python GRU implementation for torch.vmap / torch.compile compatibility.
  • in_keys (list[str]): Pair [input_key, hidden_key].
  • out_keys (list[str]): Pair [output_key, ("next", hidden_key)].

GRUModule Example

from torchrl.envs import TransformedEnv, InitTracker, GymEnv
from torchrl.modules import MLP, GRUModule
from tensordict.nn import TensorDictSequential as Seq, TensorDictModule as Mod

env = TransformedEnv(GymEnv("Pendulum-v1"), InitTracker())
obs_size = env.observation_spec["observation"].shape[-1]  # 3

gru = GRUModule(
    input_size=obs_size,
    hidden_size=64,
    in_keys=["observation", "rs"],
    out_keys=["intermediate", ("next", "rs")],
)
head = Mod(
    MLP(in_features=64, out_features=1, depth=0),
    in_keys=["intermediate"],
    out_keys=["action"],
)
policy = Seq(gru, head)
env.append_transform(gru.make_tensordict_primer())

td = env.reset()
print(policy(td)["action"].shape)  # torch.Size([1])

Recurrent Backends

Both LSTMModule and GRUModule support multiple compute backends for the recurrent pass when sequences contain episode resets mid-batch.
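"pad" (default): splits the batch into per-trajectory chunks, pads them to the same length, processes them with the cuDNN kernel, then unpads. Safe and stable, but it allocates extra memory for the padding. The "scan", "triton", and "auto" options are summarized under the recurrent_backend parameter of LSTMModule.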
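The split, pad, and unpad steps can be sketched with plain PyTorch utilities. This is an illustration of the idea under simplifying assumptions (a single [T, F] sequence and a boolean is_init flag per step), not the code TorchRL runs; split_and_pad is a hypothetical helper.

```python
import torch
from torch.nn.utils.rnn import pad_sequence


def split_and_pad(x: torch.Tensor, is_init: torch.Tensor):
    """Hypothetical helper: split a [T, F] sequence at is_init flags and pad."""
    starts = torch.nonzero(is_init).flatten().tolist()
    if not starts or starts[0] != 0:
        starts = [0] + starts  # the first step always opens a chunk
    ends = starts[1:] + [x.shape[0]]
    chunks = [x[s:e] for s, e in zip(starts, ends)]
    lengths = torch.tensor([e - s for s, e in zip(starts, ends)])
    padded = pad_sequence(chunks, batch_first=True)  # [N, T_max, F], zero-padded
    mask = torch.arange(padded.shape[1])[None, :] < lengths[:, None]  # valid steps
    return padded, mask


x = torch.arange(12, dtype=torch.float32).reshape(6, 2)
is_init = torch.tensor([True, False, False, False, True, False])
padded, mask = split_and_pad(x, is_init)
restored = padded[mask]  # unpad: back to the original [T, F] layout
```

Here a 6-step sequence with a reset at step 4 becomes two chunks of lengths 4 and 2, so the padded tensor holds 8 steps. Those two extra padded steps are the memory overhead mentioned above.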

Matmul Precision

For the Triton backend, matmul precision is controlled separately via set_recurrent_matmul_precision:
from torchrl.modules import set_recurrent_matmul_precision, get_recurrent_matmul_precision

# Options: "ieee", "tf32", "tf32x3", "fast", "high-prec", "auto"
set_recurrent_matmul_precision("tf32")
print(get_recurrent_matmul_precision())
You can also override precision per-module via the recurrent_matmul_precision constructor argument.
RecurrentMatmulPrecision and RecurrentMatmulPrecisionUserMode are the enum classes backing these settings. "auto" derives the precision from torch.get_float32_matmul_precision() and the TORCHRL_RNN_PRECISION environment variable.

Low-Level Cells: LSTMCell, GRUCell, LSTM, GRU

TorchRL also exports Python-native implementations of the raw cell and multi-step modules, all compatible with torch.vmap and torch.compile:
  • LSTMCell — single-step LSTM cell (mirrors nn.LSTMCell).
  • GRUCell — single-step GRU cell (mirrors nn.GRUCell).
  • LSTM — multi-step LSTM (mirrors nn.LSTM); fully Python-based, vmap-compatible.
  • GRU — multi-step GRU (mirrors nn.GRU); fully Python-based.
These are useful when you need vectorized rollouts with torch.vmap, model-based RL with recurrent world models, or when compiling recurrent policies end-to-end with torch.compile.
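To show why a Python-level cell composes well with function transforms, the sketch below writes one LSTM step with ordinary tensor operations and checks it against torch.nn.LSTMCell. It is not TorchRL's LSTMCell; lstm_cell_step is a hypothetical function that follows the gate ordering documented for nn.LSTMCell.

```python
import torch
from torch import nn


def lstm_cell_step(x, h, c, w_ih, w_hh, b_ih, b_hh):
    """One LSTM step with plain tensor ops (gate order: input, forget, cell, output)."""
    gates = x @ w_ih.T + b_ih + h @ w_hh.T + b_hh
    i, f, g, o = gates.chunk(4, dim=-1)
    c_next = torch.sigmoid(f) * c + torch.sigmoid(i) * torch.tanh(g)
    h_next = torch.sigmoid(o) * torch.tanh(c_next)
    return h_next, c_next


torch.manual_seed(0)
ref = nn.LSTMCell(input_size=3, hidden_size=8)
x, h, c = torch.randn(2, 3), torch.randn(2, 8), torch.randn(2, 8)

with torch.no_grad():
    h_ref, c_ref = ref(x, (h, c))
    h_py, c_py = lstm_cell_step(
        x, h, c, ref.weight_ih, ref.weight_hh, ref.bias_ih, ref.bias_hh
    )

matches = torch.allclose(h_py, h_ref, atol=1e-5) and torch.allclose(c_py, c_ref, atol=1e-5)
```

Since the step is built only from matmuls and elementwise functions, tracing and vectorizing transforms can see through it, at the cost of the speed a fused kernel would provide.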

Utilities

from torchrl.modules import (
    get_primers_from_module,   # discover all TensorDictPrimers in a policy
    canonicalize_rnn_subset,   # normalize subset indices for RNN slicing
    recurrent_mode,            # query the currently active recurrent mode
    set_recurrent_mode,        # context manager / decorator that sets the mode
)
get_primers_from_module(policy) traverses the module tree and collects all make_tensordict_primer() results from every embedded LSTMModule / GRUModule. This is the recommended way to set up primers when building complex policies with multiple recurrent sub-networks.
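from torchrl.modules import get_primers_from_module
from torchrl.envs import TransformedEnv, InitTracker, GymEnv

env = TransformedEnv(GymEnv("Pendulum-v1"), InitTracker())
# policy may contain multiple LSTMModules / GRUModules.
# get_primers_from_module returns one transform: a single TensorDictPrimer,
# or a Compose of primers when several recurrent modules are found.
primers = get_primers_from_module(policy)
if primers is not None:
    env.append_transform(primers)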
