Standard PyTorch distributions work well for supervised learning but need small adaptations for reinforcement learning. Continuous control policies often require actions bounded to a fixed range. Applying a tanh transform is the standard solution, but a naïve log_prob computation through tanh is numerically unstable near ±1. Discrete policies in environments where the set of valid actions changes from step to step (e.g. board games, sequential decision trees) need a masked distribution that assigns zero probability to invalid choices. TorchRL’s distributions address both concerns with numerically stable implementations and clean integration with ProbabilisticActor.

Why Specialized Distributions for RL

Three issues motivate TorchRL’s custom distributions:
  1. Bounded continuous actions. Most actuators have torque or velocity limits. Squashing with tanh enforces the bound but makes log_prob ill-conditioned for values near the boundary unless the inverse transform is numerically stabilized.
  2. Location scaling. Raw network outputs can grow very large, causing tanh saturation and vanishing gradients. Location scaling, loc = tanh(loc / upscale) * upscale, keeps the pre-tanh location within (-upscale, upscale).
  3. Action masking. In environments with structured or state-dependent action sets, some actions are invalid at each step. Sampling them wastes rollout budget; computing log_prob over them corrupts gradients.
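
The instability in item 1 comes from the log-determinant term log(1 - tanh(x)^2), which becomes log(0) once tanh(x) rounds to 1. The sketch below compares that naive form with an algebraically equivalent one, 2 * (log 2 - x - softplus(-2x)). It uses only core PyTorch; the helper names are illustrative and are not TorchRL API.

```python
import math

import torch
import torch.nn.functional as F


def naive_log_det(x: torch.Tensor) -> torch.Tensor:
    # log |d tanh(x) / dx| written directly; breaks once tanh(x) rounds to 1.
    return torch.log(1 - torch.tanh(x) ** 2)


def stable_log_det(x: torch.Tensor) -> torch.Tensor:
    # Same quantity rewritten so that no term saturates for large |x|.
    return 2.0 * (math.log(2.0) - x - F.softplus(-2.0 * x))


x = torch.tensor([0.5, 20.0])
naive = naive_log_det(x)
stable = stable_log_det(x)
print(naive)   # second entry is -inf in float32
print(stable)  # both entries are finite
```

For moderate inputs the two forms agree; they only diverge once the pre-tanh value is large enough for tanh to saturate.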

NormalParamExtractor

NormalParamExtractor (re-exported from tensordict.nn) splits the output of a network into (loc, scale) halves and applies a softplus-based mapping to ensure scale > 0. It is the recommended way to parameterize Normal-family distributions.
import torch
from torch import nn
from tensordict.nn import NormalParamExtractor

# A network with 8 output features becomes (loc[4], scale[4])
backbone = nn.Sequential(nn.Linear(4, 8), NormalParamExtractor())

loc, scale = backbone(torch.randn(3, 4))
print(loc.shape, scale.shape)  # torch.Size([3, 4]) torch.Size([3, 4])
NormalParamExtractor is also available as torchrl.modules.NormalParamExtractor. The legacy NormalParamWrapper class has been removed; if you see it in old code, replace it with NormalParamExtractor.
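
Conceptually, the extractor chunks the last dimension in two and passes the second half through a positive mapping. The sketch below reproduces that idea with plain softplus in core PyTorch. `split_loc_scale` and its `min_scale` floor are hypothetical, and the library's default mapping may use a different bias and lower bound.

```python
import torch
import torch.nn.functional as F


def split_loc_scale(features: torch.Tensor, min_scale: float = 1e-4):
    # Split the last dimension into two equal halves: (loc, raw_scale).
    loc, raw_scale = features.chunk(2, dim=-1)
    # softplus maps any real number to a positive one; the floor (an
    # assumption of this sketch) keeps scale away from exactly zero.
    scale = F.softplus(raw_scale).clamp_min(min_scale)
    return loc, scale


features = torch.randn(3, 8)
loc, scale = split_loc_scale(features)
print(loc.shape, scale.shape)
print(bool((scale > 0).all()))
```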

TanhNormal

TanhNormal is a squashed Gaussian for bounded continuous action spaces. It constructs a Normal(loc, scale) base distribution and passes samples through a TanhTransform (optionally composed with an affine rescaling to [low, high]). The resulting sample lies strictly within (low, high), making it suitable for direct use with bounded action specs without a spec.project() call.
loc
torch.Tensor
required
Location parameter of the underlying Normal distribution.
scale
torch.Tensor | float | callable
required
Scale parameter. Accepts a tensor, a float, or a callable (e.g. torch.ones_like) that takes loc and returns the scale tensor. Using a callable avoids device transfers and prevents graph breaks under torch.compile.
upscale
float | torch.Tensor
default:"5.0"
Factor used in location scaling: loc = tanh(loc / upscale) * upscale. Only applied when tanh_loc=True.
low
float | torch.Tensor
default:"-1.0"
Lower bound of the action range. Combined with high to define the affine rescaling applied after tanh.
high
float | torch.Tensor
default:"1.0"
Upper bound of the action range. Must be strictly greater than low.
event_dims
int
Number of trailing dimensions over which log_prob is summed. Defaults to min(1, loc.ndim). Set to 0 to get a per-element log-probability.
tanh_loc
bool
default:"False"
When True, applies location scaling (loc = tanh(loc/upscale)*upscale) before constructing the distribution.
safe_tanh
bool
default:"True"
When True, uses a numerically stable SafeTanhTransform that clips inputs to avoid atanh overflow. Set to False for torch.compile compatibility (native TanhTransform is used instead).
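
To make the roles of these parameters concrete, the sketch below assembles a comparable distribution from core `torch.distributions` pieces: optional location scaling, a Normal base, a tanh squash, and an affine map to (low, high). It is an approximation for illustration, not TorchRL's implementation. `build_squashed_normal` is a hypothetical helper, and it omits the safe-tanh handling.

```python
import torch
from torch import distributions as D
from torch.distributions.transforms import AffineTransform, TanhTransform


def build_squashed_normal(loc, scale, low=-1.0, high=1.0, upscale=5.0, tanh_loc=False):
    if tanh_loc:
        # Location scaling: keeps loc inside (-upscale, upscale).
        loc = torch.tanh(loc / upscale) * upscale
    # Treat the last dimension as one event so log_prob sums over it.
    base = D.Independent(D.Normal(loc, scale), 1)
    # tanh maps to (-1, 1); the affine map then stretches that to (low, high).
    squash = [
        TanhTransform(cache_size=1),
        AffineTransform(loc=(high + low) / 2, scale=(high - low) / 2),
    ]
    return D.TransformedDistribution(base, squash)


torch.manual_seed(0)
dist = build_squashed_normal(
    loc=torch.full((3, 4), 100.0),  # deliberately huge raw location
    scale=torch.full((3, 4), 0.5),
    low=-2.0,
    high=2.0,
    tanh_loc=True,
)
sample = dist.rsample()
log_prob = dist.log_prob(sample)
print(sample.shape, log_prob.shape)
```

With `tanh_loc=True`, the raw location of 100 is pulled back to roughly `upscale` before sampling, so the pre-tanh values stay in a range where the log-probability remains finite.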

Key Properties

Property | Description
--- | ---
deterministic_sample | The deterministic action: tanh(loc) mapped to [low, high].
low, high | Access the action bounds (emits a deprecation warning; prefer passing at construction).
get_mode() | Numerically estimates the mode using Adam (200 steps). Expensive; avoid in hot paths.
TanhNormal does not have closed-form mean or mode properties. Calling dist.mean raises NotImplementedError; calling dist.mode raises RuntimeError. Use dist.deterministic_sample for a deterministic action during evaluation, or dist.get_mode() for an approximate mode.
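
The deterministic action described above amounts to squashing loc and rescaling it. A minimal sketch of that mapping in core PyTorch follows. `squash_to_bounds` is a hypothetical name, and the formula assumes the affine step maps (-1, 1) linearly onto (low, high).

```python
import torch


def squash_to_bounds(loc: torch.Tensor, low: float, high: float) -> torch.Tensor:
    # tanh puts loc in (-1, 1); the affine step rescales to (low, high).
    half_range = (high - low) / 2
    midpoint = (high + low) / 2
    return torch.tanh(loc) * half_range + midpoint


loc = torch.tensor([-3.0, 0.0, 3.0])
action = squash_to_bounds(loc, low=0.0, high=10.0)
print(action)  # the middle entry is the midpoint of the range, 5.0
```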

Example: Stochastic Actor with TanhNormal

import torch
from torch import nn
from functools import partial
from tensordict import TensorDict
from tensordict.nn import TensorDictModule, NormalParamExtractor
from torchrl.data import Bounded
from torchrl.modules import ProbabilisticActor, TanhNormal

# Backbone: observation (4-D) → (loc, scale) each of size 4
backbone = nn.Sequential(nn.Linear(4, 8), NormalParamExtractor())
td_backbone = TensorDictModule(
    backbone, in_keys=["observation"], out_keys=["loc", "scale"]
)

# ProbabilisticActor with TanhNormal, actions bounded to [-1, 1]
actor = ProbabilisticActor(
    module=td_backbone,
    in_keys=["loc", "scale"],
    out_keys=["action"],
    spec=Bounded(shape=torch.Size([4]), low=-1, high=1),
    distribution_class=TanhNormal,
    distribution_kwargs={"low": -1.0, "high": 1.0},
    return_log_prob=True,
)

td = TensorDict({"observation": torch.randn(3, 4)}, [3])
td = actor(td)
print(td["action"].shape)          # torch.Size([3, 4]) — within [-1, 1]
print(td["sample_log_prob"].shape) # torch.Size([3])

# A callable scale avoids device transfers and torch.compile graph breaks:
dist = TanhNormal(
    loc=torch.zeros(3, 4),
    scale=partial(torch.full_like, fill_value=0.1),
)
sample = dist.rsample()
print(sample.shape)  # torch.Size([3, 4])

TruncatedNormal

TruncatedNormal implements a Gaussian distribution restricted to [low, high]. It uses the exact truncated-Normal density rather than clamping samples or relying on a rejection sampler. Like TanhNormal, it supports location scaling.
from torchrl.modules import TruncatedNormal
import torch

dist = TruncatedNormal(
    loc=torch.zeros(3, 4),
    scale=torch.ones(3, 4) * 0.5,
    low=-1.0,
    high=1.0,
    tanh_loc=False,
)
sample = dist.rsample()
print(sample.shape)  # torch.Size([3, 4]), all in [-1, 1]
print(dist.log_prob(sample).shape)  # torch.Size([3])
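
For reference, the exact truncated-Normal density divides the Normal density by the probability mass that the untruncated Normal places on [low, high]. The sketch below implements only that log-density in core PyTorch and checks numerically that it integrates to one. `truncated_normal_log_prob` is a hypothetical helper, and nothing here claims to mirror TorchRL's internals.

```python
import torch
from torch.distributions import Normal


def truncated_normal_log_prob(x, loc, scale, low, high):
    base = Normal(loc, scale)
    # Probability mass the untruncated Normal places on [low, high].
    mass = base.cdf(torch.full_like(loc, high)) - base.cdf(torch.full_like(loc, low))
    log_prob = base.log_prob(x) - torch.log(mass)
    # Outside [low, high] the density is zero, so the log-density is -inf.
    inside = (x >= low) & (x <= high)
    return torch.where(inside, log_prob, torch.full_like(log_prob, float("-inf")))


loc = torch.tensor(0.0, dtype=torch.float64)
scale = torch.tensor(0.5, dtype=torch.float64)
low, high = -1.0, 1.0

# Integrate the density over the support with the trapezoidal rule.
xs = torch.linspace(low, high, 2001, dtype=torch.float64)
density = truncated_normal_log_prob(xs, loc, scale, low, high).exp()
total = torch.trapezoid(density, xs)

outside = truncated_normal_log_prob(
    torch.tensor(2.0, dtype=torch.float64), loc, scale, low, high
)
print(total, outside)
```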

IndependentNormal

IndependentNormal wraps torch.distributions.Normal with location scaling and optional factorization over event dimensions. It is a simpler alternative to TanhNormal when actions are unbounded.
from torchrl.modules import IndependentNormal
import torch

dist = IndependentNormal(
    loc=torch.zeros(3, 4),
    scale=torch.ones(3, 4),
    upscale=5.0,
    tanh_loc=False,
)
sample = dist.sample()
print(sample.shape)  # torch.Size([3, 4])
print(dist.log_prob(sample).shape)  # torch.Size([3])
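
The factorization over event dimensions can be illustrated with core PyTorch alone: wrapping a Normal in `torch.distributions.Independent` sums the per-element log-probabilities over the trailing dimension. This sketch shows only that summation, not TorchRL's location scaling.

```python
import torch
from torch.distributions import Independent, Normal

loc = torch.zeros(3, 4)
scale = torch.ones(3, 4)

elementwise = Normal(loc, scale)          # batch shape (3, 4)
factorized = Independent(elementwise, 1)  # batch shape (3,), event shape (4,)

torch.manual_seed(0)
sample = factorized.sample()
per_element = elementwise.log_prob(sample)  # one value per action dimension
joint = factorized.log_prob(sample)         # summed over the last dimension
print(per_element.shape, joint.shape)
```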

MaskedCategorical

MaskedCategorical extends torch.distributions.Categorical with a boolean mask that sets the log-probability of invalid actions to -inf, then re-normalizes the remaining probabilities. This ensures the sampled action is always valid and the policy gradient does not receive signal from impossible actions.
logits
torch.Tensor
Unnormalized log-probabilities for each action. Exclusive with probs.
probs
torch.Tensor
Action probabilities. Invalid actions (where mask=False) are zeroed and the distribution is renormalized. Exclusive with logits.
mask
torch.Tensor
required
Boolean tensor of the same shape as logits/probs. True entries are valid actions; False entries are masked to -inf. Exclusive with indices.
indices
torch.Tensor
Sparse integer index tensor specifying valid actions. Alternative to mask for environments where only a small subset of actions is valid. Exclusive with mask.
neg_inf
float
default:"-inf"
The log-probability assigned to masked-out (invalid) actions. Defaults to float("-inf") to give zero probability; use a large negative finite value (e.g. -1e8) if downstream code cannot handle -inf.
use_cross_entropy
bool
default:"True"
When True, uses F.cross_entropy for a faster log_prob computation.

Example

import torch
from torchrl.modules import MaskedCategorical

torch.manual_seed(0)
logits = torch.randn(4) / 100  # near-uniform
mask = torch.tensor([True, False, True, True])  # action 1 is invalid

dist = MaskedCategorical(logits=logits, mask=mask)
sample = dist.sample((10,))
print(sample)                        # no 1s in the sample
print(dist.log_prob(torch.tensor(1))) # -inf
print(dist.entropy())                # entropy over 3 valid actions
Combine MaskedCategorical with ProbabilisticActor by passing the mask key in in_keys and using a custom distribution_kwargs or a wrapper module that builds the distribution with mask from the TensorDict.
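
The masking step itself is simple to sketch with core PyTorch: replace the logits of invalid actions with a very negative value and let `Categorical` renormalize. The snippet below illustrates that idea rather than TorchRL's implementation; the `neg_inf` variable mirrors the parameter described above.

```python
import math

import torch
from torch.distributions import Categorical

logits = torch.zeros(4)                         # uniform before masking
mask = torch.tensor([True, False, True, True])  # action 1 is invalid
neg_inf = float("-inf")

# Invalid actions get a logit of neg_inf, so their probability becomes zero.
masked_logits = torch.where(mask, logits, torch.full_like(logits, neg_inf))
dist = Categorical(logits=masked_logits)

torch.manual_seed(0)
samples = dist.sample((1000,))
entropy = dist.entropy()
print(dist.probs)                  # one third on each valid action, zero on action 1
print(bool((samples != 1).all()))  # the masked action is never drawn
print(entropy, math.log(3))        # entropy of a uniform choice among 3 actions
```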

MaskedOneHotCategorical

MaskedOneHotCategorical is the one-hot encoding variant of MaskedCategorical. Samples are returned as one-hot vectors rather than integer indices. Use it with environments whose action spec is OneHot.
from torchrl.modules import MaskedOneHotCategorical
import torch

logits = torch.randn(4)
mask = torch.tensor([True, True, False, True])
dist = MaskedOneHotCategorical(logits=logits, mask=mask)
sample = dist.sample()
print(sample)  # one-hot vector, never [0, 0, 1, 0]

OneHotCategorical

OneHotCategorical is a torch.distributions.Categorical subclass that returns one-hot encoded samples. Useful for discrete action spaces where downstream modules expect a binary action vector rather than a scalar index.
from torchrl.modules import OneHotCategorical
import torch

dist = OneHotCategorical(logits=torch.randn(4))
sample = dist.sample()
print(sample.shape)  # torch.Size([4])
print(sample.sum())  # 1.0
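
The relationship between index samples and one-hot samples can be sketched with core PyTorch: draw an index from `Categorical` and encode it with `torch.nn.functional.one_hot`. This only illustrates the encoding; it is not a description of how TorchRL's classes are implemented.

```python
import torch
import torch.nn.functional as F
from torch.distributions import Categorical

torch.manual_seed(0)
logits = torch.randn(4)
index = Categorical(logits=logits).sample()  # scalar integer index
one_hot = F.one_hot(index, num_classes=logits.shape[-1])

print(index.shape, one_hot.shape)
print(int(one_hot.argmax()) == int(index))
```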

Delta

Delta (re-exported from tensordict.nn.distributions) is a deterministic distribution concentrated at a single point. Its log_prob returns 0 for the exact value and -inf for any other. ProbabilisticActor uses Delta as the default distribution_class, making it suitable for deterministic policies that still go through the probabilistic interface.
from torchrl.modules import Delta
import torch

dist = Delta(param=torch.tensor([0.5, -0.3]))
print(dist.sample())             # tensor([0.5, -0.3])
print(dist.log_prob(dist.param)) # tensor(0.)

TanhDelta

TanhDelta applies a TanhTransform (and optional affine rescaling) on top of a Delta distribution. It is the natural distribution class for deterministic policies with bounded actions.
from torchrl.modules import TanhDelta
import torch

dist = TanhDelta(param=torch.zeros(4), low=-1.0, high=1.0)
sample = dist.sample()
print(sample.shape)              # torch.Size([4])
print((sample >= -1).all())      # True
print((sample <= 1).all())       # True

Ordinal and OneHotOrdinal

Ordinal treats a discrete action space as an ordinal variable where adjacent actions are “closer” to one another than distant ones. OneHotOrdinal is the one-hot encoded variant. Both inherit from torch.distributions.Categorical.
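
One common way to obtain this kind of ordinal structure is to turn per-action scores into logits through cumulative sums of log-sigmoids. The sketch below shows that construction in core PyTorch. Whether TorchRL's Ordinal uses exactly this parameterization is an assumption, and `ordinal_logits` is a hypothetical helper.

```python
import torch
import torch.nn.functional as F
from torch.distributions import Categorical


def ordinal_logits(scores: torch.Tensor) -> torch.Tensor:
    log_on = F.logsigmoid(scores)    # log sigmoid(s_j)
    log_off = F.logsigmoid(-scores)  # log (1 - sigmoid(s_j))
    # logit_i = sum_{j <= i} log_on_j + sum_{j > i} log_off_j
    on_up_to_i = log_on.cumsum(dim=-1)
    off_after_i = log_off.sum(dim=-1, keepdim=True) - log_off.cumsum(dim=-1)
    return on_up_to_i + off_after_i


scores = torch.tensor([5.0, 5.0, -5.0, -5.0])
logits = ordinal_logits(scores)
dist = Categorical(logits=logits)
flat = ordinal_logits(torch.zeros(4))
print(dist.probs)  # mass concentrates on index 1; its neighbours outrank index 3
```

Under this construction, actions next to the most likely one receive more probability than distant ones, which is the sense in which adjacent actions are "closer".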

LLMMaskedCategorical

LLMMaskedCategorical is a specialized masked categorical distribution for language model policies. It handles large vocabulary sizes and sparse valid-token masks efficiently, making it suitable for token-level action selection in LLM-based agents.

distributions_maps

distributions_maps is a dictionary mapping lowercased class-repr strings to distribution classes. The keys are produced by str(dist_class).lower(), which yields the fully qualified class path enclosed in angle brackets. Because the keys are strings, a class object cannot be used as a key directly. Look up a class by building the key via str(MyDistClass).lower(), or by searching the dictionary's values:
from torchrl.modules import distributions_maps, TanhNormal

# Look up by constructing the key the same way the dict was built
key = str(TanhNormal).lower()  # e.g. "<class 'torchrl.modules.distributions.continuous.tanhnormal'>"
dist_class = distributions_maps[key]

# Or look up by iterating values
tanhnormal_class = next(v for v in distributions_maps.values() if v is TanhNormal)
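
The key format is easier to see on a small stand-in dictionary. The sketch below builds one from two core `torch.distributions` classes using the same `str(cls).lower()` convention. `example_map` is hypothetical and is not part of TorchRL.

```python
from torch.distributions import Categorical, Normal

# Same key convention as described above: lowercased class repr -> class.
example_map = {str(cls).lower(): cls for cls in (Normal, Categorical)}

key = str(Normal).lower()
print(key)                         # <class 'torch.distributions.normal.normal'>
print(example_map[key] is Normal)  # the string key resolves to the class
print(Normal in example_map)       # the class object itself is not a key
```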

Choosing a Distribution

TanhNormal — standard for PPO/SAC with continuous bounded actions. Numerically stable log_prob through SafeTanhTransform. Use distribution_kwargs={"low": lo, "high": hi} to set the action range.
