This guide walks you through training a continuous-control agent using Proximal Policy Optimization (PPO) on the InvertedDoublePendulum-v4 Gymnasium environment. By the end you will have a working training loop that uses TorchRL’s GymEnv wrapper, a stochastic ProbabilisticActor with a TanhNormal distribution, a Collector for batching rollouts, and GAE + ClipPPOLoss for computing advantages and gradients — all connected through the shared TensorDict data model.
1

Install TorchRL

Install TorchRL and the Gymnasium continuous-control extras. TorchRL requires Python 3.10+ and PyTorch 2.1+.
pip install torchrl
pip install "gymnasium[mujoco]"
Verify the install by importing the library:
import torchrl
print(torchrl.__version__)  # e.g. 0.13.0
If you plan to run on CUDA and want hardware-accelerated prioritized replay buffers, see the Installation guide for the optional CUDA wheel.
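Before moving on, it can help to confirm that every package this guide relies on is visible to the active interpreter. The snippet below is a minimal sketch using only the standard library; the helper name `check_packages` is hypothetical and not part of TorchRL, and the package list assumes that the `gymnasium[mujoco]` extra installs the `mujoco` bindings.

```python
import importlib.util


def check_packages(names):
    """Map each top-level package name to True if the interpreter can locate it."""
    return {name: importlib.util.find_spec(name) is not None for name in names}


if __name__ == "__main__":
    # Packages used by the rest of this guide.
    for name, found in check_packages(["torch", "torchrl", "gymnasium", "mujoco"]).items():
        print(f"{name:10s} {'ok' if found else 'MISSING'}")
```

A `MISSING` entry usually means the install ran against a different Python environment than the one executing the script.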
2

Create the Environment

TorchRL wraps external simulators with a uniform TensorDict-based API. GymEnv creates a Gymnasium environment and handles device placement. TransformedEnv adds a composable transform stack — here we normalize observations, convert float64 outputs to float32, track cumulative episode rewards, and count steps.
from torchrl.envs import (
    Compose,
    DoubleToFloat,
    ObservationNorm,
    RewardSum,
    StepCounter,
    TransformedEnv,
)
from torchrl.envs.libs.gym import GymEnv
from torchrl.envs.utils import check_env_specs

# Wrap a Gymnasium environment with TorchRL's uniform API.
base_env = GymEnv("InvertedDoublePendulum-v4", device="cpu")

env = TransformedEnv(
    base_env,
    Compose(
        ObservationNorm(in_keys=["observation"]),
        DoubleToFloat(),
        RewardSum(),
        StepCounter(),
    ),
)

# Initialize the normalization statistics from 1000 random steps.
env.transform[0].init_stats(num_iter=1000, reduce_dim=0, cat_dim=0)

# Validate that specs and actual rollout data are consistent.
check_env_specs(env)

print("observation_spec:", env.observation_spec)
print("action_spec:     ", env.action_spec)
check_env_specs runs a short rollout and compares its output shapes and dtypes against the declared specs. If it returns without error, your environment is correctly configured.
You can append or remove transforms at any time with env.append_transform(t) or by indexing env.transform. Transforms participate in spec propagation, so specs always reflect the current transform stack.
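To make the normalization step concrete, here is a plain-Python sketch of what initializing statistics from random steps and then standardizing observations amounts to. It assumes the transform standardizes each observation dimension with a mean and standard deviation gathered up front; the function names `init_stats` and `normalize` and the sample data are illustrative only, and TorchRL's internal parameterization of ObservationNorm may differ.

```python
import math
import random


def init_stats(samples):
    """Per-dimension mean and standard deviation over a list of observation vectors."""
    n, dim = len(samples), len(samples[0])
    loc = [sum(s[d] for s in samples) / n for d in range(dim)]
    scale = [
        # Fall back to 1.0 for a constant dimension to avoid dividing by zero.
        math.sqrt(sum((s[d] - loc[d]) ** 2 for s in samples) / n) or 1.0
        for d in range(dim)
    ]
    return loc, scale


def normalize(obs, loc, scale):
    """Shift and rescale one observation so each dimension is roughly zero-mean, unit-variance."""
    return [(x - m) / s for x, m, s in zip(obs, loc, scale)]


random.seed(0)
# Stand-in for 1000 observations gathered with random actions (made-up data).
samples = [[random.gauss(5.0, 2.0), random.gauss(-1.0, 0.1)] for _ in range(1000)]
loc, scale = init_stats(samples)
normed = [normalize(s, loc, scale) for s in samples]
```

Keeping the two observation dimensions on a comparable scale is what lets the same Tanh MLP handle inputs whose raw ranges differ by an order of magnitude.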
3

Build the Actor and Critic

PPO uses a stochastic policy that outputs a distribution over actions. We build the actor in three stages: a neural network that maps observations to distribution parameters, a TensorDictModule wrapper that declares explicit input/output keys, and a ProbabilisticActor that constructs and samples from a TanhNormal distribution. The critic is a simpler ValueOperator that maps observations to a scalar state-value estimate.
import torch
from torch import nn
from tensordict.nn import TensorDictModule
from tensordict.nn.distributions import NormalParamExtractor

from torchrl.modules import ProbabilisticActor, TanhNormal, ValueOperator
from torchrl.envs.utils import ExplorationType

num_cells = 256  # hidden layer width

# ── Actor ──────────────────────────────────────────────────────────────────
# The network outputs loc and scale for a TanhNormal distribution.
actor_net = nn.Sequential(
    nn.LazyLinear(num_cells),
    nn.Tanh(),
    nn.LazyLinear(num_cells),
    nn.Tanh(),
    nn.LazyLinear(num_cells),
    nn.Tanh(),
    nn.LazyLinear(2 * env.action_spec.shape[-1]),
    NormalParamExtractor(),       # splits output → (loc, scale)
)

# Wrap with explicit TensorDict key contracts.
policy_module = TensorDictModule(
    actor_net, in_keys=["observation"], out_keys=["loc", "scale"]
)

# ProbabilisticActor samples actions and records log-probabilities.
policy_module = ProbabilisticActor(
    module=policy_module,
    spec=env.action_spec,
    in_keys=["loc", "scale"],
    distribution_class=TanhNormal,
    distribution_kwargs={
        "low": env.action_spec_unbatched.space.low,
        "high": env.action_spec_unbatched.space.high,
    },
    return_log_prob=True,
    default_interaction_type=ExplorationType.RANDOM,
)

# ── Critic ─────────────────────────────────────────────────────────────────
value_net = nn.Sequential(
    nn.LazyLinear(num_cells),
    nn.Tanh(),
    nn.LazyLinear(num_cells),
    nn.Tanh(),
    nn.LazyLinear(num_cells),
    nn.Tanh(),
    nn.LazyLinear(1),
)

value_module = ValueOperator(
    module=value_net,
    in_keys=["observation"],
)

# Trigger LazyLinear shape inference by running a reset observation.
policy_module.eval()
value_module.eval()
print("actor output: ", policy_module(env.reset()))
print("critic output:", value_module(env.reset()))
Both modules stay in eval() mode throughout training. TorchRL’s ExplorationType context managers control stochastic vs. deterministic action selection independently from PyTorch’s train/eval module state.
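The following sketch illustrates, for a single scalar action, what sampling from a tanh-squashed Gaussian and recording its log-probability involves. It is a plain-Python illustration under stated assumptions (a Gaussian sample squashed by tanh, then rescaled affinely to the action bounds), not TorchRL's TanhNormal implementation; the function name `sample_tanh_normal` and the small epsilon inside the logarithm are choices made for this example.

```python
import math
import random


def sample_tanh_normal(loc, scale, low, high, rng=random):
    """Sample one bounded action and its log-density (1-D case)."""
    u = rng.gauss(loc, scale)                # unbounded Gaussian sample
    t = math.tanh(u)                         # squash to [-1, 1]
    half_range = (high - low) / 2.0
    action = low + (t + 1.0) * half_range    # rescale to [low, high]

    # Gaussian log-density of the unbounded sample u ...
    log_prob = (
        -0.5 * ((u - loc) / scale) ** 2
        - math.log(scale)
        - 0.5 * math.log(2.0 * math.pi)
    )
    # ... minus log|d action / d u| from the change of variables.
    # The epsilon keeps the log finite when tanh saturates numerically.
    log_prob -= math.log(half_range * (1.0 - t * t) + 1e-12)
    return action, log_prob


random.seed(0)
action, log_prob = sample_tanh_normal(loc=0.0, scale=1.0, low=-1.0, high=1.0)
print(f"action={action:.3f} log_prob={log_prob:.3f}")
```

The stored log-probability is what the PPO loss later compares against the updated policy, which is why the actor is built with return_log_prob=True.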
4

Set Up the Collector

A Collector owns the environment execution loop. It steps the policy in the environment, accumulates data into batches, and yields TensorDict instances with shape [frames_per_batch]. You iterate over it like a standard Python iterator.
from torchrl.collectors import Collector

frames_per_batch = 1000
total_frames = 1_000_000

collector = Collector(
    env,
    policy_module,
    frames_per_batch=frames_per_batch,
    total_frames=total_frames,
    split_trajs=False,
    device="cpu",
)
The Collector returns a new TensorDict on each iteration. The batch contains observations, actions, log-probabilities, rewards, done flags, and any other keys written by the environment or policy transforms. No manual loop management is required. Call collector.shutdown() when training ends.
For larger workloads, swap Collector for MultiSyncCollector or MultiAsyncCollector to parallelize across CPU workers, or use a distributed collector to run on separate machines — the training loop code stays identical.
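As a mental model of what a collector does, the sketch below implements a toy generator that steps a policy in an environment and yields fixed-size batches until a frame budget is reached. Everything here is hypothetical and written for illustration (the `toy_collector` function, the 1-D random-walk environment, and the dictionary-based transitions); the real Collector works on TensorDict data and handles devices, specs, and resets for you.

```python
import random


def toy_collector(step_fn, reset_fn, policy_fn, frames_per_batch, total_frames):
    """Yield lists of transitions of length frames_per_batch until total_frames is reached."""
    obs = reset_fn()
    collected = 0
    while collected < total_frames:
        batch = []
        for _ in range(frames_per_batch):
            action = policy_fn(obs)
            next_obs, reward, done = step_fn(obs, action)
            batch.append(
                {"observation": obs, "action": action, "reward": reward, "done": done}
            )
            # Episodes may end mid-batch; reset and keep filling the same batch.
            obs = reset_fn() if done else next_obs
        collected += len(batch)
        yield batch


# Hypothetical 1-D random walk that terminates when the state leaves [-3, 3].
def reset_fn():
    return 0.0


def step_fn(obs, action):
    next_obs = obs + action
    return next_obs, 1.0, abs(next_obs) > 3.0


def policy_fn(obs):
    return random.choice([-1.0, 1.0])


random.seed(0)
batches = list(
    toy_collector(step_fn, reset_fn, policy_fn, frames_per_batch=10, total_frames=50)
)
print(len(batches), "batches of", len(batches[0]), "frames")
```

Because trajectories are not split in this sketch, a single batch can contain the tail of one episode and the start of the next, which mirrors the split_trajs=False setting used above.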
5

Configure GAE and ClipPPOLoss

TorchRL’s GAE module computes Generalized Advantage Estimates in-place on the collected TensorDict. ClipPPOLoss reads the resulting "advantage" and "value_target" keys alongside the log-probabilities the policy recorded during collection.
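For intuition about what the advantage module writes, here is a plain-Python sketch of the standard GAE backward recursion over one trajectory. The function name `gae` and its list-based signature are illustrative assumptions, not TorchRL's API. The sketch treats every done flag as a true termination (no bootstrapping past it) and omits the advantage standardization that average_gae=True requests.

```python
def gae(rewards, values, next_values, dones, gamma=0.99, lmbda=0.95):
    """Return (advantages, value_targets) for one trajectory.

    delta_t = r_t + gamma * V(s_{t+1}) * (1 - done_t) - V(s_t)
    A_t     = delta_t + gamma * lmbda * (1 - done_t) * A_{t+1}
    """
    advantages = [0.0] * len(rewards)
    running = 0.0
    for t in reversed(range(len(rewards))):
        not_done = 0.0 if dones[t] else 1.0
        delta = rewards[t] + gamma * next_values[t] * not_done - values[t]
        running = delta + gamma * lmbda * not_done * running
        advantages[t] = running
    # The critic's regression target is the advantage added back onto the value estimate.
    value_targets = [a + v for a, v in zip(advantages, values)]
    return advantages, value_targets


# Three-step episode with an untrained critic that predicts 0 everywhere.
advantages, value_targets = gae(
    rewards=[1.0, 1.0, 1.0],
    values=[0.0, 0.0, 0.0],
    next_values=[0.0, 0.0, 0.0],
    dones=[False, False, True],
)
print(advantages)
```

Earlier steps receive larger advantages because they accumulate the discounted TD errors of every later step in the episode.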
from torchrl.data.replay_buffers import ReplayBuffer
from torchrl.data.replay_buffers.samplers import SamplerWithoutReplacement
from torchrl.data.replay_buffers.storages import LazyTensorStorage
from torchrl.objectives import ClipPPOLoss
from torchrl.objectives.value import GAE

# Hyperparameters
lr = 3e-4
gamma = 0.99
lmbda = 0.95
clip_epsilon = 0.2
entropy_eps = 1e-4
sub_batch_size = 64
num_epochs = 10

# GAE advantage estimator — updates tensordict with "advantage" and "value_target".
advantage_module = GAE(
    gamma=gamma,
    lmbda=lmbda,
    value_network=value_module,
    average_gae=True,
)

# PPO clipped objective with entropy bonus and critic loss.
loss_module = ClipPPOLoss(
    actor_network=policy_module,
    critic_network=value_module,
    clip_epsilon=clip_epsilon,
    entropy_bonus=bool(entropy_eps),
    entropy_coeff=entropy_eps,
    critic_coeff=1.0,
    loss_critic_type="smooth_l1",
)

optim = torch.optim.Adam(loss_module.parameters(), lr)
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
    optim, total_frames // frames_per_batch, 0.0
)

# Replay buffer for sub-batch sampling inside each epoch.
replay_buffer = ReplayBuffer(
    storage=LazyTensorStorage(max_size=frames_per_batch),
    sampler=SamplerWithoutReplacement(),
)
The loss returns a TensorDict of named scalar losses. We combine three of them:
Key              Meaning
loss_objective   Clipped policy-gradient loss
loss_critic      Value network regression loss
loss_entropy     Entropy bonus (negated)
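The clipped policy-gradient term can be sketched per sample in a few lines of plain Python. This is an illustration of the standard PPO clipped surrogate under the hyperparameters above, not the ClipPPOLoss source; the function name `clipped_objective` is hypothetical.

```python
import math


def clipped_objective(log_prob_new, log_prob_old, advantage, clip_epsilon=0.2):
    """Per-sample PPO clipped surrogate, returned as a loss to minimize."""
    # Importance ratio between the updated policy and the one that collected the data.
    ratio = math.exp(log_prob_new - log_prob_old)
    clipped = max(1.0 - clip_epsilon, min(ratio, 1.0 + clip_epsilon))
    # Take the more pessimistic of the two surrogates, then negate for minimization.
    return -min(ratio * advantage, clipped * advantage)


# Ratio 1.5 with a positive advantage: the gain is capped at 1 + clip_epsilon.
print(clipped_objective(math.log(1.5), 0.0, advantage=1.0))
# Ratio 0.5 with a negative advantage: the clipped branch (0.8) is used.
print(clipped_objective(math.log(0.5), 0.0, advantage=-1.0))
```

In both printed cases the clipped branch is selected, so the gradient with respect to the ratio vanishes there, which is how the objective limits how far a single batch can move the policy.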
6

Run the Training Loop

With all pieces assembled, the training loop is a straightforward iteration over the collector. For each collected batch we compute advantages, fill the replay buffer, and run multiple epochs of mini-batch PPO updates.
from collections import defaultdict
import tqdm
from torchrl.envs.utils import set_exploration_type, ExplorationType

logs = defaultdict(list)
pbar = tqdm.tqdm(total=total_frames)
max_grad_norm = 1.0

for i, tensordict_data in enumerate(collector):
    # ── Inner optimization loop ──────────────────────────────────────────
    for _ in range(num_epochs):
        # Recompute GAE each epoch, since the critic is updated in between
        # (writes "advantage" and "value_target" in-place).
        advantage_module(tensordict_data)

        # Flatten time and batch dimensions, load into the replay buffer.
        data_view = tensordict_data.reshape(-1)
        replay_buffer.extend(data_view.cpu())

        # Mini-batch updates.
        for _ in range(frames_per_batch // sub_batch_size):
            subdata = replay_buffer.sample(sub_batch_size)
            loss_vals = loss_module(subdata)

            loss_value = (
                loss_vals["loss_objective"]
                + loss_vals["loss_critic"]
                + loss_vals["loss_entropy"]
            )

            loss_value.backward()
            torch.nn.utils.clip_grad_norm_(loss_module.parameters(), max_grad_norm)
            optim.step()
            optim.zero_grad()

    # ── Logging ───────────────────────────────────────────────────────────
    logs["reward"].append(tensordict_data["next", "reward"].mean().item())
    pbar.update(tensordict_data.numel())

    # Periodic evaluation with deterministic actions.
    if i % 10 == 0:
        with set_exploration_type(ExplorationType.DETERMINISTIC), torch.no_grad():
            eval_rollout = env.rollout(1000, policy_module)
        logs["eval reward (sum)"].append(
            eval_rollout["next", "reward"].sum().item()
        )
        del eval_rollout

    scheduler.step()

collector.shutdown()
env.close()
After 1M environment steps, the agent should typically balance the inverted double pendulum and reach the maximum episode length of 1000 steps. Exact results vary with the random seed.
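To see how the hyperparameters above translate into an update budget, the following sketch works through the arithmetic and evaluates the closed-form cosine annealing curve at a few points. The helper `cosine_lr` is a hypothetical stand-in written for this example; it assumes the scheduler is stepped once per collected batch, as in the loop above.

```python
import math

frames_per_batch = 1000
total_frames = 1_000_000
sub_batch_size = 64
num_epochs = 10
lr = 3e-4

num_batches = total_frames // frames_per_batch              # collector iterations: 1000
minibatches_per_epoch = frames_per_batch // sub_batch_size  # 15 mini-batches, 960 samples drawn
updates_per_batch = num_epochs * minibatches_per_epoch      # 150 optimizer steps per batch
total_updates = num_batches * updates_per_batch             # 150_000 optimizer steps overall


def cosine_lr(step, base_lr=lr, t_max=num_batches, eta_min=0.0):
    """Closed-form cosine annealing: base_lr at step 0, eta_min at step t_max."""
    return eta_min + 0.5 * (base_lr - eta_min) * (1.0 + math.cos(math.pi * step / t_max))


for step in (0, num_batches // 2, num_batches):
    print(f"batch {step:4d}: lr = {cosine_lr(step):.2e}")
```

Since 1000 is not a multiple of 64, each epoch draws 15 mini-batches rather than covering every collected frame; choosing a sub_batch_size that divides frames_per_batch avoids the remainder.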

Next Steps

Environments & Transforms

Explore the full environment transform library: image pre-processing, reward scaling, action masking, frame stacking, and more.

Collectors

Learn how to parallelize data collection with MultiSyncCollector and MultiAsyncCollector for faster wall-clock training.

Replay Buffers

Deep-dive into prioritized replay, memmap-backed storage, HER, and offline dataset loading.

SOTA Implementations

Browse complete, research-ready implementations of SAC, DQN, TD3, Dreamer, Decision Transformer, MAPPO, and GRPO in sota-implementations/.
