> Documentation Index: fetch the complete documentation index at https://mintlify.com/pytorch/rl/llms.txt. Use this file to discover all available pages before exploring further.

TorchRL structures every RL algorithm around the same composable building blocks: a typed environment, a policy and value network expressed as TensorDictModules, a data Collector that streams batches, and a loss module that consumes those batches. This tutorial walks you through assembling all of those pieces into a working Proximal Policy Optimization (PPO) loop for a continuous-control MuJoCo task. The code shown here is drawn directly from the SOTA implementation in the TorchRL repository.
GPU training is optional. All examples below run on CPU; switch to
device="cuda" if a GPU is available.

TorchRL wraps Gymnasium environments with GymEnv and chains preprocessing steps via TransformedEnv. Each transform is appended in order and runs at every step.

import torch
from torchrl.envs import (
ClipTransform,
DoubleToFloat,
GymEnv,
RewardSum,
StepCounter,
TransformedEnv,
VecNorm,
)
def make_env(env_name: str = "HalfCheetah-v4", device: str = "cpu") -> TransformedEnv:
    """Build the preprocessed Gymnasium environment used by the tutorial.

    Args:
        env_name: Gymnasium environment id passed to ``GymEnv``.
        device: Device handed to ``GymEnv``.

    Returns:
        A ``TransformedEnv`` wrapping ``GymEnv`` with the transforms below
        appended in the order listed.
    """
    # GymEnv wraps a Gymnasium environment and returns TensorDicts.
    transformed = TransformedEnv(GymEnv(env_name, device=device))
    pipeline = (
        # Normalises observations with a running exponential mean/variance.
        VecNorm(in_keys=["observation"], decay=0.99999, eps=1e-2),
        # Clips the normalised observations to [-10, 10].
        ClipTransform(in_keys=["observation"], low=-10, high=10),
        # Accumulates episode returns under "episode_reward".
        RewardSum(),
        # Adds a "step_count" key to every transition.
        StepCounter(),
        # Casts double-precision observations to float.
        DoubleToFloat(in_keys=["observation"]),
    )
    for transform in pipeline:
        transformed.append_transform(transform)
    return transformed
# Instantiate the environment once and print the specs that later steps use
# to size the networks.
env = make_env()
for spec in (env.observation_spec, env.action_spec):
    print(spec)
The environment’s observation_spec and action_spec describe the shape, dtype, and bounds of every tensor the environment reads or writes. These specs are used in later steps to automatically size network inputs and outputs.

PPO uses a ProbabilisticActor: a wrapper that takes a deterministic network producing distribution parameters and adds stochastic sampling. For continuous actions we use a TanhNormal distribution, which squashes samples into the action bounds.

import torch.nn as nn
from tensordict.nn import AddStateIndependentNormalScale, TensorDictModule
from torchrl.envs import ExplorationType
from torchrl.modules import MLP, ProbabilisticActor, TanhNormal
def make_actor(env: TransformedEnv, device: str = "cpu") -> ProbabilisticActor:
    """Build the stochastic PPO policy for ``env``.

    Args:
        env: Environment whose observation and action specs size the network.
        device: Device on which the network and the distribution bounds live.

    Returns:
        A ``ProbabilisticActor`` that reads "observation", produces "loc" and
        "scale", and samples actions from a ``TanhNormal`` distribution.
    """
    action_spec = env.action_spec_unbatched
    obs_dim = env.observation_spec["observation"].shape[-1]
    action_dim = action_spec.shape[-1]

    # The MLP predicts only the location (mean) of the Gaussian.
    loc_net = MLP(
        in_features=obs_dim,
        activation_class=nn.Tanh,
        out_features=action_dim,
        num_cells=[64, 64],
        device=device,
    )

    # Orthogonal weights and zero biases (standard for PPO).
    linear_layers = [m for m in loc_net.modules() if isinstance(m, nn.Linear)]
    for linear in linear_layers:
        nn.init.orthogonal_(linear.weight, 1.0)
        nn.init.zeros_(linear.bias)

    # AddStateIndependentNormalScale appends a learnable scale so the network
    # emits both distribution parameters.
    loc_and_scale = nn.Sequential(
        loc_net,
        AddStateIndependentNormalScale(action_dim, scale_lb=1e-8).to(device),
    )

    # TensorDictModule connects the network to TensorDict keys.
    param_module = TensorDictModule(
        module=loc_and_scale,
        in_keys=["observation"],
        out_keys=["loc", "scale"],
    )

    # The distribution is bounded by the action limits taken from the spec.
    dist_kwargs = {
        "low": action_spec.space.low.to(device),
        "high": action_spec.space.high.to(device),
        "tanh_loc": False,
    }

    # ProbabilisticActor wraps the module and attaches the distribution.
    return ProbabilisticActor(
        param_module,
        in_keys=["loc", "scale"],
        spec=env.full_action_spec_unbatched.to(device),
        distribution_class=TanhNormal,
        distribution_kwargs=dist_kwargs,
        return_log_prob=True,  # needed for PPO importance weights
        default_interaction_type=ExplorationType.RANDOM,
    )
