Data collection is often the main bottleneck in on-policy reinforcement learning: the policy must interact with the environment to generate experience before any gradient update can happen. TorchRL provides a hierarchy of collector classes, from the single-process Collector to the multi-machine RayCollector, that all expose the same iterator interface. You can prototype on a laptop with Collector, then switch to MultiSyncCollector for a multi-core workstation or RayCollector for a cluster, with almost no changes to the training loop. This tutorial walks up the hierarchy, showing when and how to use each level.
Single-process baseline with Collector
Collector is the simplest option: policy and environment live in the same process. It is the right choice for debugging and for environments that run very fast.
import torch
import torch.nn as nn
from tensordict.nn import TensorDictModule
from torchrl.collectors import Collector
from torchrl.envs import GymEnv

# A minimal linear policy for demonstration
policy = TensorDictModule(
    nn.Linear(3, 1),
    in_keys=["observation"],
    out_keys=["action"],
)

collector = Collector(
    create_env_fn=lambda: GymEnv("Pendulum-v1", device="cpu"),
    policy=policy,
    frames_per_batch=200,
    total_frames=10_000,
    device="cpu",
    max_frames_per_traj=50,    # reset after 50 steps
)

for data in collector:
    # data is a TensorDict with shape (200,)
    print(data.shape, data.keys())
    break

collector.shutdown()
Collector also accepts an already-constructed environment instance instead of a callable:
env = GymEnv("Pendulum-v1")
collector = Collector(create_env_fn=env, policy=policy, frames_per_batch=200, total_frames=2000)
Always call collector.shutdown() when done. This closes the underlying environment and releases any shared memory. Using a context manager (with statement) is not currently supported — use an explicit try/finally block in production code.
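The try/finally pattern mentioned above can be sketched as follows. `StubCollector` is a hypothetical stand-in, not a TorchRL class: it only mimics the two features the pattern relies on (iteration and a `shutdown()` method), so the snippet runs without TorchRL installed. With a real collector, the body of `run` should carry over unchanged.

```python
class StubCollector:
    """Hypothetical stand-in for a collector: iterable, with a shutdown() method."""

    def __init__(self, n_batches):
        self.n_batches = n_batches
        self.closed = False

    def __iter__(self):
        for i in range(self.n_batches):
            yield {"batch_index": i}

    def shutdown(self):
        self.closed = True


def run(collector, fail_at=None):
    """Consume batches and guarantee shutdown() runs even if training raises."""
    seen = 0
    try:
        for i, _batch in enumerate(collector):
            if i == fail_at:
                raise RuntimeError("simulated training failure")
            seen += 1
    finally:
        # Runs on normal exit, on break, and when an exception propagates
        collector.shutdown()
    return seen


ok = StubCollector(3)
n_ok = run(ok)

failing = StubCollector(3)
try:
    run(failing, fail_at=1)
except RuntimeError:
    pass  # the error still propagates, but the collector has been closed
```

The `finally` clause is what matters here: the exception is not swallowed, yet the cleanup call is never skipped.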
Multiprocess with MultiSyncCollector
MultiSyncCollector spawns one subprocess per entry in create_env_fn. All workers collect data in parallel, but the main process waits for every worker to finish before yielding the next batch. This is ideal for on-policy algorithms (PPO, A2C) where the training step must see fresh data from all workers.
from torchrl.collectors import MultiSyncCollector

if __name__ == "__main__":  # required for multiprocessing on Windows/macOS
    env_maker = lambda: GymEnv("Pendulum-v1", device="cpu")

    collector = MultiSyncCollector(
        create_env_fn=[env_maker, env_maker, env_maker, env_maker],  # 4 workers
        policy=policy,
        frames_per_batch=400,    # 100 frames per worker per batch
        total_frames=40_000,
        max_frames_per_traj=50,
        device="cpu",
        storing_device="cpu",
        cat_results="stack",     # stack results along a new dim (vs. "cat" to concatenate)
    )

    for i, data in enumerate(collector):
        # with cat_results="stack", data has shape (4, 100): one row of 100 frames per worker
        print(f"batch {i}: {data.shape}")
        # After each batch, push updated weights to workers
        collector.update_policy_weights_()
        if i == 4:
            break

    collector.shutdown()
Python’s multiprocessing requires the collector construction to be inside an if __name__ == "__main__": guard on Windows and macOS. On Linux (fork start method) this is not strictly required, but it remains good practice.
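To make the batch arithmetic in the MultiSyncCollector example explicit, here is a small sketch. `batch_plan` is a hypothetical helper, not part of TorchRL. It assumes frames are split evenly across workers (matching the "100 frames per worker per batch" comment) and that the number of batches is rounded up when total_frames is not a multiple of frames_per_batch; the divisibility check is this sketch's own restriction.

```python
def batch_plan(frames_per_batch, total_frames, num_workers):
    """Return (frames per worker per batch, number of batches), assuming an even split."""
    if frames_per_batch % num_workers:
        raise ValueError("this sketch expects frames_per_batch to divide evenly across workers")
    frames_per_worker = frames_per_batch // num_workers
    num_batches = -(-total_frames // frames_per_batch)  # ceiling division
    return frames_per_worker, num_batches


# Values taken from the MultiSyncCollector example
per_worker, n_batches = batch_plan(frames_per_batch=400, total_frames=40_000, num_workers=4)
print(per_worker, n_batches)  # 100 100
```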
Multiprocess with MultiAsyncCollector
MultiAsyncCollector is the asynchronous counterpart: workers never wait for each other. The main process receives batches as soon as any worker finishes collecting one. This maximises throughput at the cost of policy staleness — a worker may be running an older policy version when its batch arrives.
MultiAsyncCollector suits off-policy algorithms (SAC, TD3, DQN) where data from a slightly stale policy is still usable.
from torchrl.collectors import MultiAsyncCollector

if __name__ == "__main__":
    collector = MultiAsyncCollector(
        create_env_fn=[env_maker, env_maker, env_maker, env_maker],
        policy=policy,
        frames_per_batch=200,
        total_frames=40_000,
        max_frames_per_traj=50,
        device="cpu",
        storing_device="cpu",
        cat_results="stack",
    )

    for i, data in enumerate(collector):
        # Batches arrive from whichever worker finishes first
        print(f"batch from worker, shape={data.shape}")
        # Push the latest weights to the workers so later batches use a fresher policy
        collector.update_policy_weights_()
        if i == 10:
            break

    collector.shutdown()
Sync vs async: a quick comparison
MultiSyncCollector
Synchronous collection — all workers finish before the main process receives data.
  • All data in a batch was collected with the same policy version.
  • Lower throughput because fast workers idle while slow ones finish.
  • Natural fit for on-policy algorithms (PPO, A2C) where fresh data is required.
  • Simpler to reason about — each training step sees a clean, uniform batch.
from torchrl.collectors import MultiSyncCollector
# All 4 workers must return before training starts
collector = MultiSyncCollector(
    create_env_fn=[env_maker] * 4,
    policy=policy,
    frames_per_batch=400,
    total_frames=100_000,
)
MultiAsyncCollector
Asynchronous collection — the main process processes each batch as it arrives.
  • Workers run continuously; no idle time.
  • Batches may come from different policy versions (stale-data problem).
  • Higher throughput for environments with variable episode lengths.
  • Natural fit for off-policy algorithms (SAC, TD3, DQN) with a replay buffer.
from torchrl.collectors import MultiAsyncCollector
# Workers run independently; main process takes data as it arrives
collector = MultiAsyncCollector(
    create_env_fn=[env_maker] * 4,
    policy=policy,
    frames_per_batch=200,
    total_frames=100_000,
)
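The throughput trade-off described in this comparison can be illustrated with a toy timing model. This is not a benchmark and does not use TorchRL: it assumes each worker needs a fixed, made-up time to collect its share of a batch, and it ignores communication and training time. Both function names are hypothetical.

```python
def sync_idle_fraction(durations):
    """Fraction of total worker time spent waiting when each round lasts as long as the slowest worker."""
    round_time = max(durations)
    return 1 - sum(durations) / (round_time * len(durations))


def worker_batches_delivered(durations, horizon, sync):
    """Count per-worker batches delivered within `horizon` time units."""
    if sync:
        # Every round ends only when the slowest worker is done
        return int(horizon // max(durations)) * len(durations)
    # Async: each worker runs on its own timeline
    return sum(int(horizon // d) for d in durations)


durations = [1.0, 1.0, 1.0, 2.0]  # hypothetical: one worker is twice as slow
idle = sync_idle_fraction(durations)
sync_count = worker_batches_delivered(durations, horizon=8.0, sync=True)
async_count = worker_batches_delivered(durations, horizon=8.0, sync=False)
print(idle, sync_count, async_count)  # 0.375 16 28
```

In this model a single slow worker leaves 37.5% of the synchronous workers' time idle, while the asynchronous variant delivers 28 per-worker batches instead of 16 in the same time span. The price, as noted above, is that those extra batches may come from older policy versions.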
Weight synchronization
All collector classes expose update_policy_weights_() to push updated network weights from the training process into workers. For multiprocess collectors the transfer uses shared memory (on the same machine) or torch.distributed for cross-machine setups.
# Basic usage: push current policy weights to all workers
collector.update_policy_weights_()

# Explicit: pass a policy module
collector.update_policy_weights_(policy)

# For async collectors: call it after each batch so that subsequent batches use
# fresher weights (batches already in flight may still reflect the older policy)
collector.update_policy_weights_()
For MultiSyncCollector, calling update_policy_weights_() after every collected batch is the standard pattern for on-policy training — it ensures the next batch is always generated with the latest policy.
Distributed collection across machines with RayCollector
RayCollector uses Ray to distribute collection across a Ray cluster. Each remote collector is a separate Ray actor; the main process coordinates them. Iterating over it works the same way as with the local collectors.
pip install ray torchrl
from torchrl.collectors.distributed import RayCollector
from torchrl.collectors import Collector   # used as the per-worker collector class

if __name__ == "__main__":
    # Default: RayCollector auto-detects an existing cluster or starts a local Ray instance
    collector = RayCollector(
        create_env_fn=lambda: GymEnv("Pendulum-v1"),
        policy=policy,
        frames_per_batch=400,
        total_frames=40_000,
        num_collectors=4,          # 4 remote Ray actors
        sync=True,                 # True = synchronous, False = async (first-ready)
        collector_class=Collector, # per-worker collector type
        collector_kwargs={
            "max_frames_per_traj": 50,
        },
        # Resource spec passed to ray.remote() for each actor
        remote_configs={
            "num_cpus": 1,
            "num_gpus": 0,
            "memory": 2 * 1024 ** 3,
        },
        update_after_each_batch=True,   # auto-sync weights after every batch
    )

    for i, data in enumerate(collector):
        print(f"batch {i}: {data.shape}")
        if i == 5:
            break

    collector.shutdown()
Connecting to an existing Ray cluster
collector = RayCollector(
    create_env_fn=lambda: GymEnv("Pendulum-v1"),
    policy=policy,
    frames_per_batch=400,
    total_frames=200_000,
    num_collectors=16,
    ray_init_config={
        "address": "ray://head-node:10001",   # address of the Ray head node
    },
    remote_configs={
        "num_cpus": 2,
        "num_gpus": 0.5,
        "memory": 4 * 1024 ** 3,
    },
)
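Before launching, it can help to check what the per-actor resource spec adds up to across the cluster. The sketch below simply multiplies the remote_configs values by num_collectors; `cluster_demand` is a hypothetical helper, and it ignores whatever the driver process and Ray itself need on top.

```python
def cluster_demand(num_collectors, remote_configs):
    """Total resources requested by all collector actors (actors only)."""
    return {
        "num_cpus": num_collectors * remote_configs["num_cpus"],
        "num_gpus": num_collectors * remote_configs["num_gpus"],
        "memory_gib": num_collectors * remote_configs["memory"] / 1024 ** 3,
    }


# Values taken from the cluster example
demand = cluster_demand(
    num_collectors=16,
    remote_configs={"num_cpus": 2, "num_gpus": 0.5, "memory": 4 * 1024 ** 3},
)
print(demand)  # {'num_cpus': 32, 'num_gpus': 8.0, 'memory_gib': 64.0}
```

If the cluster offers fewer resources than this total, some actors would presumably stay pending, so a quick check of this kind is a cheap sanity test before starting a long run.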
Distributed collection with RPCCollector
RPCCollector uses torch.distributed.rpc as the transport instead of Ray. It is appropriate when you are already running in a torch.distributed context (e.g., on a SLURM cluster with PyTorch’s native launcher) and do not want to install Ray.
# RPCCollector is intended to be run from a launch script (torchrun / submitit)
# The snippet below shows the collector construction; see the TorchRL examples
# for a full launch script.

from torchrl.collectors.distributed import RPCCollector

collector = RPCCollector(
    create_env_fn=lambda: GymEnv("Pendulum-v1"),
    policy=policy,
    frames_per_batch=400,
    total_frames=100_000,
    num_collectors=4,
    sync=True,
)

for data in collector:
    collector.update_policy_weights_()

collector.shutdown()
RPCCollector requires torch.distributed.rpc to be initialised before instantiation. The torchrun launcher (or submitit for SLURM) handles this automatically. Refer to the TorchRL RPC example for the full multi-node setup.
A complete multi-worker PPO loop
The example below shows how the collector hierarchy slots into a real training loop. Switching between Collector, MultiSyncCollector, and RayCollector only requires changing the first few lines.
"""Multi-worker PPO training loop (sync variant)."""
from __future__ import annotations

import torch
import torch.nn as nn
import tqdm
from tensordict.nn import AddStateIndependentNormalScale, TensorDictModule
from torchrl.collectors import MultiSyncCollector
from torchrl.data import LazyTensorStorage, TensorDictReplayBuffer
from torchrl.data.replay_buffers.samplers import SamplerWithoutReplacement
from torchrl.envs import (
    ClipTransform, DoubleToFloat, ExplorationType,
    GymEnv,
    RewardSum, StepCounter, TransformedEnv, VecNorm,
)
from torchrl.modules import MLP, ProbabilisticActor, TanhNormal, ValueOperator
from torchrl.objectives import ClipPPOLoss, group_optimizers
from torchrl.objectives.value.advantages import GAE

ENV_NAME = "HalfCheetah-v4"
N_WORKERS = 4
FRAMES_PER_BATCH = 2048    # total across all workers
MINI_BATCH_SIZE = 256
TOTAL_FRAMES = 1_000_000
PPO_EPOCHS = 10

def make_env():
    env = GymEnv(ENV_NAME, device="cpu")
    env = TransformedEnv(env)
    env.append_transform(VecNorm(in_keys=["observation"], decay=0.99999, eps=1e-2))
    env.append_transform(ClipTransform(in_keys=["observation"], low=-10, high=10))
    env.append_transform(RewardSum())
    env.append_transform(StepCounter())
    env.append_transform(DoubleToFloat(in_keys=["observation"]))
    return env

proof_env = make_env()
obs_dim = proof_env.observation_spec["observation"].shape[-1]
act_dim = proof_env.action_spec_unbatched.shape[-1]

policy_mlp = nn.Sequential(
    MLP(obs_dim, activation_class=nn.Tanh, out_features=act_dim, num_cells=[64, 64]),
    AddStateIndependentNormalScale(act_dim, scale_lb=1e-8),
)
actor = ProbabilisticActor(
    TensorDictModule(policy_mlp, in_keys=["observation"], out_keys=["loc", "scale"]),
    in_keys=["loc", "scale"],
    spec=proof_env.full_action_spec_unbatched,
    distribution_class=TanhNormal,
    distribution_kwargs={
        "low": proof_env.action_spec_unbatched.space.low,
        "high": proof_env.action_spec_unbatched.space.high,
        "tanh_loc": False,
    },
    return_log_prob=True,
    default_interaction_type=ExplorationType.RANDOM,
)
critic = ValueOperator(
    MLP(obs_dim, activation_class=nn.Tanh, out_features=1, num_cells=[64, 64]),
    in_keys=["observation"],
)

adv_module = GAE(gamma=0.99, lmbda=0.95, value_network=critic, average_gae=False)
loss_module = ClipPPOLoss(
    actor_network=actor, critic_network=critic,
    clip_epsilon=0.2, entropy_coeff=0.01, critic_coeff=0.5, normalize_advantage=True,
)
optim = group_optimizers(
    torch.optim.Adam(actor.parameters(), lr=3e-4, eps=1e-5),
    torch.optim.Adam(critic.parameters(), lr=3e-4, eps=1e-5),
)

if __name__ == "__main__":
    # ── swap this line to scale: ──────────────────────────────────────────────
    # Single process:  Collector(make_env, actor, ...)
    # Multi-process:   MultiSyncCollector([make_env]*N_WORKERS, actor, ...)
    # Multi-machine:   RayCollector(make_env, actor, num_collectors=N, ...)
    # ─────────────────────────────────────────────────────────────────────────
    collector = MultiSyncCollector(
        create_env_fn=[make_env] * N_WORKERS,
        policy=actor,
        frames_per_batch=FRAMES_PER_BATCH,
        total_frames=TOTAL_FRAMES,
        device="cpu",
        storing_device="cpu",
    )

    data_buffer = TensorDictReplayBuffer(
        storage=LazyTensorStorage(FRAMES_PER_BATCH),
        sampler=SamplerWithoutReplacement(),
        batch_size=MINI_BATCH_SIZE,
    )

    pbar = tqdm.tqdm(total=TOTAL_FRAMES)
    for data in collector:
        pbar.update(data.numel())

        for _ in range(PPO_EPOCHS):
            with torch.no_grad():
                data = adv_module(data)
            data_buffer.extend(data.reshape(-1))
            for batch in data_buffer:
                optim.zero_grad(set_to_none=True)
                loss = loss_module(batch)
                (
                    loss["loss_objective"]
                    + loss["loss_critic"]
                    + loss["loss_entropy"]
                ).backward()
                torch.nn.utils.clip_grad_norm_(loss_module.parameters(), 0.5)
                optim.step()

        # Synchronise updated weights with all worker processes
        collector.update_policy_weights_()

        ep_r = data["next", "episode_reward"][data["next", "done"]]
        if len(ep_r):
            pbar.set_description(f"reward={ep_r.mean().item():.2f}")

    collector.shutdown()
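As a sanity check on the hyperparameters, the following sketch works out how many optimizer steps the loop performs. It assumes that one pass over the buffer with SamplerWithoutReplacement yields each frame exactly once (so 2048 / 256 minibatches per epoch) and that the collector rounds the last batch up when TOTAL_FRAMES is not a multiple of FRAMES_PER_BATCH. It is plain arithmetic and does not need TorchRL.

```python
import math

# Constants copied from the training script
N_WORKERS = 4
FRAMES_PER_BATCH = 2048
MINI_BATCH_SIZE = 256
TOTAL_FRAMES = 1_000_000
PPO_EPOCHS = 10

frames_per_worker = FRAMES_PER_BATCH // N_WORKERS            # frames each worker contributes per batch
minibatches_per_epoch = FRAMES_PER_BATCH // MINI_BATCH_SIZE  # optimizer steps in one pass over the buffer
steps_per_batch = PPO_EPOCHS * minibatches_per_epoch         # optimizer steps per collected batch
num_batches = math.ceil(TOTAL_FRAMES / FRAMES_PER_BATCH)     # collector iterations (rounded up)
total_steps = steps_per_batch * num_batches

print(frames_per_worker, minibatches_per_epoch, steps_per_batch, num_batches, total_steps)
# 512 8 80 489 39120
```

Under these assumptions, the workers receive new weights only once per 80 optimizer steps, which is the cadence at which update_policy_weights_() is called in the loop.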

Collector comparison

| Collector | Workers | Synchronisation | Best for |
| --- | --- | --- | --- |
| Collector | 1 (same process) | n/a | Debugging, fast envs |
| MultiSyncCollector | N (subprocesses) | All workers finish before main gets batch | On-policy (PPO, A2C) |
| MultiAsyncCollector | N (subprocesses) | Main gets batches as they arrive | Off-policy (SAC, DQN) |
| RayCollector | N (Ray actors) | Configurable (sync=True/False) | Multi-machine clusters |
| RPCCollector | N (RPC workers) | Configurable (sync=True/False) | SLURM / torchrun setups |

Weight update patterns

# Pattern 1: Update after every collected batch (on-policy)
for data in collector:
    train(data)
    collector.update_policy_weights_()

# Pattern 2: Update every K batches (off-policy / replay buffer)
for i, data in enumerate(collector):
    replay_buffer.extend(data.reshape(-1))
    train_from_buffer(replay_buffer)
    if i % 5 == 0:
        collector.update_policy_weights_()

# Pattern 3: Automatic update (RayCollector only)
collector = RayCollector(..., update_after_each_batch=True)
for data in collector:
    train(data)    # weights already synced before next iteration
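To see what the update interval means for policy staleness, here is an idealised model of Patterns 1 and 2. It assumes one training step per collected batch and that a batch is always collected with whatever weights the workers held at that moment. `staleness_per_batch` is a hypothetical helper for illustration; real async collectors add further lag because workers keep running while training happens.

```python
def staleness_per_batch(num_batches, sync_every):
    """For each batch, how many training steps the workers' policy lags behind the learner."""
    learner_version = 0
    worker_version = 0
    lag = []
    for i in range(num_batches):
        lag.append(learner_version - worker_version)  # batch i was collected with worker_version
        learner_version += 1                          # one training step on this batch
        if i % sync_every == 0:                       # same condition as in Pattern 2
            worker_version = learner_version          # push weights to the workers
    return lag


every_batch = staleness_per_batch(num_batches=12, sync_every=1)
every_five = staleness_per_batch(num_batches=12, sync_every=5)
print(every_batch)  # [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]
print(every_five)   # [0, 0, 1, 2, 3, 4, 0, 1, 2, 3, 4, 0]
```

Updating after every batch keeps the lag at zero in this model, which is what on-policy methods rely on. With an interval of five, the lag grows to four training steps before the next push, a level of staleness that replay-buffer methods are generally designed to tolerate.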
