TorchRL transforms are torch.nn.Module subclasses that intercept, modify, or augment the TensorDict flowing through an environment’s reset and step calls. Each transform declares the keys it reads and writes through in_keys / out_keys (plus in_keys_inv / out_keys_inv for transforms that also act on inputs such as actions), so transforms compose cleanly and can also be attached to replay buffers to process sampled data. Wrap any EnvBase instance with TransformedEnv to attach a transform or a Compose pipeline; the resulting object is itself a fully compliant EnvBase that can be further wrapped, vectorised, or compiled.
All transforms and TransformedEnv are importable from torchrl.envs:
from torchrl.envs import TransformedEnv, Compose, ObservationNorm, RewardScaling

Transform Base Class

Transform

Transform is the abstract base for every built-in and custom transform. It extends nn.Module and exposes four key hooks:
| Method | Called during | Purpose |
|---|---|---|
| _call(tensordict) | step output | Modify observations / rewards in the forward pass. |
| _inv_call(tensordict) | step input | Modify actions in the inverse pass (e.g., ActionScaling). |
| _reset(tensordict, reset_td) | reset | Apply the transform to the reset output. |
| _apply_transform(tensor) | per tensor | Convenience hook when the same operation applies to each key individually. |
from torchrl.envs import Transform
import torch

class AddConstant(Transform):
    def __init__(self, constant: float, in_keys, out_keys=None):
        super().__init__(in_keys=in_keys, out_keys=out_keys or in_keys)
        self.constant = constant

    def _apply_transform(self, obs: torch.Tensor) -> torch.Tensor:
        return obs + self.constant
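The plain-Python mock below illustrates how a per-key hook such as _apply_transform is typically dispatched by the forward pass: each in_key is read, transformed, and written under its paired out_key. MiniTransform and AddConstantMock are hypothetical names, plain dicts stand in for TensorDicts, and the real base class additionally handles nested keys, specs, and its own policy for missing keys.

```python
class MiniTransform:
    """Hypothetical stand-in for torchrl's Transform; plain dicts replace TensorDicts."""

    def __init__(self, in_keys, out_keys=None):
        self.in_keys = list(in_keys)
        self.out_keys = list(out_keys) if out_keys is not None else list(self.in_keys)
        if len(self.in_keys) != len(self.out_keys):
            raise ValueError("in_keys and out_keys must have the same length")

    def _apply_transform(self, value):
        # Subclasses override this per-value hook.
        raise NotImplementedError

    def _call(self, data):
        # Forward pass: apply the hook to each in_key and write the result
        # under the paired out_key. This mock silently skips absent keys.
        for in_key, out_key in zip(self.in_keys, self.out_keys):
            if in_key in data:
                data[out_key] = self._apply_transform(data[in_key])
        return data


class AddConstantMock(MiniTransform):
    def __init__(self, constant, in_keys, out_keys=None):
        super().__init__(in_keys, out_keys)
        self.constant = constant

    def _apply_transform(self, value):
        return [v + self.constant for v in value]


t = AddConstantMock(10.0, in_keys=["observation"], out_keys=["obs_shifted"])
out = t._call({"observation": [1.0, 2.0], "reward": [0.5]})
print(out)  # {'observation': [1.0, 2.0], 'reward': [0.5], 'obs_shifted': [11.0, 12.0]}
```

Writing to a distinct out_key keeps the original entry intact; passing no out_keys overwrites in place, as the AddConstant example above does.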

TransformedEnv

TransformedEnv is the primary wrapper that attaches one or more transforms to a base environment. It is itself an EnvBase, so it can be nested, vectorised, and compiled.
from torchrl.envs import GymEnv, TransformedEnv
from torchrl.envs.transforms import ObservationNorm, RewardScaling, Compose

env = TransformedEnv(
    GymEnv("HalfCheetah-v4"),
    Compose(
        ObservationNorm(in_keys=["observation"], loc=0.0, scale=1.0),
        RewardScaling(loc=0.0, scale=0.1),
    ),
)
env.check_env_specs()
td = env.rollout(200)

Compose

Compose chains multiple transforms into one and exposes their combined in_keys / out_keys. Data passes through the child transforms in the order given on the forward pass, and in reverse order on the inverse pass.
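from torchrl.envs import GymEnv, TransformedEnv
from torchrl.envs.transforms import Compose, DoubleToFloat, StepCounter

base_env = GymEnv("Pendulum-v1")  # any EnvBase works here
t = Compose(DoubleToFloat(), StepCounter())
env = TransformedEnv(base_env, t)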
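Compose is understood to run its children in order on the forward pass and in reverse order on the inverse pass, which keeps a round trip through the pipeline consistent. The toy sketch below illustrates that ordering with scalars; Scale, Shift, and MiniCompose are hypothetical classes, not TorchRL transforms.

```python
class Scale:
    """Hypothetical child transform: multiply on the way out, divide on the way in."""

    def __init__(self, factor):
        self.factor = factor

    def forward(self, x):
        return x * self.factor

    def inverse(self, x):
        return x / self.factor


class Shift:
    """Hypothetical child transform: add on the way out, subtract on the way in."""

    def __init__(self, offset):
        self.offset = offset

    def forward(self, x):
        return x + self.offset

    def inverse(self, x):
        return x - self.offset


class MiniCompose:
    """Hypothetical sketch of Compose's ordering rules (not TorchRL code)."""

    def __init__(self, *transforms):
        self.transforms = list(transforms)

    def forward(self, x):
        # Forward pass: children run in the order they were given.
        for t in self.transforms:
            x = t.forward(x)
        return x

    def inverse(self, x):
        # Inverse pass: children run in reverse order, undoing forward step by step.
        for t in reversed(self.transforms):
            x = t.inverse(x)
        return x


pipeline = MiniCompose(Scale(2.0), Shift(1.0))
y = pipeline.forward(3.0)  # (3 * 2) + 1 = 7.0
x = pipeline.inverse(y)    # (7 - 1) / 2 = 3.0
```

Swapping the two children changes the forward result (Shift then Scale gives 8.0), which is why the order passed to Compose matters.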

Observation Transforms

ObservationNorm

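Normalises observations with a configurable loc and scale. With the default standard_normal=False it computes obs * scale + loc; with standard_normal=True it computes (obs - loc) / scale, in which case loc and scale act as the mean and standard deviation. Both can be estimated from environment rollouts with init_stats.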
from torchrl.envs.transforms import ObservationNorm

t = ObservationNorm(in_keys=["observation"], loc=0.0, scale=1.0)
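The two loc / scale conventions can produce the same normalised output if the parameters are chosen accordingly. The library-free sketch below estimates a mean and standard deviation from sample observations (using the sample standard deviation is an assumption of this sketch) and derives matching parameters for each convention; the helper names are hypothetical.

```python
import statistics


def estimate_stats(samples):
    """Sample mean and sample standard deviation of scalar observations."""
    return statistics.fmean(samples), statistics.stdev(samples)


def norm_standard(obs, loc, scale):
    # standard_normal=True convention: (obs - loc) / scale
    return (obs - loc) / scale


def norm_affine(obs, loc, scale):
    # standard_normal=False (default) convention: obs * scale + loc
    return obs * scale + loc


samples = [2.0, 4.0, 4.0, 4.0, 5.0, 5.0, 7.0, 9.0]
mean, std = estimate_stats(samples)

# (loc, scale) pairs for each convention that yield the same normalised output.
standard_params = (mean, std)
affine_params = (-mean / std, 1.0 / std)

normalised = [norm_standard(x, *standard_params) for x in samples]
```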

ToTensorImage

Converts a uint8 image tensor in HWC format (height × width × channels) to a float32 CHW tensor in the range [0, 1]. Requires torchvision.
from torchrl.envs.transforms import ToTensorImage
t = ToTensorImage(in_keys=["pixels"])

GrayScale

Converts an RGB image (C × H × W) to grayscale by a weighted channel average. Requires torchvision.
from torchrl.envs.transforms import GrayScale
t = GrayScale(in_keys=["pixels"])
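Assuming GrayScale relies on torchvision's rgb_to_grayscale, the "weighted channel average" uses the luma weights 0.2989 (R), 0.587 (G), and 0.114 (B). The plain-Python sketch below applies those weights to a tiny CHW image stored as nested lists; rgb_to_gray_chw is a hypothetical helper, not a TorchRL function.

```python
# Luma weights used by torchvision's rgb_to_grayscale (R, G, B).
LUMA = (0.2989, 0.587, 0.114)


def rgb_to_gray_chw(image):
    """image: nested lists shaped [3][H][W]; returns nested lists shaped [1][H][W]."""
    if len(image) != 3:
        raise ValueError("expected 3 channels (RGB) in the leading dimension")
    r, g, b = image
    gray = [
        [LUMA[0] * rv + LUMA[1] * gv + LUMA[2] * bv
         for rv, gv, bv in zip(r_row, g_row, b_row)]
        for r_row, g_row, b_row in zip(r, g, b)
    ]
    return [gray]  # a single channel is kept as the leading dimension


# A 1 x 2 image: one white pixel, one pure-blue pixel.
image = [[[1.0, 0.0]], [[1.0, 0.0]], [[1.0, 1.0]]]
gray = rgb_to_gray_chw(image)
```

Keeping a channel dimension of size 1 means downstream transforms such as CatFrames(dim=-3) still find a channel axis to concatenate along.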

Resize

Resizes pixel observations to a target spatial resolution. Requires torchvision.
from torchrl.envs.transforms import Resize
t = Resize(84, 84, in_keys=["pixels"])

CenterCrop

Centre-crops pixel observations to a fixed spatial size. Requires torchvision.
from torchrl.envs.transforms import CenterCrop
t = CenterCrop(84, in_keys=["pixels"])

UnsqueezeTransform

Inserts a new dimension at position unsqueeze_dim in the selected tensors.
from torchrl.envs.transforms import UnsqueezeTransform
t = UnsqueezeTransform(unsqueeze_dim=-1, in_keys=["observation"])

SqueezeTransform

Removes a size-1 dimension at position squeeze_dim.
from torchrl.envs.transforms import SqueezeTransform
t = SqueezeTransform(squeeze_dim=-1, in_keys=["observation"])

FlattenObservation

Flattens a range of contiguous dimensions in an observation tensor into one.
from torchrl.envs.transforms import FlattenObservation
t = FlattenObservation(first_dim=-3, last_dim=-1, in_keys=["pixels"])

CatFrames

Concatenates the last N observations along a specified dimension (dim) to produce a frame-stacked observation. A standard component of Atari-style pixel pipelines.
from torchrl.envs.transforms import CatFrames, ToTensorImage, GrayScale, Resize, Compose

t = Compose(
    ToTensorImage(in_keys=["pixels"]),
    GrayScale(in_keys=["pixels"]),
    Resize(84, 84, in_keys=["pixels"]),
    CatFrames(N=4, dim=-3, in_keys=["pixels"]),
)
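The sliding window behind frame stacking can be sketched with a bounded deque: each new frame pushes out the oldest one, and the window is concatenated along the channel dimension. In this library-free sketch a frame is a list of channels, FrameStacker is a hypothetical class, and filling the window with copies of the first frame on reset is an assumption (the real transform's reset padding is configurable).

```python
from collections import deque


class FrameStacker:
    """Library-free sketch of frame stacking along the channel dimension."""

    def __init__(self, n):
        self.n = n
        self.buffer = deque(maxlen=n)

    def reset(self, first_frame):
        # Fill the window with copies of the first frame (padding assumption).
        self.buffer.clear()
        for _ in range(self.n):
            self.buffer.append(first_frame)
        return self._output()

    def step(self, frame):
        # deque(maxlen=n) drops the oldest frame automatically.
        self.buffer.append(frame)
        return self._output()

    def _output(self):
        # Concatenate channels: N frames of C channels -> N * C channels.
        out = []
        for frame in self.buffer:
            out.extend(frame)
        return out


stacker = FrameStacker(4)
first = stacker.reset(["f0"])   # ['f0', 'f0', 'f0', 'f0']
second = stacker.step(["f1"])   # ['f0', 'f0', 'f0', 'f1']
third = stacker.step(["f2"])    # ['f0', 'f0', 'f1', 'f2']
```

With the grayscale pipeline above each frame has one channel, so four stacked frames yield four channels.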

FrameSkipTransform

Repeats each action for frame_skip consecutive steps and returns the observation from the last step and the accumulated reward.
from torchrl.envs.transforms import FrameSkipTransform
t = FrameSkipTransform(frame_skip=4)
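The action-repeat logic can be sketched in a few lines: step the environment frame_skip times with the same action, sum the rewards, and keep the last observation. In this sketch env_step and CountingEnv are hypothetical, and stopping early when the episode ends is a choice made here rather than a documented property of FrameSkipTransform.

```python
def skip_frames(env_step, action, frame_skip):
    """Repeat `action` up to `frame_skip` times, summing rewards.

    `env_step` is a hypothetical callable: action -> (obs, reward, done).
    """
    total_reward = 0.0
    obs, done = None, False
    for _ in range(frame_skip):
        obs, reward, done = env_step(action)
        total_reward += reward
        if done:  # sketch choice: stop repeating once the episode ends
            break
    return obs, total_reward, done


class CountingEnv:
    """Toy env: observation is the step index, reward is 1.0, done after 6 steps."""

    def __init__(self):
        self.t = 0

    def step(self, action):
        self.t += 1
        return self.t, 1.0, self.t >= 6


env = CountingEnv()
first = skip_frames(env.step, action=0, frame_skip=4)   # (4, 4.0, False)
second = skip_frames(env.step, action=0, frame_skip=4)  # (6, 2.0, True): ends after 2 steps
```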

PermuteTransform

Permutes tensor dimensions. Useful to convert between channel layouts.
from torchrl.envs.transforms import PermuteTransform
# HWC → CHW; dims are negative indices over the trailing dimensions
t = PermuteTransform(dims=[-1, -3, -2], in_keys=["pixels"])

NextObservationDelta

Computes the element-wise difference between the next observation and the current observation, storing the delta under a configurable output key.
from torchrl.envs.transforms import NextObservationDelta
t = NextObservationDelta(in_keys=["observation"])

Reward Transforms

RewardScaling

Applies a linear transformation reward ← reward * scale + loc to the reward tensor.
from torchrl.envs.transforms import RewardScaling
t = RewardScaling(loc=0.0, scale=0.01)

RewardClipping

Clips rewards element-wise between clamp_min and clamp_max.
from torchrl.envs.transforms import RewardClipping
t = RewardClipping(clamp_min=-1.0, clamp_max=1.0)

RewardSum

Accumulates the episode return in a running episode_reward key, resetting it to zero when a new episode starts.
from torchrl.envs.transforms import RewardSum
t = RewardSum()
# Adds "episode_reward" to the output TensorDict
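The running return can be sketched as a single accumulator that is read after every step and cleared once an episode finishes. episode_returns is a hypothetical, library-free helper that processes a flat sequence of rewards and done flags.

```python
def episode_returns(rewards, dones):
    """Running episode return, cleared after each done (library-free sketch)."""
    running = 0.0
    out = []
    for reward, done in zip(rewards, dones):
        running += reward
        out.append(running)
        if done:
            running = 0.0  # the next episode starts from zero
    return out


returns = episode_returns(
    [1.0, 2.0, 3.0, 1.0, 1.0],
    [False, False, True, False, True],
)
# [1.0, 3.0, 6.0, 1.0, 2.0]
```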

BinarizeReward

Maps reward to +1 (positive) or 0 (non-positive).

Reward2GoTransform

Computes the discounted reward-to-go for each step in a stored trajectory. Primarily used on replay-buffer data rather than live environment data.
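The discounted reward-to-go follows the recursion G_t = r_t + γ · G_{t+1}, cut at episode boundaries, so it is naturally computed backwards over a stored trajectory. The sketch below is a library-free illustration with a hypothetical helper name; it treats steps after the last done flag as having no future reward.

```python
def reward_to_go(rewards, dones, gamma):
    """Discounted reward-to-go, computed backwards and cut at episode ends."""
    out = [0.0] * len(rewards)
    running = 0.0
    for t in reversed(range(len(rewards))):
        if dones[t]:
            running = 0.0  # nothing after a terminal step contributes
        running = rewards[t] + gamma * running
        out[t] = running
    return out


single = reward_to_go([1.0, 1.0, 1.0], [False, False, True], gamma=0.5)
# [1.75, 1.5, 1.0]
two_episodes = reward_to_go([1.0, 2.0, 3.0, 4.0], [False, True, False, True], gamma=0.5)
# [2.0, 2.0, 5.0, 4.0]
```

The backward pass needs the whole trajectory, which is why this kind of transform fits replay-buffer data better than live, step-by-step environment output.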

Action Transforms

ActionScaling

Rescales actions from the policy’s output range to the environment’s action_spec domain. The inverse pass (applied before env.step) undoes the scaling.
from torchrl.envs.transforms import ActionScaling
t = ActionScaling(in_keys=["action"])

FlattenAction

Flattens a multi-dimensional action tensor into a 1-D vector, so a policy can output a flat action that is reshaped to the environment's expected shape before it reaches the environment.
from torchrl.envs.transforms import FlattenAction
t = FlattenAction(first_dim=-2, last_dim=-1)

ActionDiscretizer

Discretises a continuous action space into a fixed grid of num_intervals bins per action dimension, converting a Bounded action spec into a Categorical one.
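A discretised action dimension can be pictured as num_intervals equal-width bins between the spec's low and high bounds. The sketch below maps a categorical index back to the centre of its bin and a continuous value to the index of its bin; using bin centres is one choice among several (a bin edge or a random point inside the bin would also work), and the real transform's choice is not asserted here. All helper names are hypothetical.

```python
def bin_centers(low, high, num_intervals):
    """Centres of `num_intervals` equal-width bins covering [low, high]."""
    width = (high - low) / num_intervals
    return [low + (i + 0.5) * width for i in range(num_intervals)]


def index_to_action(indices, lows, highs, num_intervals):
    """Map one categorical index per action dimension to a continuous value."""
    return [
        bin_centers(lo, hi, num_intervals)[idx]
        for idx, lo, hi in zip(indices, lows, highs)
    ]


def action_to_index(value, low, high, num_intervals):
    """Index of the bin containing `value` (values at `high` fall in the last bin)."""
    width = (high - low) / num_intervals
    idx = int((value - low) // width)
    return min(max(idx, 0), num_intervals - 1)


centres = bin_centers(-1.0, 1.0, 4)  # [-0.75, -0.25, 0.25, 0.75]
action = index_to_action([0, 3], lows=[-1.0, 0.0], highs=[1.0, 2.0], num_intervals=4)
# [-0.75, 1.75]
```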

ActionMask

Filters the set of valid discrete actions using an action mask key present in the observation. Commonly used in environments with variable action legality.
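The effect a legality mask typically has on a policy's categorical action distribution is that illegal actions receive zero probability, regardless of their logits. The plain-Python sketch below shows that effect only; it does not reproduce how ActionMask is implemented, and masked_probs and sample_legal are hypothetical helpers.

```python
import math
import random


def masked_probs(logits, mask):
    """Softmax over `logits` where illegal actions (mask False) get zero probability."""
    if not any(mask):
        raise ValueError("at least one action must be legal")
    masked = [logit if legal else -math.inf for logit, legal in zip(logits, mask)]
    peak = max(masked)  # finite, because at least one action is legal
    exps = [math.exp(x - peak) for x in masked]  # exp(-inf) == 0.0
    total = sum(exps)
    return [e / total for e in exps]


def sample_legal(logits, mask, rng):
    probs = masked_probs(logits, mask)
    return rng.choices(range(len(probs)), weights=probs, k=1)[0]


logits = [1.0, 5.0, 2.0]
mask = [True, False, True]  # action 1 is illegal despite its large logit
probs = masked_probs(logits, mask)
rng = random.Random(0)
samples = [sample_legal(logits, mask, rng) for _ in range(1000)]
```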

Utility Transforms

StepCounter

Appends a "step_count" integer tensor to the output TensorDict and, optionally, a "truncated" flag when max_steps is reached.
from torchrl.envs.transforms import StepCounter
t = StepCounter(max_steps=1000)
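Step counting with optional truncation amounts to a counter that is zeroed on reset and compared against max_steps after each step. StepCountSketch is a hypothetical, library-free class; plain dicts stand in for the "step_count" and "truncated" entries.

```python
class StepCountSketch:
    """Library-free sketch of step counting with truncation at max_steps."""

    def __init__(self, max_steps=None):
        self.max_steps = max_steps
        self.step_count = 0

    def reset(self):
        self.step_count = 0
        return {"step_count": 0, "truncated": False}

    def step(self):
        self.step_count += 1
        # Without max_steps the counter never truncates the episode.
        truncated = self.max_steps is not None and self.step_count >= self.max_steps
        return {"step_count": self.step_count, "truncated": truncated}


counter = StepCountSketch(max_steps=3)
start = counter.reset()
outs = [counter.step() for _ in range(3)]
# truncated flags: [False, False, True]
```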

DTypeCastTransform

Casts specified tensor keys from one dtype to another.
from torchrl.envs.transforms import DTypeCastTransform
import torch
t = DTypeCastTransform(dtype_in=torch.float64, dtype_out=torch.float32)

DoubleToFloat

Convenience wrapper around DTypeCastTransform that converts all float64 tensors to float32.
from torchrl.envs.transforms import DoubleToFloat
t = DoubleToFloat()

DeviceCastTransform

Moves tensors matching the selected keys to a specified device.
from torchrl.envs.transforms import DeviceCastTransform
t = DeviceCastTransform(device="cuda:0")

VecNorm

VecNorm is deprecated in favour of VecNormV2. Use VecNormV2 for new code. VecNorm will be removed in a future release.
Applies running normalisation of observations and/or rewards using an exponential moving average of their mean and variance. When used with EnvCreator and ParallelEnv, all worker processes share the same running statistics.
from torchrl.envs.transforms import VecNorm
t = VecNorm(in_keys=["observation"], decay=0.9999)
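Decayed running statistics of this kind can be pictured with the estimator sketched below for a single scalar feature: every past contribution is multiplied by decay, so recent data dominates. This is one common formulation written for illustration; RunningNorm is a hypothetical class, and the exact update rule, eps handling, and sharing mechanism of VecNorm / VecNormV2 are not reproduced here.

```python
import math


class RunningNorm:
    """Sketch of decayed running mean/variance normalisation for one scalar feature."""

    def __init__(self, decay=0.9999, eps=1e-4):
        self.decay = decay
        self.eps = eps
        self.weighted_sum = 0.0
        self.weighted_sum_sq = 0.0
        self.weight = 0.0

    def update(self, x):
        # Every past contribution is multiplied by `decay`, so older data fades.
        self.weighted_sum = self.decay * self.weighted_sum + x
        self.weighted_sum_sq = self.decay * self.weighted_sum_sq + x * x
        self.weight = self.decay * self.weight + 1.0

    @property
    def mean(self):
        return self.weighted_sum / self.weight

    @property
    def var(self):
        # Guard against tiny negative values caused by rounding.
        return max(self.weighted_sum_sq / self.weight - self.mean ** 2, 0.0)

    def normalize(self, x):
        return (x - self.mean) / max(math.sqrt(self.var), self.eps)


no_forget = RunningNorm(decay=1.0)  # decay=1.0: plain mean and population variance
for x in [1.0, 2.0, 3.0]:
    no_forget.update(x)

forgetful = RunningNorm(decay=0.5)  # strong decay: the last value dominates
for x in [0.0, 0.0, 0.0, 10.0]:
    forgetful.update(x)
```

A decay close to 1 (such as 0.9999 above) behaves almost like a cumulative average, while smaller values track non-stationary statistics more quickly.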

ExcludeTransform

Removes specified keys from the TensorDict. Useful to strip large pixel tensors after they have been processed.
from torchrl.envs.transforms import ExcludeTransform
t = ExcludeTransform("pixels")

InitTracker

Adds an "is_init" boolean flag to every step’s TensorDict, set to True on the first step after a reset. Required by some recurrent policies to detect episode boundaries.

TensorDictPrimer

Pre-populates the reset TensorDict with zero-filled entries for keys that are absent at reset but expected at step time (e.g., hidden states, info-dict keys).
from torchrl.envs.transforms import TensorDictPrimer
from torchrl.data import Unbounded

t = TensorDictPrimer({"hidden": Unbounded(shape=(256,))})

RenameTransform

Renames one or more keys in the TensorDict in the forward pass; keys listed in in_keys_inv / out_keys_inv are also renamed in the inverse pass.
from torchrl.envs.transforms import RenameTransform
t = RenameTransform(in_keys=["obs"], out_keys=["observation"])

SelectTransform

Keeps only the specified keys in the output TensorDict, discarding everything else (reward and done keys are kept by default).

TerminateTransform

Injects a custom termination condition — any callable over the TensorDict — as an additional terminated flag. Useful for combining open-loop playback with early stopping.

ExpandAs

Expands a source tensor to match the shape of a reference tensor. Handy for broadcasting scalar rewards to match batched observation shapes.

NoopResetEnv

After each reset, takes a number of random-action steps (randomly selected between 0 and noops when random=True, exactly noops otherwise) to randomise the starting state, in the spirit of the Atari no-op reset protocol.
from torchrl.envs.transforms import NoopResetEnv
t = NoopResetEnv(noops=30, random=True)
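The reset-time warm-up can be sketched as: reset, draw how many extra steps to take, and advance the environment that many times before handing the observation to the policy. In this sketch reset and warmup_step are hypothetical callables, the step count is drawn from 0..noops inclusive, and raising when the episode ends during warm-up is a choice made here.

```python
import random


def reset_with_warmup(reset, warmup_step, noops=30, randomize=True, rng=None):
    """Reset, then advance the env by a (possibly random) number of warm-up steps.

    `reset() -> obs` and `warmup_step() -> (obs, done)` are hypothetical callables.
    Returns the final observation and the number of warm-up steps taken.
    """
    rng = rng if rng is not None else random.Random()
    n_steps = rng.randint(0, noops) if randomize else noops  # 0..noops inclusive here
    obs = reset()
    for _ in range(n_steps):
        obs, done = warmup_step()
        if done:
            # Sketch choice: refuse to start from an episode that already ended.
            raise RuntimeError("episode ended during the warm-up steps")
    return obs, n_steps


class ToyEnv:
    """Observation is the number of steps taken since the last reset."""

    def __init__(self):
        self.t = 0

    def reset(self):
        self.t = 0
        return self.t

    def warmup_step(self):
        self.t += 1
        return self.t, False


env = ToyEnv()
obs, n = reset_with_warmup(env.reset, env.warmup_step, noops=5, randomize=False)
```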

BatchSizeTransform

Reshapes the batch_size of the environment. Useful when wrapping a single environment so it appears as a batch of size [1].

BurnInTransform

Runs the given module(s) over the first burn_in steps of each sampled sequence to obtain an up-to-date recurrent state, then drops those steps from the output. Intended as a replay-buffer transform for stateful models such as RNNs, rather than as an environment transform.
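The burn-in idea can be sketched with any recurrent update hidden ← cell(hidden, x): run it over the first burn_in items of a sampled sequence, keep the resulting hidden state, and return only the remaining items for training. burn_in and decaying_sum are hypothetical, library-free helpers; a real setup would use a recurrent network module instead of a scalar update.

```python
def burn_in(sequence, burn_in_steps, cell, initial_hidden):
    """Warm up a recurrent state on the first `burn_in_steps` items, then drop them.

    `cell(hidden, x) -> hidden` is a hypothetical recurrent update.
    Returns (hidden state at the first kept step, remaining sequence).
    """
    hidden = initial_hidden
    for x in sequence[:burn_in_steps]:
        hidden = cell(hidden, x)
    return hidden, sequence[burn_in_steps:]


def decaying_sum(hidden, x):
    # Toy "RNN": an exponentially decaying sum of inputs.
    return 0.5 * hidden + x


hidden, rest = burn_in([1.0, 2.0, 3.0, 4.0, 5.0], 2, decaying_sum, 0.0)
# hidden == 2.5 (0.5 * 1.0 + 2.0), rest == [3.0, 4.0, 5.0]
```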

Compose Multiple Transforms

from torchrl.envs import GymEnv, TransformedEnv
from torchrl.envs.transforms import (
    Compose,
    ToTensorImage,
    GrayScale,
    Resize,
    CatFrames,
    DoubleToFloat,
    StepCounter,
    RewardClipping,
)

env = TransformedEnv(
    GymEnv("ALE/Pong-v5", from_pixels=True),
    Compose(
        ToTensorImage(in_keys=["pixels"]),
        GrayScale(in_keys=["pixels"]),
        Resize(84, 84, in_keys=["pixels"]),
        CatFrames(N=4, dim=-3, in_keys=["pixels"]),
        DoubleToFloat(),
        StepCounter(max_steps=108_000),
        RewardClipping(clamp_min=-1.0, clamp_max=1.0),
    ),
)
env.check_env_specs()
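Assuming the standard 210 × 160 RGB Atari frame, the pixel pipeline above changes the shape of "pixels" as traced by the sketch below. The helpers only manipulate shape tuples and are hypothetical, not TorchRL calls; DoubleToFloat, StepCounter, and RewardClipping do not change the pixel shape, so they are omitted.

```python
def to_tensor_image(shape):
    *batch, h, w, c = shape          # HWC -> CHW
    return (*batch, c, h, w)


def gray_scale(shape):
    *batch, c, h, w = shape
    if c != 3:
        raise ValueError("GrayScale expects 3 (RGB) channels")
    return (*batch, 1, h, w)


def resize_square(shape, size):
    *batch, c, _h, _w = shape
    return (*batch, c, size, size)


def cat_frames(shape, n, dim=-3):
    out = list(shape)
    out[dim] *= n                    # N frames concatenated along `dim`
    return tuple(out)


shape = (210, 160, 3)                # assumed raw ALE frame: H x W x RGB
shape = to_tensor_image(shape)       # (3, 210, 160)
shape = gray_scale(shape)            # (1, 210, 160)
shape = resize_square(shape, 84)     # (1, 84, 84)
shape = cat_frames(shape, 4)         # (4, 84, 84)
print(shape)
```

Leading batch dimensions pass through unchanged, which mirrors how pixel transforms act on the trailing image dimensions.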

Quick-Reference Table

Transforms marked New in 0.13 were added in TorchRL 0.13 and may not be present in older releases.
| Transform | Category | Description |
|---|---|---|
| ObservationNorm | Observation | Shift + scale observations with configurable loc/scale. |
| ToTensorImage | Observation | uint8 HWC → float32 CHW in [0, 1]. |
| GrayScale | Observation | RGB → single-channel grayscale. |
| Resize | Observation | Resize pixel tensors to (w, h). |
| CenterCrop | Observation | Centre-crop pixels to a fixed size. |
| UnsqueezeTransform | Observation | Insert a new dimension. |
| SqueezeTransform | Observation | Remove a size-1 dimension. |
| FlattenObservation | Observation | Flatten a range of dimensions. |
| CatFrames | Observation | Concatenate the last N frames. |
| PermuteTransform | Observation | Permute tensor axes. |
| NextObservationDelta | Observation | Next obs − current obs delta. (New in 0.13) |
| FrameSkipTransform | Observation | Repeat each action N steps. |
| RewardScaling | Reward | Linear reward transform. |
| RewardClipping | Reward | Clip rewards to [min, max]. |
| RewardSum | Reward | Running episode return accumulator. |
| BinarizeReward | Reward | Map reward to {0, 1}. |
| Reward2GoTransform | Reward | Discounted return for offline data. |
| ActionScaling | Action | Rescale policy outputs to action spec. |
| FlattenAction | Action | Flatten multi-dimensional actions. |
| ActionDiscretizer | Action | Continuous → discrete action grid. |
| ActionMask | Action | Apply action legality mask. |
| StepCounter | Utility | Count steps; optional truncation. |
| DTypeCastTransform | Utility | Cast tensor dtypes. |
| DoubleToFloat | Utility | float64 → float32 convenience. |
| DeviceCastTransform | Utility | Move tensors to a device. |
| VecNorm | Utility | Shared running normalisation (deprecated; use VecNormV2). |
| ExcludeTransform | Utility | Remove keys from TensorDict. |
| InitTracker | Utility | Mark first post-reset step. |
| TensorDictPrimer | Utility | Pre-fill absent keys at reset. |
| RenameTransform | Utility | Rename TensorDict keys. |
| TerminateTransform | Utility | Custom termination condition. (New in 0.13) |
| ExpandAs | Utility | Broadcast tensor shape. (New in 0.13) |
| NoopResetEnv | Utility | Random-action warm-up steps after reset. |
| BurnInTransform | Utility | Burn-in for stateful models. |
| BatchSizeTransform | Utility | Reshape environment batch size. |
