A collector owns the execution loop that runs a policy inside one or more environments and returns batches of trajectory data as TensorDicts. Rather than writing your own rollout loop, you hand a collector a policy, an environment constructor, and a batch size — it handles stepping, resetting, device movement, weight synchronization, and trajectory packaging. The result is an iterable that emits one TensorDict per iteration, ready to be sent to a replay buffer or consumed directly by an on-policy loss.

Why collectors exist

The alternative to collectors is a hand-rolled loop: call env.step(), accumulate tensors, move them to the right device, handle episode boundaries, and then replicate all of that across multiple processes. That loop is easy to get wrong: mismatched devices, missing done-flag handling, and synchronization bugs in multiprocess code are all common. Collectors encapsulate all of that so your training loop can focus on the learning update.
Collector
  ├── create_env_fn   →  one or more environment instances
  ├── policy          →  any TensorDictModule (or plain nn.Module)
  ├── frames_per_batch  →  how many transitions to emit per iteration
  └── total_frames    →  when to stop (−1 = run forever)
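
For comparison, the hand-rolled loop described above might look like the following minimal sketch. It uses only the Python standard library. `ToyEnv`, `random_policy`, and `collect_batch` are hypothetical stand-ins rather than TorchRL APIs, and the sketch ignores devices and multiprocessing entirely.

```python
import random


class ToyEnv:
    """Hypothetical 1-D environment: the episode ends after `horizon` steps."""

    def __init__(self, horizon=5):
        self.horizon = horizon
        self.t = 0

    def reset(self):
        self.t = 0
        return float(self.t)

    def step(self, action):
        self.t += 1
        done = self.t >= self.horizon
        reward = 1.0 if action == 1 else 0.0
        return float(self.t), reward, done


def random_policy(observation, rng):
    """Hypothetical policy: picks action 0 or 1 uniformly at random."""
    return rng.randint(0, 1)


def collect_batch(env, policy, frames_per_batch, rng, observation=None):
    """Collect exactly `frames_per_batch` transitions, resetting at episode ends."""
    if observation is None:
        observation = env.reset()
    batch = []
    while len(batch) < frames_per_batch:
        action = policy(observation, rng)
        next_observation, reward, done = env.step(action)
        batch.append({
            "observation": observation,
            "action": action,
            "next_observation": next_observation,
            "reward": reward,
            "done": done,
        })
        # Forgetting this reset is one of the classic bugs a collector avoids.
        observation = env.reset() if done else next_observation
    return batch, observation


rng = random.Random(0)
env = ToyEnv(horizon=5)
batch, last_obs = collect_batch(env, random_policy, frames_per_batch=12, rng=rng)
```

Even in this toy form the loop has to carry the last observation across batches and reset at the right moment, which is the bookkeeping a collector takes over.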

The Collector class

Collector is the single-process, single-environment collector. It is the simplest entry point and suitable for local development and on-policy algorithms.
from torchrl.collectors import Collector
from torchrl.envs.libs.gym import GymEnv

collector = Collector(
    create_env_fn=lambda: GymEnv("HalfCheetah-v4"),
    policy=policy,
    frames_per_batch=1024,
    total_frames=1_024_000,  # 1000 batches of 1024 frames
    device="cpu",
)

for batch in collector:
    # batch is a TensorDict with shape [1024].
    # Keys include "observation", "action", ("next", "observation"),
    # ("next", "reward"), ("next", "done"), etc.
    train_step(batch)

collector.shutdown()
total_frames must be divisible by frames_per_batch. Pass total_frames=-1 to create an endless collector that you break out of manually.
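
A quick way to sanity-check a pair of values before constructing a collector is sketched below. `expected_batches` is a hypothetical helper that only mirrors the rule stated above. It is not TorchRL's own validation code.

```python
def expected_batches(total_frames, frames_per_batch):
    """Return how many batches a collector would emit, or None if endless.

    Hypothetical helper mirroring the divisibility rule; not part of TorchRL.
    """
    if frames_per_batch <= 0:
        raise ValueError("frames_per_batch must be positive")
    if total_frames == -1:
        return None  # endless collector: the caller must break out of the loop
    if total_frames % frames_per_batch != 0:
        raise ValueError(
            f"total_frames={total_frames} is not a multiple of "
            f"frames_per_batch={frames_per_batch}"
        )
    return total_frames // frames_per_batch


print(expected_batches(1_024_000, 1024))  # 1000
print(expected_batches(-1, 1024))         # None
```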

Constructor arguments

Argument                         Description
create_env_fn                    Callable that returns an EnvBase instance, or an existing env
policy                           A TensorDictModule or any callable that accepts a TensorDict
frames_per_batch                 Number of transitions emitted per __next__ call
total_frames                     Total transitions before the collector is exhausted (-1 for infinite)
device                           Convenience device for both env and policy; overridden by env_device / policy_device
env_device                       Device on which environment steps are executed
policy_device                    Device on which the policy forward pass runs
storing_device                   Device on which the emitted TensorDict is stored
max_frames_per_traj              Truncate episodes at this many steps
compile_policy                   Pass True or a dict of kwargs to torch.compile the policy
cudagraph_policy                 Wrap the policy in CUDA graphs for faster inference
auto_register_policy_transforms  Register any env transforms on the policy automatically
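
The precedence between device, env_device, and policy_device can be pictured with the small sketch below. `resolve_devices` is a hypothetical helper operating on plain strings. It assumes the specific arguments simply win over the convenience argument when both are given, and it is not TorchRL's internal logic.

```python
def resolve_devices(device=None, env_device=None, policy_device=None):
    """Hypothetical helper: specific arguments win over the convenience `device`."""
    return {
        "env_device": env_device if env_device is not None else device,
        "policy_device": policy_device if policy_device is not None else device,
    }


# Environment stays on CPU while the policy runs on the GPU.
print(resolve_devices(device="cpu", policy_device="cuda:0"))
# {'env_device': 'cpu', 'policy_device': 'cuda:0'}
```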

Using Collector in an on-policy loop

import torch
from torchrl.collectors import Collector
from torchrl.envs import TransformedEnv, GymEnv, ObservationNorm, Compose
from torchrl.objectives import ClipPPOLoss
from torchrl.objectives.value import GAE

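env = TransformedEnv(
    GymEnv("HalfCheetah-v4"),
    Compose(ObservationNorm(in_keys=["observation"])),
)
# ObservationNorm has no loc/scale yet: estimate them from random rollouts first.
env.transform[0].init_stats(num_iter=1000, reduce_dim=0, cat_dim=0)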

collector = Collector(
    create_env_fn=env,
    policy=actor,
    frames_per_batch=2048,
    total_frames=512_000,  # 250 batches of 2048 frames
)
loss_fn = ClipPPOLoss(actor_network=actor, critic_network=critic)
advantage_fn = GAE(value_network=critic, gamma=0.99, lmbda=0.95)
optimizer = torch.optim.Adam(loss_fn.parameters(), lr=3e-4)

for batch in collector:
    batch = advantage_fn(batch)           # compute GAE advantages in-place
    for epoch in range(10):
        mini = batch[torch.randperm(len(batch))[:256]]
        losses = loss_fn(mini)
        total = losses["loss_objective"] + losses["loss_critic"] + losses["loss_entropy"]
        optimizer.zero_grad()
        total.backward()
        optimizer.step()

collector.shutdown()
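
To make the role of advantage_fn more concrete, here is a minimal pure-Python sketch of the standard GAE recursion over a single trajectory. It is illustrative only: `gae_advantages` is a hypothetical function working on Python lists, it treats every done flag as a hard episode end, and it does not reproduce how TorchRL's GAE module handles TensorDicts, batching, or truncation.

```python
def gae_advantages(rewards, values, next_values, dones, gamma=0.99, lmbda=0.95):
    """Standard GAE recursion over one trajectory stored as Python lists."""
    advantages = [0.0] * len(rewards)
    running = 0.0
    # Walk the trajectory backwards so each step can reuse the next advantage.
    for t in reversed(range(len(rewards))):
        not_done = 0.0 if dones[t] else 1.0
        # One-step TD error; the bootstrap term is dropped at episode ends.
        delta = rewards[t] + gamma * next_values[t] * not_done - values[t]
        running = delta + gamma * lmbda * not_done * running
        advantages[t] = running
    return advantages


# With gamma = lmbda = 1 and a zero value function, the advantage is the
# sum of remaining rewards up to the end of the episode.
print(gae_advantages(
    rewards=[1.0, 1.0, 1.0],
    values=[0.0, 0.0, 0.0],
    next_values=[0.0, 0.0, 0.0],
    dones=[False, True, False],
    gamma=1.0,
    lmbda=1.0,
))  # [2.0, 1.0, 1.0]
```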

AsyncCollector

AsyncCollector runs the environment loop in a background thread while the main thread processes the previous batch. This overlaps simulation and learning for a modest throughput gain in single-environment settings.
from torchrl.collectors import AsyncCollector

collector = AsyncCollector(
    create_env_fn=lambda: GymEnv("HalfCheetah-v4"),
    policy=policy,
    frames_per_batch=1024,
    total_frames=1_024_000,  # 1000 batches of 1024 frames
)
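
The overlap between collection and learning can be illustrated with a standard-library sketch: a background thread fills a bounded queue while the consumer processes the previous item. `background_collector` and `fake_rollout` are hypothetical names, and this is only a conceptual picture of the pattern, not AsyncCollector's implementation.

```python
import queue
import threading


def background_collector(make_batch, num_batches, maxsize=1):
    """Yield batches produced by a background thread (hypothetical sketch)."""
    batches = queue.Queue(maxsize=maxsize)
    sentinel = object()

    def worker():
        for i in range(num_batches):
            batches.put(make_batch(i))  # blocks while the queue is full
        batches.put(sentinel)

    thread = threading.Thread(target=worker, daemon=True)
    thread.start()
    while True:
        item = batches.get()
        if item is sentinel:
            break
        # While the caller handles this batch, the worker prepares the next one.
        yield item
    thread.join()


def fake_rollout(i):
    """Stand-in for an environment rollout: returns 4 fake frame ids."""
    return [i * 4 + k for k in range(4)]


collected = list(background_collector(fake_rollout, num_batches=3))
```

The bounded queue keeps the producer at most one batch ahead, which limits how stale the collected data can become relative to the learner.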

MultiSyncCollector

MultiSyncCollector spawns N worker processes, each running its own copy of the environment. The main process waits for all workers to return their share of the batch before yielding the merged result; updated policy weights are then pushed back to the workers with update_policy_weights_(). This is the right choice for synchronous on-policy training with large batch sizes.
from torchrl.collectors import MultiSyncCollector

collector = MultiSyncCollector(
    create_env_fn=[lambda: GymEnv("HalfCheetah-v4")] * 8,
    policy=policy,
    frames_per_batch=2048,
    total_frames=2_048_000,  # 1000 batches of 2048 frames
    num_workers=8,
)

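for batch in collector:
    # batch.batch_size == torch.Size([2048])
    # Contributions from all 8 workers are merged.
    losses = loss_fn(advantage_fn(batch))
    total = losses["loss_objective"] + losses["loss_critic"] + losses["loss_entropy"]
    optimizer.zero_grad()
    total.backward()
    optimizer.step()
    # Push updated weights back to workers.
    collector.update_policy_weights_()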
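
The synchronous gather step can be pictured with the sketch below. It uses threads instead of processes to stay short, and it assumes the batch is split evenly across workers (2048 / 8 = 256 frames each). `worker_rollout` and `sync_collect` are hypothetical names, not TorchRL functions.

```python
from concurrent.futures import ThreadPoolExecutor


def worker_rollout(worker_idx, frames):
    """Stand-in for one worker's rollout: returns `frames` tagged transitions."""
    return [(worker_idx, step) for step in range(frames)]


def sync_collect(num_workers, frames_per_batch):
    """Wait for every worker, then merge their shares into one batch."""
    if frames_per_batch % num_workers != 0:
        raise ValueError("this sketch assumes an even split across workers")
    share = frames_per_batch // num_workers
    with ThreadPoolExecutor(max_workers=num_workers) as pool:
        futures = [pool.submit(worker_rollout, i, share) for i in range(num_workers)]
        # result() blocks until that worker is done, so the slowest worker
        # determines when the merged batch becomes available.
        parts = [future.result() for future in futures]
    return [transition for part in parts for transition in part]


batch = sync_collect(num_workers=8, frames_per_batch=2048)
```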

MultiAsyncCollector

MultiAsyncCollector also runs N worker processes but does not synchronize: workers return batches independently as soon as they are ready. The main process yields the first available batch, making off-policy training pipelines more efficient when simulation is slow or its speed varies between workers.
from torchrl.collectors import MultiAsyncCollector

collector = MultiAsyncCollector(
    create_env_fn=[lambda: GymEnv("HalfCheetah-v4")] * 8,
    policy=policy,
    frames_per_batch=256,   # per-worker batch size
    total_frames=2_048_000,  # 8000 batches of 256 frames
    num_workers=8,
)

for batch in collector:
    replay_buffer.extend(batch)
    if len(replay_buffer) > 10_000:
        sample = replay_buffer.sample(256)
        # Off-policy update ...
        collector.update_policy_weights_()
Use MultiSyncCollector for on-policy algorithms (PPO, A2C) and MultiAsyncCollector for off-policy algorithms (SAC, TD3, DQN) where stale data is acceptable.
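
By contrast, the asynchronous pattern hands over whichever batch is ready first. The sketch below again uses threads and a queue from the standard library. `async_collect` is a hypothetical name and the per-worker sleep times are arbitrary values chosen to simulate uneven simulation speed.

```python
import queue
import threading
import time


def async_collect(num_workers, frames_per_batch, batches_per_worker, delays):
    """Yield per-worker batches in completion order (hypothetical sketch)."""
    ready = queue.Queue()

    def worker(worker_idx):
        for b in range(batches_per_worker):
            time.sleep(delays[worker_idx])  # simulate uneven simulation speed
            ready.put([(worker_idx, b, k) for k in range(frames_per_batch)])

    threads = [
        threading.Thread(target=worker, args=(i,), daemon=True)
        for i in range(num_workers)
    ]
    for thread in threads:
        thread.start()
    for _ in range(num_workers * batches_per_worker):
        yield ready.get()  # first available batch, whichever worker produced it
    for thread in threads:
        thread.join()


batches = list(
    async_collect(
        num_workers=3,
        frames_per_batch=4,
        batches_per_worker=2,
        delays=[0.03, 0.0, 0.01],
    )
)
```

Each yielded batch comes from a single worker, and fast workers are never held back by slow ones. The price is that some batches were generated with slightly older policy weights.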

AsyncBatchedCollector

AsyncBatchedCollector is similar to MultiAsyncCollector but batches the outputs of multiple environments together before yielding, reducing overhead when each individual env is cheap.

Weight synchronization

All multi-worker collectors expose update_policy_weights_(). Internally TorchRL uses a WeightUpdaterBase to copy parameters from the learner process to worker processes. Several updater implementations are available:

VanillaWeightUpdater

Copies parameters using shared memory or pickle. The default for MultiSyncCollector and MultiAsyncCollector.

MultiProcessedWeightUpdater

Uses a shared-memory tensor dict to broadcast weights without serialization overhead. Good for large models.

RayWeightUpdater

Syncs weights across Ray workers for distributed training on a Ray cluster.

RemoteModuleWeightUpdater

Pushes weights to a remote nn.Module over RPC, useful for parameter-server training styles.
from torchrl.collectors import MultiSyncCollector, VanillaWeightUpdater

collector = MultiSyncCollector(
    create_env_fn=[lambda: GymEnv("HalfCheetah-v4")] * 4,
    policy=policy,
    frames_per_batch=1024,
    total_frames=1_024_000,  # 1000 batches of 1024 frames
    weight_updater=VanillaWeightUpdater(),
)

for batch in collector:
    # ... learning update ...
    collector.update_policy_weights_()  # push new weights to all workers
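
Why an explicit synchronization call is needed can be seen in a toy model where each worker holds its own copy of the parameters. `Learner`, `Worker`, and `update_policy_weights` below are hypothetical illustrations built on plain dictionaries. They do not reflect how any WeightUpdaterBase subclass is implemented.

```python
import copy


class Learner:
    """Hypothetical learner holding the up-to-date parameters."""

    def __init__(self):
        self.weights = {"w": 0.0}

    def update(self):
        self.weights["w"] += 1.0  # pretend this is an optimizer step


class Worker:
    """Hypothetical worker acting with its own copy of the parameters."""

    def __init__(self, weights):
        self.weights = copy.deepcopy(weights)


def update_policy_weights(learner, workers):
    """Copy the learner's parameters into every worker."""
    for worker in workers:
        worker.weights = copy.deepcopy(learner.weights)


learner = Learner()
workers = [Worker(learner.weights) for _ in range(4)]

learner.update()
stale = [w.weights["w"] for w in workers]  # workers still hold the old value
update_policy_weights(learner, workers)
fresh = [w.weights["w"] for w in workers]  # workers now match the learner
```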

Evaluator

Evaluator is a companion class that runs periodic evaluation rollouts (without exploration) in a separate process, reporting metrics without interrupting the main training loop.
from torchrl.collectors import Collector, Evaluator

evaluator = Evaluator(
    env=lambda: GymEnv("HalfCheetah-v4"),
    policy=policy,
    num_trajectories=5,   # number of evaluation rollouts
    max_steps=1000,
    on_result=lambda result: print(result),
)

collector = Collector(
    create_env_fn=lambda: GymEnv("HalfCheetah-v4"),
    policy=policy,
    frames_per_batch=1024,
    total_frames=1_024_000,  # 1000 batches of 1024 frames
)

for i, batch in enumerate(collector):
    # ... training ...
    evaluator.trigger_eval(policy, step=i)
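
The idea of evaluating in the background and reporting through a callback can be sketched with the standard library alone. The example uses a single background thread rather than a separate process, and `CountdownEnv` and `evaluate` are hypothetical names. It is a conceptual picture, not the Evaluator's implementation.

```python
from concurrent.futures import ThreadPoolExecutor


class CountdownEnv:
    """Hypothetical env: reward 1.0 per step, episode ends after 3 steps."""

    def reset(self):
        self.remaining = 3
        return self.remaining

    def step(self, action):
        self.remaining -= 1
        return self.remaining, 1.0, self.remaining == 0


def evaluate(policy, make_env, num_trajectories, max_steps):
    """Run deterministic rollouts and return the mean episode return."""
    returns = []
    for _ in range(num_trajectories):
        env = make_env()
        observation, total, done, steps = env.reset(), 0.0, False, 0
        while not done and steps < max_steps:
            observation, reward, done = env.step(policy(observation))
            total += reward
            steps += 1
        returns.append(total)
    return sum(returns) / len(returns)


results = []
with ThreadPoolExecutor(max_workers=1) as pool:
    for step in range(3):
        # Training would continue here while the evaluation runs in the pool.
        future = pool.submit(evaluate, lambda obs: 0, CountdownEnv, 5, 1000)
        future.add_done_callback(
            lambda f, step=step: results.append((step, f.result()))
        )
```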

Profiling collector workers

ProfileConfig lets you attach a PyTorch profiler to one or more collector workers, saving trace files for performance analysis. Call collector.enable_profile() after construction to activate profiling.
from torchrl.collectors import MultiSyncCollector

collector = MultiSyncCollector(
    create_env_fn=[lambda: GymEnv("HalfCheetah-v4")] * 4,
    policy=policy,
    frames_per_batch=1024,
    total_frames=102_400,  # 100 batches of 1024 frames
)

# Enable profiling after construction.
collector.enable_profile(
    workers=[0, 1],
    num_rollouts=5,
    warmup_rollouts=2,
    save_path="./traces/worker_{worker_idx}.json",
)

Choosing the right collector

1. Single process, on-policy: use Collector for PPO / A2C smoke tests or small-scale training.
2. Multi-worker, synchronous, on-policy: use MultiSyncCollector when you need a large synchronized batch from many environments (PPO at scale).
3. Multi-worker, asynchronous, off-policy: use MultiAsyncCollector for SAC / TD3 / DQN where workers keep running while the learner updates.
4. Distributed / Ray cluster: use MultiAsyncCollector with RayWeightUpdater and configure workers via a RayReplayBuffer for fully distributed jobs.
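
The four cases above can be condensed into a small decision helper. `choose_collector` is a hypothetical function that only returns the class names mentioned in this guide as strings. Treating every single-process setup as a Collector case is an assumption of this sketch.

```python
def choose_collector(multi_worker, on_policy, distributed=False):
    """Hypothetical helper encoding the four cases listed above."""
    if distributed:
        return "MultiAsyncCollector + RayWeightUpdater"
    if not multi_worker:
        return "Collector"
    return "MultiSyncCollector" if on_policy else "MultiAsyncCollector"


print(choose_collector(multi_worker=False, on_policy=True))  # Collector
print(choose_collector(multi_worker=True, on_policy=True))   # MultiSyncCollector
print(choose_collector(multi_worker=True, on_policy=False))  # MultiAsyncCollector
```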
