

TorchRL models every reinforcement-learning environment as a subclass of EnvBase, which extends torch.nn.Module and exposes a consistent TensorDict-in / TensorDict-out interface. All observations, actions, rewards, and done flags are packed into TensorDict objects, making it straightforward to move data between devices, batch across environments, and compose transforms without any environment-specific boilerplate. This page covers the core public API of EnvBase, the companion spec classes that declare data shapes and dtypes, the GymLikeEnv adapter for gym-compatible backends, and the EnvCreator / EnvMetaData utilities used by vectorised environments.
Unless noted otherwise, every class and function documented here can be imported directly from torchrl.envs: from torchrl.envs import EnvBase, GymEnv, check_env_specs.

EnvBase

EnvBase is the abstract base class for all TorchRL environments. It inherits from torch.nn.Module, so parameters and buffers follow standard PyTorch conventions. Concrete environments implement the private _reset, _step, and _set_seed methods; the public reset, step, and set_seed wrappers add validation, spec locking, and housekeeping, and should not be overridden.
from torchrl.envs import GymEnv

env = GymEnv("Pendulum-v1", device="cpu")
td = env.reset()
print(td)
# TensorDict with observation, done, terminated, truncated

td["action"] = env.action_spec.rand()  # step() reads the action from the input TensorDict
td = env.step(td)
print(td["next"])
# TensorDict with next observation, reward, done flags

Constructor


Core Methods

reset(tensordict=None, *, set_state=None, **kwargs) → TensorDictBase

Resets the environment and returns a TensorDict populated with initial observations. The public reset should not be overridden — implement _reset in subclasses instead.
td = env.reset()
# td contains observation keys defined by env.observation_spec

step(tensordict) → TensorDictBase

Executes one environment step. The input tensordict must contain the action under the key(s) declared by env.action_spec. Results land in the "next" sub-TensorDict of the returned object.
td = env.reset()
td["action"] = env.action_spec.rand()
td = env.step(td)

obs_next = td["next", "observation"]
reward   = td["next", "reward"]
done     = td["next", "done"]

rollout(max_steps, policy=None, *, auto_reset=True, break_when_any_done=True, ...) → TensorDictBase

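Runs a trajectory of up to max_steps steps and returns the collected data as a single TensorDict stacked along a trailing time dimension. When policy is None, actions are sampled at random from action_spec. With break_when_any_done=True (the default), the rollout stops as soon as any done flag is set.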
from torchrl.envs import GymEnv

env = GymEnv("CartPole-v1")
td = env.rollout(max_steps=200)
print(td.shape)            # torch.Size([T]) with T <= 200; a random policy usually drops the pole well before 200 steps
print(td["next", "reward"].sum())

rand_step(tensordict=None) → TensorDictBase

Sample a random action from action_spec and execute one step.
td = env.rand_step()

check_env_specs(*args, **kwargs)

Verify that the environment’s specs are internally consistent and that actual reset / step outputs match the declared specs. Raises on any mismatch.
from torchrl.envs import check_env_specs, GymEnv

env = GymEnv("HalfCheetah-v4")
check_env_specs(env)  # raises if specs are inconsistent
Always run check_env_specs before placing an environment inside ParallelEnv. Inconsistent specs cause hard-to-debug crashes in worker processes.

close(*, raise_if_closed=True)

Release all resources held by the environment (file handles, subprocess workers, GPU memory). After calling close, the environment is marked as closed and subsequent calls to step or reset will raise.

fake_tensordict() → TensorDictBase

Return a zero-filled TensorDict whose structure exactly matches what reset / step would produce. Useful for spec validation and pre-allocating buffers.

Spec Attributes

TorchRL uses specs to declare the shape, dtype, domain, and device of every tensor in the environment interface. Specs are TensorSpec instances stored on the environment and locked after construction.
| Attribute | Type | Description |
| --- | --- | --- |
| observation_spec | Composite | Full specification of all observation tensors. Alias for full_observation_spec. |
| action_spec | TensorSpec | Leaf spec when there is a single action tensor; otherwise full_action_spec. |
| reward_spec | TensorSpec | Leaf spec when there is a single reward; otherwise full_reward_spec. |
| done_spec | Composite | Composite spec containing at minimum "done" and "terminated" leaves. |
| state_spec | Composite | Inputs that are not actions (e.g., hidden states). |
| full_action_spec | Composite | Complete composite of all action entries. |
| full_observation_spec | Composite | Complete composite of all observation entries. |
| full_reward_spec | Composite | Complete composite of all reward entries. |
| full_done_spec | Composite | Complete composite of all done entries. |
| full_state_spec | Composite | Complete composite of all state inputs. |
env = GymEnv("HalfCheetah-v4")

print(env.observation_spec)
# Composite(observation: UnboundedContinuous(shape=torch.Size([17]), ...))

print(env.action_spec)
# BoundedContinuous(shape=torch.Size([6]), low=-1.0, high=1.0, ...)

# Sample a valid observation
obs = env.observation_spec.rand()

# Zero all done flags
done_td = env.done_spec.zero()

Spec Classes

Specs describe the domain of each tensor in the environment. They live in torchrl.data but are re-exported through torchrl.envs.

Composite

A dictionary-like container that groups multiple named specs. Mirrors TensorDict for specs.
import torch
from torchrl.data import Composite, Bounded, Unbounded

spec = Composite(
    observation=Unbounded(shape=(8,), dtype=torch.float32),
    action=Bounded(low=-1.0, high=1.0, shape=(2,)),
)
sample = spec.rand()          # TensorDict with random tensors
zero   = spec.zero()          # TensorDict of zeros

Bounded

A continuous or discrete spec with explicit lower and upper bounds.
import torch
from torchrl.data import Bounded
spec = Bounded(low=-1.0, high=1.0, shape=(6,), dtype=torch.float32)
print(spec.rand())    # Tensor of shape [6], values in [-1, 1]

Unbounded

A continuous spec with no range constraint. Used for most observations and rewards.
import torch
from torchrl.data import Unbounded
spec = Unbounded(shape=(17,), dtype=torch.float64)

Categorical

An integer-valued spec representing a categorical action or observation with n possible values.
import torch
from torchrl.data import Categorical
spec = Categorical(n=4, shape=(), dtype=torch.int64)
print(spec.rand())    # tensor(2) or similar

OneHot

Like Categorical but samples are one-hot encoded boolean tensors of shape (..., n).
import torch
from torchrl.data import OneHot
spec = OneHot(n=4, shape=(4,), dtype=torch.bool)
print(spec.rand())    # e.g. tensor([False, True, False, False])

GymLikeEnv

GymLikeEnv is an intermediate abstract class that sits between EnvBase and concrete wrappers around gym-style backends (Gymnasium, DM Control, etc.). It standardises how info dictionaries from the underlying environment are ingested.
from torchrl.envs.gym_like import GymLikeEnv
Key features:
  • Implements _step by forwarding the action to the wrapped environment's step method and unpacking the returned (obs, reward, terminated, truncated, info) tuple into the output TensorDict.
  • Supports set_info_dict_reader(info_dict_reader), which attaches a reader (such as default_info_dict_reader) that maps selected info keys into the output TensorDict.
  • frame_skip parameter repeats actions and accumulates rewards automatically.

default_info_dict_reader

A callable that maps selected keys from the environment info dict into the output TensorDict.
from torchrl.envs import GymWrapper, default_info_dict_reader
import gymnasium as gym

env = GymWrapper(gym.make("HalfCheetah-v4"))
reader = default_info_dict_reader(["x_position", "x_velocity"])
env.set_info_dict_reader(reader)

td = env.rollout(5)
print("x_velocity" in td["next"].keys())  # True

EnvMetaData

EnvMetaData is a lightweight serialisable snapshot of an environment’s specs, batch size, device, and a sample TensorDict. It is used internally by ParallelEnv and SerialEnv to propagate environment metadata to worker processes without instantiating the full environment in the main process.
from torchrl.envs import EnvMetaData, GymEnv

env = GymEnv("Pendulum-v1")
meta = EnvMetaData.metadata_from_env(env)
print(meta.batch_size)   # torch.Size([])
print(meta.device)       # device(type='cpu')
Key attributes: tensordict, specs, batch_size, device, batch_locked, supports_set_state.

EnvCreator and get_env_metadata

EnvCreator wraps a callable environment factory so that it can be safely pickled and sent to subprocess workers. When the factory uses a VecNorm transform, EnvCreator also wires up the shared-memory pointers so all workers stay synchronised.
from torchrl.envs import EnvCreator, GymEnv, ParallelEnv
from torchrl.envs.transforms import TransformedEnv, VecNorm

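env_fn = lambda: TransformedEnv(GymEnv("Pendulum-v1"), VecNorm())

if __name__ == "__main__":  # required when worker processes are started with "spawn"
    creator = EnvCreator(env_fn)
    penv = ParallelEnv(4, creator)
    td = penv.rollout(10)   # batch shape [4, 10]: 4 workers x 10 steps
    penv.close()            # shut down the worker processes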
get_env_metadata(env_fn, **kwargs) constructs EnvMetaData from a factory without requiring EnvCreator:
from torchrl.envs import get_env_metadata, GymEnv

meta = get_env_metadata(lambda: GymEnv("CartPole-v1"))
print(meta.specs)

Full Environment Lifecycle Example

import torch
from torchrl.envs import GymEnv, check_env_specs
from torchrl.envs.transforms import Compose, TransformedEnv, RewardScaling, DoubleToFloat

# 1. Create the base environment
base_env = GymEnv("HalfCheetah-v4", device="cpu")

# 2. Wrap with transforms (applied in order)
env = TransformedEnv(
    base_env,
    Compose(
        RewardScaling(loc=0.0, scale=0.1),
        DoubleToFloat(),  # cast float64 entries (MuJoCo observations) to float32
    ),
)

# 3. Verify specs before training
check_env_specs(env)

# 4. Collect a trajectory with a random policy
td = env.rollout(max_steps=500)

print(f"Trajectory shape : {td.shape}")
print(f"Total reward     : {td['next', 'reward'].sum().item():.2f}")

# 5. Close when done
env.close()
