Standard PyTorch distributions work well for supervised learning but need small adaptations for reinforcement learning. Continuous control policies often require actions bounded to a fixed range; applying a tanh transform is the standard solution, but naïve log_prob computation through tanh is numerically unstable near ±1. Discrete policies in environments with dynamic invalid actions (e.g. board games, sequential decision trees) need a masked distribution that assigns zero probability to invalid choices. TorchRL’s distributions address both concerns with numerically stable implementations and clean integration with ProbabilisticActor.
Why Specialized Distributions for RL
Three issues motivate TorchRL’s custom distributions:
- Bounded continuous actions. Most actuators have torque or velocity limits. Squashing with tanh enforces the bound but makes log_prob ill-conditioned for values near the boundary unless the inverse transform is numerically stabilized.
- Location scaling. Raw network outputs can grow very large, causing tanh saturation and vanishing gradients. Location scaling, loc = tanh(loc / upscale) * upscale, keeps the pre-tanh value in a well-behaved range.
- Action masking. In partially observable or structured-action environments, some actions are invalid at each step. Sampling them wastes rollout budget, and computing log_prob over them corrupts gradients.
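The first two issues can be illustrated in plain PyTorch. The sketch below is not TorchRL's implementation; the helper names (naive_log_abs_det, stable_log_abs_det, scale_location) are hypothetical and only show the arithmetic involved. It compares the direct log-Jacobian of tanh with an algebraically identical form that stays finite when tanh saturates, and applies location scaling with an assumed upscale of 5.0.

```python
import math

import torch
import torch.nn.functional as F


def naive_log_abs_det(x: torch.Tensor) -> torch.Tensor:
    # Direct formula log(1 - tanh(x)^2); becomes log(0) = -inf once tanh(x) rounds to 1.
    return torch.log(1 - torch.tanh(x) ** 2)


def stable_log_abs_det(x: torch.Tensor) -> torch.Tensor:
    # Algebraically identical form: 2 * (log 2 - x - softplus(-2x)).
    return 2.0 * (math.log(2.0) - x - F.softplus(-2.0 * x))


def scale_location(loc: torch.Tensor, upscale: float = 5.0) -> torch.Tensor:
    # Location scaling: the result always lies in [-upscale, upscale].
    return torch.tanh(loc / upscale) * upscale


x = torch.tensor([0.5, 20.0])
naive = naive_log_abs_det(x)    # second entry is -inf in float32
stable = stable_log_abs_det(x)  # both entries are finite
scaled = scale_location(torch.tensor([0.1, 1e4]))
```

Both formulas agree for moderate inputs; they differ only where tanh has saturated, which is exactly where a squashed policy's log_prob would otherwise break.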
NormalParamExtractor
NormalParamExtractor (re-exported from tensordict.nn) splits the output of a network into (loc, scale) halves and applies a softplus-based mapping to ensure scale > 0. It is the recommended way to parameterize Normal-family distributions.
NormalParamExtractor is also available as torchrl.modules.NormalParamExtractor. The legacy NormalParamWrapper class has been removed; if you see it in old code, replace it with NormalParamExtractor.
TanhNormal
TanhNormal is a squashed Gaussian for bounded continuous action spaces. It constructs a Normal(loc, scale) base distribution and passes samples through a TanhTransform (optionally composed with an affine rescaling to [low, high]).
The resulting sample lies strictly within (low, high), making it suitable for direct use with bounded action specs without a spec.project() call.
- loc: Location parameter of the underlying Normal distribution.
- scale: Scale parameter. Accepts a tensor, a float, or a callable (e.g. torch.ones_like) that takes loc and returns the scale tensor. Using a callable avoids device transfers and prevents graph breaks under torch.compile.
- upscale: Factor used in location scaling: loc = tanh(loc / upscale) * upscale. Only applied when tanh_loc=True.
- low: Lower bound of the action range. Combined with high to define the affine rescaling applied after tanh.
- high: Upper bound of the action range. Must be strictly greater than low.
- event_dims: Number of trailing dimensions over which log_prob is summed. Defaults to min(1, loc.ndim). Set to 0 to get a per-element log-probability.
- tanh_loc: When True, applies location scaling (loc = tanh(loc / upscale) * upscale) before constructing the distribution.
- safe_tanh: When True, uses a numerically stable SafeTanhTransform that clips inputs to avoid atanh overflow. Set to False for torch.compile compatibility (native TanhTransform is used instead).
Key Properties
| Property | Description |
|---|---|
| deterministic_sample | The deterministic action: tanh(loc) mapped to [low, high]. |
| low, high | Access the action bounds (emits a deprecation warning; prefer passing at construction). |
| get_mode() | Numerically estimates the mode using Adam (200 steps). Expensive; avoid in hot paths. |
Example: Stochastic Actor with TanhNormal
TruncatedNormal
TruncatedNormal implements a truncated Gaussian distribution clamped to [low, high] using the exact truncated-Normal density (not a rejection sampler). Like TanhNormal, it supports location scaling.
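The density being described can be sketched in plain PyTorch: the base Normal log-density minus the log of the probability mass that falls inside [low, high]. The function truncated_normal_log_prob below is a hypothetical helper written for illustration, not TorchRL's implementation, and the parameter values are arbitrary.

```python
import math

import torch
from torch.distributions import Normal


def truncated_normal_log_prob(x, loc, scale, low, high):
    """Log-density of Normal(loc, scale) renormalized to [low, high]."""
    base = Normal(loc, scale)
    low_t = torch.as_tensor(low, dtype=loc.dtype)
    high_t = torch.as_tensor(high, dtype=loc.dtype)
    # Probability mass of the base Normal inside the interval.
    log_mass = torch.log(base.cdf(high_t) - base.cdf(low_t))
    log_p = base.log_prob(x) - log_mass
    # Points outside the interval have zero density.
    inside = (x >= low_t) & (x <= high_t)
    return torch.where(inside, log_p, torch.full_like(log_p, -math.inf))


loc = torch.tensor(0.3, dtype=torch.float64)
scale = torch.tensor(0.5, dtype=torch.float64)
xs = torch.linspace(-1.0, 1.0, 2001, dtype=torch.float64)

density = truncated_normal_log_prob(xs, loc, scale, -1.0, 1.0).exp()
total_mass = torch.trapezoid(density, xs)  # numerically close to 1
```

Integrating the density over the interval is a quick sanity check that the renormalization is correct.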
IndependentNormal
IndependentNormal wraps torch.distributions.Normal with location scaling and optional factorization over event dimensions. It is a simpler alternative to TanhNormal when actions are unbounded.
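The factorization over event dimensions can be illustrated with stock PyTorch alone. This sketch uses torch.distributions.Independent as a stand-in for the idea; it does not call IndependentNormal itself and omits location scaling.

```python
import torch
from torch.distributions import Independent, Normal

loc = torch.zeros(4, 3)
scale = torch.ones(4, 3)

per_element = Normal(loc, scale)          # batch shape (4, 3), no event dims
factorized = Independent(per_element, 1)  # batch shape (4,), event shape (3,)

action = factorized.sample()
elementwise_lp = per_element.log_prob(action)  # shape (4, 3)
joint_lp = factorized.log_prob(action)         # shape (4,)
```

The joint log-probability is the sum of the per-component values over the last dimension, which is the quantity a policy-gradient loss typically expects for a multi-dimensional action.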
MaskedCategorical
MaskedCategorical extends torch.distributions.Categorical with a boolean mask that sets the log-probability of invalid actions to -inf, then re-normalizes the remaining probabilities. This ensures the sampled action is always valid and the policy gradient does not receive signal from impossible actions.
- logits: Unnormalized log-probabilities for each action. Exclusive with probs.
- probs: Action probabilities. Invalid actions (where mask=False) are zeroed and the distribution is renormalized. Exclusive with logits.
- mask: Boolean tensor of the same shape as logits/probs. True entries are valid actions; False entries are masked to -inf. Exclusive with indices.
- indices: Sparse integer index tensor specifying valid actions. Alternative to mask for environments where only a small subset of actions is valid. Exclusive with mask.
- neg_inf: The log-probability assigned to masked-out (invalid) actions. Defaults to float("-inf") to give zero probability; use a large negative finite value (e.g. -1e8) if downstream code cannot handle -inf.
- use_cross_entropy: When True, uses F.cross_entropy for a faster log_prob computation.
Example
MaskedOneHotCategorical
MaskedOneHotCategorical is the one-hot encoding variant of MaskedCategorical. Samples are returned as one-hot vectors rather than integer indices. Use it with environments whose action spec is OneHot.
OneHotCategorical
OneHotCategorical is a torch.distributions.Categorical subclass that returns one-hot encoded samples. Useful for discrete action spaces where downstream modules expect a binary action vector rather than a scalar index.
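The difference between one-hot samples and scalar indices can be shown with the stock torch.distributions.OneHotCategorical. This sketch uses the PyTorch class as a stand-in to illustrate the sample format and does not call TorchRL's variant.

```python
import torch
from torch.distributions import OneHotCategorical

# Batch of 2 independent choices over 5 discrete actions.
dist = OneHotCategorical(logits=torch.zeros(2, 5))

one_hot = dist.sample()           # shape (2, 5), exactly one 1 per row
index = one_hot.argmax(dim=-1)    # shape (2,), the equivalent scalar indices
log_prob = dist.log_prob(one_hot) # shape (2,)
```

Downstream modules that expect a binary action vector consume one_hot directly, while argmax recovers the index form when an integer action is needed.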
Delta
Delta (re-exported from tensordict.nn.distributions) is a deterministic distribution concentrated at a single point. Its log_prob returns 0 for the exact value and -inf for any other. ProbabilisticActor uses Delta as the default distribution_class, making it suitable for deterministic policies that still go through the probabilistic interface.
TanhDelta
TanhDelta applies a TanhTransform (and optional affine rescaling) on top of a Delta distribution. It is the natural distribution class for deterministic policies with bounded actions.
Ordinal and OneHotOrdinal
Ordinal treats a discrete action space as an ordinal variable where adjacent actions are “closer” to one another than distant ones. OneHotOrdinal is the one-hot encoded variant. Both inherit from torch.distributions.Categorical.
LLMMaskedCategorical
LLMMaskedCategorical is a specialized masked categorical distribution for language model policies. It handles large vocabulary sizes and sparse valid-token masks efficiently, making it suitable for token-level action selection in LLM-based agents.
distributions_maps
distributions_maps is a dictionary mapping lowercased class-repr strings to distribution classes. The keys are produced by str(dist_class).lower(), which yields the fully qualified class path enclosed in angle brackets. Because the keys are strings, look up a class by building the key with str(MyDistClass).lower() rather than indexing with the class object itself.
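The key format can be illustrated with a stand-in registry built the same way from stock PyTorch classes. The registry name and its contents are hypothetical; only the str(cls).lower() convention is taken from the description above.

```python
from torch.distributions import Categorical, Normal

# Stand-in registry keyed by the lowercased class repr.
registry = {str(cls).lower(): cls for cls in (Normal, Categorical)}

key = str(Normal).lower()
# key == "<class 'torch.distributions.normal.normal'>"

dist_class = registry[key]  # resolves back to the Normal class
```

Note that lowercasing also affects the class name inside the path, so keys should always be generated from the class rather than typed by hand.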
Choosing a Distribution
- Bounded Continuous
- Unbounded Continuous
- Truncated Continuous
- Discrete (unmasked)
- Discrete (masked)
- Deterministic
TanhNormal — standard for PPO/SAC with continuous bounded actions.
Numerically stable log_prob through SafeTanhTransform. Use
distribution_kwargs={"low": lo, "high": hi} to set the action range.