TorchRL makes it straightforward to run many environment instances at once. All vectorised environment classes derive from BatchedEnvBase, which itself extends EnvBase, so they expose the same reset / step / rollout interface and can be validated with check_env_specs. Observations, actions, rewards, and done flags are stacked along a leading batch dimension and returned as a single TensorDict, so downstream policy and training code does not need to know whether it is talking to one environment or a thousand.
All classes documented on this page are importable from torchrl.envs:
from torchrl.envs import SerialEnv, ParallelEnv, EnvCreator
from torchrl.envs import AsyncEnvPool, ProcessorAsyncEnvPool, ThreadingAsyncEnvPool

The create_env_fn Pattern

Every BatchedEnvBase subclass accepts a create_env_fn argument — a callable (or list of callables) that returns a new EnvBase instance each time it is invoked. Workers call this factory to instantiate their private copy of the environment without sharing state.
import torchrl.envs as envs

# Simplest form: a lambda
create_env_fn = lambda: envs.GymEnv("Pendulum-v1", device="cpu")

# Or a named function
def make_env():
    return envs.GymEnv("CartPole-v1")

# Or different environments per worker
create_env_fns = [
    lambda: envs.GymEnv("CartPole-v1"),
    lambda: envs.GymEnv("Pendulum-v1"),
]
Lambda functions cannot be pickled by the standard pickle module. When using ParallelEnv (multiprocess), wrap your factory in EnvCreator or define it as a module-level function so it can be sent to worker processes.
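The pickling constraint can be checked without TorchRL. The sketch below uses only the standard library: `is_picklable` is a hypothetical helper, and `dict` stands in for an environment class so the example runs anywhere. It shows a lambda failing to pickle while a `functools.partial` over an importable callable succeeds.

```python
import pickle
from functools import partial


def is_picklable(obj) -> bool:
    """Return True if the standard pickle module can serialise obj."""
    try:
        pickle.dumps(obj)
    except (pickle.PicklingError, AttributeError, TypeError):
        return False
    return True


# A lambda has no importable name, so pickle cannot store a reference to it.
lambda_factory = lambda: dict(env_name="CartPole-v1")

# partial over an importable callable is pickled by reference to that callable.
# `dict` is only a stand-in for an environment class such as GymEnv.
partial_factory = partial(dict, env_name="CartPole-v1")

print(is_picklable(lambda_factory))   # False
print(is_picklable(partial_factory))  # True
```

The same reasoning suggests `partial(GymEnv, "CartPole-v1")` as a compact alternative to a named factory function when the factory only needs fixed arguments.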

BatchedEnvBase

BatchedEnvBase is the abstract parent of SerialEnv and ParallelEnv. It initialises shared infrastructure — spec negotiation across workers, shared memory or memory-map buffers, and metadata caching — before the workers are started.

Constructor Parameters


SerialEnv

SerialEnv creates and steps all environment instances sequentially within the same process. It shares the BatchedEnvBase interface but incurs no IPC overhead, making it the right choice for debugging and for lightweight environments whose step cost is small compared with the cost of inter-process communication.
from torchrl.envs import SerialEnv, GymEnv

env = SerialEnv(
    num_workers=4,
    create_env_fn=lambda: GymEnv("Pendulum-v1"),
)

td = env.reset()
print(td.shape)             # torch.Size([4])
print(td["observation"].shape)  # torch.Size([4, 3])

td = env.rollout(max_steps=200)
print(td.shape)             # torch.Size([4, 200])
All four environments step in turn inside a single Python thread. The returned TensorDict has a leading dimension of num_workers.

Key Characteristics

  • No serialisation cost — environments live in the same process.
  • Easy to debug — standard Python breakpoints and profilers work.
  • No shared memory required — each env writes to its own tensor.
  • Sequential execution — one environment completes before the next starts; no speedup on multi-core machines.

ParallelEnv

ParallelEnv spawns one subprocess per worker and exchanges data via shared memory (shared_memory=True) or memory-mapped files (memmap=True). The main process sends action TensorDicts to all workers simultaneously, waits for results, and stacks them.
from torchrl.envs import ParallelEnv, GymEnv, check_env_specs

def make_env():
    return GymEnv("HalfCheetah-v4")

# Always check_env_specs before ParallelEnv!
check_env_specs(make_env())

env = ParallelEnv(
    num_workers=8,
    create_env_fn=make_env,
    device="cpu",
)

td = env.reset()
print(td.shape)             # torch.Size([8])

td = env.rollout(max_steps=500)
print(td.shape)             # torch.Size([8, 500])

env.close()
Always call check_env_specs on a single instance of your environment before wrapping it in ParallelEnv. Shared-memory buffers are pre-allocated from the specs; a spec mismatch causes a hard crash inside worker processes that is difficult to diagnose.

Start Method

By default, TorchRL selects "spawn" on macOS / Windows and "fork" on Linux. Override with mp_start_method:
env = ParallelEnv(4, make_env, mp_start_method="spawn")
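When the start method has to be chosen at runtime, the standard library can report what the platform supports. The following is a minimal sketch, not TorchRL's own selection logic: `pick_start_method` and its `uses_cuda` flag are hypothetical names introduced for illustration.

```python
import multiprocessing as mp


def pick_start_method(uses_cuda: bool) -> str:
    """Hypothetical helper: choose a value to pass as mp_start_method."""
    available = mp.get_all_start_methods()
    # Forked children cannot safely reuse a CUDA context initialised by the
    # parent, so prefer "spawn" whenever CUDA is involved.
    if uses_cuda or "fork" not in available:
        return "spawn"
    return "fork"


print(mp.get_all_start_methods())         # platform dependent
print(pick_start_method(uses_cuda=True))  # spawn
```

The returned string could then be forwarded as `ParallelEnv(4, make_env, mp_start_method=pick_start_method(uses_cuda=False))`.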

Worker Timeout

If a worker does not respond within BATCHED_PIPE_TIMEOUT seconds, it is considered dead and an error is raised. Set the timeout through the environment variable of the same name:
export BATCHED_PIPE_TIMEOUT=3600   # 1 hour timeout
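The same variable can be set from Python when exporting it in the shell is inconvenient. This is a sketch under one assumption: the value needs to be present in the process environment before TorchRL reads it, so the assignment is placed ahead of any `import torchrl` in the script.

```python
import os

# Assumption: TorchRL reads this variable from the process environment, so it
# is set before any `import torchrl` appears in the same script.
os.environ["BATCHED_PIPE_TIMEOUT"] = "3600"  # seconds, stored as a string

timeout_s = int(os.environ["BATCHED_PIPE_TIMEOUT"])
print(timeout_s)  # 3600
```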

Configuring Parallel Execution After Construction

Use configure_parallel to adjust worker parameters before the environment is started (before the first reset / step):
env = ParallelEnv(4, make_env)
env.configure_parallel(
    use_buffers=True,
    num_threads=4,
    num_sub_threads=2,
    non_blocking=True,
)
td = env.reset()   # workers start here with the new config

EnvCreator

EnvCreator wraps an arbitrary callable so it can be safely pickled and sent to worker subprocesses. It is the recommended replacement for lambdas in multiprocessing contexts. When the factory builds a TransformedEnv with VecNorm, EnvCreator also wires up the shared-memory pointers so all workers share the same running statistics.
from torchrl.envs import EnvCreator, ParallelEnv, GymEnv
from torchrl.envs.transforms import TransformedEnv, VecNorm

def make_env():
    return TransformedEnv(
        GymEnv("HalfCheetah-v4"),
        VecNorm(in_keys=["observation"]),
    )

creator = EnvCreator(make_env)
env = ParallelEnv(4, creator)
td = env.rollout(max_steps=100)

get_env_metadata

get_env_metadata constructs an EnvMetaData snapshot from a factory function without keeping the environment alive. It is useful when you want to inspect specs before committing to launching workers.
from torchrl.envs import get_env_metadata, GymEnv

meta = get_env_metadata(lambda: GymEnv("CartPole-v1"))
print(meta.specs)
print(meta.batch_size)

Async Environment Pools

For use cases that need even finer control over execution scheduling, TorchRL provides three async pool classes:

AsyncEnvPool

Abstract base for asynchronous environment pools. Manages a pool of workers that accept step requests without blocking and deliver results when ready.
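The submit-now, collect-when-ready pattern can be illustrated with the standard library alone. The sketch below is not TorchRL's implementation: `slow_step` is a hypothetical stand-in for an environment step, and a plain `ThreadPoolExecutor` plays the role of the pool.

```python
import time
from concurrent.futures import ThreadPoolExecutor, as_completed


def slow_step(env_id: int, delay: float) -> int:
    """Hypothetical stand-in for an environment step lasting `delay` seconds."""
    time.sleep(delay)  # sleeping releases the GIL, as I/O or native code would
    return env_id


delays = {0: 0.20, 1: 0.01, 2: 0.10}

with ThreadPoolExecutor(max_workers=len(delays)) as pool:
    # submit() returns immediately with a Future for each request.
    futures = [pool.submit(slow_step, env_id, d) for env_id, d in delays.items()]
    # as_completed() yields each Future as soon as its result is available.
    completion_order = [f.result() for f in as_completed(futures)]

print(completion_order)  # the fastest worker usually comes first
```

Results arrive in completion order rather than submission order, which is why an asynchronous pool has to report which environment each result belongs to.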

ProcessorAsyncEnvPool

An AsyncEnvPool backed by a multiprocessing.Pool. Each worker runs in its own process.
from functools import partial
from torchrl.envs import GymEnv, ProcessorAsyncEnvPool

pool = ProcessorAsyncEnvPool(
    [partial(GymEnv, "Pendulum-v1")] * 4,
    backend="multiprocessing",
)

ThreadingAsyncEnvPool

An AsyncEnvPool backed by a ThreadPoolExecutor. All workers run in the same process using threads. Best suited for I/O-bound or GIL-releasing environments (e.g., environments with native C extensions).
from functools import partial
from torchrl.envs import GymEnv, ThreadingAsyncEnvPool

pool = ThreadingAsyncEnvPool(
    [partial(GymEnv, "Pendulum-v1")] * 4,
    backend="threading",
)

TensorDict Structure from Vectorised Envs

When num_workers=N, every TensorDict returned by reset or step has a leading batch dimension of N. Rollouts add an additional time dimension, producing shape [N, T].
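from functools import partial
from torchrl.envs import ParallelEnv, GymEnv

# partial is picklable, unlike a lambda, so it can be sent to worker processes
env = ParallelEnv(4, partial(GymEnv, "Pendulum-v1"))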
td = env.rollout(max_steps=50)

# Shapes
print(td.shape)                        # torch.Size([4, 50])
print(td["action"].shape)              # torch.Size([4, 50, 1])
print(td["next", "observation"].shape) # torch.Size([4, 50, 3])
print(td["next", "reward"].shape)      # torch.Size([4, 50, 1])
print(td["next", "done"].shape)        # torch.Size([4, 50, 1])

env.close()
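The "time" dimension name is attached to the last rollout dimension of the TensorDict. The name lives on the TensorDict rather than on the tensors it holds, so look up the dimension index before reducing a leaf tensor:
total_rewards = td["next", "reward"].sum(td.names.index("time"))   # torch.Size([4, 1])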

Partial Resets

When break_when_any_done=False is passed to rollout, done environments are reset automatically while others continue stepping. This partial-reset mode is the standard for on-policy data collection with vectorised environments.
td = env.rollout(
    max_steps=1000,
    break_when_any_done=False,   # individual envs reset silently on done
)
print(td.shape)   # torch.Size([4, 1000])  — guaranteed length
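The effect of partial resets can be seen in a library-free toy model. Everything below is hypothetical and only mimics the behaviour described above: `CountdownEnv` ends an episode after a fixed number of steps, and `rollout_with_partial_resets` resets only the environment that just finished.

```python
class CountdownEnv:
    """Hypothetical toy environment: an episode ends after `horizon` steps."""

    def __init__(self, horizon: int):
        self.horizon = horizon
        self.t = 0

    def reset(self) -> int:
        self.t = 0
        return self.t

    def step(self):
        self.t += 1
        return self.t, self.t >= self.horizon  # (observation, done)


def rollout_with_partial_resets(envs, max_steps: int):
    """Step every env max_steps times, resetting only the ones that are done."""
    for env in envs:
        env.reset()
    trajectories = [[] for _ in envs]
    for _ in range(max_steps):
        for i, env in enumerate(envs):
            obs, done = env.step()
            trajectories[i].append((obs, done))
            if done:
                env.reset()  # the other environments keep their state
    return trajectories


traj = rollout_with_partial_resets([CountdownEnv(3), CountdownEnv(5)], max_steps=7)
print([len(t) for t in traj])        # [7, 7]
print([obs for obs, _ in traj[0]])   # [1, 2, 3, 1, 2, 3, 1]
```

Every trajectory has exactly `max_steps` entries even though the two environments finish episodes at different times, which is the property that gives a fixed [N, T] batch.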

Code Examples

from torchrl.envs import SerialEnv, GymEnv

def make_env():
    return GymEnv("CartPole-v1")

env = SerialEnv(4, make_env)

# Random-policy rollout
td = env.rollout(max_steps=200, break_when_any_done=False)
print(f"Shape: {td.shape}")               # [4, 200]
print(f"Mean reward: {td['next', 'reward'].mean():.4f}")

env.close()

Tips and Common Pitfalls

Always check specs first

Run check_env_specs(make_env()) on a single environment before wrapping in ParallelEnv. Shared-memory buffer sizes are fixed at construction; a spec mismatch causes obscure worker crashes.

Use EnvCreator for lambdas

Lambdas cannot be pickled by the default pickle module. Wrap them in EnvCreator or define the factory at module level so multiprocessing can serialise it.

fork vs spawn

On Linux, mp_start_method="fork" is fastest. On macOS and Windows use "spawn". Never use "fork" with CUDA: if the parent process has already initialised CUDA, forked children cannot re-initialise it and fail at runtime.

Partial resets for data collection

Pass break_when_any_done=False to rollout so individual done environments auto-reset while others continue, giving you a guaranteed [N, T] shaped batch every time.
