Skip to main content

Documentation Index

Fetch the complete documentation index at: https://mintlify.com/pytorch/rl/llms.txt

Use this file to discover all available pages before exploring further.

Every environment in TorchRL—whether a simple grid world, a physics simulator, or a third-party library—implements the same EnvBase contract. The contract boils down to three abstract methods (_reset, _step, _set_seed) and a set of TensorSpec objects that describe the shapes and dtypes of every tensor the environment reads and writes. Once those are in place, the environment gains rollouts, batching, transforms, and collector support for free. This tutorial builds a custom environment from scratch, then shows how to wrap an existing Gymnasium environment and add preprocessing transforms.
1
Understand the EnvBase contract
2
EnvBase (from torchrl.envs) is a torch.nn.Module subclass. Every concrete environment that derives from it must implement:
3
| Method | Signature | Responsibility |
| --- | --- | --- |
| _reset | (tensordict) -> tensordict | Return the initial observation (must not mutate input) |
| _step | (tensordict) -> tensordict | Apply action, return next obs / reward / done (output only) |
| _set_seed | (seed: int) -> None | Configure the environment’s RNG |
4
The specs you assign in __init__ define everything about the TensorDicts the environment produces:
5
from torchrl.data import BoundedContinuous, Categorical, Composite, Unbounded
6
Composite, BoundedContinuous, and Categorical are all in torchrl.data. The older aliases CompositeSpec, BoundedTensorSpec, and DiscreteTensorSpec still work but the shorter names are preferred in current code.
7
Implement a minimal custom environment
8
The example below implements a 1-D corridor: the agent controls a continuous force along a line and must reach position 0 from a random starting point. It demonstrates all required methods and spec assignments.
9
from __future__ import annotations

import torch
from tensordict import TensorDict, TensorDictBase
from torchrl.data import BoundedContinuous, Composite, Unbounded
from torchrl.envs import EnvBase


class CorridorEnv(EnvBase):
    """1-D corridor: move a particle to position 0 with continuous force.

    The observation is ``[position, velocity]``, the action is a scalar force
    in ``[-1, 1]`` and the reward is the negative squared distance from the
    origin. ``done`` / ``terminated`` are set once the particle is within
    ``GOAL_TOL`` of the origin.

    Args:
        device: Device on which the specs and state tensors are created.
        batch_size: Batch dimensions of the environment; ``()`` gives a
            single, unbatched environment.
    """

    # Maximum position magnitude
    MAX_POS: float = 5.0
    # Maximum velocity magnitude; shared by the observation spec and the
    # dynamics so the two cannot drift apart
    MAX_VEL: float = 2.0
    # Euler integration step used for both the velocity and position updates
    DT: float = 0.1
    # Distance from the origin below which the episode is flagged as done
    GOAL_TOL: float = 0.1

    def __init__(self, device=None, batch_size=()):
        super().__init__(device=device, batch_size=batch_size)

        # --- observation spec: position and velocity ---
        self.observation_spec = Composite(
            observation=BoundedContinuous(
                low=torch.tensor([-self.MAX_POS, -self.MAX_VEL]),
                high=torch.tensor([self.MAX_POS, self.MAX_VEL]),
                shape=(*batch_size, 2),
                device=device,
            ),
            shape=batch_size,
        )

        # --- action spec: scalar force in [-1, 1] ---
        self.action_spec = BoundedContinuous(
            low=-1.0,
            high=1.0,
            shape=(*batch_size, 1),
            device=device,
        )

        # --- reward spec: scalar reward ---
        self.reward_spec = Unbounded(
            shape=(*batch_size, 1),
            device=device,
        )

        # Internal state (position and velocity)
        self._pos = torch.zeros(*batch_size, 1, device=device)
        self._vel = torch.zeros(*batch_size, 1, device=device)

    # ------------------------------------------------------------------
    # Required abstract methods
    # ------------------------------------------------------------------

    def _reset(self, tensordict: TensorDictBase | None, **kwargs) -> TensorDictBase:
        """Return a fresh initial observation.

        Draws a uniform random position in ``[-MAX_POS, MAX_POS]`` with zero
        velocity. If ``tensordict`` carries a ``"_reset"`` boolean mask
        (TorchRL's partial-reset convention), only the flagged entries are
        re-initialised; the others keep their current state so the internal
        state stays in sync with the observations of still-running entries.

        Args:
            tensordict: Optional input; its ``batch_size`` is used when given,
                otherwise the environment's own ``batch_size``.
            **kwargs: Accepted for interface compatibility; unused.

        Returns:
            A TensorDict holding the ``"observation"`` entry.
        """
        batch_size = tensordict.batch_size if tensordict is not None else self.batch_size
        device = self.device

        # Random starting position and zero velocity
        pos = (
            torch.rand(*batch_size, 1, device=device) * 2 * self.MAX_POS
            - self.MAX_POS
        )
        vel = torch.zeros(*batch_size, 1, device=device)

        reset_mask = tensordict.get("_reset", None) if tensordict is not None else None
        if reset_mask is not None:
            # The mask is assumed to hold one flag per batch entry; reshape it
            # so it lines up element-wise with the (*batch, 1) state tensors.
            reset_mask = reset_mask.reshape(pos.shape)
            pos = torch.where(reset_mask, pos, self._pos)
            vel = torch.where(reset_mask, vel, self._vel)

        self._pos = pos
        self._vel = vel

        obs = torch.cat([self._pos, self._vel], dim=-1)
        return TensorDict(
            {"observation": obs},
            batch_size=batch_size,
            device=device,
        )

    def _step(self, tensordict: TensorDictBase) -> TensorDictBase:
        """Apply action and return next observation, reward, and done flag.

        The output tensordict should contain *only* the next-step data.
        TorchRL merges it with the input automatically.

        Args:
            tensordict: Input holding the ``"action"`` entry.

        Returns:
            A TensorDict with ``"observation"``, ``"reward"``, ``"done"`` and
            ``"terminated"`` entries.
        """
        action = tensordict["action"]  # shape (*batch, 1)

        # Simple Euler integration; both quantities are clamped to the bounds
        # declared in the observation spec
        self._vel = (self._vel + action * self.DT).clamp(-self.MAX_VEL, self.MAX_VEL)
        self._pos = (self._pos + self._vel * self.DT).clamp(-self.MAX_POS, self.MAX_POS)

        obs = torch.cat([self._pos, self._vel], dim=-1)
        # Reward: negative squared distance from origin
        reward = -(self._pos**2).sum(dim=-1, keepdim=True)
        done = self._pos.abs() < self.GOAL_TOL  # shape (*batch, 1)

        return TensorDict(
            {
                "observation": obs,
                "reward": reward,
                "done": done,
                # Separate tensor so in-place edits of one flag never leak
                # into the other
                "terminated": done.clone(),
            },
            batch_size=tensordict.batch_size,
            device=self.device,
        )

    def _set_seed(self, seed: int | None) -> None:
        """Seed the environment's internal RNG.

        The environment draws from torch's global generator, so seeding goes
        through ``torch.manual_seed``; ``None`` leaves the RNG untouched.
        """
        if seed is not None:
            torch.manual_seed(seed)
