TorchRL structures every RL algorithm around the same composable building blocks: a typed environment, policy and value networks expressed as TensorDictModules, a data Collector that streams batches, and a loss module that consumes those batches. This tutorial walks you through assembling all of those pieces into a working Proximal Policy Optimization (PPO) loop for a continuous-control MuJoCo task. The code shown here is drawn directly from the SOTA implementation in the TorchRL repository.
Install dependencies
pip install torchrl tensordict torch tqdm "gymnasium[mujoco]"
GPU training is optional. All examples below run on CPU; set device="cuda" if a GPU is available.
Create and transform the environment
TorchRL wraps Gymnasium environments with GymEnv and chains preprocessing steps via TransformedEnv. Transforms run in the order they are appended, on every reset and every step.
import torch
from torchrl.envs import (
    ClipTransform,
    DoubleToFloat,
    GymEnv,
    RewardSum,
    StepCounter,
    TransformedEnv,
    VecNorm,
)

def make_env(env_name: str = "HalfCheetah-v4", device: str = "cpu") -> TransformedEnv:
    # GymEnv wraps a Gymnasium environment and returns TensorDicts
    env = GymEnv(env_name, device=device)
    env = TransformedEnv(env)
    # VecNorm normalises observations with a running exponential mean/variance
    env.append_transform(VecNorm(in_keys=["observation"], decay=0.99999, eps=1e-2))
    env.append_transform(ClipTransform(in_keys=["observation"], low=-10, high=10))
    # RewardSum accumulates episode returns under "episode_reward"
    env.append_transform(RewardSum())
    # StepCounter adds a "step_count" key to every transition
    env.append_transform(StepCounter())
    env.append_transform(DoubleToFloat(in_keys=["observation"]))
    return env

env = make_env()
print(env.observation_spec)
print(env.action_spec)
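The ordering of the transform chain above can be illustrated with a framework-free sketch. It follows a single scalar observation stream through normalise, then clip, then reward sum, then step count. `RunningNorm` is a hypothetical, simplified stand-in for VecNorm that uses decayed sums. It is not VecNorm's exact update rule, and it uses plain floats instead of TensorDicts.

```python
import math


class RunningNorm:
    """Hypothetical, simplified stand-in for VecNorm: decayed running mean/variance."""

    def __init__(self, decay: float = 0.99999, eps: float = 1e-2) -> None:
        self.decay = decay
        self.eps = eps
        self.sum = 0.0    # decayed sum of observations
        self.ssq = 0.0    # decayed sum of squared observations
        self.count = 0.0  # decayed number of observations

    def __call__(self, x: float) -> float:
        # Update the decayed statistics, then normalise with them
        self.sum = self.decay * self.sum + x
        self.ssq = self.decay * self.ssq + x * x
        self.count = self.decay * self.count + 1.0
        mean = self.sum / self.count
        var = max(self.ssq / self.count - mean * mean, 0.0)
        return (x - mean) / max(math.sqrt(var), self.eps)


def run_pipeline(observations, rewards, low=-10.0, high=10.0):
    norm = RunningNorm()
    episode_reward = 0.0  # the quantity RewardSum tracks
    step_count = 0        # the quantity StepCounter tracks
    outputs = []
    for obs, rew in zip(observations, rewards):
        z = norm(obs)               # 1. normalise
        z = min(max(z, low), high)  # 2. clip to [low, high]
        episode_reward += rew       # 3. accumulate the return
        step_count += 1             # 4. count steps
        outputs.append((z, episode_reward, step_count))
    return outputs


for z, ep_r, t in run_pipeline([1.0, 3.0, 5.0], [0.5, 0.5, 1.0]):
    print(f"obs={z:+.3f} episode_reward={ep_r} step_count={t}")
```

Because clipping runs after normalisation, the bounds of -10 and 10 apply to normalised values, not raw observations.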
The environment’s observation_spec and action_spec describe the shape, dtype, and bounds of every tensor the environment reads or writes. The next steps read these specs to size the network inputs and outputs and to set the bounds of the action distribution.
Build the actor (policy) network
PPO uses a ProbabilisticActor: a wrapper that takes a deterministic network producing distribution parameters and adds stochastic sampling. For continuous actions we use a TanhNormal distribution, which squashes samples into the action bounds.
import torch.nn as nn
from tensordict.nn import AddStateIndependentNormalScale, TensorDictModule
from torchrl.envs import ExplorationType
from torchrl.modules import MLP, ProbabilisticActor, TanhNormal

def make_actor(env: TransformedEnv, device: str = "cpu") -> ProbabilisticActor:
    obs_shape = env.observation_spec["observation"].shape
    num_outputs = env.action_spec_unbatched.shape[-1]

    # The MLP predicts only the location (mean) of the Gaussian
    policy_mlp = MLP(
        in_features=obs_shape[-1],
        activation_class=nn.Tanh,
        out_features=num_outputs,
        num_cells=[64, 64],
        device=device,
    )
    # Orthogonal weight initialisation (standard for PPO)
    for layer in policy_mlp.modules():
        if isinstance(layer, nn.Linear):
            nn.init.orthogonal_(layer.weight, 1.0)
            layer.bias.data.zero_()

    # AddStateIndependentNormalScale appends a learnable log-std parameter
    policy_mlp = nn.Sequential(
        policy_mlp,
        AddStateIndependentNormalScale(num_outputs, scale_lb=1e-8).to(device),
    )

    # TensorDictModule connects the network to TensorDict keys
    actor_module = TensorDictModule(
        module=policy_mlp,
        in_keys=["observation"],
        out_keys=["loc", "scale"],
    )

    # ProbabilisticActor wraps the module and attaches the distribution
    actor = ProbabilisticActor(
        actor_module,
        in_keys=["loc", "scale"],
        spec=env.full_action_spec_unbatched.to(device),
        distribution_class=TanhNormal,
        distribution_kwargs={
            "low": env.action_spec_unbatched.space.low.to(device),
            "high": env.action_spec_unbatched.space.high.to(device),
            "tanh_loc": False,
        },
        return_log_prob=True,            # needed for PPO importance weights
        default_interaction_type=ExplorationType.RANDOM,
    )
    return actor
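The squashing idea behind TanhNormal can be sketched in plain Python. The sketch makes four assumptions:

- Each action dimension is an independent Gaussian with its own `loc` and `scale`.
- Each sample is squashed into (-1, 1) with tanh.
- The squashed value is rescaled affinely into [low, high].
- The log-probability is corrected with the change-of-variables term.

This illustrates the technique only and does not reproduce TorchRL's TanhNormal. For example, its `tanh_loc` option and numerical safeguards are omitted, and `sample_tanh_normal` is a hypothetical helper.

```python
import math
import random


def sample_tanh_normal(loc, scale, low, high, rng):
    """Hypothetical helper: sample a squashed Gaussian action and its log-prob."""
    action, log_prob = [], 0.0
    for mu, sigma, lo, hi in zip(loc, scale, low, high):
        u = rng.gauss(mu, sigma)             # pre-squash Gaussian sample
        y = math.tanh(u)                     # squash into (-1, 1)
        half_range = (hi - lo) / 2.0
        a = lo + (y + 1.0) * half_range      # rescale into [lo, hi]
        # Gaussian log-density of u
        lp = -0.5 * ((u - mu) / sigma) ** 2 - math.log(sigma) - 0.5 * math.log(2 * math.pi)
        # Change of variables: da/du = half_range * (1 - tanh(u)^2)
        lp -= math.log(half_range * (1.0 - y * y) + 1e-12)
        action.append(a)
        log_prob += lp                       # independent dims: log-probs add up
    return action, log_prob


rng = random.Random(0)
act, lp = sample_tanh_normal([0.0, 2.0], [1.0, 0.5], [-1.0, -1.0], [1.0, 1.0], rng)
print(act, lp)
```

Every sampled action stays inside the bounds however large `loc` becomes. This is why a squashed distribution is convenient for bounded MuJoCo action spaces.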
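return_log_prob=True tells ProbabilisticActor to also write the log-probability of each sampled action, which PPO needs for its importance ratio. The key is "sample_log_prob" (recent TensorDict releases may name it "action_log_prob" instead), and ClipPPOLoss reads this key automatically.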
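The stored log-probability matters because the clipped surrogate compares two quantities for each action:

- its log-probability under the current policy;
- the log-probability recorded at collection time.

The plain-Python sketch below computes that per-sample term. It assumes `clip_epsilon=0.2`, as in the ClipPPOLoss configured later, and a loss equal to the negative mean of the clipped objective. The helper name `clipped_surrogate_loss` is hypothetical.

```python
import math


def clipped_surrogate_loss(new_log_probs, old_log_probs, advantages, clip_epsilon=0.2):
    """Hypothetical helper: negative mean of PPO's clipped surrogate objective."""
    terms = []
    for new_lp, old_lp, adv in zip(new_log_probs, old_log_probs, advantages):
        ratio = math.exp(new_lp - old_lp)  # pi_new(a|s) / pi_old(a|s)
        clipped = min(max(ratio, 1.0 - clip_epsilon), 1.0 + clip_epsilon)
        # Pessimistic bound: keep the smaller of the unclipped and clipped terms
        terms.append(min(ratio * adv, clipped * adv))
    return -sum(terms) / len(terms)        # minimise the negative objective


# A ratio of 1.5 with a positive advantage is capped at 1.2, giving a loss of -1.2
print(clipped_surrogate_loss([math.log(1.5)], [0.0], [1.0]))
```

Taking the minimum means the policy gains nothing from moving the ratio beyond 1 ± clip_epsilon in the direction the advantage favours.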
Build the critic (value) network
The critic predicts a scalar state value. ValueOperator is a thin wrapper around an nn.Module that reads its in_keys and, by default, writes its prediction under "state_value", the key that GAE and ClipPPOLoss expect.
from torchrl.modules import ValueOperator

def make_critic(env: TransformedEnv, device: str = "cpu") -> ValueOperator:
    obs_shape = env.observation_spec["observation"].shape

    value_mlp = MLP(
        in_features=obs_shape[-1],
        activation_class=nn.Tanh,
        out_features=1,
        num_cells=[64, 64],
        device=device,
    )
    for layer in value_mlp.modules():
        if isinstance(layer, nn.Linear):
            nn.init.orthogonal_(layer.weight, 0.01)
            layer.bias.data.zero_()

    critic = ValueOperator(value_mlp, in_keys=["observation"])
    return critic
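The key routing behind TensorDictModule and ValueOperator can be illustrated with plain dictionaries. `KeyedModule` below is a hypothetical stand-in, not a TorchRL class. It reads its in_keys from a dict, calls a function, and writes the results under its out_keys. When no out_keys are given, it falls back to "state_value", as ValueOperator does.

```python
from typing import Callable, Dict, List, Optional


class KeyedModule:
    """Hypothetical stand-in illustrating TensorDictModule-style key routing."""

    def __init__(self, fn: Callable, in_keys: List[str], out_keys: Optional[List[str]] = None):
        self.fn = fn
        self.in_keys = in_keys
        self.out_keys = out_keys if out_keys is not None else ["state_value"]

    def __call__(self, data: Dict[str, object]) -> Dict[str, object]:
        outputs = self.fn(*(data[k] for k in self.in_keys))
        if len(self.out_keys) == 1:
            outputs = (outputs,)
        for key, value in zip(self.out_keys, outputs):
            data[key] = value  # write results into the same dict
        return data


# A toy "critic": value = sum of the observation entries
toy_critic = KeyedModule(lambda obs: sum(obs), in_keys=["observation"])
batch_dict = toy_critic({"observation": [0.5, 1.5]})
print(batch_dict["state_value"])  # 2.0
```

Because modules agree only on key names, the same dict can flow from actor to critic to loss without any module knowing about the others. This is the role TensorDicts play in the real pipeline.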
Set up GAE and ClipPPOLoss
GAE (Generalized Advantage Estimation) computes per-step advantages and value targets from a batch of transitions and writes them under "advantage" and "value_target". ClipPPOLoss implements the clipped surrogate objective together with a critic loss and an entropy bonus.
from torchrl.objectives import ClipPPOLoss, group_optimizers
from torchrl.objectives.value.advantages import GAE

device = "cpu"
actor = make_actor(env, device)
critic = make_critic(env, device)

# GAE wraps the value network and is applied with torch.no_grad() at training time
adv_module = GAE(
    gamma=0.99,
    lmbda=0.95,
    value_network=critic,
    average_gae=False,
    device=device,
)

loss_module = ClipPPOLoss(
    actor_network=actor,
    critic_network=critic,
    clip_epsilon=0.2,
    loss_critic_type="smooth_l1",
    entropy_coeff=0.01,
    critic_coeff=0.5,
    normalize_advantage=True,
)

# group_optimizers merges both optimizers into one, so a single zero_grad()/step() updates actor and critic
actor_optim = torch.optim.Adam(actor.parameters(), lr=3e-4, eps=1e-5)
critic_optim = torch.optim.Adam(critic.parameters(), lr=3e-4, eps=1e-5)
optim = group_optimizers(actor_optim, critic_optim)
ClipPPOLoss looks up keys like "observation", "action", "sample_log_prob", and "advantage" automatically when they match the defaults. Use loss_module.set_keys(...) to override them for non-standard environments.
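For reference, the GAE recursion can be written out for a single trajectory. The sketch makes three assumptions:

- There is one environment, with scalar rewards and values.
- A `done` flag marks a true termination, which stops both bootstrapping and the backward accumulation.
- Defaults are gamma=0.99 and lmbda=0.95, as in the GAE setup above.

`compute_gae` is a hypothetical helper. TorchRL's GAE additionally handles batch dimensions, distinguishes truncation from termination, and vectorizes the computation.

```python
def compute_gae(rewards, values, next_values, dones, gamma=0.99, lmbda=0.95):
    """Hypothetical helper: GAE advantages and value targets for one trajectory.

    values[t] = V(s_t) and next_values[t] = V(s_{t+1}); dones[t] is True when the
    episode ends at step t (treated as a termination: no bootstrap past it).
    """
    advantages = [0.0] * len(rewards)
    running = 0.0
    for t in reversed(range(len(rewards))):
        not_done = 0.0 if dones[t] else 1.0
        # One-step TD error: r_t + gamma * V(s_{t+1}) - V(s_t)
        delta = rewards[t] + gamma * next_values[t] * not_done - values[t]
        # Exponentially weighted (gamma * lmbda) sum of future TD errors
        running = delta + gamma * lmbda * not_done * running
        advantages[t] = running
    # Value targets are advantages plus the baseline (the lambda-returns)
    value_targets = [a + v for a, v in zip(advantages, values)]
    return advantages, value_targets


# With gamma = lmbda = 1 and zero values, advantages are Monte Carlo returns-to-go
print(compute_gae([1.0, 1.0, 1.0], [0.0] * 3, [0.0] * 3, [False, False, True], gamma=1.0, lmbda=1.0))
# ([3.0, 2.0, 1.0], [3.0, 2.0, 1.0])
```

Setting lmbda=0 reduces each advantage to the one-step TD error, and lmbda=1 gives full Monte Carlo returns. The value 0.95 trades off between the two.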
Create the data Collector
Collector runs the policy in the environment and yields TensorDict batches. It handles device placement, auto-reset between episodes, and optional torch.compile acceleration.
from torchrl.collectors import Collector

FRAMES_PER_BATCH = 2048
TOTAL_FRAMES = 1_000_000

collector = Collector(
    create_env_fn=make_env,        # called once to spin up the env
    policy=actor,
    frames_per_batch=FRAMES_PER_BATCH,
    total_frames=TOTAL_FRAMES,
    device=device,
    max_frames_per_traj=-1,        # -1 means no forced resets
)
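Conceptually, a collector is a generator that steps the policy and resets finished episodes. It yields fixed-size batches until total_frames is reached. The framework-free sketch below assumes a toy environment with `reset()` and `step(action)` methods. `ToyEnv` and `collect` are hypothetical names. The real Collector yields TensorDicts and also handles devices and policy-weight syncing.

```python
import random


class ToyEnv:
    """Hypothetical environment: episodes last `horizon` steps, reward = action."""

    def __init__(self, horizon: int = 3) -> None:
        self.horizon = horizon
        self.t = 0

    def reset(self) -> float:
        self.t = 0
        return 0.0

    def step(self, action: float):
        self.t += 1
        return float(self.t), action, self.t >= self.horizon  # obs, reward, done


def collect(env, policy, frames_per_batch: int, total_frames: int):
    """Yield lists of transitions, auto-resetting the env when an episode ends."""
    obs = env.reset()
    collected = 0
    while collected < total_frames:
        batch = []
        for _ in range(frames_per_batch):
            action = policy(obs)
            next_obs, reward, done = env.step(action)
            batch.append({"observation": obs, "action": action,
                          "next": {"observation": next_obs, "reward": reward, "done": done}})
            obs = env.reset() if done else next_obs  # auto-reset between episodes
        collected += len(batch)
        yield batch


rng = random.Random(0)
for batch in collect(ToyEnv(), lambda obs: rng.uniform(-1, 1), frames_per_batch=4, total_frames=10):
    print(len(batch), [step["next"]["done"] for step in batch])
```

Batches can span episode boundaries, as the done flags in the output show. The "done" entries are what let GAE stop its recursion at the right step. This sketch always emits whole batches, so it may overshoot total_frames slightly.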
Create the replay buffer for mini-batch sampling
PPO reuses each collected batch for several epochs of gradient updates. A TensorDictReplayBuffer with SamplerWithoutReplacement provides epoch-level mini-batching without repetition.
from torchrl.data import LazyTensorStorage, TensorDictReplayBuffer
from torchrl.data.replay_buffers.samplers import SamplerWithoutReplacement

MINI_BATCH_SIZE = 256

data_buffer = TensorDictReplayBuffer(
    storage=LazyTensorStorage(FRAMES_PER_BATCH, device=device),
    sampler=SamplerWithoutReplacement(),
    batch_size=MINI_BATCH_SIZE,
)
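With FRAMES_PER_BATCH = 2048 and MINI_BATCH_SIZE = 256, one pass over the buffer yields 2048 / 256 = 8 mini-batches. Across PPO_EPOCHS = 10 passes, that makes 80 gradient steps per collected batch. The shuffle-and-slice logic can be sketched as follows. `epoch_minibatches` is a hypothetical helper, not TorchRL's SamplerWithoutReplacement. It keeps a smaller trailing batch when the sizes do not divide evenly.

```python
import random


def epoch_minibatches(num_items: int, batch_size: int, rng: random.Random):
    """Hypothetical helper: yield index lists covering every item exactly once."""
    order = list(range(num_items))
    rng.shuffle(order)  # fresh random order on every pass
    for start in range(0, num_items, batch_size):
        yield order[start:start + batch_size]  # the last batch may be smaller


FRAMES_PER_BATCH, MINI_BATCH_SIZE, PPO_EPOCHS = 2048, 256, 10
rng = random.Random(0)
updates = 0
for _ in range(PPO_EPOCHS):
    seen = []
    for idx in epoch_minibatches(FRAMES_PER_BATCH, MINI_BATCH_SIZE, rng):
        seen.extend(idx)
        updates += 1  # one optimizer step per mini-batch
    assert sorted(seen) == list(range(FRAMES_PER_BATCH))  # no repeats, no gaps
print(updates)  # 80
```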
Run the training loop
The outer loop iterates over the collector; the inner loop performs multiple epochs of PPO updates on the collected batch.
import tqdm
from torchrl.envs import set_exploration_type, ExplorationType

PPO_EPOCHS = 10
NUM_MINI_BATCHES = FRAMES_PER_BATCH // MINI_BATCH_SIZE  # 2048 // 256 = 8 mini-batches per epoch

pbar = tqdm.tqdm(total=TOTAL_FRAMES)
collected_frames = 0

for data in collector:
    frames_in_batch = data.numel()
    collected_frames += frames_in_batch
    pbar.update(frames_in_batch)

    for _ in range(PPO_EPOCHS):
        # Compute advantages and value targets (no gradients needed here)
        with torch.no_grad():
            data = adv_module(data)

        # Fill the replay buffer with the flattened batch
        data_buffer.extend(data.reshape(-1))

        for batch in data_buffer:
            optim.zero_grad(set_to_none=True)
            loss = loss_module(batch)

            # Sum the three PPO loss terms
            total_loss = (
                loss["loss_objective"]
                + loss["loss_critic"]
                + loss["loss_entropy"]
            )
            total_loss.backward()
            # Gradient clipping is standard practice for PPO
            torch.nn.utils.clip_grad_norm_(loss_module.parameters(), max_norm=0.5)
            optim.step()

    # Push the updated weights back to the collector's copy of the policy
    collector.update_policy_weights_()

    # Log episode reward (only entries where "done" is True correspond to completed episodes)
    ep_rewards = data["next", "episode_reward"][data["next", "done"]]
    if len(ep_rewards):
        pbar.set_description(f"reward={ep_rewards.mean().item():.2f}")

pbar.close()
collector.shutdown()
env.close()
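The `clip_grad_norm_` call in the loop rescales all gradients together whenever their combined L2 norm exceeds `max_norm`. The plain-Python sketch below applies that rule to a flat list of gradient values. `clip_by_global_norm` is a hypothetical helper. PyTorch's version works in place on parameter tensors. The small epsilon in the denominator mirrors what PyTorch adds.

```python
import math


def clip_by_global_norm(grads, max_norm: float = 0.5):
    """Hypothetical helper: scale all gradients so their joint L2 norm is <= max_norm."""
    total_norm = math.sqrt(sum(g * g for g in grads))
    scale = min(1.0, max_norm / (total_norm + 1e-6))  # shrink only, never scale up
    return [g * scale for g in grads], total_norm


clipped, norm_before = clip_by_global_norm([3.0, 4.0], max_norm=0.5)
print(norm_before, clipped)  # 5.0, then values close to [0.3, 0.4]
```

Scaling every gradient by the same factor preserves the update direction and only limits its size. This keeps a single outlier mini-batch from destabilising training.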
Evaluate the trained policy
Run a deterministic rollout (no exploration noise) using set_exploration_type.
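test_env = make_env(device=device)
# make_env() creates a fresh VecNorm, so this env's running statistics are separate
# from the ones accumulated by the training environments
test_env.eval()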

with torch.no_grad(), set_exploration_type(ExplorationType.DETERMINISTIC):
    actor.eval()
    td = test_env.rollout(
        policy=actor,
        max_steps=1000,
        auto_reset=True,
        break_when_any_done=True,
    )
    episode_reward = td["next", "episode_reward"][td["next", "done"]]
    print(f"Test reward: {episode_reward.mean().item():.2f}")

test_env.close()
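The training log and this evaluation both average "episode_reward" only at steps where "done" is True. At any other step, RewardSum's running total belongs to an unfinished episode. The plain-Python sketch below shows that masking. `completed_episode_returns` is a hypothetical helper that operates on lists rather than tensors.

```python
def completed_episode_returns(episode_reward, done):
    """Hypothetical helper: keep the running return only where an episode ended."""
    return [r for r, d in zip(episode_reward, done) if d]


# RewardSum-style running totals over two episodes (3 steps, then 2 steps)
episode_reward = [1.0, 2.0, 3.0, 0.5, 1.0]
done = [False, False, True, False, True]
finished = completed_episode_returns(episode_reward, done)
mean_return = sum(finished) / len(finished) if finished else float("nan")
print(finished, mean_return)  # [3.0, 1.0] 2.0
```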

Putting it all together

"""Minimal self-contained PPO for HalfCheetah-v4."""
from __future__ import annotations

import torch
import torch.nn as nn
import tqdm

from tensordict.nn import AddStateIndependentNormalScale, TensorDictModule
from torchrl.collectors import Collector
from torchrl.data import LazyTensorStorage, TensorDictReplayBuffer
from torchrl.data.replay_buffers.samplers import SamplerWithoutReplacement
from torchrl.envs import (
    ClipTransform, DoubleToFloat, ExplorationType,
    GymEnv,
    RewardSum, StepCounter, TransformedEnv, VecNorm,
    set_exploration_type,
)
from torchrl.modules import MLP, ProbabilisticActor, TanhNormal, ValueOperator
from torchrl.objectives import ClipPPOLoss, group_optimizers
from torchrl.objectives.value.advantages import GAE

DEVICE = "cpu"
ENV_NAME = "HalfCheetah-v4"
FRAMES_PER_BATCH = 2048
TOTAL_FRAMES = 500_000
MINI_BATCH_SIZE = 256
PPO_EPOCHS = 10
LR = 3e-4

def make_env():
    env = GymEnv(ENV_NAME, device=DEVICE)
    env = TransformedEnv(env)
    env.append_transform(VecNorm(in_keys=["observation"], decay=0.99999, eps=1e-2))
    env.append_transform(ClipTransform(in_keys=["observation"], low=-10, high=10))
    env.append_transform(RewardSum())
    env.append_transform(StepCounter())
    env.append_transform(DoubleToFloat(in_keys=["observation"]))
    return env

env = make_env()
obs_dim = env.observation_spec["observation"].shape[-1]
act_dim = env.action_spec_unbatched.shape[-1]

# Actor
policy_mlp = nn.Sequential(
    MLP(obs_dim, activation_class=nn.Tanh, out_features=act_dim, num_cells=[64, 64]),
    AddStateIndependentNormalScale(act_dim, scale_lb=1e-8),
)
actor = ProbabilisticActor(
    TensorDictModule(policy_mlp, in_keys=["observation"], out_keys=["loc", "scale"]),
    in_keys=["loc", "scale"],
    spec=env.full_action_spec_unbatched,
    distribution_class=TanhNormal,
    distribution_kwargs={
        "low": env.action_spec_unbatched.space.low,
        "high": env.action_spec_unbatched.space.high,
        "tanh_loc": False,
    },
    return_log_prob=True,
    default_interaction_type=ExplorationType.RANDOM,
)

# Critic
critic = ValueOperator(
    MLP(obs_dim, activation_class=nn.Tanh, out_features=1, num_cells=[64, 64]),
    in_keys=["observation"],
)

adv_module = GAE(gamma=0.99, lmbda=0.95, value_network=critic, average_gae=False)
loss_module = ClipPPOLoss(
    actor_network=actor, critic_network=critic,
    clip_epsilon=0.2, entropy_coeff=0.01, critic_coeff=0.5, normalize_advantage=True,
)
optim = group_optimizers(
    torch.optim.Adam(actor.parameters(), lr=LR, eps=1e-5),
    torch.optim.Adam(critic.parameters(), lr=LR, eps=1e-5),
)

collector = Collector(make_env, actor, frames_per_batch=FRAMES_PER_BATCH, total_frames=TOTAL_FRAMES)
data_buffer = TensorDictReplayBuffer(
    storage=LazyTensorStorage(FRAMES_PER_BATCH),
    sampler=SamplerWithoutReplacement(),
    batch_size=MINI_BATCH_SIZE,
)

pbar = tqdm.tqdm(total=TOTAL_FRAMES)
for data in collector:
    pbar.update(data.numel())
    for _ in range(PPO_EPOCHS):
        with torch.no_grad():
            data = adv_module(data)
        data_buffer.extend(data.reshape(-1))
        for batch in data_buffer:
            optim.zero_grad(set_to_none=True)
            loss = loss_module(batch)
            (loss["loss_objective"] + loss["loss_critic"] + loss["loss_entropy"]).backward()
            torch.nn.utils.clip_grad_norm_(loss_module.parameters(), 0.5)
            optim.step()
    collector.update_policy_weights_()
    ep_r = data["next", "episode_reward"][data["next", "done"]]
    if len(ep_r):
        pbar.set_description(f"reward={ep_r.mean().item():.2f}")

collector.shutdown()
env.close()

Key concepts

Concept | Class | Role
Environment wrapper | GymEnv | Bridges Gymnasium to TorchRL’s TensorDict API
Preprocessing pipeline | TransformedEnv | Chains transforms like VecNorm, RewardSum
Stochastic policy | ProbabilisticActor | Wraps a deterministic net + a distribution
Value network | ValueOperator | Scalar critic with typed keys
Advantage estimation | GAE | Computes λ-returns with a value network
Policy loss | ClipPPOLoss | Clipped surrogate + critic + entropy
Data collection | Collector | Iterates the policy in the env, yields batches
Mini-batch sampling | TensorDictReplayBuffer | Epoch-level sampling without replacement
