Value estimators compute the advantage A(s, a) and the value target V_target(s) from a batch of rollout data. These quantities are the foundation of every policy gradient and actor-critic algorithm — the advantage controls the direction of the policy update, while the value target provides the critic’s training signal. TorchRL provides GAE, TDLambdaEstimator, TD0Estimator, TD1Estimator, VTrace, and MultiAgentGAE, all inheriting from the ValueEstimatorBase abstract class.

Why Value Estimation Matters

The bias-variance tradeoff is central to advantage estimation:
  • Low bias, high variance: Monte Carlo returns (sum of future rewards) are unbiased but have high variance because they depend on the entire future trajectory. Useful when episodes are short or the value function is unreliable.
  • High bias, low variance: 1-step TD bootstraps from the current value estimate, which is biased if the value function is inaccurate, but has low variance because it depends only on one reward step.
  • Interpolation: GAE and TDLambdaEstimator smoothly interpolate between these extremes using a lmbda (λ) parameter. λ = 0 is pure TD(0); λ = 1 is pure Monte Carlo.
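GAE and TDLambdaEstimator both weight information l steps ahead by (γλ)^l. A common rule of thumb for how far ahead an estimate effectively looks is therefore the sum of those weights, 1 / (1 − γλ). The helper below is a hypothetical illustration of that heuristic. It is not a TorchRL function:

```python
def effective_horizon(gamma, lmbda):
    """Sum of the weights (gamma * lmbda) ** l for l = 0, 1, 2, ..., i.e. 1 / (1 - gamma * lmbda)."""
    decay = gamma * lmbda
    if decay >= 1.0:
        return float("inf")  # gamma = lmbda = 1: the weights never decay
    return 1.0 / (1.0 - decay)


for lmbda in (0.0, 0.9, 0.95, 1.0):
    print(f"gamma=0.99, lmbda={lmbda}: ~{effective_horizon(0.99, lmbda):.1f} steps")
```

With γ = 0.99 this gives 1 step for λ = 0, about 17 steps for λ = 0.95, and 100 steps for λ = 1.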
All TorchRL estimators accept a TensorDict of shape [*B, T, *F] (batch × time × features), annotate it with "advantage" and "value_target" entries, and return the updated TensorDict. They are fully compatible with the make_value_estimator() method on any LossModule.

ValueEstimatorBase

ValueEstimatorBase is the abstract parent of all estimators. Subclasses must implement forward() and optionally value_estimate(). The key shared behaviours are:
  • Key configuration via set_keys() (same pattern as LossModule).
  • Value network chunking via value_chunk_size or num_chunks for long sequences that don’t fit in memory.
  • Shifted mode (shifted=True) for a memory-efficient single forward pass over T + shifted_budget time steps instead of two passes over T steps.
  • Differentiable mode (differentiable=True) to propagate gradients through the advantage computation (required for some meta-RL setups).

Default Input/Output Keys

| Key | Default | Direction |
|---|---|---|
| advantage | "advantage" | Written to output TensorDict |
| value_target | "value_target" | Written to output TensorDict |
| value | "state_value" | Read from TensorDict (written by value network) |
| reward | "reward" | Read from ("next", "reward") |
| done | "done" | Read from ("next", "done") |
| terminated | "terminated" | Read from ("next", "terminated") |

GAE

Generalized Advantage Estimation (Schulman et al. 2015) is the most widely-used advantage estimator for on-policy algorithms. It computes a weighted sum of multi-step TD residuals:
A_t^GAE(γ,λ) = Σ_{l=0}^{∞} (γλ)^l δ_{t+l}
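where δ_t = r_t + γ (1 − terminated_t) V(s_{t+1}) − V(s_t) is the 1-step TD error. The bootstrap term is dropped when step t terminates the episode, and in practice the sum is truncated at the end of the episode or rollout.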
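In practice the infinite sum is evaluated backwards in time as A_t = δ_t + γλ(1 − done_t) A_{t+1}. Below is a minimal plain-Python sketch of that recursion. It is not TorchRL's implementation. The function name `gae` and the list-based inputs are hypothetical. It assumes the common convention that `terminated` masks the bootstrap term in δ_t, while `done` cuts the accumulation at episode boundaries. The sketch also forms a value target as advantage + V(s_t).

```python
def gae(rewards, values, next_values, terminated, done, gamma, lmbda):
    """Backward GAE over one time-ordered segment of plain Python lists.

    values[t] is V(s_t), next_values[t] is V(s_{t+1}); terminated/done are per-step flags.
    """
    T = len(rewards)
    advantages = [0.0] * T
    running = 0.0  # holds A_{t+1}; zero past the end of the segment
    for t in reversed(range(T)):
        not_terminated = 0.0 if terminated[t] else 1.0
        not_done = 0.0 if done[t] else 1.0
        # 1-step TD error; no bootstrap from a terminal next state
        delta = rewards[t] + gamma * not_terminated * next_values[t] - values[t]
        # Accumulate (gamma * lmbda)-discounted TD errors, restarting at episode ends
        running = delta + gamma * lmbda * not_done * running
        advantages[t] = running
    value_targets = [a + v for a, v in zip(advantages, values)]
    return advantages, value_targets


rewards = [1.0, 0.0, 2.0, -1.0]
values = [0.5, 0.2, 0.1, 0.3]
next_values = [0.2, 0.1, 0.3, 0.4]
flags = [False] * 4
adv, target = gae(rewards, values, next_values, flags, flags, gamma=0.9, lmbda=0.8)
```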
from torchrl.objectives.value import GAE

gae = GAE(
    gamma=0.99,
    lmbda=0.95,
    value_network=critic,
    average_gae=False,
    differentiable=False,
    vectorized=True,
)

Constructor Parameters

gamma
float | Tensor
required
Discount factor γ ∈ (0, 1]. Registered as a buffer, so it moves with the module when calling .to(device).
lmbda
float | Tensor
required
GAE λ ∈ [0, 1]. Controls the bias-variance tradeoff:
  • lmbda=0: pure 1-step TD (low variance, high bias)
  • lmbda=1: Monte Carlo returns (high variance, zero bias)
  • lmbda=0.95: typical default, near-MC with reduced variance
value_network
TensorDictModule | None
required
Value operator that writes "state_value" into the TensorDict. If None, the estimator expects pre-computed values already present under the value key, both at the root (V(s_t)) and under "next" (V(s_{t+1})).
average_gae
bool
default:"False"
If True, normalises the computed advantage values to zero mean and unit variance across the batch.
differentiable
bool
default:"False"
If True, gradients are propagated through the GAE computation. Required for some meta-RL and differentiable planning setups. To make sure a call is non-differentiable, wrap it in torch.no_grad().
vectorized
bool | None
default:"None"
Whether to use the vectorized (vmap-based) implementation of the λ-return recursion. Defaults to True unless torch.compile is active, in which case it falls back to False automatically.
shifted
bool
default:"False"
If True, evaluates the value network once over T + shifted_budget time steps instead of twice over T steps. More memory-efficient for long sequences. Requires the standard one-step rollout layout (obs[t+1] equals next_obs[t] for non-reset steps). Incompatible with multi-step bootstrapping.
shifted_budget
int
default:"1"
Number of extra time slots when shifted=True. Use 2 if your rollout contains internal resets (truncated episodes) so their next-observations can be inserted without displacing other samples.
skip_existing
bool | None
default:"None"
If True, the value network skips modules whose outputs are already present in the TensorDict. None defers to the global tensordict.nn.skip_existing() setting.
time_dim
int | None
default:"None"
Dimension of the time axis in the input TensorDict. If None, uses the dimension labelled "time" in the TensorDict’s names, or the last dimension as a fallback. Can be overridden per-call via gae(data, time_dim=1).
value_chunk_size
int | None
default:"None"
Splits value-network calls into chunks of this size along the leading batch dimension. Useful for very large batches on memory-constrained hardware. Mutually exclusive with num_chunks.
num_chunks
int | None
default:"None"
Splits value-network calls into this many chunks. Mutually exclusive with value_chunk_size.
deactivate_vmap
bool
default:"False"
If True, replaces vmap calls with plain Python for-loops. Required when using RNN-based value networks (which are not compatible with torch.vmap).
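Conceptually, chunking only changes how the value network is called, not what it returns. The batch is split along its leading dimension, each slice is evaluated separately, and the outputs are concatenated, which lowers peak memory. The sketch below illustrates that idea on plain lists. The helper names are hypothetical. The num_chunks sizing (ceil(N / num_chunks) items per chunk, as torch.chunk does) is an assumption about how TorchRL splits the batch.

```python
import math


def split_batch(batch, value_chunk_size=None, num_chunks=None):
    """Split a batch along its leading dimension (here: a Python list)."""
    if value_chunk_size is not None and num_chunks is not None:
        raise ValueError("value_chunk_size and num_chunks are mutually exclusive")
    if value_chunk_size is None and num_chunks is None:
        return [batch]  # no chunking: a single call on the whole batch
    if num_chunks is not None:
        # torch.chunk-style sizing: ceil(N / num_chunks) items per chunk, the last may be smaller
        value_chunk_size = math.ceil(len(batch) / num_chunks)
    return [batch[i:i + value_chunk_size] for i in range(0, len(batch), value_chunk_size)]


def chunked_values(value_fn, batch, **chunk_kwargs):
    """Call value_fn on each chunk and concatenate the per-chunk outputs."""
    outputs = []
    for chunk in split_batch(batch, **chunk_kwargs):
        outputs.extend(value_fn(chunk))
    return outputs


def double(xs):
    # Stand-in "value network": any function mapping a list of inputs to a list of values
    return [2.0 * x for x in xs]


batch = [float(i) for i in range(10)]
values = chunked_values(double, batch, value_chunk_size=3)
```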

Output Keys Written to TensorDict

| Key | Shape | Description |
|---|---|---|
| "advantage" | [*B, T, 1] | GAE advantage estimates |
| "value_target" | [*B, T, 1] | Value function training targets |

GAE Usage Example

import torch
from torch import nn
from tensordict import TensorDict
from tensordict.nn import TensorDictModule
from torchrl.objectives.value import GAE

value_net = TensorDictModule(
    nn.Linear(3, 1),
    in_keys=["obs"],
    out_keys=["state_value"],
)

gae = GAE(gamma=0.99, lmbda=0.95, value_network=value_net)

# Input: rollout batch of shape [batch_size, time_steps]
obs, next_obs = torch.randn(2, 4, 10, 3)
reward      = torch.randn(4, 10, 1)
done        = torch.zeros(4, 10, 1, dtype=torch.bool)
terminated  = torch.zeros(4, 10, 1, dtype=torch.bool)

tensordict = TensorDict({
    "obs":  obs,
    "next": {
        "obs":        next_obs,
        "reward":     reward,
        "done":       done,
        "terminated": terminated,
    },
}, batch_size=[4, 10])

gae(tensordict)

assert "advantage"    in tensordict.keys()
assert "value_target" in tensordict.keys()
print(tensordict["advantage"].shape)   # torch.Size([4, 10, 1])

GAE with ClipPPOLoss

import torch
from torchrl.objectives import ClipPPOLoss
from torchrl.objectives.value import GAE

# `actor`, `critic`, `collector` and `optimizer` are assumed to be defined elsewhere.
# Standalone GAE (compute before the inner update loop)
gae = GAE(gamma=0.99, lmbda=0.95, value_network=critic)

loss_module = ClipPPOLoss(actor, critic, clip_epsilon=0.2)

# ---- Training loop ----
for batch in collector:
    # 1. Compute advantages once per rollout
    with torch.no_grad():
        gae(batch)

    # 2. Multiple gradient steps on the same batch (PPO's inner loop)
    for _ in range(4):
        loss_td = loss_module(batch)
        # ClipPPOLoss also returns "loss_entropy" when its entropy bonus is enabled (the default)
        loss = loss_td["loss_objective"] + loss_td["loss_critic"] + loss_td["loss_entropy"]
        loss.backward()
        optimizer.step()
        optimizer.zero_grad()
Call gae(batch) once per rollout, outside the inner PPO update loop. This is both faster (the value network is evaluated once per rollout instead of once per inner step) and more consistent (the advantages stay tied to the behaviour policy that collected the data throughout the K gradient steps).

TDLambdaEstimator

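TDLambdaEstimator computes TD(λ) returns, a geometric mixture of n-step returns for n = 1, 2, …. Unlike GAE, which accumulates per-step TD errors, TDLambdaEstimator mixes full multi-step bootstrapped returns:
G_t^λ = (1 − λ) Σ_{n=1}^{∞} λ^{n-1} G_t^{(n)}
where G_t^{(n)} = Σ_{l=0}^{n-1} γ^l r_{t+l} + γ^n V(s_{t+n}) is the n-step return. The λ-return is the value target, and the advantage is G_t^λ − V(s_t).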
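The infinite mixture does not have to be enumerated term by term. The sketch below compares a truncated version of the mixture with the backward recursion G_t^λ = r_t + γ[(1 − λ)V(s_{t+1}) + λG_{t+1}^λ]. It assumes a single episode segment with no terminations. The longest available return G_t^{(T−t)} absorbs the leftover weight λ^{T−t−1}, and the segment is bootstrapped from V(s_T). All function names are hypothetical.

```python
def n_step_return(rewards, next_values, t, n, gamma):
    """G_t^(n): n discounted rewards from step t, then bootstrap from V(s_{t+n})."""
    discounted = sum(gamma**l * rewards[t + l] for l in range(n))
    return discounted + gamma**n * next_values[t + n - 1]  # next_values[k] holds V(s_{k+1})


def lambda_return_mixture(rewards, next_values, t, gamma, lmbda):
    """Truncated TD(lambda) mixture; the longest return takes the leftover weight."""
    longest = len(rewards) - t
    mixture = sum(
        (1 - lmbda) * lmbda ** (n - 1) * n_step_return(rewards, next_values, t, n, gamma)
        for n in range(1, longest)
    )
    return mixture + lmbda ** (longest - 1) * n_step_return(rewards, next_values, t, longest, gamma)


def lambda_return_recursive(rewards, next_values, gamma, lmbda):
    """Same quantity via G_t = r_t + gamma * ((1 - lmbda) * V(s_{t+1}) + lmbda * G_{t+1})."""
    returns = [0.0] * len(rewards)
    g = next_values[-1]  # G_T := V(s_T), the bootstrap value at the end of the segment
    for t in reversed(range(len(rewards))):
        g = rewards[t] + gamma * ((1 - lmbda) * next_values[t] + lmbda * g)
        returns[t] = g
    return returns


rewards = [1.0, -0.5, 0.3, 2.0, 0.0]
next_values = [0.4, 0.1, -0.2, 0.7, 0.5]  # V(s_1), ..., V(s_5)
recursive = lambda_return_recursive(rewards, next_values, gamma=0.9, lmbda=0.7)
mixture = [lambda_return_mixture(rewards, next_values, t, gamma=0.9, lmbda=0.7) for t in range(5)]
```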
from torchrl.objectives.value import TDLambdaEstimator

estimator = TDLambdaEstimator(
    gamma=0.99,
    lmbda=0.95,
    value_network=critic,
    vectorized=True,
)

Constructor Parameters

gamma
float | Tensor
required
Discount factor.
lmbda
float | Tensor
required
TD(λ) mixing parameter. lmbda=0 → 1-step TD; lmbda=1 → Monte Carlo.
value_network
TensorDictModule
required
Value operator producing "state_value".
vectorized
bool
default:"True"
Use the vmap-parallelised scan implementation. Set to False for RNNs or when torch.compile is active.
average_rewards
bool
default:"False"
If True, normalises rewards before TD computation.
Output keys: same as GAE — "advantage" and "value_target".

GAE vs. TDLambdaEstimator

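Both compute the same λ-weighted multi-step estimate, expressed in different forms. For the same value estimates, A_t^GAE = G_t^λ − V(s_t), so GAE's value target (advantage + value) equals the TD(λ) return. The differences are:
| | GAE | TDLambdaEstimator |
|---|---|---|
| Formula | Weighted TD errors δ_t | Weighted n-step returns G^(n) |
| Equivalence | A_t = G_t^λ − V(s_t) | G_t^λ = A_t + V(s_t) |
| Typical use | PPO, A2C | DRL textbooks, legacy codebases |
| Performance | Identical in practice | Identical in practice |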
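The equivalence can be checked numerically. This plain-Python sketch (hypothetical function names) computes GAE advantages and TD(λ) returns on one episode segment with no terminations. Inside the segment next_values[t] = values[t + 1], and the last entry is the bootstrap value V(s_T). The sketch then compares the advantage with the λ-return minus V(s_t) at every step.

```python
def gae_advantages(rewards, values, next_values, gamma, lmbda):
    """A_t = delta_t + gamma * lmbda * A_{t+1}, computed backwards."""
    advantages, running = [0.0] * len(rewards), 0.0
    for t in reversed(range(len(rewards))):
        delta = rewards[t] + gamma * next_values[t] - values[t]
        running = delta + gamma * lmbda * running
        advantages[t] = running
    return advantages


def lambda_returns(rewards, next_values, gamma, lmbda):
    """G_t = r_t + gamma * ((1 - lmbda) * V(s_{t+1}) + lmbda * G_{t+1}), with G_T = V(s_T)."""
    returns, g = [0.0] * len(rewards), next_values[-1]
    for t in reversed(range(len(rewards))):
        g = rewards[t] + gamma * ((1 - lmbda) * next_values[t] + lmbda * g)
        returns[t] = g
    return returns


rewards = [1.0, 0.0, -1.0, 2.0]
values = [0.3, 0.8, -0.4, 1.1]       # V(s_0), ..., V(s_3)
next_values = [0.8, -0.4, 1.1, 0.6]  # V(s_1), ..., V(s_4): shifted values plus the bootstrap
advantages = gae_advantages(rewards, values, next_values, gamma=0.99, lmbda=0.95)
returns = lambda_returns(rewards, next_values, gamma=0.99, lmbda=0.95)
for a, g, v in zip(advantages, returns, values):
    print(f"A_t = {a:+.6f}   G_t^lambda - V(s_t) = {g - v:+.6f}")
```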

TD0Estimator

TD0Estimator computes 1-step bootstrapped TD targets:
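V_target(s_t) = r_t + γ (1 − terminated_t) V(s_{t+1})
A(s_t, a_t) = V_target(s_t) − V(s_t)
The bootstrap is masked by the terminated flag rather than by done, so a step that is only truncated (done but not terminated) still bootstraps from V(s_{t+1}).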
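A minimal plain-Python sketch of the two lines above. The name td0 is hypothetical. The sketch assumes the bootstrap is masked with a per-step terminated flag, so a step that is merely truncated still bootstraps from V(s_{t+1}).

```python
def td0(rewards, values, next_values, terminated, gamma):
    """1-step TD targets and advantages over plain Python lists."""
    targets = [
        r + gamma * (0.0 if term else 1.0) * nv  # no bootstrap from a terminal next state
        for r, nv, term in zip(rewards, next_values, terminated)
    ]
    advantages = [target - v for target, v in zip(targets, values)]
    return advantages, targets


# Step 0 is mid-episode (or merely truncated); step 1 terminates the episode.
advantages, targets = td0([1.0, 1.0], [0.5, 0.5], [2.0, 2.0], [False, True], gamma=0.5)
```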
from torchrl.objectives.value import TD0Estimator

estimator = TD0Estimator(
    gamma=0.99,
    value_network=critic,
    average_rewards=False,
    differentiable=False,
)

Constructor Parameters

gamma
float | Tensor
required
Discount factor.
value_network
TensorDictModule
required
Value operator.
average_rewards
bool
default:"False"
Normalise rewards before the TD computation.
shifted
bool
default:"False"
Single-call shifted backend (see GAE docs for details).
TD0Estimator is the default value estimator for DQNLoss. Loss modules that do not need the full λ-weighted return (e.g. Q-learning variants that compute their own targets) will use TD0 by default.

TD1Estimator

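TD1Estimator computes the full Monte Carlo (∞-step TD) return. It sums discounted rewards to the end of the trajectory, and bootstraps from the final next-state value only if the trajectory was truncated rather than terminated:
G_t = Σ_{l=0}^{T-t-1} γ^l r_{t+l} + γ^{T-t} (1 − terminated_{T-1}) V(s_T)
where step T − 1 is the last step of the trajectory containing t. This is the limiting case λ = 1 of TDLambdaEstimator.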
from torchrl.objectives.value import TD1Estimator

estimator = TD1Estimator(gamma=0.99, value_network=critic)
Monte Carlo returns have high variance. Use TD1Estimator only for short-horizon tasks or when the value function is not yet reliable enough for bootstrapping.
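A backward-recursion sketch of TD(1) returns in plain Python, assuming per-step done and terminated flags. The running return restarts at every episode end, and bootstraps from V(s_{t+1}) when the episode was only truncated. The end of the segment is treated as a truncation. The function name is hypothetical, and TorchRL's own vectorized implementation is not reproduced here.

```python
def td1_returns(rewards, next_values, terminated, done, gamma):
    """Discounted returns to the end of each trajectory, computed backwards.

    next_values[t] is V(s_{t+1}); it is only used where a trajectory ends without terminating.
    """
    T = len(rewards)
    returns = [0.0] * T
    g = 0.0
    for t in reversed(range(T)):
        if done[t] or t == T - 1:
            # Trajectory boundary: bootstrap unless the episode truly terminated here
            g = 0.0 if terminated[t] else next_values[t]
        g = rewards[t] + gamma * g
        returns[t] = g
    return returns


returns = td1_returns([1.0] * 3, [0.0, 0.0, 5.0], [False] * 3, [False] * 3, gamma=0.5)
```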

VTrace

V-Trace (Espeholt et al. 2018, IMPALA) is an off-policy advantage estimator designed for distributed RL with asynchronous data collection. It corrects for the discrepancy between the behaviour policy (which collected the data) and the learning policy (which is being updated) using clipped importance weights:
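v_s = V(x_s) + Σ_{t=s}^{s+n-1} γ^{t-s} (Π_{i=s}^{t-1} c_i) δ_t^V
where δ_t^V = ρ_t (r_t + γ V(x_{t+1}) − V(x_t)) is an importance-weighted TD error. Here c_i = min(c̄, π(a_i|x_i) / μ(a_i|x_i)) and ρ_t = min(ρ̄, π(a_t|x_t) / μ(a_t|x_t)) are clipped importance weights, with π the learning policy and μ the behaviour policy.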
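The V-trace sum can be computed in one backward pass. Write δ_s^V = ρ_s (r_s + γ V(x_{s+1}) − V(x_s)). Then v_s − V(x_s) = δ_s^V + γ c_s (v_{s+1} − V(x_{s+1})). Below is a plain-Python sketch of that recursion for a single episode segment, with no done handling. The name vtrace_targets is hypothetical. It takes per-step log-probabilities log π(a_t|x_t) and log μ(a_t|x_t) as inputs. When π = μ and both thresholds are at least 1, every weight is 1 and v_s reduces to the bootstrapped Monte Carlo return.

```python
import math


def vtrace_targets(rewards, values, next_values, log_pi, log_mu, gamma,
                   rho_thresh=1.0, c_thresh=1.0):
    """V-trace value targets v_s for one episode segment (no done/terminated handling)."""
    ratios = [math.exp(lp - lm) for lp, lm in zip(log_pi, log_mu)]  # pi(a|x) / mu(a|x)
    rhos = [min(rho_thresh, w) for w in ratios]  # clipped at rho_bar
    cs = [min(c_thresh, w) for w in ratios]      # clipped at c_bar
    targets = [0.0] * len(rewards)
    correction = 0.0  # v_{s+1} - V(x_{s+1}); zero beyond the segment, so v_T = V(x_T)
    for s in reversed(range(len(rewards))):
        delta = rhos[s] * (rewards[s] + gamma * next_values[s] - values[s])  # delta_s^V
        correction = delta + gamma * cs[s] * correction
        targets[s] = values[s] + correction
    return targets


# Behaviour policy mu collected the data; pi is the current learning policy.
log_mu = [math.log(0.25), math.log(0.6), math.log(0.6)]
log_pi = [math.log(0.5), math.log(0.9), math.log(0.3)]
vs = vtrace_targets([1.0, 0.5, -1.0], [0.2, 0.4, 0.1], [0.4, 0.1, 0.3], log_pi, log_mu, gamma=0.9)
```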
from torchrl.objectives.value import VTrace

vtrace = VTrace(
    gamma=0.99,
    value_network=critic,
    actor_network=actor,
    rho_thresh=1.0,
    c_thresh=1.0,
    average_adv=False,
    differentiable=True,
)

Constructor Parameters

gamma
float | Tensor
required
Discount factor.
value_network
TensorDictModule
required
State value operator V(s).
actor_network
TensorDictModule
required
Current (learning) policy. Used to compute log-probabilities of the collected actions under the current policy.
rho_thresh
float | Tensor
default:"1.0"
Clipping threshold ρ̄ for the IS weight used in the value target computation. Larger values reduce bias at the cost of higher variance. The IMPALA paper uses rho_thresh=1.0.
c_thresh
float | Tensor
default:"1.0"
Clipping threshold c̄ for the IS weight used in the multi-step bootstrapping trace. Also 1.0 in the original paper.
average_adv
bool
default:"False"
Normalise resulting advantage values across the batch.
differentiable
bool
default:"False"
Propagate gradients through the V-trace computation.
shifted
bool
default:"False"
Single-call shifted backend (see GAE docs for details).

VTrace vs. GAE

| | GAE | VTrace |
|---|---|---|
| Policy requirement | On-policy | Off-policy |
| Importance weights | None (or implicit) | Explicit, clipped (ρ, c) |
| Typical setting | PPO / A2C | IMPALA, distributed RL |
| Bias correction | Not needed (on-policy) | Needed (stale batches) |
from torchrl.objectives.value import VTrace
from torchrl.objectives.utils import ValueEstimators

# Switch a loss module to VTrace
loss_module.make_value_estimator(
    ValueEstimators.VTrace,
    gamma=0.99,
    rho_thresh=1.0,
    c_thresh=1.0,
)

MultiAgentGAE

MultiAgentGAE extends GAE for cooperative MARL, where the value network outputs per-agent estimates of shape [*B, T, n_agents, 1] but the reward and done signals are team-shared with shape [*B, T, 1]. It automatically broadcasts the team signals to the agent dimension before running the standard GAE recursion.
from torchrl.objectives.value import MultiAgentGAE

gae = MultiAgentGAE(
    gamma=0.99,
    lmbda=0.95,
    value_network=critic,
    average_gae=False,
    agent_dim=-2,   # penultimate dimension holds agents
)

# Writes ("agents", "advantage") and ("agents", "value_target")
gae(data)
agent_dim
int
default:"-2"
Dimension holding the agent index in the value tensor. Defaults to -2 (penultimate), consistent with MultiAgentMLP’s output convention.
All other parameters are forwarded to GAE unchanged.

Per-Agent Advantage Normalization

When average_gae=True, MultiAgentGAE normalises advantages for each agent independently (reducing over the batch and time dimensions but not over the agent dimension). This prevents high-variance agents from dominating the gradient signal:
gae = MultiAgentGAE(
    gamma=0.99,
    lmbda=0.95,
    value_network=critic,
    average_gae=True,   # enables per-agent normalisation
)
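The reduction pattern can be illustrated in plain Python. Each agent's advantages are standardised using statistics pooled over batch and time only, so agents with very different advantage scales end up on a comparable scale. This is a hypothetical sketch, not MultiAgentGAE's code. The use of the population standard deviation and the eps value are assumptions.

```python
import statistics


def normalize_per_agent(advantages, eps=1e-8):
    """advantages[b][t][a] -> (x - mean_a) / (std_a + eps), statistics taken over (b, t) per agent a."""
    n_agents = len(advantages[0][0])
    out = [[[0.0] * n_agents for _ in row] for row in advantages]
    for a in range(n_agents):
        column = [step[a] for row in advantages for step in row]  # all (b, t) values for agent a
        mean, std = statistics.fmean(column), statistics.pstdev(column)
        for b, row in enumerate(advantages):
            for t, step in enumerate(row):
                out[b][t][a] = (step[a] - mean) / (std + eps)
    return out


# Two agents whose raw advantages differ in scale by two orders of magnitude.
raw = [[[1.0, 100.0], [2.0, 300.0]], [[3.0, 200.0], [4.0, 400.0]]]
normalized = normalize_per_agent(raw)
```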

make_value_estimator()

Every LossModule exposes make_value_estimator() for swapping the advantage estimator at runtime. The method accepts a ValueEstimators enum value and any keyword arguments the estimator constructor accepts:
from torchrl.objectives.utils import ValueEstimators

# Available estimators
ValueEstimators.GAE          # GAE (default for PPO, A2C)
ValueEstimators.TDLambda     # TDLambdaEstimator
ValueEstimators.TD0          # TD0Estimator (default for DQN)
ValueEstimators.TD1          # TD1Estimator
ValueEstimators.VTrace       # VTrace (off-policy)
ValueEstimators.MAGAE        # MultiAgentGAE (default for MAPPO/IPPO)

loss_module.make_value_estimator(
    ValueEstimators.GAE,
    gamma=0.99,
    lmbda=0.95,
)
You can also construct the estimator manually and attach it:
from torchrl.objectives.value import GAE

gae = GAE(gamma=0.99, lmbda=0.95, value_network=critic)

# Use it standalone (pre-compute before the loss)
gae(batch)
loss_td = loss_module(batch)   # reads pre-computed advantage

# Or inject it so the loss computes advantages on-the-fly when the key is absent
loss_module.value_estimator = gae

Shifted and Compact Estimator Variants

Starting from TorchRL 0.13, all estimators support shifted=True for a memory-efficient single-call path. In standard mode (shifted=False) the value network is called twice per update — once on the current observations and once on the next observations. With shifted=True, a single call is made over a fused [T + shifted_budget]-length sequence:
gae = GAE(
    gamma=0.99,
    lmbda=0.95,
    value_network=critic,
    shifted=True,
    shifted_budget=1,   # 1 extra slot handles rollout boundaries
)
shifted=True requires that obs[t+1] equals next_obs[t] for all non-reset steps (the standard single-step rollout invariant). It is incompatible with multi-step return processing (n-step bootstrapping) where next_obs[t] is set to obs[t+n].
Use shifted_budget=2 when your rollout contains internal resets (truncated episodes that auto-reset mid-batch) so their true next-observations can be inserted without displacing other samples.
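To make the single-call idea concrete, here is a plain-Python sketch with hypothetical helper names. A stand-in value function is evaluated once on the T current observations plus the final next observation. The T + 1 outputs are then sliced into V(s_t) and V(s_{t+1}). The sketch deliberately ignores internal resets. That is exactly the case where the invariant breaks and extra shifted_budget slots are needed, and the sketch does not reproduce TorchRL's handling of those slots.

```python
def two_pass_values(value_fn, obs, next_obs):
    """Standard mode: one call on obs[0..T-1], a second call on next_obs[0..T-1]."""
    return value_fn(obs), value_fn(next_obs)


def shifted_values(value_fn, obs, next_obs):
    """Shifted-mode sketch with one extra slot: a single call on obs[0..T-1] + [next_obs[T-1]]."""
    fused = list(obs) + [next_obs[-1]]  # T + 1 inputs
    values = value_fn(fused)
    # Valid only if obs[t + 1] == next_obs[t] for every t < T - 1 (no internal resets)
    return values[:-1], values[1:]  # V(s_t), V(s_{t+1})


def value_fn(xs):
    return [0.5 * x for x in xs]  # stand-in for a value network


obs = [0.0, 1.0, 2.0, 3.0]
next_obs = [1.0, 2.0, 3.0, 4.0]  # next_obs[t] == obs[t + 1]
standard = two_pass_values(value_fn, obs, next_obs)
shifted = shifted_values(value_fn, obs, next_obs)
```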

Choosing a Value Estimator
