Partially observable environments, where the agent cannot see the full environment state, are common in robotics, finance, and real-world control. A feedforward policy cannot integrate information across time steps; a recurrent policy can. TorchRL provides LSTMModule and GRUModule, TensorDict-compatible wrappers around torch.nn.LSTM and torch.nn.GRU that handle the mechanics of hidden-state routing automatically. This tutorial covers how to build recurrent policies, how the hidden-state lifecycle interacts with collectors and environments, and the performance backends available for large-scale training.
Standard feedforward policies map a single observation to an action: π(a_t | o_t). In a partially observable Markov decision process (POMDP), the single observation o_t does not contain enough information for optimal decisions. A recurrent policy accumulates a hidden state h_t across time steps: π(a_t, h_{t+1} | o_t, h_t). LSTMModule and GRUModule handle the resulting hidden-state bookkeeping (initialisation, carry-over between steps, and reset at episode boundaries) automatically when integrated with TransformedEnv and InitTracker.
LSTMModule wraps torch.nn.LSTM and connects it to the TensorDict data flow. The hidden state keys follow the convention ("next", "rs_h") and ("next", "rs_c") so that, after each step, the collector automatically copies ("next", "rs_h") → "rs_h" for the next step.

import torch
import torch.nn as nn
from tensordict.nn import TensorDictModule, TensorDictSequential
from torchrl.envs import GymEnv, TransformedEnv, InitTracker
from torchrl.modules import LSTMModule, MLP
# Build the base environment with InitTracker
# InitTracker adds an "is_init" flag that tells LSTMModule when to zero hidden states
base_env = GymEnv("Pendulum-v1")
env = TransformedEnv(base_env, InitTracker())
obs_dim = env.observation_spec["observation"].shape[-1] # 3 for Pendulum-v1
hidden_size = 64
action_dim = env.action_spec.shape[-1] # 1 for Pendulum-v1
# LSTMModule: obs → hidden representation + carries (h, c) across steps
lstm_module = LSTMModule(
    input_size=obs_dim,
    hidden_size=hidden_size,
    num_layers=1,
    # in_keys: [input_feature, hidden_h, hidden_c]
    in_keys=["observation", "rs_h", "rs_c"],
    # out_keys: [output_feature, next_hidden_h, next_hidden_c]
    out_keys=["lstm_out", ("next", "rs_h"), ("next", "rs_c")],
)
# A simple head MLP that maps the LSTM output to an action
mlp_head = TensorDictModule(
    MLP(in_features=hidden_size, out_features=action_dim, num_cells=[64]),
    in_keys=["lstm_out"],
    out_keys=["action"],
)
# Chain LSTM and MLP into a single policy
policy = TensorDictSequential(lstm_module, mlp_head)
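
To make the key convention concrete, here is a minimal pure-Python sketch of the routing. It is not TorchRL's implementation: plain dicts stand in for TensorDicts, a toy scalar recurrence stands in for the LSTM, and `toy_cell`, `policy_step`, and `advance` are hypothetical names introduced only for illustration.

```python
# Conceptual sketch only: plain dicts stand in for TensorDicts and a toy
# scalar "cell" stands in for the LSTM. All function names are hypothetical.

def toy_cell(obs, h):
    """Toy recurrence: the new hidden state mixes the old state and the input."""
    return 0.5 * h + 0.5 * obs


def policy_step(td):
    """Read "rs_h" from the root, write the update under ("next", "rs_h")."""
    h = 0.0 if td["is_init"] else td["rs_h"]  # zero the state at episode start
    h_next = toy_cell(td["observation"], h)
    td["lstm_out"] = h_next
    td[("next", "rs_h")] = h_next
    return td


def advance(td, next_obs, next_is_init):
    """Mimic the between-step copy: ("next", "rs_h") becomes the new root "rs_h"."""
    return {
        "observation": next_obs,
        "is_init": next_is_init,
        "rs_h": td[("next", "rs_h")],
    }


observations = [1.0, 1.0, 1.0, 4.0]
is_init_flags = [True, False, False, True]  # a new episode starts at step 3

td = {"observation": observations[0], "is_init": is_init_flags[0], "rs_h": 0.0}
outputs = []
for t in range(len(observations)):
    td = policy_step(td)
    outputs.append(td["lstm_out"])
    if t + 1 < len(observations):
        td = advance(td, observations[t + 1], is_init_flags[t + 1])

print(outputs)  # [0.5, 0.75, 0.875, 2.0]
```

The last output ignores everything seen in the first episode because the "is_init" flag zeroed the carried state, which is the behaviour InitTracker enables for the real modules.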
LSTMModule requires "is_init" in the input TensorDict. This flag marks episode beginnings so that the hidden state is reset to zero. Always wrap your environment with InitTracker() before using a recurrent policy with a Collector.
Before collecting data, the environment must know that the "rs_h" and "rs_c" keys need to be pre-allocated. make_tensordict_primer() returns a TensorDictPrimer transform that injects zero-initialised hidden states into the reset TensorDict.

from torchrl.envs import TransformedEnv
# make_tensordict_primer reads the LSTMModule's out_keys to determine
# what shape and dtype to pre-allocate
primer = lstm_module.make_tensordict_primer()
env = env.append_transform(primer)
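# Alternatively (instead of the two lines above), for a policy composed of
# multiple modules, gather the primers of all recurrent submodules at once.
# get_primers_from_module returns a single transform (a TensorDictPrimer, or a
# Compose when several primers are found), so it is appended directly:
# from torchrl.modules.utils import get_primers_from_module
# env = env.append_transform(get_primers_from_module(policy))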
# Verify everything is wired up
from torchrl.envs.utils import check_env_specs
check_env_specs(env)
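
As a rough mental model of what the primer contributes at reset time, the sketch below fills in zero hidden states for keys the environment does not produce itself. It assumes hidden states of shape (num_layers, hidden_size); nested lists stand in for tensors, and `make_zero_state` and `prime_on_reset` are hypothetical helpers, not TorchRL APIs.

```python
# Conceptual sketch only: nested lists stand in for tensors, dicts for TensorDicts.
# `make_zero_state` and `prime_on_reset` are hypothetical helpers, not TorchRL APIs.

def make_zero_state(num_layers, hidden_size):
    """Build a zero-filled (num_layers, hidden_size) hidden state."""
    return [[0.0] * hidden_size for _ in range(num_layers)]


def prime_on_reset(reset_td, state_keys, num_layers, hidden_size):
    """Add a zero hidden state for every key the reset output does not provide."""
    primed = dict(reset_td)
    for key in state_keys:
        primed.setdefault(key, make_zero_state(num_layers, hidden_size))
    return primed


reset_td = {"observation": [0.1, -0.2, 0.3], "is_init": True}
primed = prime_on_reset(reset_td, ["rs_h", "rs_c"], num_layers=1, hidden_size=4)

print(sorted(primed))
print(primed["rs_h"])
```

With the keys present from the very first step, the policy never has to special-case a missing hidden state.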
GRUModule has only one hidden state vector (no cell state), so its interface is slightly simpler. All other concepts are identical.

from torchrl.modules import GRUModule
gru_module = GRUModule(
    input_size=obs_dim,
    hidden_size=hidden_size,
    num_layers=1,
    # in_keys: [input_feature, hidden_state]
    in_keys=["observation", "rs"],
    # out_keys: [output_feature, next_hidden_state]
    out_keys=["gru_out", ("next", "rs")],
)
mlp_head_gru = TensorDictModule(
    MLP(in_features=hidden_size, out_features=action_dim, num_cells=[64]),
    in_keys=["gru_out"],
    out_keys=["action"],
)
gru_policy = TensorDictSequential(gru_module, mlp_head_gru)
# Add primer for the GRU hidden state
env_gru = TransformedEnv(GymEnv("Pendulum-v1"), InitTracker())
env_gru = env_gru.append_transform(gru_module.make_tensordict_primer())
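
For reference, the single-state interface follows from the GRU update itself. The equations below use the gate formulation documented for torch.nn.GRU; the symbols are the usual textbook ones and are not identifiers from TorchRL.

```latex
\begin{aligned}
r_t &= \sigma\left(W_{ir} x_t + b_{ir} + W_{hr} h_{t-1} + b_{hr}\right) \\
z_t &= \sigma\left(W_{iz} x_t + b_{iz} + W_{hz} h_{t-1} + b_{hz}\right) \\
n_t &= \tanh\left(W_{in} x_t + b_{in} + r_t \odot \left(W_{hn} h_{t-1} + b_{hn}\right)\right) \\
h_t &= (1 - z_t) \odot n_t + z_t \odot h_{t-1}
\end{aligned}
```

Only h_t is carried to the next step, which is why a single "rs" key suffices, whereas an LSTM carries both a hidden state and a cell state.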
LSTMModule and GRUModule operate in two distinct modes controlled by the set_recurrent_mode context manager (imported from torchrl.modules):
Step mode (collection)
In step mode (default), the module processes a single time step. The hidden state from the previous step is read from "rs_h" / "rs_c" and the updated state is written to ("next", "rs_h") / ("next", "rs_c"). TorchRL’s step_mdp utility and the Collector copy the ("next", ...) keys back to the root automatically between steps, so the hidden state flows across an episode without any manual management.
Recurrent mode (training)
In recurrent mode, the module processes a full trajectory (the time dimension is the last batch dimension). This uses the multi-step LSTM/GRU forward pass, which is more efficient for training on stored rollouts.
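
The two modes are meant to compute the same outputs and differ only in how the work is batched. The sketch below illustrates that equivalence with a toy scalar recurrence in place of the LSTM/GRU; `toy_cell`, `run_step_mode`, and `run_recurrent_mode` are hypothetical names, and the trajectory is assumed to start at an episode beginning.

```python
# Conceptual sketch only: a toy scalar recurrence stands in for the LSTM/GRU.
# `toy_cell`, `run_step_mode` and `run_recurrent_mode` are hypothetical names.

def toy_cell(obs, h):
    return 0.5 * h + 0.5 * obs


def run_step_mode(observations, is_init):
    """One call per time step; the caller carries the hidden state forward."""
    outputs, h = [], 0.0
    for obs, init in zip(observations, is_init):
        if init:
            h = 0.0  # episode start: reset the hidden state
        h = toy_cell(obs, h)
        outputs.append(h)
    return outputs


def run_recurrent_mode(observations, is_init):
    """One call per trajectory: split at episode starts, then unroll each piece."""
    starts = [t for t, init in enumerate(is_init) if init]
    if not starts or starts[0] != 0:
        starts = [0] + starts  # treat the first step as a segment start
    bounds = starts + [len(observations)]
    outputs = []
    for begin, end in zip(bounds[:-1], bounds[1:]):
        h = 0.0
        for obs in observations[begin:end]:
            h = toy_cell(obs, h)
            outputs.append(h)
    return outputs


observations = [1.0, 1.0, 1.0, 4.0, 2.0]
is_init = [True, False, False, True, False]

step_out = run_step_mode(observations, is_init)
seq_out = run_recurrent_mode(observations, is_init)
print(step_out == seq_out, seq_out)
```

In the real modules the per-segment unroll is a single multi-step RNN call rather than a Python loop, which is where the training-time speed-up comes from.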
from torchrl.collectors import Collector
FRAMES_PER_BATCH = 1024
TOTAL_FRAMES = 100_000
collector = Collector(
    create_env_fn=lambda: TransformedEnv(
        GymEnv("Pendulum-v1"),
        InitTracker(),
    ).append_transform(lstm_module.make_tensordict_primer()),
    policy=policy,
    frames_per_batch=FRAMES_PER_BATCH,
    total_frames=TOTAL_FRAMES,
    device="cpu",
)
# The collector automatically handles:
# - resetting hidden states at episode boundaries (via "is_init")
# - copying ("next", "rs_h") → "rs_h" between steps
for batch in collector:
    # batch["rs_h"] and batch["rs_c"] are the per-step hidden states
    # batch[("next", "rs_h")] is the hidden state after each step
    print(batch.shape, batch.keys())
    break
collector.shutdown()
When using MultiSyncCollector or MultiAsyncCollector with recurrent policies, each worker maintains its own hidden-state buffer. The make_tensordict_primer() transform must be included in every worker’s environment constructor.
Recurrent policies require special care during training: a loss computed on individual steps ignores temporal dependencies. The standard approach is to train on fixed-length sequence segments.
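
One way to picture the segmenting step is the sketch below, which cuts a trajectory into fixed-length pieces and pads the last one. It is an illustration under simplifying assumptions, not how TorchRL stores or samples data: a list of step indices stands in for a rollout, and `split_into_segments` is a hypothetical helper.

```python
# Conceptual sketch only: a list of per-step records stands in for a rollout.
# `split_into_segments` is a hypothetical helper, not a TorchRL API.

def split_into_segments(steps, segment_len, pad_value=None):
    """Cut a trajectory into fixed-length segments, padding the last one.

    Returns (segments, masks); masks mark real steps (True) versus padding (False).
    """
    segments, masks = [], []
    for start in range(0, len(steps), segment_len):
        chunk = steps[start:start + segment_len]
        n_pad = segment_len - len(chunk)
        segments.append(chunk + [pad_value] * n_pad)
        masks.append([True] * len(chunk) + [False] * n_pad)
    return segments, masks


trajectory = list(range(10))  # ten time steps, identified by their index
segments, masks = split_into_segments(trajectory, segment_len=4)

for seg, mask in zip(segments, masks):
    print(seg, mask)
```

The mask lets a loss skip the padded steps. TorchRL's replay-buffer samplers (for example SliceSampler) offer trajectory-slice sampling for the same purpose.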
from torchrl.data import LazyTensorStorage, TensorDictReplayBuffer
from torchrl.data.replay_buffers.samplers import SamplerWithoutReplacement
from torchrl.objectives import ClipPPOLoss
from torchrl.objectives.value.advantages import GAE
from torchrl.modules import ValueOperator, set_recurrent_mode
# Critic with its own LSTM and its own hidden-state keys ("critic_rs_h",
# "critic_rs_c"), followed by a feedforward value head on the LSTM output
lstm_critic = LSTMModule(
    input_size=obs_dim,
    hidden_size=hidden_size,
    in_keys=["observation", "critic_rs_h", "critic_rs_c"],
    out_keys=["critic_lstm_out", ("next", "critic_rs_h"), ("next", "critic_rs_c")],
)
value_head = TensorDictModule(
    MLP(in_features=hidden_size, out_features=1, num_cells=[64]),
    in_keys=["critic_lstm_out"],
    out_keys=["state_value"],
)
critic = TensorDictSequential(lstm_critic, value_head)
adv_module = GAE(gamma=0.99, lmbda=0.95, value_network=critic)
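# NOTE: ClipPPOLoss expects a probabilistic actor that exposes a distribution
# and log-probabilities (e.g. a ProbabilisticActor); the deterministic `policy`
# built above needs a distribution head before it can be used here.
loss_module = ClipPPOLoss(
    actor_network=policy,
    critic_network=critic,
    clip_epsilon=0.2,
    entropy_coeff=0.01,
)
replay_buffer = TensorDictReplayBuffer(
    storage=LazyTensorStorage(FRAMES_PER_BATCH),
    sampler=SamplerWithoutReplacement(),
    batch_size=128,
)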
optim = torch.optim.Adam(loss_module.parameters(), lr=3e-4)
for batch in collector:
    # Switch LSTM to recurrent mode for advantage computation
    with set_recurrent_mode(True), torch.no_grad():
        batch = adv_module(batch)
    replay_buffer.extend(batch.reshape(-1))
    for minibatch in replay_buffer:
        optim.zero_grad(set_to_none=True)
        loss = loss_module(minibatch)
        (loss["loss_objective"] + loss["loss_critic"] + loss["loss_entropy"]).backward()
        torch.nn.utils.clip_grad_norm_(loss_module.parameters(), 0.5)
        optim.step()
    collector.update_policy_weights_()
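
The advantage module in the loop above implements generalized advantage estimation. As a hedged illustration of the recursion it is based on, the sketch below computes GAE over a single trajectory with plain Python lists. `gae_advantages` is a hypothetical helper, the value estimates are made-up numbers rather than critic outputs, and the terminated/truncated distinction is collapsed into one `dones` flag.

```python
# Conceptual sketch only: plain Python lists instead of tensors.
# `gae_advantages` is a hypothetical helper, not the TorchRL GAE module.

def gae_advantages(rewards, values, next_values, dones, gamma=0.99, lmbda=0.95):
    """Generalized advantage estimation over one trajectory, computed backwards."""
    advantages = [0.0] * len(rewards)
    running = 0.0
    for t in reversed(range(len(rewards))):
        not_done = 0.0 if dones[t] else 1.0
        # One-step TD error; no bootstrapping past the end of an episode
        delta = rewards[t] + gamma * next_values[t] * not_done - values[t]
        # Exponentially weighted sum of future TD errors
        running = delta + gamma * lmbda * not_done * running
        advantages[t] = running
    return advantages


rewards = [1.0, 1.0, 1.0]
values = [0.5, 0.5, 0.5]
next_values = [0.5, 0.5, 0.0]
dones = [False, False, True]

# gamma and lmbda of 0.5 keep the arithmetic easy to check by hand
adv = gae_advantages(rewards, values, next_values, dones, gamma=0.5, lmbda=0.5)
print(adv)  # [0.96875, 0.875, 0.5]
```

With a recurrent critic, the value estimates depend on the whole preceding sequence, which is why the advantage pass in the training loop runs under set_recurrent_mode(True).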
LSTMModule and GRUModule support three backends for the recurrent computation in training mode, plus an "auto" selector. Each trades off compilation compatibility, memory usage, and raw throughput:
| Backend | Setting | Notes |
|---|---|---|
| "pad" (default) | recurrent_backend="pad" | … recurrent_recompute. |
| "scan" | recurrent_backend="scan" | Requires hoptorch (pip install hoptorch). Avoids padding-induced memory waste on variable-length trajectories. Works with torch.compile. |
| "triton" (prototype) | recurrent_backend="triton" | Requires CUDA and triton>=2.2. Highest throughput on Ampere+ GPUs. Supports recurrent_recompute="full" for activation-memory trade-off. |
| "auto" | recurrent_backend="auto" | Uses "pad" in eager mode; switches to "scan" under torch.compile. |

# Compile-friendly setup with the scan backend
lstm_scan = LSTMModule(
    input_size=obs_dim,
    hidden_size=hidden_size,
    in_keys=["observation", "rs_h", "rs_c"],
    out_keys=["lstm_out", ("next", "rs_h"), ("next", "rs_c")],
    recurrent_backend="scan",  # requires: pip install hoptorch
)
# CUDA Triton backend with gradient checkpointing
lstm_triton = LSTMModule(
    input_size=obs_dim,
    hidden_size=hidden_size,
    in_keys=["observation", "rs_h", "rs_c"],
    out_keys=["lstm_out", ("next", "rs_h"), ("next", "rs_c")],
    recurrent_backend="triton",  # requires: CUDA + pip install triton>=2.2
    recurrent_recompute="full",  # drop gate buffers from autograd graph
    recurrent_matmul_precision="tf32",  # fastest on Ampere+
)
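
To get a feel for the padding-induced memory waste mentioned for variable-length trajectories, the arithmetic can be sketched as follows. The trajectory lengths are made-up numbers and `padding_overhead` is a hypothetical helper; the sketch only counts elements and says nothing about how any backend lays out memory.

```python
# Conceptual sketch only: counts elements, allocates nothing.
# `padding_overhead` is a hypothetical helper, not a TorchRL API.

def padding_overhead(lengths):
    """Return (padded_elements, real_elements, wasted_fraction) for a padded batch."""
    padded = max(lengths) * len(lengths)  # every trajectory padded to the longest
    real = sum(lengths)
    return padded, real, (padded - real) / padded


lengths = [200, 50, 25, 125]  # hypothetical trajectory lengths in one batch
padded, real, wasted = padding_overhead(lengths)
print(f"padded={padded} real={real} wasted={wasted:.0%}")
```

When episode lengths are nearly uniform the overhead is small, so the choice of backend matters most for workloads with highly variable episode lengths.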
Key concepts summary
| Concept | API | Description |
|---|---|---|
| Recurrent module (LSTM) | LSTMModule | Wraps nn.LSTM; keys: [obs, h, c] → [out, (next,h), (next,c)] |
| Recurrent module (GRU) | GRUModule | Wraps nn.GRU; keys: [obs, h] → [out, (next,h)] |
| Episode boundary marker | InitTracker | Adds "is_init" flag; used by LSTM/GRU to zero hidden states |
| Hidden state injection | make_tensordict_primer() | Pre-allocates zero hidden states on environment reset |
| Mode switching | set_recurrent_mode(True/False) | Context manager from torchrl.modules for step vs sequence mode |
| Scan backend | recurrent_backend="scan" | torch.compile-compatible; needs hoptorch |
| Triton backend | recurrent_backend="triton" | GPU kernel; needs triton>=2.2 |