TorchRL loss modules are nn.Module subclasses that read a TensorDict, compute one or more differentiable losses, and return them in a TensorDict under keys prefixed with "loss_". Unlike monolithic loss functions, each component loss is exposed separately, so optimizers, loggers, and schedulers can work on individual terms. Key mappings (which TensorDict field each part of the loss reads) are configurable on each loss instance, so a loss can be adapted to a custom environment schema without modifying the algorithm itself.

LossModule: the base class

Every TorchRL loss inherits from LossModule. It extends TensorDictModuleBase and adds:
  • Named loss keys: every differentiable output is prefixed with "loss_", so the terms can be collected and summed generically.
  • Configurable key mappings: every loss declares an _AcceptedKeys dataclass; call loss.set_keys(action="my_action") to remap keys without subclassing.
  • Target network management: SoftUpdate / HardUpdate polyak-average or copy the target parameters.
  • Value estimator registry: each loss declares a default_value_estimator; call loss.make_value_estimator(ValueEstimators.GAE, gamma=0.99) to swap it out.
  • Schedulable buffers: scalar coefficients such as entropy_coeff are registered as nn.Module buffers and can be updated by direct assignment.
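from torchrl.objectives.common import LossModule
from torchrl.objectives.utils import SoftUpdate, HardUpdate, ValueEstimators

# `loss` is any LossModule instance (e.g. a ClipPPOLoss or a SACLoss).
# Remap tensordict keys to match your environment schema. Only names declared
# in the loss's _AcceptedKeys can be remapped; observation keys are read by the
# actor/critic modules (their in_keys), not by the loss itself.
loss.set_keys(
    action="my_action",
    reward="extrinsic_reward",
)

# Swap the value estimator.
loss.make_value_estimator(ValueEstimators.TDLambda, gamma=0.99, lmbda=0.95)

# Attach a target network updater (only for losses that keep target networks).
# tau is the Polyak weight on the online parameters (equivalent to eps=0.995).
updater = SoftUpdate(loss, tau=0.005)
updater.step()   # call after each gradient step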
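Because every differentiable term shares the "loss_" prefix, a training loop can build the total objective without knowing which algorithm produced it. The sketch below uses a plain Python dict of floats as a stand-in for the TensorDict a loss module returns (an assumption made so the example runs without TorchRL); with real outputs the same key filter applies to the tensor values.

```python
def total_loss(losses):
    """Sum every entry whose key starts with "loss_".

    `losses` is a plain dict standing in for a loss module's output;
    diagnostic entries such as "clip_fraction" or "ESS" are skipped.
    """
    return sum(value for key, value in losses.items() if key.startswith("loss_"))


outputs = {
    "loss_objective": -0.12,
    "loss_critic": 0.40,
    "loss_entropy": -0.01,
    "clip_fraction": 0.15,  # diagnostic, not part of the objective
    "ESS": 0.98,            # diagnostic, not part of the objective
}
print(round(total_loss(outputs), 6))  # 0.27
```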

Policy-gradient losses

ClipPPOLoss

The clipped surrogate objective from Proximal Policy Optimization. Computes the clipped importance-weighted surrogate, a value (critic) loss, and an entropy bonus.
from torchrl.objectives import ClipPPOLoss

loss_fn = ClipPPOLoss(
    actor_network=actor,
    critic_network=critic,
    clip_epsilon=0.2,         # symmetric clip range [1-ε, 1+ε]
    entropy_bonus=True,
    entropy_coeff=0.01,
    critic_coeff=1.0,
    loss_critic_type="smooth_l1",
    normalize_advantage=False,
)

losses = loss_fn(batch)
total = (
    losses["loss_objective"]   # clipped surrogate
    + losses["loss_critic"]    # value function loss
    + losses["loss_entropy"]   # entropy bonus
)
total.backward()
ClipPPOLoss also reports "clip_fraction" and "ESS" (effective sample size) as non-differentiable scalars that help monitor training stability.
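For intuition, the per-sample arithmetic behind "loss_objective" and "clip_fraction" can be sketched in plain Python. This is a simplified, unbatched sketch of the clipped objective from the PPO paper, not TorchRL's implementation; the function name is hypothetical and the inputs are plain lists of new/old log-probabilities and advantages.

```python
import math


def clipped_surrogate(log_prob_new, log_prob_old, advantages, clip_epsilon=0.2):
    """Return (loss, clip_fraction) for the PPO clipped surrogate.

    loss = -mean(min(r * A, clip(r, 1 - eps, 1 + eps) * A)),
    with r = exp(log_prob_new - log_prob_old).
    clip_fraction is the share of samples whose ratio lies outside
    [1 - eps, 1 + eps].
    """
    objectives, clipped = [], 0
    for lp_new, lp_old, adv in zip(log_prob_new, log_prob_old, advantages):
        ratio = math.exp(lp_new - lp_old)
        clipped_ratio = min(max(ratio, 1.0 - clip_epsilon), 1.0 + clip_epsilon)
        # Pessimistic bound: take the smaller of the two surrogate values.
        objectives.append(min(ratio * adv, clipped_ratio * adv))
        if clipped_ratio != ratio:
            clipped += 1
    n = len(objectives)
    return -sum(objectives) / n, clipped / n


loss, frac = clipped_surrogate([math.log(1.5), 0.0], [0.0, 0.0], [1.0, -1.0])
print(round(loss, 6), frac)  # -0.1 0.5
```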

KLPENPPOLoss

An alternative PPO formulation that replaces clipping with a KL-divergence penalty whose coefficient is adapted toward a target divergence. It can be a better fit when the clipped surrogate proves too conservative.
from torchrl.objectives import KLPENPPOLoss

loss_fn = KLPENPPOLoss(
    actor_network=actor,
    critic_network=critic,
    dtarg=0.01,       # target KL divergence
    beta=1.0,         # initial KL penalty coefficient
)
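The adaptive coefficient can be illustrated with the rule from the PPO paper: after an update, compare the measured KL divergence with dtarg and scale beta up or down. This plain-Python sketch assumes the paper's 1.5 tolerance and factor-of-2 adjustment; the constants TorchRL uses may differ, so check the KLPENPPOLoss signature. Both helper names are hypothetical.

```python
def update_beta(beta, kl, dtarg, tolerance=1.5, factor=2.0):
    """Adapt the KL penalty coefficient as described in the PPO paper.

    If the observed KL is well above the target, strengthen the penalty;
    if it is well below, relax it; otherwise leave beta unchanged.
    """
    if kl > dtarg * tolerance:
        return beta * factor
    if kl < dtarg / tolerance:
        return beta / factor
    return beta


def kl_penalized_loss(surrogate, kl, beta):
    """Loss to minimize: -(surrogate - beta * KL)."""
    return -(surrogate - beta * kl)


beta = 1.0
beta = update_beta(beta, kl=0.05, dtarg=0.01)   # KL too large, beta doubles
beta = update_beta(beta, kl=0.001, dtarg=0.01)  # KL too small, beta halves
```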

ReinforceLoss

The classic REINFORCE (vanilla policy gradient) loss. Lightweight; useful as a baseline or for discrete-action tasks.
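from torchrl.objectives import ReinforceLoss
from torchrl.objectives.utils import ValueEstimators

loss_fn = ReinforceLoss(
    actor_network=actor,
    critic_network=critic,     # optional baseline
)
# gamma is configured on the value estimator, not in the constructor.
# TD1 computes discounted (lambda = 1) returns, close to classic REINFORCE.
# The advantage is read from "advantage" by default (remap with set_keys).
loss_fn.make_value_estimator(ValueEstimators.TD1, gamma=0.99)
losses = loss_fn(batch)
losses["loss_actor"].backward()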
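Without a baseline, classic REINFORCE weights each action's log-probability by the discounted return that follows it. The plain-Python sketch below shows that Monte Carlo estimator for a single episode; the helper names are hypothetical, and TorchRL's ReinforceLoss instead reads the advantage produced by its value estimator.

```python
def returns_to_go(rewards, gamma=0.99):
    """G_t = r_t + gamma * G_{t+1}, computed backwards over one episode."""
    returns, running = [], 0.0
    for reward in reversed(rewards):
        running = reward + gamma * running
        returns.append(running)
    returns.reverse()
    return returns


def reinforce_loss(log_probs, rewards, gamma=0.99):
    """Monte Carlo policy-gradient loss: -mean(log pi(a_t | s_t) * G_t)."""
    returns = returns_to_go(rewards, gamma)
    return -sum(lp * g for lp, g in zip(log_probs, returns)) / len(returns)


print(returns_to_go([1.0, 1.0, 1.0], gamma=0.5))  # [1.75, 1.5, 1.0]
```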

A2CLoss

Advantage Actor-Critic loss, combining policy gradient, value, and entropy terms. Suitable for synchronous parallel environments.
from torchrl.objectives import A2CLoss

loss_fn = A2CLoss(
    actor_network=actor,
    critic_network=critic,
    entropy_bonus=True,
    entropy_coeff=0.01,
    critic_coeff=0.5,
)

Actor-critic off-policy losses

SACLoss

Soft Actor-Critic combines a stochastic actor, twin Q-networks, and a learnable entropy temperature. Returns "loss_actor", "loss_qvalue", and "loss_alpha".
from torchrl.objectives import SACLoss
from torchrl.objectives.utils import SoftUpdate

loss_fn = SACLoss(
    actor_network=actor,
    qvalue_network=qvalue,      # or a list of 2 Q-networks
    num_qvalue_nets=2,
    alpha_init=1.0,
    target_entropy="auto",      # −dim(action)
    loss_function="smooth_l1",
    delay_qvalue=True,          # use target Q-networks
)
target_updater = SoftUpdate(loss_fn, tau=0.005)   # Polyak weight on online params

losses = loss_fn(batch)
# losses["loss_actor"]   — maximize entropy-regularized Q
# losses["loss_qvalue"]  — Bellman backup
# losses["loss_alpha"]   — temperature tuning
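The "loss_qvalue" term regresses each Q-network toward a soft Bellman target that takes the minimum over the twin target critics and subtracts the entropy term. The plain-Python sketch below computes that target for a single transition, assuming the next action was sampled from the current policy with log-probability `next_log_prob`; the function name is hypothetical and the real loss works on batched tensors.

```python
def sac_target(reward, done, next_q_values, next_log_prob, alpha, gamma=0.99):
    """Soft Bellman target for one transition.

    y = r + gamma * (1 - done) * (min_i Q_i(s', a') - alpha * log pi(a' | s'))
    next_q_values: target-critic estimates Q_i(s', a'), one per Q-network.
    done: True if the episode terminated at this transition.
    """
    soft_value = min(next_q_values) - alpha * next_log_prob
    return reward + gamma * (0.0 if done else 1.0) * soft_value


print(sac_target(1.0, False, [2.0, 3.0], next_log_prob=-1.0, alpha=0.5, gamma=0.5))  # 2.25
```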

TD3Loss

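Twin Delayed DDPG (TD3). Builds on DDPGLoss with twin Q-networks (clipped double Q-learning), target policy smoothing, and delayed actor updates. In TorchRL, delay_actor is a boolean that enables a target copy of the actor; how often the actor is updated relative to the critic is controlled by the training loop.
from torchrl.objectives import TD3Loss

loss_fn = TD3Loss(
    actor_network=actor,
    qvalue_network=qvalue,
    action_spec=env.action_spec,
    policy_noise=0.2,          # target policy smoothing noise
    noise_clip=0.5,
    delay_actor=True,          # keep a target actor network
    loss_function="smooth_l1",
)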
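Target policy smoothing perturbs the target actor's action with clipped Gaussian noise before the twin critics evaluate it, and the critic target then uses the smaller of the two Q estimates. The plain-Python sketch below handles a single action component and assumes actions are bounded to [low, high] (here [-1, 1]); both function names are hypothetical.

```python
import random


def smoothed_target_action(target_action, policy_noise=0.2, noise_clip=0.5,
                           low=-1.0, high=1.0, rng=random):
    """TD3 target policy smoothing for one action component.

    noise ~ clip(N(0, policy_noise), -noise_clip, noise_clip); the perturbed
    action is then clipped back into the action bounds.
    """
    noise = rng.gauss(0.0, policy_noise)
    noise = max(-noise_clip, min(noise_clip, noise))
    return max(low, min(high, target_action + noise))


def td3_target(reward, done, next_q_values, gamma=0.99):
    """Clipped double-Q target: r + gamma * (1 - done) * min_i Q_i(s', a~)."""
    return reward + gamma * (0.0 if done else 1.0) * min(next_q_values)


rng = random.Random(0)
print(smoothed_target_action(0.9, rng=rng))  # stays within [0.4, 1.0]
```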

DDPGLoss

Deep Deterministic Policy Gradient. A deterministic actor trained with Q-value gradients.
from torchrl.objectives import DDPGLoss

loss_fn = DDPGLoss(
    actor_network=actor,
    value_network=qvalue,
    loss_function="smooth_l1",
)

Q-learning losses

DQNLoss

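Standard DQN with optional double Q-learning. The value network is typically a QValueActor, which selects greedy actions; wrap it with an exploration module such as EGreedyModule for ε-greedy action selection.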
from torchrl.objectives import DQNLoss
from torchrl.objectives.utils import HardUpdate

loss_fn = DQNLoss(
    value_network=q_net,
    action_space=env.action_spec,
    loss_function="l2",
    double_dqn=True,
    delay_value=True,
)
target_updater = HardUpdate(loss_fn, value_network_update_interval=1000)
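The double Q-learning rule lets the online network choose the greedy next action while the target network evaluates it, which reduces the overestimation caused by taking a plain max over noisy estimates. The plain-Python sketch below computes both targets for one transition with discrete actions, where Q-values are lists indexed by action; the function name is hypothetical.

```python
def dqn_target(reward, done, next_q_online, next_q_target, gamma=0.99, double_dqn=True):
    """One-step DQN target for a single transition.

    Standard DQN: r + gamma * max_a Q_target(s', a)
    Double DQN:   r + gamma * Q_target(s', argmax_a Q_online(s', a))
    """
    if done:
        return reward
    if double_dqn:
        best_action = max(range(len(next_q_online)), key=next_q_online.__getitem__)
        next_value = next_q_target[best_action]
    else:
        next_value = max(next_q_target)
    return reward + gamma * next_value


print(dqn_target(1.0, False, [1.0, 5.0], [3.0, 2.0], gamma=0.5))                    # 2.0
print(dqn_target(1.0, False, [1.0, 5.0], [3.0, 2.0], gamma=0.5, double_dqn=False))  # 2.5
```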

DistributionalDQNLoss

Distributional DQN (C51). Computes the cross-entropy between the projected target distribution and the predicted atom probabilities.
from torchrl.objectives import DistributionalDQNLoss

loss_fn = DistributionalDQNLoss(
    value_network=distributional_q_net,
    gamma=0.99,
)
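The projected target distribution is obtained by shifting each support atom through the Bellman update and splitting its probability between the two nearest atoms. The plain-Python sketch below implements that categorical projection from the C51 paper for one transition, assuming evenly spaced atoms between v_min and v_max; the function names are hypothetical and the real loss operates on batched tensors.

```python
import math


def project_distribution(next_probs, reward, done, v_min, v_max, gamma=0.99):
    """Project r + gamma * z onto the fixed support {z_0, ..., z_{N-1}}.

    next_probs: target-network probabilities over the N atoms at s'.
    Returns the projected probabilities used as the cross-entropy target.
    """
    n_atoms = len(next_probs)
    delta_z = (v_max - v_min) / (n_atoms - 1)
    projected = [0.0] * n_atoms
    for j, p in enumerate(next_probs):
        z_j = v_min + j * delta_z
        tz = reward + (0.0 if done else gamma) * z_j
        tz = min(max(tz, v_min), v_max)             # keep within the support
        b = (tz - v_min) / delta_z                  # fractional atom index
        b = min(max(b, 0.0), n_atoms - 1)           # guard against rounding
        lower, upper = math.floor(b), math.ceil(b)
        if lower == upper:                          # lands exactly on an atom
            projected[lower] += p
        else:
            projected[lower] += p * (upper - b)
            projected[upper] += p * (b - lower)
    return projected


def cross_entropy(target_probs, predicted_probs, eps=1e-8):
    """C51 loss: -sum_i target_i * log(predicted_i)."""
    return -sum(t * math.log(q + eps) for t, q in zip(target_probs, predicted_probs))


target = project_distribution([0.2] * 5, reward=0.5, done=False, v_min=-2.0, v_max=2.0, gamma=1.0)
```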

CQLLoss and IQLLoss

Offline Q-learning objectives from Conservative Q-Learning and Implicit Q-Learning.
from torchrl.objectives import CQLLoss, IQLLoss

# Conservative Q-Learning: penalizes Q-values of out-of-distribution (OOD) actions.
cql = CQLLoss(
    actor_network=actor,
    qvalue_network=qvalue,
    alpha_init=1.0,
)

# Implicit Q-Learning: expectile regression on V avoids querying Q on unseen actions;
# the policy is then extracted with advantage-weighted regression.
iql = IQLLoss(
    actor_network=actor,
    qvalue_network=qvalue,
    value_network=value_net,
    expectile=0.7,
    loss_function="l2",
)
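IQL fits its state-value network with an asymmetric squared loss: residuals where Q exceeds V are weighted by the expectile τ and the rest by 1 - τ, so τ = 0.7 pushes V toward an upper expectile of Q. The plain-Python sketch below computes that loss over a batch of Q and V estimates; the function name is hypothetical.

```python
def expectile_loss(q_values, v_values, expectile=0.7):
    """Mean of |tau - 1(u < 0)| * u**2 with u = Q(s, a) - V(s)."""
    total = 0.0
    for q, v in zip(q_values, v_values):
        diff = q - v
        # Positive residuals get weight tau, negative ones get 1 - tau.
        weight = expectile if diff > 0 else 1.0 - expectile
        total += weight * diff ** 2
    return total / len(q_values)


print(expectile_loss([1.0], [0.0]))  # 0.7
```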

Value estimators

Value estimators compute advantages and return targets from raw trajectory data and populate the TensorDict with "advantage" and "value_target" keys before the loss is called.
from torchrl.objectives.value import GAE

advantage_fn = GAE(
    value_network=critic,
    gamma=0.99,
    lmbda=0.95,
    average_gae=False,
)
batch = advantage_fn(batch)
# batch["advantage"]    — GAE advantage
# batch["value_target"] — λ-return target
GAE (Generalized Advantage Estimation) is the standard choice for PPO and A2C.
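The recursion GAE evaluates can be sketched in plain Python for a single trajectory: one TD error per step, accumulated backwards with the factor γλ and reset at episode ends. This is a simplified, unbatched sketch with a hypothetical function name and a single `dones` flag that both stops bootstrapping and resets the accumulation; TorchRL's GAE runs the equivalent computation on TensorDict batches.

```python
def gae(rewards, values, next_values, dones, gamma=0.99, lmbda=0.95):
    """Generalized Advantage Estimation over one trajectory.

    delta_t = r_t + gamma * (1 - d_t) * V(s_{t+1}) - V(s_t)
    A_t     = delta_t + gamma * lmbda * (1 - d_t) * A_{t+1}
    Returns (advantages, value_targets) with value_target_t = A_t + V(s_t).
    """
    advantages = [0.0] * len(rewards)
    running = 0.0
    for t in reversed(range(len(rewards))):
        not_done = 0.0 if dones[t] else 1.0
        delta = rewards[t] + gamma * not_done * next_values[t] - values[t]
        running = delta + gamma * lmbda * not_done * running
        advantages[t] = running
    value_targets = [a + v for a, v in zip(advantages, values)]
    return advantages, value_targets


adv, targets = gae([1.0, 1.0], [0.0, 0.0], [0.0, 0.0], [False, True], gamma=0.5, lmbda=1.0)
print(adv)  # [1.5, 1.0]
```

With λ = 1 and a zero value function, the advantages reduce to discounted returns; with λ = 0 they reduce to one-step TD errors.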

Attaching an estimator to a loss module

Every loss that needs value estimation exposes make_value_estimator():
from torchrl.objectives import ClipPPOLoss
from torchrl.objectives.utils import ValueEstimators

loss_fn = ClipPPOLoss(actor_network=actor, critic_network=critic)
loss_fn.make_value_estimator(
    ValueEstimators.GAE,
    gamma=0.99,
    lmbda=0.95,
)

Multi-agent losses

MAPPOLoss

Multi-Agent PPO with a centralized critic. Reads per-agent observations under ("agents", "observation") and expects a group_map to identify agent groups.

IPPOLoss

Independent PPO: each agent has its own actor and critic; no shared parameters across agents.

QMixerLoss

QMIX / VDN monotonic mixing network that combines per-agent Q-values into a joint Q-value for cooperative tasks.

A MAPPO setup pairs the loss with a multi-agent advantage estimator:
from torchrl.objectives.multiagent import MAPPOLoss
from torchrl.objectives.value import MultiAgentGAE

mappo_loss = MAPPOLoss(
    actor_network=actor,
    critic_network=centralized_critic,
    clip_epsilon=0.2,
    entropy_coeff=0.01,
)
advantage_fn = MultiAgentGAE(
    gamma=0.99,
    lmbda=0.95,
    value_network=centralized_critic,
    agent_dim=-2,
)
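The two mixers named for QMixerLoss differ in how per-agent Q-values are combined: VDN sums them, while QMIX uses state-dependent weights forced to be non-negative, so the joint value never decreases when any single agent's Q-value increases. The plain-Python sketch below is a one-layer simplification with hypothetical names; the actual QMIX mixer is a two-layer network whose weights come from hypernetworks conditioned on the global state.

```python
def vdn_mix(agent_qs):
    """VDN: the joint Q-value is the sum of per-agent Q-values."""
    return sum(agent_qs)


def monotonic_mix(agent_qs, raw_weights, bias):
    """Single-layer QMIX-style mixer.

    raw_weights stand in for hypernetwork outputs; taking abs() keeps every
    weight non-negative, which makes the output monotonic in each agent's Q.
    """
    return sum(abs(w) * q for w, q in zip(raw_weights, agent_qs)) + bias


print(vdn_mix([1.0, 2.0, 3.0]))                      # 6.0
print(monotonic_mix([1.0, 2.0], [-0.5, 2.0], 0.1))  # 4.6 (up to float rounding)
```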

Target network utilities

from torchrl.objectives.utils import SoftUpdate, HardUpdate, TargetNetUpdater

# Soft (Polyak) update: θ_target ← (1 - τ) * θ_target + τ * θ
# (equivalently eps = 1 - τ, the weight kept on the target parameters).
soft = SoftUpdate(loss_fn, tau=0.005)
soft.step()   # call once per gradient step

# Hard update: copy weights every N steps.
hard = HardUpdate(loss_fn, value_network_update_interval=1000)
hard.step()   # increments internal counter; copies on the N-th call
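To make the tau/eps convention concrete, here is a plain-Python sketch of both update rules on a flat list of parameters. SoftUpdate documents eps as the weight kept on the target (tau = 1 - eps), so eps=0.995 and tau=0.005 describe the same update. The parameter lists and function names are hypothetical stand-ins for the loss module's networks.

```python
def soft_update(target_params, online_params, tau=0.005):
    """Polyak averaging: theta_target <- (1 - tau) * theta_target + tau * theta."""
    return [(1.0 - tau) * t + tau * o for t, o in zip(target_params, online_params)]


def hard_update_steps(num_steps, interval=1000):
    """1-based step indices at which a hard update copies the online weights."""
    return [step for step in range(1, num_steps + 1) if step % interval == 0]


target = soft_update([0.0, 2.0], [1.0, 2.0], tau=0.5)
print(target)                         # [0.5, 2.0]
print(hard_update_steps(3000, 1000))  # [1000, 2000, 3000]
```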

Complete PPO example with GAE

The following example assembles a full PPO pipeline: environment, actor, critic, ClipPPOLoss, GAE, and an optimizer.
import torch
from torch import nn
from tensordict.nn import TensorDictModule
from tensordict.nn.distributions import NormalParamExtractor

from torchrl.collectors import Collector
from torchrl.envs import GymEnv, TransformedEnv, ObservationNorm, Compose
from torchrl.modules import MLP, ProbabilisticActor, TanhNormal, ValueOperator
from torchrl.objectives import ClipPPOLoss
from torchrl.objectives.value import GAE

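# --- Environment ---
env = TransformedEnv(
    GymEnv("HalfCheetah-v4"),
    Compose(ObservationNorm(in_keys=["observation"])),
)
# ObservationNorm needs its loc/scale before use; estimate them from random rollouts.
env.transform[0].init_stats(num_iter=1000, reduce_dim=0, cat_dim=0)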
obs_dim = env.observation_spec["observation"].shape[-1]
action_dim = env.action_spec.shape[-1]

# --- Actor ---
actor = ProbabilisticActor(
    TensorDictModule(
        nn.Sequential(
            MLP(obs_dim, 2 * action_dim, num_cells=[256, 256]),
            NormalParamExtractor(),
        ),
        in_keys=["observation"],
        out_keys=["loc", "scale"],
    ),
    in_keys=["loc", "scale"],
    out_keys=["action"],
    distribution_class=TanhNormal,
    distribution_kwargs={"low": -1.0, "high": 1.0},
    return_log_prob=True,
)

# --- Critic ---
critic = ValueOperator(
    MLP(obs_dim, 1, num_cells=[256, 256]),
    in_keys=["observation"],
    out_keys=["state_value"],
)

# --- Loss and advantage ---
loss_fn = ClipPPOLoss(
    actor_network=actor,
    critic_network=critic,
    clip_epsilon=0.2,
    entropy_bonus=True,
    entropy_coeff=0.01,
    critic_coeff=0.5,
    normalize_advantage=True,
)
advantage_fn = GAE(value_network=critic, gamma=0.99, lmbda=0.95)
optimizer = torch.optim.Adam(loss_fn.parameters(), lr=3e-4)

# --- Collector ---
collector = Collector(
    create_env_fn=env,
    policy=actor,
    frames_per_batch=2048,
    total_frames=500_000,
)

# --- Training loop ---
for batch in collector:
    # 1. Compute advantages in-place.
    with torch.no_grad():
        batch = advantage_fn(batch)

    # 2. Several epochs, each covering the whole batch in shuffled mini-batches.
    flat = batch.reshape(-1)
    for _ in range(10):
        perm = torch.randperm(flat.numel())
        for start in range(0, flat.numel(), 256):
            mini = flat[perm[start : start + 256]]
            losses = loss_fn(mini)
            total_loss = (
                losses["loss_objective"]
                + losses["loss_critic"]
                + losses["loss_entropy"]
            )
            optimizer.zero_grad()
            total_loss.backward()
            torch.nn.utils.clip_grad_norm_(loss_fn.parameters(), 0.5)
            optimizer.step()

collector.shutdown()

Additional loss modules

| Loss | Algorithm |
| --- | --- |
| REDQLoss | Randomized Ensembled Double Q-Learning |
| CrossQLoss | CrossQ (off-policy actor-critic without target networks) |
| TD3BCLoss | TD3+BC offline RL |
| BCLoss | Behavior Cloning |
| GAILLoss | Generative Adversarial Imitation Learning |
| RNDLoss | Random Network Distillation (intrinsic motivation) |
| DreamerActorLoss, DreamerModelLoss, DreamerValueLoss | Dreamer model-based RL |
| DreamerV3ActorLoss, DreamerV3ModelLoss, DreamerV3ValueLoss | DreamerV3 |
| DTLoss, OnlineDTLoss | Decision Transformer |
| ACTLoss | Action Chunking Transformer |
| DiffusionBCLoss | Diffusion-based behavior cloning |
| WorldModelLoss | Generic world-model loss |
