Recurrent policies are essential whenever the environment is partially observed and a single frame is not enough to determine the optimal action. Standard PyTorch nn.LSTM and nn.GRU modules require you to thread hidden states manually between calls, which conflicts with TorchRL’s TensorDict-based data flow. TorchRL solves this with LSTMModule and GRUModule: they store hidden states as ordinary TensorDict keys under well-defined names, making recurrent state part of the same structured dictionary that carries observations, actions, and rewards. Collectors automatically carry hidden states forward and reset them at episode boundaries.
Why Hidden States Need Special Handling
In a standard collector loop, each step calls policy(tensordict), and the tensordict flows through the whole pipeline: transforms, replay buffers, samplers. If the hidden state is just a Python variable held outside the tensordict, it is silently lost when the tensordict is sliced, shuffled, or sent across processes. By placing hidden states in the tensordict under keys like ("next", "recurrent_state_h"), TorchRL ensures they are:
- Automatically carried forward: each step reads recurrent_state_h and writes ("next", "recurrent_state_h"), so the rollout machinery (step_mdp from torchrl.envs.utils, applied between steps by env.rollout and the collectors) moves them to the root before the next step.
- Reset at episode boundaries: InitTracker and TensorDictPrimer transforms zero them when is_init=True.
- Batch-compatible: they live inside the same batched tensordict as observations, so they are correctly sliced for sub-batches and padded for variable-length sequences.
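The batch-compatibility and reset points can be illustrated without TorchRL. The following is a minimal, framework-free sketch in which a batch is a plain dict of columns; index_batch and reset_hidden are hypothetical helpers written for this illustration, not TorchRL functions.

```python
def index_batch(batch, idx):
    """Select rows `idx` from every column, hidden state included."""
    return {key: [column[i] for i in idx] for key, column in batch.items()}


def reset_hidden(batch, hidden_key="recurrent_state_h"):
    """Zero the hidden state of every row flagged as an episode start."""
    out = dict(batch)
    out[hidden_key] = [
        [0.0] * len(h) if init else list(h)
        for h, init in zip(batch[hidden_key], batch["is_init"])
    ]
    return out


batch = {
    "observation": [0.1, 0.2, 0.3, 0.4],
    "is_init": [False, True, False, False],
    "recurrent_state_h": [[1.0, 1.0], [2.0, 2.0], [3.0, 3.0], [4.0, 4.0]],
}

# Shuffling or slicing touches every column at once, so each hidden state
# stays attached to the observation it belongs to.
shuffled = index_batch(batch, [2, 0, 3, 1])
sub_batch = index_batch(shuffled, [0, 1])

# Rows that start a new episode get a zeroed hidden state.
cleaned = reset_hidden(batch)
```

A hidden state kept in a separate Python variable would not be reordered by index_batch, which is the failure mode the key-based design avoids.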
LSTMModule
LSTMModule wraps torch.nn.LSTM (or TorchRL’s Python-native LSTM implementation) with TensorDict-compatible I/O. It has two modes of operation:
- Single-step mode (default): processes one time step at a time, updating hidden states in-place. Used during environment collection.
- Recurrent mode: processes a full time sequence of shape [B, T, *], enabling truncated BPTT during training. Enabled with set_recurrent_mode(True) or the recurrent_mode context manager.
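To make the relationship between the two modes concrete, here is a small framework-free sketch. It uses a scalar tanh recurrence as a stand-in for the LSTM (an assumption made for brevity); rnn_step, run_single_step, and run_sequence are hypothetical names, not TorchRL APIs.

```python
import math


def rnn_step(x, h, w_x=0.5, w_h=0.9):
    """Toy recurrent cell standing in for an LSTM step."""
    return math.tanh(w_x * x + w_h * h)


def run_single_step(xs, h0=0.0):
    """Single-step mode: one call per environment step, state threaded along."""
    h, outputs = h0, []
    for x in xs:
        h = rnn_step(x, h)
        outputs.append(h)
    return outputs, h


def run_sequence(batch_xs, h0=0.0):
    """Recurrent mode: consume a whole [B, T] block in one call."""
    num_rows, num_steps = len(batch_xs), len(batch_xs[0])
    h = [h0] * num_rows
    outputs = [[0.0] * num_steps for _ in range(num_rows)]
    for t in range(num_steps):
        for b in range(num_rows):
            h[b] = rnn_step(batch_xs[b][t], h[b])
            outputs[b][t] = h[b]
    return outputs, h


batch_xs = [[0.1, -0.2, 0.3], [1.0, 0.0, -1.0]]
seq_out, seq_h = run_sequence(batch_xs)
step_out, step_h = run_single_step(batch_xs[0])
```

Both paths apply the same recurrence, so the outputs for a given trajectory match; the modes differ only in how much of the time axis a single call consumes.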
LSTMModule accepts the following constructor arguments:
- input_size: Number of expected input features (observation / embedding dimensionality).
- hidden_size: Number of features in the LSTM hidden state h (and cell state c).
- num_layers: Number of stacked LSTM layers.
- bias: Whether to include bias terms b_ih and b_hh.
- dropout: Dropout probability on outputs of each LSTM layer except the last.
- python_based: When True, uses TorchRL’s fully Python-implemented LSTMCell instead of the cuDNN kernel. Required for torch.vmap and torch.compile.
- Recurrent backend: Backend used when trajectories reset mid-batch. Options:
  - "pad": splits trajectories and pads to uniform length (default).
  - "scan": uses a scan loop via hoptorch; avoids materialization of padded chunks.
  - "triton": prototype Triton kernels (CUDA only, requires Triton ≥ 2.2).
  - "auto": uses "pad" in eager mode and "scan" under torch.compile.
- in_key: Shorthand for in_keys when hidden state key names follow the default convention. Exclusive with in_keys.
- in_keys: A triplet [input_key, hidden_h_key, hidden_c_key] specifying what to read from the input TensorDict. Exclusive with in_key.
- out_key: Shorthand for out_keys. Exclusive with out_keys.
- out_keys: A triplet [output_key, next_hidden_h_key, next_hidden_c_key]. For correct rollout behavior, hidden output keys should be nested under "next", e.g. [("next", "rs_h"), ("next", "rs_c")].
- default_recurrent_mode: Default value for recurrent_mode when no context manager is active. Defaults to False.
Hidden State Key Convention
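TorchRL uses the "next" nesting convention to propagate state between steps. The pattern is: at step t the module reads the hidden state from the root of the tensordict (e.g. td["rs_h"]) and writes the updated state under "next" (td["next"]["rs_h"]). step_mdp (from torchrl.envs.utils, applied between steps by env.rollout and the collectors) then copies td["next"]["rs_h"] → td["rs_h"] automatically before the next step.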
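The read-at-root, write-under-"next" pattern can be sketched with nested dicts. This is an illustration only: policy, env_step, and promote_next are hypothetical stand-ins, and the arithmetic is arbitrary; none of it is TorchRL's implementation.

```python
def policy(td):
    """Read the root hidden state, write the updated one under "next"."""
    new_h = 0.5 * td["rs_h"] + td["observation"]
    out = dict(td)
    out["action"] = new_h
    out["next"] = {"rs_h": new_h}
    return out


def env_step(td):
    """Hypothetical environment: next observation is the action plus one."""
    out = dict(td)
    out["next"] = {**td["next"], "observation": td["action"] + 1.0}
    return out


def promote_next(td):
    """Move everything under "next" to the root for the following step."""
    return dict(td["next"])


td = {"observation": 1.0, "rs_h": 0.0}
trace = []
for _ in range(3):
    td = env_step(policy(td))
    trace.append(td["next"]["rs_h"])
    td = promote_next(td)
```

After each promotion the root holds both the new observation and the hidden state that the next policy call will read, so no state lives outside the dictionary.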
set_recurrent_mode
set_recurrent_mode (from torchrl.modules) switches between single-step and sequence-processing behavior. Use it as a context manager during training, or set the default via the default_recurrent_mode constructor argument.
The instance method lstm_module.set_recurrent_mode() was removed in TorchRL v0.8 and now raises RuntimeError. Use the set_recurrent_mode context manager from torchrl.modules or the default_recurrent_mode constructor argument instead.
make_tensordict_primer
LSTMModule.make_tensordict_primer() returns a TensorDictPrimer transform that initializes hidden-state keys with zeros in the environment’s observation tensordict. Apply it to the environment via TransformedEnv.
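Conceptually, a primer guarantees that the hidden-state keys exist, filled with zeros, before the policy is first called. The sketch below mimics that idea with plain dicts; make_primer is a hypothetical helper, and leaving existing keys untouched is an assumption of this sketch rather than a statement about TensorDictPrimer.

```python
def make_primer(hidden_keys, num_layers, hidden_size):
    """Return a function that adds zero-filled hidden states to a reset dict."""

    def prime(reset_td):
        primed = dict(reset_td)
        for key in hidden_keys:
            if key not in primed:
                # One zero vector per layer: shape [num_layers, hidden_size].
                primed[key] = [[0.0] * hidden_size for _ in range(num_layers)]
        return primed

    return prime


prime = make_primer(["rs_h", "rs_c"], num_layers=1, hidden_size=4)
td = prime({"observation": [0.3, -0.1], "is_init": True})
```

With the keys present from the first step, the policy can read its hidden state unconditionally instead of special-casing the start of an episode.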
Full Example: LSTM-Based Actor
GRUModule
GRUModule is analogous to LSTMModule but wraps torch.nn.GRU. GRU has a single hidden state (no cell state), so its key triplets become pairs:
- in_keys = [input_key, hidden_key]
- out_keys = [output_key, ("next", hidden_key)]
All other features (set_recurrent_mode, make_tensordict_primer, recurrent backends, and the python_based flag) are identical to LSTMModule.
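The single-state pair convention can be sketched with a scalar GRU step over plain dicts. The gate equations follow the standard GRU formulation with biases omitted for brevity; toy_gru_step, its weights, and the key names are hypothetical choices for this illustration.

```python
import math


def sigmoid(x):
    return 1.0 / (1.0 + math.exp(-x))


def toy_gru_step(td, w=(0.5, -0.3, 0.8), u=(0.9, 0.4, -0.6)):
    """One scalar GRU step: reads `hidden`, writes ("next", "hidden")."""
    x, h = td["observation"], td["hidden"]
    r = sigmoid(w[0] * x + u[0] * h)          # reset gate
    z = sigmoid(w[1] * x + u[1] * h)          # update gate
    n = math.tanh(w[2] * x + r * (u[2] * h))  # candidate state
    new_h = (1.0 - z) * n + z * h
    out = dict(td)
    out["features"] = new_h
    out["next"] = {"hidden": new_h}           # a single state, no cell state
    return out


td = toy_gru_step({"observation": 0.0, "hidden": 0.0})
td_nonzero = toy_gru_step({"observation": 1.0, "hidden": 0.0})
```

Compared with the LSTM triplets, only one entry is read from the root and only one is written under "next".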
GRUModule accepts the following constructor arguments:
- input_size: Number of expected input features.
- hidden_size: Number of features in the GRU hidden state.
- num_layers: Number of stacked GRU layers.
- bias: Whether to include bias terms.
- dropout: Dropout on intermediate GRU layer outputs.
- python_based: Use the Python GRU implementation for torch.vmap / torch.compile compatibility.
- in_keys: Pair [input_key, hidden_key].
- out_keys: Pair [output_key, ("next", hidden_key)].
GRUModule Example
Recurrent Backends
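Both LSTMModule and GRUModule support multiple compute backends for the recurrent pass when sequences contain episode resets mid-batch.
- pad (default): Splits the batch into per-trajectory chunks, pads them to the same length, processes with the cuDNN kernel, then unpads. Safe and stable; materializes extra memory for padding.
- scan: Uses a scan loop; avoids materialization of padded chunks.
- triton (prototype): Prototype Triton kernels (CUDA only, requires Triton ≥ 2.2).
- auto: Uses "pad" in eager mode and "scan" under torch.compile.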
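The split, pad, and unpad steps of the "pad" strategy can be sketched on a flat sequence. This is a simplified illustration on Python lists (the recurrent computation itself is left out); split_and_pad and unpad are hypothetical helpers, not the functions TorchRL uses internally.

```python
def split_and_pad(values, is_init, pad_value=0.0):
    """Cut `values` at every is_init flag and pad chunks to equal length."""
    chunks = []
    for value, init in zip(values, is_init):
        if init or not chunks:
            chunks.append([])  # a reset starts a new trajectory chunk
        chunks[-1].append(value)
    max_len = max(len(chunk) for chunk in chunks)
    padded = [c + [pad_value] * (max_len - len(c)) for c in chunks]
    mask = [[True] * len(c) + [False] * (max_len - len(c)) for c in chunks]
    return padded, mask


def unpad(padded, mask):
    """Drop padded positions and restore the original flat ordering."""
    return [
        value
        for row, row_mask in zip(padded, mask)
        for value, keep in zip(row, row_mask)
        if keep
    ]


values = [1.0, 2.0, 3.0, 4.0, 5.0, 6.0]
is_init = [True, False, False, True, False, True]
padded, mask = split_and_pad(values, is_init)
restored = unpad(padded, mask)
```

The padded rows show where the extra memory goes: the shortest chunk is stored at the length of the longest one.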
Matmul Precision
For the Triton backend, matmul precision is controlled separately via set_recurrent_matmul_precision or the recurrent_matmul_precision constructor argument. RecurrentMatmulPrecision and RecurrentMatmulPrecisionUserMode are the enum classes backing these settings. "auto" derives the precision from torch.get_float32_matmul_precision() and the TORCHRL_RNN_PRECISION environment variable.
Low-Level Cells: LSTMCell, GRUCell, LSTM, GRU
TorchRL also exports Python-native implementations of the raw cell and multi-step modules, all compatible with torch.vmap and torch.compile:
- LSTMCell: single-step LSTM cell (mirrors nn.LSTMCell).
- GRUCell: single-step GRU cell (mirrors nn.GRUCell).
- LSTM: multi-step LSTM (mirrors nn.LSTM); fully Python-based, vmap-compatible.
- GRU: multi-step GRU (mirrors nn.GRU); fully Python-based.
These are useful with torch.vmap, for model-based RL with recurrent world models, or when compiling recurrent policies end-to-end with torch.compile.
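To show what a Python-level cell computes and how a multi-step module builds on it, here is a scalar sketch of the standard LSTM equations with biases omitted. toy_lstm_cell and toy_lstm are hypothetical names with arbitrary weights; they illustrate the cell/sequence split, not TorchRL's classes.

```python
import math


def sigmoid(x):
    return 1.0 / (1.0 + math.exp(-x))


def toy_lstm_cell(x, h, c, w=(0.5, -0.3, 0.8, 0.2), u=(0.9, 0.4, -0.6, 0.1)):
    """One scalar LSTM step returning the new hidden and cell states."""
    i = sigmoid(w[0] * x + u[0] * h)    # input gate
    f = sigmoid(w[1] * x + u[1] * h)    # forget gate
    g = math.tanh(w[2] * x + u[2] * h)  # candidate cell value
    o = sigmoid(w[3] * x + u[3] * h)    # output gate
    new_c = f * c + i * g
    new_h = o * math.tanh(new_c)
    return new_h, new_c


def toy_lstm(xs, h=0.0, c=0.0):
    """Multi-step module: a plain Python loop over the single-step cell."""
    outputs = []
    for x in xs:
        h, c = toy_lstm_cell(x, h, c)
        outputs.append(h)
    return outputs, (h, c)


outputs, (last_h, last_c) = toy_lstm([0.5, -1.0, 0.25])
```

Because the loop is ordinary Python over elementwise operations, there is no opaque fused kernel in the way of function transforms or tracing.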
Utilities
get_primers_from_module(policy) traverses the module tree and collects all make_tensordict_primer() results from every embedded LSTMModule / GRUModule. This is the recommended way to set up primers when building complex policies with multiple recurrent sub-networks.
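The traversal idea can be sketched with a toy module tree. ToyModule, ToyRecurrent, and collect_primers are hypothetical names for this illustration, and primers are reduced to simple key-to-zero mappings; the sketch does not reproduce get_primers_from_module itself.

```python
class ToyModule:
    """Container with child modules, standing in for a policy tree."""

    def __init__(self, *children):
        self.children = list(children)


class ToyRecurrent(ToyModule):
    """Leaf module that knows which hidden-state key it needs primed."""

    def __init__(self, hidden_key):
        super().__init__()
        self.hidden_key = hidden_key

    def make_primer(self):
        return {self.hidden_key: 0.0}


def collect_primers(module):
    """Walk the tree depth-first and merge every primer that is found."""
    primers = {}
    if hasattr(module, "make_primer"):
        primers.update(module.make_primer())
    for child in module.children:
        primers.update(collect_primers(child))
    return primers


policy = ToyModule(ToyRecurrent("h_actor"), ToyModule(ToyRecurrent("h_critic")))
primers = collect_primers(policy)
```

Collecting primers from the tree means a policy with several recurrent sub-networks needs no hand-maintained list of hidden-state keys.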