return_log_prob=True tells ProbabilisticActor to write the log-probability of the sampled action under the key "sample_log_prob". ClipPPOLoss reads this key automatically.

The critic predicts a scalar state value. ValueOperator is a thin wrapper around an nn.Module that registers the correct in/out keys for downstream loss modules.

from torchrl.modules import ValueOperator
def make_critic(env: TransformedEnv, device: str = "cpu") -> ValueOperator:
    """Build the state-value network for ``env``.

    Args:
        env: Environment whose observation spec sizes the network input.
        device: Device on which the MLP is created.

    Returns:
        A ``ValueOperator`` that reads "observation" and wraps an MLP with a
        single output feature.
    """
    obs_dim = env.observation_spec["observation"].shape[-1]
    value_net = MLP(
        in_features=obs_dim,
        activation_class=nn.Tanh,
        out_features=1,
        num_cells=[64, 64],
        device=device,
    )

    # Orthogonal weights with a small gain (0.01) and zero biases.
    for linear in (m for m in value_net.modules() if isinstance(m, nn.Linear)):
        nn.init.orthogonal_(linear.weight, 0.01)
        nn.init.zeros_(linear.bias)

    return ValueOperator(value_net, in_keys=["observation"])
GAE (Generalized Advantage Estimation) computes advantages and value targets from a batch of transitions. ClipPPOLoss implements the clipped surrogate objective together with a critic loss and an entropy bonus.

from torchrl.objectives import ClipPPOLoss, group_optimizers
from torchrl.objectives.value.advantages import GAE
device = "cpu"
actor = make_actor(env, device)
critic = make_critic(env, device)

# GAE wraps the value network and is applied with torch.no_grad() at training time.
adv_module = GAE(
    gamma=0.99,
    lmbda=0.95,
    value_network=critic,
    average_gae=False,
    device=device,
)

# Clipped surrogate objective combined with a critic loss and an entropy bonus.
loss_module = ClipPPOLoss(
    actor_network=actor,
    critic_network=critic,
    clip_epsilon=0.2,
    loss_critic_type="smooth_l1",
    entropy_coeff=0.01,
    critic_coeff=0.5,
    normalize_advantage=True,
)

# One Adam optimizer per network, both with the same hyperparameters.
adam_kwargs = {"lr": 3e-4, "eps": 1e-5}
actor_optim = torch.optim.Adam(actor.parameters(), **adam_kwargs)
critic_optim = torch.optim.Adam(critic.parameters(), **adam_kwargs)

# group_optimizers produces a single optimizer from two; useful for LR scheduling.
optim = group_optimizers(actor_optim, critic_optim)
ClipPPOLoss looks up keys like "observation", "action", "sample_log_prob", and "advantage" automatically when they match the defaults. Use loss_module.set_keys(...) to override them for non-standard environments.

Collector runs the policy in the environment and yields TensorDict batches. It handles device placement, auto-reset between episodes, and optional torch.compile acceleration.

from torchrl.collectors import Collector
# Frames requested per collector iteration, and the overall frame budget.
FRAMES_PER_BATCH = 2048
TOTAL_FRAMES = 1_000_000

# The collector receives the env factory (not an env instance) and the actor.
collector = Collector(
    create_env_fn=make_env,  # called once to spin up the env
    policy=actor,
    frames_per_batch=FRAMES_PER_BATCH,
    total_frames=TOTAL_FRAMES,
    device=device,
    max_frames_per_traj=-1,  # -1 means no forced resets
)
PPO reuses each collected batch for several epochs of gradient updates. A TensorDictReplayBuffer with SamplerWithoutReplacement provides epoch-level mini-batching without repetition.