10
_step must return only the next-step values. Never put the action or the current observation into the output — TorchRL constructs the full transition by merging _step’s output with the input tensordict.
11
Verify with check_env_specs
12
check_env_specs runs a handful of resets and steps and verifies that the actual outputs match every declared spec. Always call it before connecting your environment to a collector or loss module.
13
from torchrl.envs.utils import check_env_specs

# Build the environment and validate its outputs against the declared specs
env = CorridorEnv()
check_env_specs(env)  # raises AssertionError with a helpful message if something is wrong
print("Specs OK!")

# Quick smoke-test: run a random rollout and show the collected data
print(env.rollout(max_steps=10))
14
check_env_specs checks:
15
  • Every output key is present and matches its declared spec shape and dtype.
  • batch_size is consistent throughout.
  • Done/terminated flags are boolean tensors.
  • 16
    Wrap a third-party Gym environment
    17
    If the underlying simulator is already wrapped in a Gymnasium API, you do not need to subclass EnvBase — GymEnv handles the conversion automatically.
    18
    from torchrl.envs import GymEnv
    
    # The specs (observation_spec, action_spec, etc.) are read by GymEnv from
    # the underlying Gym spaces, so nothing has to be declared by hand
    env = GymEnv("Pendulum-v1", device="cpu")
    check_env_specs(env)

    observation_spec, action_spec = env.observation_spec, env.action_spec
    print(observation_spec)  # Composite with a BoundedContinuous "observation" entry
    print(action_spec)  # BoundedContinuous action
    
    19
    For environments that return an info dict, use default_info_dict_reader to declare which keys should flow into TensorDicts:
    20
    from torchrl.envs import default_info_dict_reader, GymWrapper
    import gymnasium as gym
    
    # Keys of the Gym info dict that should be copied into the TensorDict
    info_keys = ["my_metric"]

    base_env = gym.make("Pendulum-v1")
    env = GymWrapper(base_env)
    env.set_info_dict_reader(default_info_dict_reader(info_keys))
    
    21
    Add preprocessing with TransformedEnv
    22
    TransformedEnv wraps any EnvBase and allows transforms to be appended dynamically. Transforms run after _step and _reset, so the policy and collector always see transformed observations.
    23
    import torch
    from torchrl.envs import (
        DoubleToFloat,
        ObservationNorm,
        RewardScaling,
        StepCounter,
        TransformedEnv,
    )

    # Wrap our custom environment
    base_env = CorridorEnv()
    env = TransformedEnv(base_env)

    # Standardise observations to roughly zero mean / unit variance.
    # ObservationNorm is created here without loc/scale, so its statistics
    # must be gathered (init_stats runs random rollouts) before the
    # environment is used; otherwise the transform is uninitialised.
    env.append_transform(ObservationNorm(in_keys=["observation"], standard_normal=True))
    env.transform[0].init_stats(num_iter=1000, reduce_dim=0, cat_dim=0)
    # Cast float64 observations to float32 (required if your network uses float32).
    # CorridorEnv already emits float32, so only add the cast when the
    # observation really is float64.
    if env.observation_spec["observation"].dtype == torch.float64:
        env.append_transform(DoubleToFloat(in_keys=["observation"]))
    # Scale rewards
    env.append_transform(RewardScaling(loc=0.0, scale=0.1))
    # Track episode step count
    env.append_transform(StepCounter(max_steps=200))

    check_env_specs(env)
    
    24
    You can call env.append_transform(transform) at any time — even after the environment has already been used. The transform stack is applied lazily.
    25
    Use specs for network auto-sizing
    26
    The specs on a TransformedEnv reflect the post-transform shapes. Use them to wire your network without hard-coding sizes:
    27
    import torch.nn as nn
    from tensordict.nn import TensorDictModule
    from torchrl.modules import MLP
    
    # Read the layer sizes off the specs instead of hard-coding them
    observation_dim = env.observation_spec["observation"].shape[-1]   # after transforms
    action_dim = env.action_spec.shape[-1]

    # Plain MLP backbone, wrapped so it reads and writes TensorDict entries
    backbone = MLP(in_features=observation_dim, out_features=action_dim, num_cells=[64, 64])
    policy = TensorDictModule(backbone, in_keys=["observation"], out_keys=["action"])

    # Verify the policy works end-to-end
    td = policy(env.reset())
    print(td["action"])
    
    28
    BatchedEnvs: running multiple copies in parallel
    29
    TorchRL provides ParallelEnv and SerialEnv to vectorise any EnvBase subclass. They accept a constructor (not an instance) because worker processes may not share memory:
    30
    from torchrl.envs import ParallelEnv, SerialEnv
    
    # Run 4 environments in separate processes. The class itself is the
    # constructor, so it is passed directly instead of through a lambda.
    parallel_env = ParallelEnv(num_workers=4, create_env_fn=CorridorEnv)
    try:
        parallel_env.reset()
        parallel_env.rollout(max_steps=10)
    finally:
        # Shut the worker processes down even if reset/rollout raises
        parallel_env.close()

    # Run 4 environments sequentially (useful for debugging)
    serial_env = SerialEnv(num_workers=4, create_env_fn=CorridorEnv)
    

    Common spec types

    | Spec class | Use case | Key parameters |
    | --- | --- | --- |
    | BoundedContinuous | Continuous observations / actions with explicit bounds | low, high, shape, dtype |
    | Unbounded | Continuous tensors with no bounds (rewards, states) | shape, dtype |
    | Categorical | Discrete actions or integer observations | n (number of categories), shape |
    | Composite | Groups multiple specs under named keys | `**{key: spec}`, shape |

    Full custom environment template

    from __future__ import annotations
    import torch
    from tensordict import TensorDict, TensorDictBase
    from torchrl.data import BoundedContinuous, Categorical, Composite, Unbounded
    from torchrl.envs import EnvBase
    from torchrl.envs.utils import check_env_specs
    
    
    class MyEnv(EnvBase):
        """Template environment: 4-dimensional observation, binary action.

        Args:
            device: Device on which the specs and output tensors are created.
            batch_size: Batch dimensions of the environment.
        """

        def __init__(self, device=None, batch_size=()):
            super().__init__(device=device, batch_size=batch_size)

            self.observation_spec = Composite(
                observation=Unbounded(shape=(*batch_size, 4), device=device),
                shape=batch_size,
            )
            self.action_spec = Categorical(n=2, shape=(*batch_size,), device=device)
            self.reward_spec = Unbounded(shape=(*batch_size, 1), device=device)

        def _reset(self, tensordict, **kwargs):
            """Return an all-zero initial observation."""
            obs = torch.zeros(*self.batch_size, 4, device=self.device)
            return TensorDict(
                {"observation": obs},
                batch_size=self.batch_size,
                device=self.device,
            )

        def _step(self, tensordict):
            """Return a random next observation, zero reward and a cleared done flag."""
            # Placeholder: a real environment would use the action to advance
            # its state; the template only shows where to read it from.
            action = tensordict["action"]
            obs = torch.randn(*self.batch_size, 4, device=self.device)
            reward = torch.zeros(*self.batch_size, 1, device=self.device)
            done = torch.zeros(*self.batch_size, 1, dtype=torch.bool, device=self.device)
            return TensorDict(
                {
                    "observation": obs,
                    "reward": reward,
                    "done": done,
                    # Separate tensor so the two flags never alias each other
                    "terminated": done.clone(),
                },
                batch_size=self.batch_size,
                device=self.device,
            )

        def _set_seed(self, seed):
            """Seed torch's global RNG; ``None`` leaves it untouched."""
            if seed is not None:
                torch.manual_seed(seed)
    
    
    # Validate the template against its declared specs
    env = MyEnv()
    check_env_specs(env)

    # Collect a short random rollout and display it
    rollout = env.rollout(max_steps=5)
    print(rollout)
    

    Build docs developers (and LLMs) love