Partially observable environments — tasks where the agent cannot see the full environment state — are common in robotics, finance, and real-world control. A feedforward policy cannot integrate information across time steps; a recurrent policy can. TorchRL provides LSTMModule and GRUModule, TensorDict-compatible wrappers around torch.nn.LSTM and torch.nn.GRU that handle the mechanics of hidden-state routing automatically. This tutorial covers how to build recurrent policies, how the hidden-state lifecycle interacts with collectors and environments, and the performance backends available for large-scale training.
Why recurrent policies?
Standard feedforward policies map a single observation to an action: π(a_t | o_t). In a partially observable Markov decision process (POMDP), the single observation o_t does not contain enough information for optimal decisions. A recurrent policy accumulates a hidden state h_t across time steps: π(a_t, h_{t+1} | o_t, h_t).
The key engineering challenge is that h_t must be:
  • Carried step-to-step during collection (inference mode).
  • Reset to zero at episode boundaries.
  • Replayed along the time dimension during training (recurrent mode).
    LSTMModule and GRUModule handle all three automatically when integrated with TransformedEnv and InitTracker.
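The three requirements above can be illustrated without TorchRL at all. The following is a minimal sketch in plain PyTorch, assuming a hand-written loop around torch.nn.LSTMCell and a made-up is_init tensor; it is not how LSTMModule is implemented, only a picture of the bookkeeping it takes off your hands.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
obs_dim, hidden = 3, 8
cell = nn.LSTMCell(obs_dim, hidden)

T = 6
obs = torch.randn(T, obs_dim)
# Hypothetical episode-start flags: new episodes begin at t=0 and t=4
is_init = torch.tensor([True, False, False, False, True, False])


def run_steps(obs, is_init):
    """Carry (h, c) step to step, zeroing both whenever is_init is set."""
    h = torch.zeros(1, hidden)
    c = torch.zeros(1, hidden)
    outputs = []
    for t in range(obs.shape[0]):
        if is_init[t]:
            # Episode boundary: reset the recurrent state to zero
            h = torch.zeros_like(h)
            c = torch.zeros_like(c)
        h, c = cell(obs[t].unsqueeze(0), (h, c))
        outputs.append(h.squeeze(0))
    return torch.stack(outputs)


with torch.no_grad():
    collected = run_steps(obs, is_init)   # "collection": one step at a time
    replayed = run_steps(obs, is_init)    # "training": replay the stored steps
    # The second episode, run on its own starting from a zero state
    second_alone = run_steps(obs[4:], torch.tensor([True, False]))
```

Because the state is zeroed at t=4, the outputs for the second episode do not depend on the first one: collected[4:] matches second_alone.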
    Build an LSTM policy
    LSTMModule wraps torch.nn.LSTM and connects it to the TensorDict data flow. The hidden state keys follow the convention ("next", "rs_h") and ("next", "rs_c") so that, after each step, the collector automatically copies ("next", "rs_h") → "rs_h" for the next step.
    import torch
    import torch.nn as nn
    from tensordict.nn import TensorDictModule, TensorDictSequential
    from torchrl.envs import GymEnv, TransformedEnv, InitTracker
    from torchrl.modules import LSTMModule, MLP
    
    # Build the base environment with InitTracker
    # InitTracker adds an "is_init" flag that tells LSTMModule when to zero hidden states
    base_env = GymEnv("Pendulum-v1")
    env = TransformedEnv(base_env, InitTracker())
    
    obs_dim = env.observation_spec["observation"].shape[-1]   # 3 for Pendulum-v1
    hidden_size = 64
    action_dim = env.action_spec.shape[-1]                    # 1 for Pendulum-v1
    
    # LSTMModule: obs → hidden representation + carries (h, c) across steps
    lstm_module = LSTMModule(
        input_size=obs_dim,
        hidden_size=hidden_size,
        num_layers=1,
        # in_keys: [input_feature, hidden_h, hidden_c]
        in_keys=["observation", "rs_h", "rs_c"],
        # out_keys: [output_feature, next_hidden_h, next_hidden_c]
        out_keys=["lstm_out", ("next", "rs_h"), ("next", "rs_c")],
    )
    
    # A simple head MLP that maps the LSTM output to an action
    mlp_head = TensorDictModule(
        MLP(in_features=hidden_size, out_features=action_dim, num_cells=[64]),
        in_keys=["lstm_out"],
        out_keys=["action"],
    )
    
    # Chain LSTM and MLP into a single policy
    policy = TensorDictSequential(lstm_module, mlp_head)
    
    LSTMModule requires "is_init" in the input TensorDict. This flag marks episode beginnings so that the hidden state is reset to zero. Always wrap your environment with InitTracker() before using a recurrent policy with a Collector.
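To make the meaning of the flag concrete, here is a small sketch in plain PyTorch of how "is_init" relates to the done flags of the preceding steps. It assumes a single environment that resets right after each done step and a recording that starts just after a reset; the done tensor is invented for illustration, and this is not InitTracker's actual code.

```python
import torch

# Hypothetical done flags for 8 consecutive steps from one environment
done = torch.tensor([False, False, True, False, False, False, True, False])

# The first recorded step follows a reset; afterwards, a step starts an
# episode exactly when the previous step ended one.
is_init = torch.cat([torch.tensor([True]), done[:-1]])

print(is_init.tolist())
# [True, False, False, True, False, False, False, True]
```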
    Register hidden-state priming with the environment
    Before collecting data, the environment must know that "rs_h" and "rs_c" keys need to be pre-allocated. make_tensordict_primer() returns a TensorDictPrimer transform that injects zero-initialised hidden states into the reset TensorDict.
    from torchrl.envs import TransformedEnv
    
    # make_tensordict_primer reads the LSTMModule's out_keys to determine
    # what shape and dtype to pre-allocate
    primer = lstm_module.make_tensordict_primer()
    env = env.append_transform(primer)
    
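    # Alternative for a policy composed of multiple modules. Use it instead of
    # the call above, not in addition to it, so each key is primed only once.
    # get_primers_from_module returns a single transform (a Compose when the
    # policy contains several recurrent modules), so no loop is needed:
    # from torchrl.modules.utils import get_primers_from_module
    # env = env.append_transform(get_primers_from_module(policy))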
    
    # Verify everything is wired up
    from torchrl.envs.utils import check_env_specs
    check_env_specs(env)
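Zero-initialised priming lines up with how torch.nn.LSTM behaves when no state is passed. The sketch below uses plain PyTorch only and shows that explicit zero states give the same result as the default. The shapes follow nn.LSTM's own [num_layers, batch, hidden_size] layout; the layout TorchRL stores inside a TensorDict may differ, so treat the shapes here as an assumption of this example.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
lstm = nn.LSTM(input_size=3, hidden_size=8, num_layers=1, batch_first=True)
x = torch.randn(1, 1, 3)  # [batch, time, features]: one observation

# Explicit zero states, shaped [num_layers, batch, hidden_size]
h0 = torch.zeros(1, 1, 8)
c0 = torch.zeros(1, 1, 8)

with torch.no_grad():
    out_primed, (h1, c1) = lstm(x, (h0, c0))
    # No state passed: nn.LSTM falls back to zeros internally
    out_default, (h1_default, c1_default) = lstm(x)
```

The primed and default outputs agree, which is why a zero primer is a safe starting state for the first step of an episode.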
    
    Use GRUModule as a simpler alternative
    GRUModule has only one hidden state vector (no cell state), so its interface is slightly simpler. All other concepts are identical.
    from torchrl.modules import GRUModule
    
    gru_module = GRUModule(
        input_size=obs_dim,
        hidden_size=hidden_size,
        num_layers=1,
        # in_keys: [input_feature, hidden_state]
        in_keys=["observation", "rs"],
        # out_keys: [output_feature, next_hidden_state]
        out_keys=["gru_out", ("next", "rs")],
    )
    
    mlp_head_gru = TensorDictModule(
        MLP(in_features=hidden_size, out_features=action_dim, num_cells=[64]),
        in_keys=["gru_out"],
        out_keys=["action"],
    )
    
    gru_policy = TensorDictSequential(gru_module, mlp_head_gru)
    
    # Add primer for the GRU hidden state
    env_gru = TransformedEnv(GymEnv("Pendulum-v1"), InitTracker())
    env_gru = env_gru.append_transform(gru_module.make_tensordict_primer())
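The difference in state bookkeeping comes straight from the underlying torch.nn layers. As a rough illustration in plain PyTorch (the sizes are arbitrary), an LSTM returns a pair of states while a GRU returns one, and the GRU has three gate blocks where the LSTM has four:

```python
import torch
import torch.nn as nn

lstm = nn.LSTM(input_size=3, hidden_size=8, batch_first=True)
gru = nn.GRU(input_size=3, hidden_size=8, batch_first=True)
x = torch.randn(2, 5, 3)  # [batch, time, features]

with torch.no_grad():
    _, (h, c) = lstm(x)   # two state tensors to carry between steps
    _, h_gru = gru(x)     # a single state tensor

n_lstm = sum(p.numel() for p in lstm.parameters())
n_gru = sum(p.numel() for p in gru.parameters())
print(n_lstm, n_gru)  # 416 312
```

With identical sizes the GRU has three quarters of the LSTM's parameters, which is one reason it is often tried first on small control tasks.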
    
    Understand the recurrent state lifecycle
    The LSTMModule and GRUModule operate in two distinct modes controlled by the set_recurrent_mode context manager (imported from torchrl.modules):
    Step mode (collection)
    In step mode (default), the module processes a single time step. The hidden state from the previous step is read from "rs_h" / "rs_c" and the updated state is written to ("next", "rs_h") / ("next", "rs_c"). TorchRL’s StepMDP transform and the Collector copy the ("next", ...) keys back to the root automatically between steps, so the hidden state flows seamlessly across an episode without any manual management.
    # Step mode: process one observation at a time
    # (this is the default; collectors use this mode)
    from torchrl.envs.utils import step_mdp
    
    td = env.reset()
    for _ in range(5):
        td = policy(td)    # lstm_module reads rs_h, rs_c; writes (next, rs_h), (next, rs_c)
        td = env.step(td)
        # In a hand-written loop, step_mdp promotes the ("next", ...) entries,
        # including (next, rs_h) and (next, rs_c), to the root for the next call
        td = step_mdp(td)
    
    Recurrent mode (training)
    In recurrent mode, the module processes a full trajectory (the time dimension is the last batch dimension). This uses the multi-step LSTM/GRU forward pass, which is more efficient for training on stored rollouts.
    from torchrl.modules import set_recurrent_mode
    
    # Collect a trajectory in step mode first (batch shape [T])
    traj = env.rollout(policy=policy, max_steps=32)
    
    # Switch to recurrent mode using the context manager
    # (calling lstm_module.set_recurrent_mode() directly is not supported)
    with set_recurrent_mode(True):
        # Apply the LSTM over the full trajectory in one pass
        traj = lstm_module(traj)
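The two modes compute the same thing in different ways. A minimal sketch with a bare torch.nn.LSTM (no TorchRL, no episode resets, arbitrary sizes) shows that feeding a trajectory one step at a time while carrying the state gives the same outputs as a single pass over the whole sequence:

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
lstm = nn.LSTM(input_size=3, hidden_size=8, batch_first=True)
traj = torch.randn(1, 32, 3)  # [batch, time, features]

with torch.no_grad():
    # "Step mode": one time step per call, state carried by hand
    state = None
    step_outs = []
    for t in range(traj.shape[1]):
        out, state = lstm(traj[:, t : t + 1], state)
        step_outs.append(out)
    step_out = torch.cat(step_outs, dim=1)

    # "Recurrent mode": the whole trajectory in a single call
    seq_out, _ = lstm(traj)
```

The single call avoids 32 separate kernel launches, which is where the training-time efficiency comes from.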
    
    Collect data with a recurrent policy
    from torchrl.collectors import Collector
    
    FRAMES_PER_BATCH = 1024
    TOTAL_FRAMES = 100_000
    
    collector = Collector(
        create_env_fn=lambda: TransformedEnv(
            GymEnv("Pendulum-v1"),
            InitTracker(),
        ).append_transform(lstm_module.make_tensordict_primer()),
        policy=policy,
        frames_per_batch=FRAMES_PER_BATCH,
        total_frames=TOTAL_FRAMES,
        device="cpu",
    )
    
    # The collector automatically handles:
    # - resetting hidden states at episode boundaries (via "is_init")
    # - copying ("next", "rs_h") → "rs_h" between steps
    for batch in collector:
        # batch["rs_h"] and batch["rs_c"] are the per-step hidden states
        # batch[("next", "rs_h")] is the hidden state after each step
        print(batch.shape, batch.keys())
        break
    
    collector.shutdown()
    
    When using MultiSyncCollector or MultiAsyncCollector with recurrent policies, each worker maintains its own hidden-state buffer. The make_tensordict_primer() transform must be included in every worker’s environment constructor.
    Training on sequences
    Recurrent policies require special care during training: a loss computed on individual steps ignores temporal dependencies. The standard approach is to train on fixed-length sequence segments.
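As a rough picture of what fixed-length segments look like, the sketch below cuts a flat batch of steps into equal-length chunks using plain PyTorch. The helper split_into_segments is hypothetical (it is not a TorchRL function), and it ignores episode boundaries: in practice a segment may span a reset, in which case the "is_init" flags inside the segment are what tell the recurrent module where to zero its state.

```python
import torch


def split_into_segments(flat, seg_len):
    """Reshape [N, ...] into [N // seg_len, seg_len, ...], dropping the remainder."""
    n_segments = flat.shape[0] // seg_len
    usable = flat[: n_segments * seg_len]
    return usable.reshape(n_segments, seg_len, *flat.shape[1:])


# 10 consecutive steps with 3 features each
obs = torch.arange(10 * 3, dtype=torch.float32).reshape(10, 3)
segments = split_into_segments(obs, seg_len=4)
print(segments.shape)  # torch.Size([2, 4, 3])
```

Each row of segments keeps its steps in their original order, so a recurrent module can be unrolled along dimension 1.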
    from torchrl.data import LazyTensorStorage, TensorDictReplayBuffer
    from torchrl.data.replay_buffers.samplers import SamplerWithoutReplacement
    from torchrl.objectives import ClipPPOLoss
    from torchrl.objectives.value.advantages import GAE
    from torchrl.modules import ValueOperator, set_recurrent_mode
    
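    # Recurrent critic with its own LSTM and hidden-state keys.
    # Its primer (lstm_critic.make_tensordict_primer()) must also be appended to
    # the collection environment so "critic_rs_h" / "critic_rs_c" exist at reset.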
    lstm_critic = LSTMModule(
        input_size=obs_dim,
        hidden_size=hidden_size,
        in_keys=["observation", "critic_rs_h", "critic_rs_c"],
        out_keys=["critic_lstm_out", ("next", "critic_rs_h"), ("next", "critic_rs_c")],
    )
    value_head = TensorDictModule(
        MLP(in_features=hidden_size, out_features=1, num_cells=[64]),
        in_keys=["critic_lstm_out"],
        out_keys=["state_value"],
    )
    critic = TensorDictSequential(lstm_critic, value_head)
    
    adv_module = GAE(gamma=0.99, lmbda=0.95, value_network=critic)
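    loss_module = ClipPPOLoss(
        # ClipPPOLoss needs a probabilistic actor that can return log-probabilities;
        # wrap `policy` in a ProbabilisticActor as in the full example
        actor_network=policy,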
        critic_network=critic,
        clip_epsilon=0.2,
        entropy_coeff=0.01,
    )
    
    replay_buffer = TensorDictReplayBuffer(
        storage=LazyTensorStorage(FRAMES_PER_BATCH),
        sampler=SamplerWithoutReplacement(),
        batch_size=128,
    )
    optim = torch.optim.Adam(loss_module.parameters(), lr=3e-4)
    
    for batch in collector:
        # Switch LSTM to recurrent mode for advantage computation
        with set_recurrent_mode(True), torch.no_grad():
            batch = adv_module(batch)
    
        replay_buffer.extend(batch.reshape(-1))
    
        for minibatch in replay_buffer:
            optim.zero_grad(set_to_none=True)
            loss = loss_module(minibatch)
            (loss["loss_objective"] + loss["loss_critic"] + loss["loss_entropy"]).backward()
            torch.nn.utils.clip_grad_norm_(loss_module.parameters(), 0.5)
            optim.step()
    
        collector.update_policy_weights_()
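The advantage step in the loop above depends on time ordering, which is why it runs on the unflattened batch. For intuition, here is the textbook GAE backward recursion written in plain PyTorch for a single trajectory. This is a sketch under simplifying assumptions (one done flag, no distinction between termination and truncation), not TorchRL's GAE implementation.

```python
import torch


def gae(reward, value, next_value, done, gamma=0.99, lmbda=0.95):
    """Generalised advantage estimate for one trajectory of shape [T]."""
    not_done = 1.0 - done.float()
    # One-step TD error; the bootstrap term is dropped after a done step
    delta = reward + gamma * next_value * not_done - value
    adv = torch.zeros_like(reward)
    running = torch.zeros(())
    for t in reversed(range(reward.shape[0])):
        # The accumulated advantage does not flow back across an episode end
        running = delta[t] + gamma * lmbda * not_done[t] * running
        adv[t] = running
    return adv


reward = torch.ones(3)
value = torch.zeros(3)
next_value = torch.zeros(3)
done = torch.tensor([False, False, True])
adv = gae(reward, value, next_value, done, gamma=1.0, lmbda=1.0)
print(adv)  # tensor([3., 2., 1.])
```

With gamma and lmbda set to 1 and a zero critic, the advantage reduces to the remaining return, which makes the recursion easy to check by hand.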
    
    Performance backends
    LSTMModule and GRUModule support three backends for the recurrent computation in recurrent mode ("pad", "scan" and "triton"), plus an "auto" setting that selects one of them. Each trades off compilation compatibility, memory usage, and raw throughput:
    | Backend | Keyword | When to use |
    | --- | --- | --- |
    | "pad" (default) | recurrent_backend="pad" | Stable baseline; uses cuDNN under the hood. Cannot be used with recurrent_recompute. |
    | "scan" | recurrent_backend="scan" | Requires hoptorch (pip install hoptorch). Avoids padding-induced memory waste on variable-length trajectories. Works with torch.compile. |
    | "triton" (prototype) | recurrent_backend="triton" | CUDA only; requires triton>=2.2. Highest throughput on Ampere+ GPUs. Supports recurrent_recompute="full" for activation-memory trade-off. |
    | "auto" | recurrent_backend="auto" | Uses "pad" in eager mode; switches to "scan" under torch.compile. |
    # Compile-friendly setup with the scan backend
    lstm_scan = LSTMModule(
        input_size=obs_dim,
        hidden_size=hidden_size,
        in_keys=["observation", "rs_h", "rs_c"],
        out_keys=["lstm_out", ("next", "rs_h"), ("next", "rs_c")],
        recurrent_backend="scan",     # requires: pip install hoptorch
    )
    
    # CUDA Triton backend with gradient checkpointing
    lstm_triton = LSTMModule(
        input_size=obs_dim,
        hidden_size=hidden_size,
        in_keys=["observation", "rs_h", "rs_c"],
        out_keys=["lstm_out", ("next", "rs_h"), ("next", "rs_c")],
        recurrent_backend="triton",   # requires: CUDA + pip install triton>=2.2
        recurrent_recompute="full",   # drop gate buffers from autograd graph
        recurrent_matmul_precision="tf32",  # fastest on Ampere+
    )
    
    The recurrent_backend setting only affects recurrent mode (sequence training). During step-mode collection the module always runs cell-by-cell regardless of this setting.
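The padding-induced memory waste mentioned for variable-length trajectories can be made concrete with torch.nn.utils.rnn.pad_sequence. This sketch illustrates padding in general with three made-up trajectory lengths; it does not reproduce what the "pad" backend does internally.

```python
import torch
from torch.nn.utils.rnn import pad_sequence

# Three hypothetical trajectories of different lengths, 3 features per step
trajs = [torch.ones(5, 3), torch.ones(2, 3), torch.ones(4, 3)]

# Pad every trajectory with zeros up to the longest one: shape [3, 5, 3]
padded = pad_sequence(trajs, batch_first=True)

# Boolean mask marking the real (non-padded) steps
lengths = torch.tensor([t.shape[0] for t in trajs])
mask = torch.arange(padded.shape[1])[None, :] < lengths[:, None]

# Fraction of the padded batch that holds no data
wasted = 1.0 - mask.float().mean().item()
print(padded.shape, int(mask.sum()), round(wasted, 3))
```

Here 4 of the 15 time slots are padding. The more the trajectory lengths vary, the larger that fraction becomes.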

    Full recurrent policy example

    """Minimal LSTM policy on Pendulum-v1 with PPO."""
    from __future__ import annotations
    
    import torch
    import torch.nn as nn
    from tensordict.nn import TensorDictModule, TensorDictSequential
    from torchrl.collectors import Collector
    from torchrl.data import LazyTensorStorage, TensorDictReplayBuffer
    from torchrl.data.replay_buffers.samplers import SamplerWithoutReplacement
    from torchrl.envs import GymEnv, InitTracker, TransformedEnv, StepCounter
    from torchrl.modules import LSTMModule, MLP, ValueOperator, set_recurrent_mode
    from torchrl.objectives import ClipPPOLoss
    from torchrl.objectives.value.advantages import GAE
    
    OBS_DIM = 3       # Pendulum-v1
    ACT_DIM = 1
    HIDDEN = 64
    DEVICE = "cpu"
    
    def make_env():
        env = TransformedEnv(GymEnv("Pendulum-v1"), InitTracker())
        env.append_transform(StepCounter(max_steps=200))
        return env
    
    # Policy
    lstm = LSTMModule(
        input_size=OBS_DIM, hidden_size=HIDDEN,
        in_keys=["observation", "rs_h", "rs_c"],
        out_keys=["lstm_out", ("next", "rs_h"), ("next", "rs_c")],
    )
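    from tensordict.nn import NormalParamExtractor
    head = TensorDictModule(
        # The MLP emits one tensor of size 2 * ACT_DIM; NormalParamExtractor
        # splits it into a loc and a strictly positive scale
        nn.Sequential(
            MLP(HIDDEN, out_features=ACT_DIM * 2, num_cells=[64]),
            NormalParamExtractor(),
        ),
        in_keys=["lstm_out"], out_keys=["loc", "scale"],
    )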
    policy = TensorDictSequential(lstm, head)
    
    from torchrl.modules import ProbabilisticActor, TanhNormal
    from torchrl.envs import ExplorationType
    proof_env = make_env()
    actor = ProbabilisticActor(
        policy,
        in_keys=["loc", "scale"],
        spec=proof_env.full_action_spec_unbatched,
        distribution_class=TanhNormal,
        distribution_kwargs={"low": -2.0, "high": 2.0},
        return_log_prob=True,
        default_interaction_type=ExplorationType.RANDOM,
    )
    
    # Critic
    critic = ValueOperator(
        MLP(OBS_DIM, out_features=1, num_cells=[64, 64]),
        in_keys=["observation"],
    )
    
    # Env with primer
    env = make_env().append_transform(lstm.make_tensordict_primer())
    
    adv = GAE(gamma=0.99, lmbda=0.95, value_network=critic)
    loss = ClipPPOLoss(actor_network=actor, critic_network=critic, clip_epsilon=0.2)
    optim = torch.optim.Adam(loss.parameters(), lr=3e-4)
    
    collector = Collector(
        create_env_fn=lambda: make_env().append_transform(lstm.make_tensordict_primer()),
        policy=actor,
        frames_per_batch=512,
        total_frames=50_000,
    )
    rb = TensorDictReplayBuffer(
        storage=LazyTensorStorage(512),
        sampler=SamplerWithoutReplacement(),
        batch_size=64,
    )
    
    for data in collector:
        with set_recurrent_mode(True), torch.no_grad():
            data = adv(data)
        rb.extend(data.reshape(-1))
        for batch in rb:
            optim.zero_grad(set_to_none=True)
            l = loss(batch)
            (l["loss_objective"] + l["loss_critic"] + l["loss_entropy"]).backward()
            torch.nn.utils.clip_grad_norm_(loss.parameters(), 0.5)
            optim.step()
        collector.update_policy_weights_()
    
    collector.shutdown()
    

    Key concepts summary

    | Concept | API | Description |
    | --- | --- | --- |
    | Recurrent module (LSTM) | LSTMModule | Wraps nn.LSTM; keys: [obs, h, c] → [out, (next,h), (next,c)] |
    | Recurrent module (GRU) | GRUModule | Wraps nn.GRU; keys: [obs, h] → [out, (next,h)] |
    | Episode boundary marker | InitTracker | Adds "is_init" flag; used by LSTM/GRU to zero hidden states |
    | Hidden state injection | make_tensordict_primer() | Pre-allocates zero hidden states on environment reset |
    | Mode switching | set_recurrent_mode(True/False) | Context manager from torchrl.modules for step vs sequence mode |
    | Scan backend | recurrent_backend="scan" | torch.compile-compatible; needs hoptorch |
    | Triton backend | recurrent_backend="triton" | GPU kernel; needs triton>=2.2 |
