
Every environment in TorchRL derives from EnvBase, a PyTorch nn.Module subclass that communicates via TensorDict rather than raw tuples. Observations, actions, rewards, and done signals all live under named keys, and each field is described by a spec — a typed, bounded, or discrete descriptor that validates shapes, dtypes, and value ranges before a long training job starts. The same API works whether you are running a single-instance Pendulum environment, a 64-worker MuJoCo farm, or a custom game-engine simulator.

EnvBase: the core API

EnvBase's core methods are reset(), step(), rollout(), and set_seed(). You interact with them the same way regardless of which environment library is underneath.
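from torchrl.envs.libs.gym import GymEnv

env = GymEnv("Pendulum-v1")
td = env.reset()
# td["observation"] -> initial observation tensor of shape [3]
print(td)
# TensorDict(
#     fields={
#         done: Tensor(shape=torch.Size([1]), dtype=torch.bool, ...),
#         observation: Tensor(shape=torch.Size([3]), dtype=torch.float32, ...),
#         terminated: Tensor(shape=torch.Size([1]), dtype=torch.bool, ...),
#         truncated: Tensor(shape=torch.Size([1]), dtype=torch.bool, ...)},
#     batch_size=torch.Size([]),
#     device=cpu)
reset() returns a TensorDict populated with the initial observation and with the done flags (done, terminated, truncated) initialised to False. The reward and action keys are absent because they have not been produced yet.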

Specs: describing the data contract

Specs tell TorchRL (and your code) exactly what to expect from an environment before a single step is taken. They are used to pre-allocate storage, validate outputs, and initialise policy networks lazily.
import torch
from torchrl.data import Bounded, Categorical, Composite, Unbounded

# Continuous action in [-1, 1]^2
action_spec = Bounded(low=-1.0, high=1.0, shape=(2,), dtype=torch.float32)

# Discrete action: one of 4 choices
discrete_spec = Categorical(n=4)

# Composite groups multiple specs under named keys.
obs_spec = Composite(
    observation=Unbounded(shape=(8,), dtype=torch.float32),
    pixels=Bounded(low=0, high=255, shape=(3, 84, 84), dtype=torch.uint8),
)
Every EnvBase subclass exposes four spec properties:
| Property | What it describes |
|---|---|
| observation_spec | All observation fields (maps to full_observation_spec) |
| action_spec | The action field(s) the environment expects |
| reward_spec | The reward scalar or vector |
| done_spec | Done, terminated, and truncated flags |
full_observation_spec, full_action_spec, full_reward_spec, and full_done_spec are the canonical Composite specs. For environments with a single action, reward, or done key, the shorthand properties action_spec, reward_spec, and done_spec return the leaf spec inside the corresponding composite; observation_spec always returns the full composite.

TransformedEnv and the transform pipeline

TransformedEnv wraps any EnvBase with a stack of Transform objects. Transforms can preprocess observations, post-process rewards, convert dtypes, or inject priors. They are applied in order on every step() and reset() call, and each one participates in the spec system — adding or modifying specs so that downstream components always see the correct shapes.
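from torchrl.envs import Compose, DoubleToFloat, ObservationNorm, TransformedEnv
from torchrl.envs.libs.gym import GymEnv

base_env = GymEnv("HalfCheetah-v4", device="cuda:0")
env = TransformedEnv(
    base_env,
    Compose(
        # No loc/scale given: statistics must be initialised before use.
        ObservationNorm(in_keys=["observation"]),
        DoubleToFloat(),  # MuJoCo observations are float64
    ),
)
# Estimate the normalisation statistics from 1000 environment steps.
env.transform[0].init_stats(num_iter=1000, reduce_dim=0, cat_dim=0)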
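Use env.transform to inspect or modify the transform stack; index into it (for example env.transform[0]) to reach an individual transform. Transforms can be inserted or removed without re-wrapping the base environment, for example with env.append_transform() or env.insert_transform().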

Common transforms

TorchRL ships a large library of built-in transforms. Here are the most commonly used:

Observation transforms

ObservationNorm — affine normalisation with fixed loc/scale (VecNorm keeps running statistics)
CatFrames — frame stacking
GrayScale, Resize, CenterCrop, ToTensorImage — pixel pipelines
FlattenObservation — flatten spatial dims

Reward transforms

RewardScaling — multiply/shift reward
RewardClipping — clip to a range
RewardSum — cumulative reward tracking
BinarizeReward — convert to a binary {0, 1} reward

Action transforms

ActionScaling — map to a different range
ActionDiscretizer — discretize continuous actions
FlattenAction — flatten multi-head actions
ActionMask — mask out illegal actions

Episode & timing

StepCounter — count steps per episode
TrajCounter — count completed trajectories
FrameSkipTransform — repeat actions
AutoResetTransform — auto-reset on done

Composing transforms

from torchrl.envs import (
    CatFrames,
    Compose,
    GrayScale,
    Resize,
    RewardClipping,
    StepCounter,
    ToTensorImage,
    TransformedEnv,
)
from torchrl.envs.libs.gym import GymEnv
from torchrl.envs.utils import check_env_specs

env = TransformedEnv(
    GymEnv("ALE/Pong-v5"),
    Compose(
        ToTensorImage(),            # [H, W, C] uint8 -> [C, H, W] float in [0, 1]
        GrayScale(),                # [C, H, W] -> [1, H, W]
        Resize(84, 84),             # resize to 84x84
        CatFrames(N=4, dim=-3),     # stack last 4 frames -> [4, 84, 84]
        # ToTensorImage already rescales to [0, 1], so no extra /255 step is needed.
        RewardClipping(clamp_min=-1, clamp_max=1),
        StepCounter(max_steps=27_000),
    ),
)
check_env_specs(env)

PyTorch-native custom environments

To create a custom environment, subclass EnvBase, declare its specs, and implement three methods: _reset(), _step(), and _set_seed().
import torch
from tensordict import TensorDict
from torchrl.data import Bounded, Composite, Unbounded
from torchrl.envs import EnvBase


class MyEnv(EnvBase):
    def __init__(self, device="cpu"):
        super().__init__(device=device)
        self.observation_spec = Composite(
            observation=Unbounded(shape=(4,), dtype=torch.float32),
            device=device,
        )
        self.action_spec = Bounded(
            low=-1.0, high=1.0, shape=(1,), dtype=torch.float32, device=device
        )
        self.reward_spec = Unbounded(shape=(1,), dtype=torch.float32, device=device)

    def _reset(self, tensordict):
        return TensorDict(
            {"observation": torch.zeros(4, device=self.device)},
            batch_size=self.batch_size,
        )

    def _step(self, tensordict):
        action = tensordict["action"]
        obs = torch.randn(4, device=self.device)
        reward = -action.pow(2).sum(dim=-1, keepdim=True)
        done = torch.zeros(1, dtype=torch.bool, device=self.device)
        return TensorDict(
            {"observation": obs, "reward": reward, "done": done, "terminated": done},
            batch_size=self.batch_size,
        )

    def _set_seed(self, seed):
        torch.manual_seed(seed)
Always call check_env_specs(env) after implementing a custom environment. Mismatches between what _step() returns and what the specs declare are one of the most common sources of silent training bugs.

Wrapping third-party libraries

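TorchRL includes wrappers for many RL environment libraries, among them Gym/Gymnasium (GymEnv), DeepMind Control (DMControlEnv), Brax (BraxEnv), PettingZoo (PettingZooEnv), and VMAS (VmasEnv). All wrappers follow the same EnvBase API.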
from torchrl.envs.libs.gym import GymEnv

env = GymEnv("HalfCheetah-v4", device="cpu")
# Use set_gym_backend() to toggle between gymnasium and gym.

Vectorized environments

TorchRL provides two vectorized wrappers that run multiple copies of an environment simultaneously.

SerialEnv

SerialEnv runs N environments sequentially in a single process. Useful for testing or when environment stepping is cheap. The batch dimension of all returned TensorDicts gains a leading [N] axis.
from torchrl.envs import SerialEnv
from torchrl.envs.libs.gym import GymEnv

env = SerialEnv(4, lambda: GymEnv("Pendulum-v1"))
td = env.reset()
print(td.batch_size)  # torch.Size([4])

ParallelEnv

ParallelEnv runs N environments in separate worker processes. It exposes the same API as SerialEnv but uses multiprocessing to execute environment steps in parallel, which is valuable when simulation is the throughput bottleneck.
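from torchrl.envs import ParallelEnv
from torchrl.envs.libs.gym import GymEnv

if __name__ == "__main__":  # guard needed when workers are started with "spawn"
    env = ParallelEnv(
        num_workers=8,
        create_env_fn=lambda: GymEnv("HalfCheetah-v4"),
    )
    td = env.reset()
    print(td.batch_size)  # torch.Size([8])
    # No policy given: actions are sampled from env.action_spec.
    # Pass a policy module via `policy=` to act with a learned agent.
    rollout = env.rollout(max_steps=1000)
    print(rollout.batch_size)  # torch.Size([8, 1000])
    env.close()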
ParallelEnv serializes the create_env_fn callable and sends it to the worker processes. If the callable closes over objects that cannot be pickled (GPU tensors, open file handles), replace it with a picklable callable or wrap it in EnvCreator.

AsyncEnvPool

For even higher throughput, AsyncEnvPool overlaps environment stepping and policy inference using threads or separate processes.
from torchrl.envs import AsyncEnvPool
from torchrl.envs.libs.gym import GymEnv

pool = AsyncEnvPool(
    [lambda: GymEnv("HalfCheetah-v4") for _ in range(16)],
)

Model-based environments

TorchRL includes ModelBasedEnvBase and DreamerEnv / WorldModelEnv for model-based RL workflows where the environment is a learned neural network. They share the same EnvBase API so policies and collectors work without changes.
from torchrl.envs import GymEnv, WorldModelEnv
from torchrl.modules import WorldModel

# WorldModelEnv wraps a WorldModel and a reference env as an EnvBase.
# base_env is only used for its specs (action, reward, done); it is not stepped.
# `world_model` and `actor` are assumed to have been built beforehand.
base_env = GymEnv("Pendulum-v1")
imagined_env = WorldModelEnv(
    world_model=world_model,   # a torchrl.modules.WorldModel instance
    base_env=base_env,         # reference env for specs
    batch_size=[4],
)
rollout = imagined_env.rollout(max_steps=15, policy=actor)
