TorchRL transforms are torch.nn.Module subclasses that intercept, modify, or augment the TensorDict flowing through an environment’s reset and step calls. Because each transform declares its own in_keys / out_keys and optional in_keys_inv / out_keys_inv, they compose cleanly and are replayable inside replay buffers without any additional bookkeeping. Wrap any EnvBase instance with TransformedEnv to attach a transform or a Compose pipeline, and the resulting object is itself a fully compliant EnvBase that can be further wrapped, vectorised, or compiled.
All transforms and TransformedEnv are importable from torchrl.envs.
Transform Base Class
Transform
Transform is the abstract base for every built-in and custom transform. It extends nn.Module and exposes four key hooks:
| Method | Called during | Purpose |
|---|---|---|
| _call(tensordict) | step output | Modify observations / rewards in the forward pass. |
| _inv_call(tensordict) | step input | Modify actions in the inverse pass (e.g., ActionScaling). |
| _reset(tensordict, reset_td) | reset | Apply the transform to reset output. |
| _apply_transform(tensor) | per tensor | Convenience hook when the same operation applies to each key individually. |
TransformedEnv
TransformedEnv is the primary wrapper that attaches one or more transforms to a base environment. It is itself an EnvBase, so it can be nested, vectorised, and compiled.
Compose
Compose chains multiple transforms into one and exposes the combined in_keys / out_keys. Data passes sequentially through each child transform, in the order the children are listed.
Observation Transforms
ObservationNorm
Normalises observations with a configurable loc (mean) and scale (standard deviation). Supports in-place statistics collection via init_stats.
ToTensorImage
Converts a uint8 image tensor in HWC format (height × width × channels) to a float32 CHW tensor in the range [0, 1]. Requires torchvision.
GrayScale
Converts an RGB image (C × H × W) to grayscale by a weighted channel average. Requires torchvision.
Resize
Resizes pixel observations to a target spatial resolution. Requires torchvision.
CenterCrop
Centre-crops pixel observations to a fixed spatial size. Requires torchvision.
UnsqueezeTransform
Inserts a new dimension at position unsqueeze_dim in the selected tensors.
SqueezeTransform
Removes a size-1 dimension at position squeeze_dim.
FlattenObservation
Flattens a range of contiguous dimensions in an observation tensor into one.
CatFrames
Stacks the last N observations along a specified dimension to produce a frame-stacked observation. Essential for Atari-style pixel environments.
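To make the frame-stacking idea concrete, here is a library-free sketch; it does not call CatFrames itself. The hypothetical FrameStack helper keeps the last N observations in a deque, and the rule of padding with copies of the first observation after a reset is an assumption made for illustration.

```python
from collections import deque


class FrameStack:
    """Hypothetical helper: keeps the last n observations, oldest first."""

    def __init__(self, n: int):
        self.n = n
        self.frames = deque(maxlen=n)

    def reset(self, first_obs):
        # Assumed padding rule: repeat the first observation n times.
        self.frames.clear()
        self.frames.extend([first_obs] * self.n)
        return list(self.frames)

    def step(self, obs):
        # Appending to a full deque drops the oldest frame automatically.
        self.frames.append(obs)
        return list(self.frames)


stack = FrameStack(n=3)
after_reset = stack.reset("f0")
stack.step("f1")
latest = stack.step("f2")
```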
FrameSkipTransform
Repeats each action for frame_skip consecutive steps and returns the observation from the last step and the accumulated reward.
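The behaviour described above can be sketched without the library as a loop that repeats one action and sums the rewards. CounterEnv and step_with_frame_skip are hypothetical names used only for this illustration, and the sketch ignores episode termination inside the skipped steps.

```python
class CounterEnv:
    """Toy stand-in for an environment: the observation is the step index."""

    def __init__(self):
        self.t = 0

    def step(self, action):
        self.t += 1
        return self.t, 1.0  # (observation, reward)


def step_with_frame_skip(env, action, frame_skip):
    """Repeat `action` frame_skip times; return last observation and summed reward."""
    total_reward = 0.0
    obs = None
    for _ in range(frame_skip):
        obs, reward = env.step(action)
        total_reward += reward
    return obs, total_reward


env = CounterEnv()
obs, reward = step_with_frame_skip(env, action=0, frame_skip=4)
```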
PermuteTransform
Permutes tensor dimensions. Useful to convert between channel layouts.
NextObservationDelta
Computes the element-wise difference between the next observation and the current observation, storing the delta under a configurable output key.
Reward Transforms
RewardScaling
Applies a linear transformation reward ← reward * scale + loc to the reward tensor.
RewardClipping
Clips rewards element-wise between clamp_min and clamp_max.
RewardSum
Accumulates the episode return in a running episode_reward key, resetting on done.
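A library-free sketch of the accumulate-and-reset logic follows. The function name running_episode_return is hypothetical, and the sketch works on plain Python lists rather than on a TensorDict.

```python
def running_episode_return(rewards, dones):
    """Cumulative reward per step, restarting from zero after each done flag."""
    out = []
    total = 0.0
    for reward, done in zip(rewards, dones):
        total += reward
        out.append(total)
        if done:
            # The next step belongs to a new episode.
            total = 0.0
    return out


episode_reward = running_episode_return(
    rewards=[1.0, 2.0, 3.0, 4.0],
    dones=[False, True, False, False],
)
```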
BinarizeReward
Maps reward to +1 (positive) or 0 (non-positive).
Reward2GoTransform
Computes the discounted reward-to-go for each step in a stored trajectory. Primarily used on replay-buffer data rather than live environment data.
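The quantity involved is the recursion G_t = r_t + gamma * G_{t+1}, evaluated backwards from the end of the trajectory. The sketch below computes it for a single trajectory stored as a Python list; it does not call Reward2GoTransform and does not handle done flags or batched data, and the function name reward_to_go is chosen here for illustration.

```python
def reward_to_go(rewards, gamma):
    """Discounted reward-to-go for one trajectory, computed back to front."""
    out = [0.0] * len(rewards)
    running = 0.0
    for t in reversed(range(len(rewards))):
        running = rewards[t] + gamma * running
        out[t] = running
    return out


returns = reward_to_go([1.0, 1.0, 1.0], gamma=0.5)
```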
Action Transforms
ActionScaling
Rescales actions from the policy’s output range to the environment’s action_spec domain. The inverse pass (applied before env.step) undoes the scaling.
FlattenAction
Flattens a multi-dimensional action tensor into a 1-D vector. Useful when a policy outputs a flat action that must be reshaped for the environment.
ActionDiscretizer
Discretises a continuous action space into a fixed grid of num_intervals bins per action dimension, converting a Bounded action spec into a Categorical one.
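The idea of a fixed grid can be sketched without the library: a discrete index selects one of num_intervals evenly spaced values between the bounds. The helper make_action_grid is hypothetical, and placing grid points on the bounds themselves is an assumption; the library's own rule for picking a value within each bin may differ.

```python
def make_action_grid(low, high, num_intervals):
    """Return num_intervals evenly spaced values covering [low, high]."""
    if num_intervals < 2:
        raise ValueError("num_intervals must be at least 2")
    step = (high - low) / (num_intervals - 1)
    return [low + i * step for i in range(num_intervals)]


grid = make_action_grid(-1.0, 1.0, num_intervals=5)
# A categorical policy outputs an index; the grid maps it to a continuous value.
continuous_action = grid[3]
```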
ActionMask
Filters the set of valid discrete actions using an action mask key present in the observation. Commonly used in environments with variable action legality.
Utility Transforms
StepCounter
Appends a "step_count" integer tensor to the output TensorDict and, optionally, a "truncated" flag when max_steps is reached.
DTypeCastTransform
Casts specified tensor keys from one dtype to another.
DoubleToFloat
Convenience wrapper around DTypeCastTransform that converts all float64 tensors to float32.
DeviceCastTransform
Moves tensors matching the selected keys to a specified device.
VecNorm
Running normalisation of observations and/or rewards using a shared exponential moving average of mean and variance. When used with EnvCreator and ParallelEnv, all worker processes share the same running statistics.
ExcludeTransform
Removes specified keys from the TensorDict. Useful to strip large pixel tensors after they have been processed.
InitTracker
Adds an "is_init" boolean flag to every step’s TensorDict, set to True on the first step after a reset. Required by some recurrent policies to detect episode boundaries.
TensorDictPrimer
Pre-populates the reset TensorDict with zero-filled entries for keys that are absent at reset but expected at step time (e.g., hidden states, info-dict keys).
RenameTransform
Renames one or more keys in the TensorDict, both in the forward and inverse passes.
SelectTransform
Keeps only the specified keys in the output TensorDict, discarding everything else.
TerminateTransform
Injects a custom termination condition — any callable over the TensorDict — as an additional terminated flag. Useful for combining open-loop playback with early stopping.
ExpandAs
Expands a source tensor to match the shape of a reference tensor. Handy for broadcasting scalar rewards to match batched observation shapes.
NoopResetEnv
Executes a random number of no-op (zero-action) steps after each reset to randomise the starting state, following the Atari training protocol.
BatchSizeTransform
Reshapes the batch_size of the environment. Useful when wrapping a single environment so it appears as a batch of size [1].
BurnInTransform
Runs a burn-in phase of burn_in steps at the beginning of each sequence without collecting data. Required by stateful models such as RNNs when sampling from a replay buffer.
Compose Multiple Transforms
- Sequential Pipeline
- Normalised MuJoCo
Quick-Reference Table
Transforms marked New in 0.13 were added in TorchRL 0.13 and may not be present in older releases.
| Transform | Category | Description |
|---|---|---|
| ObservationNorm | Observation | Shift + scale observations with configurable loc/scale. |
| ToTensorImage | Observation | uint8 HWC → float32 CHW in [0, 1]. |
| GrayScale | Observation | RGB → single-channel grayscale. |
| Resize | Observation | Resize pixel tensors to (w, h). |
| CenterCrop | Observation | Centre-crop pixels to a fixed size. |
| UnsqueezeTransform | Observation | Insert a new dimension. |
| SqueezeTransform | Observation | Remove a size-1 dimension. |
| FlattenObservation | Observation | Flatten a range of dimensions. |
| CatFrames | Observation | Stack N consecutive frames. |
| PermuteTransform | Observation | Permute tensor axes. |
| NextObservationDelta | Observation | Next obs − current obs delta. (New in 0.13) |
| FrameSkipTransform | Observation | Repeat action N steps. |
| RewardScaling | Reward | Linear reward transform. |
| RewardClipping | Reward | Clip rewards to [min, max]. |
| RewardSum | Reward | Running episode return accumulator. |
| BinarizeReward | Reward | Map reward to {0, 1}. |
| Reward2GoTransform | Reward | Discounted return for offline data. |
| ActionScaling | Action | Rescale policy outputs to action spec. |
| FlattenAction | Action | Flatten multi-dimensional actions. |
| ActionDiscretizer | Action | Continuous → discrete action grid. |
| ActionMask | Action | Apply action legality mask. |
| StepCounter | Utility | Count steps; optional truncation. |
| DTypeCastTransform | Utility | Cast tensor dtypes. |
| DoubleToFloat | Utility | float64 → float32 convenience. |
| DeviceCastTransform | Utility | Move tensors to a device. |
| VecNorm | Utility | Shared running normalisation. |
| ExcludeTransform | Utility | Remove keys from TensorDict. |
| InitTracker | Utility | Mark first post-reset step. |
| TensorDictPrimer | Utility | Pre-fill absent keys at reset. |
| RenameTransform | Utility | Rename TensorDict keys. |
| SelectTransform | Utility | Keep only the specified keys. |
| TerminateTransform | Utility | Custom termination condition. (New in 0.13) |
| ExpandAs | Utility | Broadcast tensor shape. (New in 0.13) |
| NoopResetEnv | Utility | Random no-op steps after reset. |
| BurnInTransform | Utility | Burn-in for stateful models. |
| BatchSizeTransform | Utility | Reshape environment batch size. |