from torchrl.data import LazyTensorStorage, TensorDictReplayBuffer
from torchrl.data.replay_buffers.samplers import SamplerWithoutReplacement
MINI_BATCH_SIZE = 256

# The storage is sized to hold one collected batch of FRAMES_PER_BATCH frames.
storage = LazyTensorStorage(FRAMES_PER_BATCH, device=device)
# Sampling without replacement gives epoch-level mini-batching without repetition.
sampler = SamplerWithoutReplacement()

data_buffer = TensorDictReplayBuffer(
    storage=storage,
    sampler=sampler,
    batch_size=MINI_BATCH_SIZE,
)
The outer loop iterates over the collector; the inner loop performs multiple epochs of PPO updates on the collected batch.
import tqdm
from torchrl.envs import set_exploration_type, ExplorationType
PPO_EPOCHS = 10
# Mini-batches per epoch when a full batch is stored (informational; not used below).
NUM_MINI_BATCHES = FRAMES_PER_BATCH // MINI_BATCH_SIZE

pbar = tqdm.tqdm(total=TOTAL_FRAMES)
collected_frames = 0

# Outer loop: one iteration per batch yielded by the collector.
for data in collector:
    frames_in_batch = data.numel()
    collected_frames += frames_in_batch
    pbar.update(frames_in_batch)

    # Inner loop: several epochs of PPO updates on the same collected batch.
    for _ in range(PPO_EPOCHS):
        # Compute advantages and value targets (no gradients needed here).
        # This is redone every epoch because the critic changes between epochs.
        with torch.no_grad():
            data = adv_module(data)

        # Fill the replay buffer with the flattened batch.
        data_buffer.extend(data.reshape(-1))

        for batch in data_buffer:
            optim.zero_grad(set_to_none=True)
            loss = loss_module(batch)

            # Sum the three PPO loss terms.
            total_loss = (
                loss["loss_objective"]
                + loss["loss_critic"]
                + loss["loss_entropy"]
            )
            total_loss.backward()

            # Gradient clipping is standard practice for PPO.
            torch.nn.utils.clip_grad_norm_(loss_module.parameters(), max_norm=0.5)
            optim.step()

    # Push the updated weights back to the collector's copy of the policy.
    collector.update_policy_weights_()

    # Log the mean return of the episodes that ended in this batch: the "done"
    # mask selects the accumulated "episode_reward" at episode boundaries.
    ep_rewards = data["next", "episode_reward"][data["next", "done"]]
    if ep_rewards.numel() > 0:
        pbar.set_description(f"reward={ep_rewards.mean().item():.2f}")

pbar.close()
collector.shutdown()
env.close()

# Evaluation: roll out the trained policy deterministically in a fresh env.
# NOTE(review): test_env comes from a new make_env() call, so its VecNorm
# transform presumably does not start from the statistics gathered during
# training -- confirm whether they should be copied over before evaluating.
test_env = make_env(device=device)
test_env.eval()
with torch.no_grad(), set_exploration_type(ExplorationType.DETERMINISTIC):
    actor.eval()
    td = test_env.rollout(
        policy=actor,
        max_steps=1000,
        auto_reset=True,
        break_when_any_done=True,
    )

rewards = td["next", "episode_reward"]
episode_reward = rewards[td["next", "done"]]
if episode_reward.numel() == 0:
    # No episode ended within max_steps: report the return accumulated at the
    # last step instead of the NaN that the mean of an empty tensor would give.
    episode_reward = rewards[-1]
print(f"Test reward: {episode_reward.mean().item():.2f}")
test_env.close()
Putting it all together
Key concepts
| Concept | Class | Role |
|---|---|---|
| Environment wrapper | GymEnv | Bridges Gymnasium to TorchRL’s TensorDict API |
| Preprocessing pipeline | TransformedEnv | Chains transforms like VecNorm, RewardSum |
| Stochastic policy | ProbabilisticActor | Wraps a deterministic net + a distribution |
| Value network | ValueOperator | Scalar critic with typed keys |
| Advantage estimation | GAE | Computes λ-returns with a value network |
| Policy loss | ClipPPOLoss | Clipped surrogate + critic + entropy |
| Data collection | Collector | Iterates the policy in the env, yields batches |
| Mini-batch sampling | TensorDictReplayBuffer | Epoch-level sampling without replacement |