Value estimators compute the advantage A(s, a) and the value target V_target(s) from a batch of rollout data. These quantities are the foundation of every policy gradient and actor-critic algorithm: the advantage controls the direction of the policy update, while the value target provides the critic’s training signal. TorchRL provides GAE, TDLambdaEstimator, TD0Estimator, TD1Estimator, VTrace, and MultiAgentGAE, all inheriting from the ValueEstimatorBase abstract class.
Why Value Estimation Matters
The bias-variance tradeoff is central to advantage estimation:
- Low bias, high variance: Monte Carlo returns (sum of future rewards) are unbiased but have high variance because they depend on the entire future trajectory. Useful when episodes are short or the value function is unreliable.
- High bias, low variance: 1-step TD bootstraps from the current value estimate, which is biased if the value function is inaccurate, but has low variance because it depends on a single reward step.
- Interpolation: GAE and TDLambdaEstimator smoothly interpolate between these extremes using a lmbda (λ) parameter. λ = 0 is pure TD(0); λ = 1 is pure Monte Carlo.
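The two extremes in the list above can be made concrete with a toy episode. The following is a minimal, dependency-free sketch; the helper names `monte_carlo_return` and `td0_target` are illustrative and are not part of TorchRL.

```python
def monte_carlo_return(rewards, gamma):
    """Discounted sum of all future rewards from the first step (no bootstrapping)."""
    ret = 0.0
    for r in reversed(rewards):
        ret = r + gamma * ret
    return ret


def td0_target(reward, next_value, gamma):
    """One reward step plus the discounted value estimate of the next state."""
    return reward + gamma * next_value


rewards = [1.0, 1.0, 1.0]  # toy episode that terminates after three steps
gamma = 0.5

# Monte Carlo: uses every reward of the episode, no critic involved.
mc = monte_carlo_return(rewards, gamma)

# 1-step TD: only the first reward, the rest comes from the critic.
# The true value of the second state is 1 + 0.5 * 1 = 1.5.
td_exact_critic = td0_target(rewards[0], 1.5, gamma)
td_wrong_critic = td0_target(rewards[0], 0.0, gamma)  # inaccurate critic biases the target

print(mc, td_exact_critic, td_wrong_critic)
```

With an exact critic the 1-step target matches the Monte Carlo return; with an inaccurate critic it does not, which is the bias the list refers to.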
Estimators take a TensorDict of shape [*B, T, *F] (batch × time × features), annotate it with "advantage" and "value_target" entries, and return the updated TensorDict. They are fully compatible with the make_value_estimator() method on any LossModule.
ValueEstimatorBase
ValueEstimatorBase is the abstract parent of all estimators. Subclasses must implement forward() and optionally value_estimate(). The key shared behaviours are:
- Key configuration via set_keys() (same pattern as LossModule).
- Value network chunking via value_chunk_size or num_chunks for long sequences that don’t fit in memory.
- Shifted mode (shifted=True) for a memory-efficient single forward pass over T + shifted_budget time steps instead of two passes over T steps.
- Differentiable mode (differentiable=True) to propagate gradients through the advantage computation (required for some meta-RL setups).
Default Input/Output Keys
| Key | Default | Direction |
|---|---|---|
| advantage | "advantage" | Written to output TensorDict |
| value_target | "value_target" | Written to output TensorDict |
| value | "state_value" | Read from TensorDict (written by value network) |
| reward | "reward" | Read from ("next", "reward") |
| done | "done" | Read from ("next", "done") |
| terminated | "terminated" | Read from ("next", "terminated") |
GAE
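Generalized Advantage Estimation (Schulman et al. 2015) is the most widely used advantage estimator for on-policy algorithms. It computes an exponentially weighted sum of 1-step TD residuals:

$$\hat{A}_t^{\text{GAE}(\gamma,\lambda)} = \sum_{l=0}^{\infty} (\gamma\lambda)^{l} \, \delta_{t+l}$$

where δ_t = r_t + γ V(s_{t+1}) − V(s_t) is the 1-step TD error.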
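The weighted sum of TD errors is usually evaluated with a backward recursion, Â_t = δ_t + γλ Â_{t+1}. Below is a minimal pure-Python sketch of that recursion for a single trajectory. The function name `gae_reference` is hypothetical, a single `dones` flag stands in for both termination and truncation, and the value target is taken as advantage plus value, which is an assumption of this sketch rather than a statement about TorchRL internals.

```python
def gae_reference(rewards, values, next_values, dones, gamma, lmbda):
    """Backward recursion A_t = delta_t + gamma * lmbda * A_{t+1} on one trajectory."""
    advantages = [0.0] * len(rewards)
    running = 0.0
    for t in reversed(range(len(rewards))):
        not_done = 0.0 if dones[t] else 1.0
        # 1-step TD error; the bootstrap term is dropped at episode end.
        delta = rewards[t] + gamma * not_done * next_values[t] - values[t]
        # Accumulate discounted TD errors, resetting across episode boundaries.
        running = delta + gamma * lmbda * not_done * running
        advantages[t] = running
    # Assumed convention: critic regression target = advantage + current value.
    value_targets = [a + v for a, v in zip(advantages, values)]
    return advantages, value_targets


rewards = [1.0, 1.0, 1.0]
values = [1.0, 1.0, 1.0]        # V(s_t) from a (deliberately imperfect) critic
next_values = [1.0, 1.0, 0.0]   # V(s_{t+1}); irrelevant at the terminal step
dones = [False, False, True]

adv_mc, target_mc = gae_reference(rewards, values, next_values, dones, gamma=0.5, lmbda=1.0)
adv_td, target_td = gae_reference(rewards, values, next_values, dones, gamma=0.5, lmbda=0.0)
print(adv_mc, target_mc)
print(adv_td, target_td)
```

With lmbda=1.0 the value targets coincide with the Monte Carlo returns of the toy episode, and with lmbda=0.0 each advantage is just the 1-step TD error.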
Constructor Parameters
- gamma: Discount factor γ ∈ (0, 1]. Registered as a buffer, so it moves with the module when calling .to(device).
- lmbda: GAE λ ∈ [0, 1]. Controls the bias-variance tradeoff:
  - lmbda=0: pure 1-step TD (low variance, high bias)
  - lmbda=1: Monte Carlo returns (high variance, zero bias)
  - lmbda=0.95: typical default, near-MC with reduced variance
- value_network: Value operator that writes "state_value" into the TensorDict. If None, the estimator expects pre-computed values already present under the value key.
- average_gae: If True, normalises the computed advantage values to zero mean and unit variance across the batch.
- differentiable: If True, gradients are propagated through the GAE computation. Required for some meta-RL and differentiable planning setups. The typical way to disable gradients otherwise is to wrap the call in torch.no_grad().
- vectorized: Whether to use the vectorized (vmap-based) implementation of the λ-return recursion. Defaults to True unless torch.compile is active, in which case it falls back to False automatically.
- shifted: If True, evaluates the value network once over T + shifted_budget time steps instead of twice over T steps. More memory-efficient for long sequences. Requires the standard one-step rollout layout (obs[t+1] equals next_obs[t] for non-reset steps). Incompatible with multi-step bootstrapping.
- shifted_budget: Number of extra time slots when shifted=True. Use 2 if your rollout contains internal resets (truncated episodes) so their next-observations can be inserted without displacing other samples.
- skip_existing: If True, the value network skips modules whose outputs are already present in the TensorDict. None defers to the global tensordict.nn.skip_existing() setting.
- time_dim: Dimension of the time axis in the input TensorDict. If None, uses the dimension labelled "time" in the TensorDict’s names, or the last dimension as a fallback. Can be overridden per-call via gae(data, time_dim=1).
- value_chunk_size: Splits value-network calls into chunks of this size along the leading batch dimension. Useful for very large batches on memory-constrained hardware. Mutually exclusive with num_chunks.
- num_chunks: Splits value-network calls into this many chunks. Mutually exclusive with value_chunk_size.
- deactivate_vmap: If True, replaces vmap calls with plain Python for-loops. Required when using RNN-based value networks (which are not compatible with torch.vmap).

Output Keys Written to TensorDict
| Key | Shape | Description |
|---|---|---|
| "advantage" | [*B, T, 1] | GAE advantage estimates |
| "value_target" | [*B, T, 1] | Value function training targets |
GAE Usage Example
GAE with ClipPPOLoss
TDLambdaEstimator
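TDLambdaEstimator computes TD(λ) returns, a geometric mixture of n-step returns for n = 1, 2, …. Unlike GAE, which accumulates per-step TD errors, TDLambdaEstimator works with the full multi-step bootstrapped returns:

$$G_t^{\lambda} = (1-\lambda) \sum_{n=1}^{\infty} \lambda^{n-1} G_t^{(n)}, \qquad G_t^{(n)} = \sum_{k=0}^{n-1} \gamma^{k} r_{t+k} + \gamma^{n} V(s_{t+n})$$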
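The geometric mixture of n-step returns can be computed with the backward recursion G_t = r_t + γ[(1 − λ) V(s_{t+1}) + λ G_{t+1}]. The sketch below is a pure-Python illustration for one trajectory; `td_lambda_returns` is a hypothetical helper, and a single `dones` flag is assumed to mark episode ends.

```python
def td_lambda_returns(rewards, next_values, dones, gamma, lmbda):
    """Backward recursion for the lambda-return on a single trajectory."""
    returns = [0.0] * len(rewards)
    # Beyond the last step there is no further return, so bootstrap from the critic.
    next_return = next_values[-1]
    for t in reversed(range(len(rewards))):
        not_done = 0.0 if dones[t] else 1.0
        # Mix the 1-step bootstrap with the longer lambda-return of the next step.
        mixed = (1.0 - lmbda) * next_values[t] + lmbda * next_return
        returns[t] = rewards[t] + gamma * not_done * mixed
        next_return = returns[t]
    return returns


rewards = [1.0, 1.0, 1.0]
values = [1.0, 1.0, 1.0]
next_values = [1.0, 1.0, 0.0]
dones = [False, False, True]

lambda_returns = td_lambda_returns(rewards, next_values, dones, gamma=0.5, lmbda=0.5)
# Subtracting V(s_t) gives the advantage; it satisfies the GAE recursion
# A_t = delta_t + gamma * lmbda * A_{t+1} for the same gamma and lmbda.
advantages = [g - v for g, v in zip(lambda_returns, values)]
print(lambda_returns, advantages)
```

On this toy data the TD errors are [0.5, 0.5, 0.0], so the GAE recursion with γλ = 0.25 yields the same advantages as the subtraction above.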
Constructor Parameters
- gamma: Discount factor.
- lmbda: TD(λ) mixing parameter. lmbda=0 → 1-step TD; lmbda=1 → Monte Carlo.
- value_network: Value operator producing "state_value".
- vectorized: Use the vmap-parallelised scan implementation. Set to False for RNNs or when torch.compile is active.
- average_rewards: If True, normalises rewards before TD computation.

Output keys: "advantage" and "value_target".
GAE vs. TDLambdaEstimator
Both compute similar weighted multi-step estimates. The differences are:

| | GAE | TDLambdaEstimator |
|---|---|---|
| Formula | Weighted TD errors δ_t | Weighted n-step returns G^(n) |
| Relationship | Advantage equals the λ-return minus V(s_t) | λ-return equals the GAE advantage plus V(s_t) |
| Typical use | PPO, A2C | DRL textbooks, legacy codebases |
| Performance | Identical in practice | Identical in practice |
TD0Estimator
TD0Estimator computes 1-step bootstrapped TD targets:

$$V_{\text{target}}(s_t) = r_t + \gamma V(s_{t+1}), \qquad A_t = V_{\text{target}}(s_t) - V(s_t)$$
Constructor Parameters
- gamma: Discount factor.
- value_network: Value operator.
- average_rewards: Normalise rewards before the TD computation.
- shifted: Single-call shifted backend (see GAE docs for details).
TD0Estimator is the default value estimator for DQNLoss. Loss modules that do not need the full λ-weighted return (e.g. Q-learning variants that compute their own targets) will use TD0 by default.
TD1Estimator
TD1Estimator computes the full Monte Carlo (∞-step TD) return:

$$G_t = \sum_{k=0}^{\infty} \gamma^{k} r_{t+k}$$

This is the special case λ = 1 of TDLambdaEstimator.
VTrace
V-Trace (Espeholt et al. 2018, IMPALA) is an off-policy advantage estimator designed for distributed RL with asynchronous data collection. It corrects for the discrepancy between the behaviour policy μ (which collected the data) and the learning policy π (which is being updated) using clipped importance weights:

$$v_s = V(x_s) + \sum_{t=s}^{s+n-1} \gamma^{t-s} \Big( \prod_{i=s}^{t-1} c_i \Big) \rho_t \big( r_t + \gamma V(x_{t+1}) - V(x_t) \big)$$

where c_i = min(c̄, π(a_i|x_i) / μ(a_i|x_i)) and ρ_t = min(ρ̄, π(a_t|x_t) / μ(a_t|x_t)) are clipped importance weights.
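The V-trace target satisfies the backward recursion v_s − V(x_s) = ρ_s δ_s + γ c_s (v_{s+1} − V(x_{s+1})), where δ_s is the 1-step TD error. The following pure-Python sketch illustrates that recursion for one trajectory; `vtrace_targets` and its argument names are hypothetical, and the log-probabilities are passed in directly instead of being computed by a policy network.

```python
import math


def vtrace_targets(rewards, values, next_values, dones, log_pi, log_mu,
                   gamma, rho_thresh=1.0, c_thresh=1.0):
    """V-trace value targets for one trajectory via the backward recursion."""
    targets = [0.0] * len(rewards)
    correction = 0.0  # v_{t+1} - V(x_{t+1}); zero beyond the last step
    for t in reversed(range(len(rewards))):
        not_done = 0.0 if dones[t] else 1.0
        ratio = math.exp(log_pi[t] - log_mu[t])  # pi(a|x) / mu(a|x)
        rho = min(rho_thresh, ratio)  # clipped weight on the TD error
        c = min(c_thresh, ratio)      # clipped weight on the trace
        delta = rho * (rewards[t] + gamma * not_done * next_values[t] - values[t])
        correction = delta + gamma * not_done * c * correction
        targets[t] = values[t] + correction
    return targets


rewards = [1.0, 1.0, 1.0]
values = [0.0, 0.0, 0.0]
next_values = [0.0, 0.0, 0.0]
dones = [False, False, True]
zeros = [0.0, 0.0, 0.0]

# On-policy data (ratio 1): the targets reduce to discounted returns.
on_policy = vtrace_targets(rewards, values, next_values, dones, zeros, zeros, gamma=0.5)
# Ratio 2 is clipped to 1 by both thresholds, so the result is unchanged.
clipped = vtrace_targets(rewards, values, next_values, dones,
                         [math.log(2.0)] * 3, zeros, gamma=0.5)
# Ratio 0.5 is below the thresholds and down-weights the correction.
down_weighted = vtrace_targets(rewards, values, next_values, dones,
                               [math.log(0.5)] * 3, zeros, gamma=0.5)
print(on_policy, clipped, down_weighted)
```

The clipping only acts when the learning policy assigns more probability to the collected action than the behaviour policy did; smaller ratios pass through unchanged and shrink the update.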
Constructor Parameters
- gamma: Discount factor.
- value_network: State value operator V(s).
- actor_network: Current (learning) policy. Used to compute log-probabilities of the collected actions under the current policy.
- rho_thresh: Clipping threshold ρ̄ for the IS weight used in the value target computation. Larger values reduce bias at the cost of higher variance. The IMPALA paper uses rho_thresh=1.0.
- c_thresh: Clipping threshold c̄ for the IS weight used in the multi-step bootstrapping trace. Also 1.0 in the original paper.
- average_adv: Normalise resulting advantage values across the batch.
- differentiable: Propagate gradients through the V-trace computation.
- shifted: Single-call shifted backend (see GAE docs for details).
VTrace vs. GAE
| | GAE | VTrace |
|---|---|---|
| Policy requirement | On-policy | Off-policy |
| Importance weights | None (or implicit) | Explicit, clipped (ρ, c) |
| Typical setting | PPO / A2C | IMPALA, distributed RL |
| Bias correction | Not needed (on-policy) | Needed (stale batches) |
MultiAgentGAE
MultiAgentGAE extends GAE for cooperative MARL, where the value network outputs per-agent estimates of shape [*B, T, n_agents, 1] but the reward and done signals are team-shared with shape [*B, T, 1]. It automatically broadcasts the team signals to the agent dimension before running the standard GAE recursion.
Dimension holding the agent index in the value tensor. Defaults to -2 (penultimate), consistent with MultiAgentMLP’s output convention.
The remaining constructor parameters are the same as for GAE, unchanged.
Per-Agent Advantage Normalization
MultiAgentGAE normalises advantages per-agent independently (reducing over batch + time but not the agent dimension). This prevents high-variance agents from dominating the gradient signal.
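The effect of reducing over batch and time but not over agents can be sketched without any tensor library. In the illustration below, `normalise_per_agent` is a hypothetical helper operating on a list of rows, one row per flattened (batch, time) index and one column per agent; the population standard deviation and the `eps` term are assumptions of this sketch, not a description of TorchRL's exact formula.

```python
from statistics import fmean, pstdev


def normalise_per_agent(advantages, eps=1e-8):
    """advantages[i][k]: advantage of agent k at flattened (batch, time) index i."""
    n_agents = len(advantages[0])
    # Collect each agent's advantages across all batch and time positions.
    columns = [[row[k] for row in advantages] for k in range(n_agents)]
    # One (mean, std) pair per agent; agents never share statistics.
    stats = [(fmean(col), pstdev(col)) for col in columns]
    return [
        [(row[k] - stats[k][0]) / (stats[k][1] + eps) for k in range(n_agents)]
        for row in advantages
    ]


# Agent 1 has advantages 100x larger than agent 0.
raw = [[1.0, 100.0], [3.0, 300.0]]
normalised = normalise_per_agent(raw)
print(normalised)
```

After normalisation both agents have advantages on the same scale, whereas a single mean and standard deviation shared across agents would leave agent 0's values close to constant.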
make_value_estimator()
Every LossModule exposes make_value_estimator() for swapping the advantage estimator at runtime. The method accepts a ValueEstimators enum value and any keyword arguments the estimator constructor accepts.
Shifted and Compact Estimator Variants
Starting from TorchRL 0.13, all estimators support shifted=True for a memory-efficient single-call path. In standard mode (shifted=False) the value network is called twice per update: once on the current observations and once on the next observations. With shifted=True, a single call is made over a fused sequence of length T + shifted_budget.
Use shifted_budget=2 when your rollout contains internal resets (truncated episodes that auto-reset mid-batch) so their true next-observations can be inserted without displacing other samples.
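The idea behind the single-call path can be illustrated with plain Python lists. The sketch below assumes a reset-free trajectory, so that next_obs[t] equals obs[t+1] and one extra slot is enough; `values_two_pass`, `values_shifted`, and the toy `value_fn` are hypothetical names and do not reflect how TorchRL implements the feature.

```python
def values_two_pass(value_fn, obs, next_obs):
    """Standard mode: evaluate the critic on obs and on next_obs separately."""
    return [value_fn(o) for o in obs], [value_fn(o) for o in next_obs]


def values_shifted(value_fn, obs, next_obs):
    """Shifted mode: evaluate once on obs plus the final next-observation."""
    fused = list(obs) + [next_obs[-1]]          # length T + 1
    fused_values = [value_fn(o) for o in fused]
    # V(s_t) is the fused sequence without its last entry,
    # V(s_{t+1}) is the same sequence shifted by one step.
    return fused_values[:-1], fused_values[1:]


call_log = []


def value_fn(o):
    call_log.append(o)  # record every critic evaluation
    return 2.0 * o      # toy critic


obs = [0.0, 1.0, 2.0, 3.0]
next_obs = [1.0, 2.0, 3.0, 4.0]  # next_obs[t] == obs[t + 1]

two_pass = values_two_pass(value_fn, obs, next_obs)
two_pass_calls = len(call_log)

call_log.clear()
shifted = values_shifted(value_fn, obs, next_obs)
shifted_calls = len(call_log)
print(two_pass_calls, shifted_calls)
```

Both paths produce the same values here, but the shifted path evaluates T + 1 observations instead of 2T. A rollout with an internal reset breaks the obs[t+1] == next_obs[t] assumption at that step, which is why extra slots are needed in that case.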
Choosing a Value Estimator
- GAE (Recommended)
- TDLambdaEstimator
- TD0Estimator
- VTrace
- MultiAgentGAE
GAE is the standard choice for on-policy methods. λ=0.95 is a robust default that
works well across most environments.