

Scaling reinforcement learning beyond a single process requires distributing the environment-policy interaction loop across multiple worker processes, machines, or cloud VMs. TorchRL ships five distributed collection strategies that all present the same iterator interface as single-process collectors:

- MultiSyncCollector and MultiAsyncCollector for multi-process collection on one machine.
- DistributedCollector (torch.distributed backend), RPCCollector (torch.distributed.rpc backend), and RayCollector (Ray backend) for true multi-node distribution.

All of them inherit from BaseCollector, so switching from a local Collector to a distributed one requires only changing the constructor; the training loop stays identical.
All multi-process collector classes must be created inside an if __name__ == "__main__": guard on Windows and macOS, where new processes are started with the spawn method. See the Python multiprocessing docs for background.
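Because of this shared interface, a training loop needs only three things from a collector:

- iteration over batches
- update_policy_weights_()
- shutdown()

The sketch below writes such a loop once and runs it with FakeCollector. FakeCollector is a hypothetical stand-in, not a TorchRL class. Its batches are plain lists of frame indices, so the example runs without TorchRL installed.

```python
from typing import Iterator, List


class FakeCollector:
    """Hypothetical stand-in that mimics the collector interface used on this page."""

    def __init__(self, frames_per_batch: int, total_frames: int) -> None:
        self.frames_per_batch = frames_per_batch
        self.total_frames = total_frames
        self.weight_updates = 0
        self.closed = False

    def __iter__(self) -> Iterator[List[int]]:
        collected = 0
        while collected < self.total_frames:
            # A "batch" here is just a list of frame indices
            yield list(range(collected, collected + self.frames_per_batch))
            collected += self.frames_per_batch

    def update_policy_weights_(self, policy=None) -> None:
        self.weight_updates += 1

    def shutdown(self) -> None:
        self.closed = True


def train(collector, policy=None) -> int:
    """Training loop written only against the shared interface; returns frames seen."""
    frames_seen = 0
    try:
        for batch in collector:
            frames_seen += len(batch)  # a real loop would compute a loss here
            collector.update_policy_weights_(policy)
    finally:
        collector.shutdown()  # always release workers, even if training fails
    return frames_seen


if __name__ == "__main__":
    print(train(FakeCollector(frames_per_batch=100, total_frames=1_000)))  # 1000
```

This page states that every collector exposes the same interface. If that holds, replacing FakeCollector with MultiSyncCollector, DistributedCollector, or RayCollector would leave train unchanged.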

MultiSyncCollector

MultiSyncCollector (from torchrl.collectors) runs a configurable number of Collector workers in separate processes and collects synchronously. It waits for every worker to finish its rollout before assembling the results and yielding them to the caller.

As a result, all frames in a yielded batch are collected with the same policy version. Each batch holds frames_per_batch frames; this is exact when frames_per_batch is divisible by the number of workers and preemption is disabled. This makes MultiSyncCollector ideal for on-policy algorithms.
+----------------------------------------------------------------------+
|               MultiSyncCollector              |         Main         |
| Collector 1   | Collector 2   | Collector 3   |                      |
| env1    env2  | env3    env4  | env5    env6  |                      |
| reset   reset | reset   reset | reset   reset |                      |
|     actor     |     actor     |     actor     |                      |
| step    step  | step    step  | step    step  |                      |
|              yield batch ─────────────────────> collect, train       |
+----------------------------------------------------------------------+
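The synchronous pattern in the diagram is: fan out, wait for every worker, then assemble. It can be sketched with the standard library. This is a conceptual model, not TorchRL's implementation:

- Threads stand in for worker processes.
- rollout is a hypothetical placeholder that records which policy version produced each frame.
- The returned list of per-worker rows plays the role of a stacked batch.

```python
from concurrent.futures import ThreadPoolExecutor
from typing import List, Tuple

Frame = Tuple[int, int, int]  # (worker_id, step, policy_version)


def rollout(worker_id: int, num_frames: int, policy_version: int) -> List[Frame]:
    """Hypothetical stand-in for one worker's rollout."""
    return [(worker_id, step, policy_version) for step in range(num_frames)]


def sync_collect(num_workers: int, frames_per_worker: int, policy_version: int) -> List[List[Frame]]:
    """Launch one rollout per worker and block until all of them finish."""
    with ThreadPoolExecutor(max_workers=num_workers) as pool:
        futures = [
            pool.submit(rollout, worker_id, frames_per_worker, policy_version)
            for worker_id in range(num_workers)
        ]
        # f.result() blocks, so nothing is returned until every worker is done.
        # Results stay in worker order, one row per worker ("stacked").
        return [f.result() for f in futures]


batch = sync_collect(num_workers=3, frames_per_worker=100, policy_version=7)
print(len(batch), len(batch[0]))  # 3 100
```

Every row carries the same policy_version, so the assembled batch is consistent for an on-policy update.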

Constructor

MultiSyncCollector(
    create_env_fn,          # list of env callables (one per worker)
    policy=None,
    *,
    policy_factory=None,
    frames_per_batch,
    total_frames=-1,
    device=None,
    storing_device=None,
    env_device=None,
    policy_device=None,
    max_frames_per_traj=None,
    init_random_frames=None,
    reset_at_each_iter=False,
    postproc=None,
    split_trajs=False,
    exploration_type=ExplorationType.RANDOM,
    cat_results="stack",
    update_at_each_batch=False,
    preemptive_threshold=1.0,
    num_threads=None,
    num_sub_threads=1,
    ...
)

Key Parameters

create_env_fn
list[Callable] or Callable
required
A list of callables each returning an EnvBase instance — one per worker process. All callables can be the same lambda for identical environments. Passing a single callable creates one worker.
policy
Callable
default:"None"
Policy executed at each step on every worker. Must accept a TensorDictBase. If None, a RandomPolicy is used. Mutually exclusive with policy_factory.
frames_per_batch
int
required
Total frames per yielded batch across all workers. Each worker collects ceil(frames_per_batch / num_workers) frames, and the results are assembled (stacked or concatenated) by the main process.
total_frames
int
default:"-1"
Total frames over the collector’s lifetime. Pass -1 for endless collection. The iterator stops once the accumulated frame count reaches or exceeds this value.
cat_results
str
default:"\"stack\""
How to combine results from multiple workers. "stack" creates a leading worker dimension; "cat" concatenates along the batch dimension. Use "cat" when workers collect different environments.
update_at_each_batch
bool
default:"False"
If True, update_policy_weights_() is called automatically before every collection cycle. Convenient for on-policy algorithms that always want fresh weights.
preemptive_threshold
float
default:"1.0"
Fraction (0 to 1) of workers that must finish before stragglers are interrupted. Values below 1.0 reduce tail latency at the cost of slightly fewer frames from slow workers.
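The frame accounting behind frames_per_batch, total_frames, and cat_results can be checked with plain arithmetic. The helpers below are illustrative and model the behavior documented above rather than calling TorchRL. They assume non-batched environments, where each worker returns a 1-D run of frames.

```python
import math
from typing import Tuple


def frames_per_worker(frames_per_batch: int, num_workers: int) -> int:
    """Each worker collects ceil(frames_per_batch / num_workers) frames."""
    return math.ceil(frames_per_batch / num_workers)


def batch_shape(frames_per_batch: int, num_workers: int, cat_results: str = "stack") -> Tuple[int, ...]:
    """Shape of one assembled batch, assuming non-batched environments."""
    per_worker = frames_per_worker(frames_per_batch, num_workers)
    if cat_results == "stack":
        return (num_workers, per_worker)  # leading worker dimension
    if cat_results == "cat":
        return (num_workers * per_worker,)  # joined along the frame dimension
    raise ValueError(f"unknown cat_results: {cat_results!r}")


def num_batches(total_frames: int, frames_per_batch: int, num_workers: int) -> int:
    """Iterations until the accumulated frame count reaches or exceeds total_frames."""
    per_batch = num_workers * frames_per_worker(frames_per_batch, num_workers)
    return math.ceil(total_frames / per_batch)


print(batch_shape(300, 3))          # (3, 100)
print(batch_shape(300, 3, "cat"))   # (300,)
print(frames_per_worker(200, 3))    # 67, so one batch holds 201 frames
print(num_batches(30_000, 300, 3))  # 100
```

The usage example on this page has three workers and frames_per_batch=300, which gives 100 frames per worker. When frames_per_batch is not divisible by the worker count, rounding up makes each batch slightly larger than requested.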

Usage

from torchrl.envs.libs.gym import GymEnv
from tensordict.nn import TensorDictModule
from torch import nn
from torchrl.collectors import MultiSyncCollector

if __name__ == "__main__":
    env_maker = lambda: GymEnv("Pendulum-v1", device="cpu")
    policy = TensorDictModule(nn.Linear(3, 1), in_keys=["observation"], out_keys=["action"])

    collector = MultiSyncCollector(
        create_env_fn=[env_maker, env_maker, env_maker],
        policy=policy,
        frames_per_batch=300,
        total_frames=30_000,
        max_frames_per_traj=50,
        device="cpu",
        storing_device="cpu",
        cat_results="stack",
    )

    for i, batch in enumerate(collector):
        # batch.batch_size is torch.Size([3, 100]): 100 frames per worker, stacked along dim 0
        loss = train_step(batch)  # train_step: user-defined training function (not shown)
        collector.update_policy_weights_(policy)

    collector.shutdown()
MultiSyncCollector is the recommended choice for on-policy algorithms like PPO. Every frame in a yielded batch is collected with the same policy version. This satisfies the on-policy requirement as long as weights are updated between batches, as in the loop above.

MultiAsyncCollector

MultiAsyncCollector (from torchrl.collectors) also fans out collection to multiple worker processes, but unlike MultiSyncCollector it does not wait for all workers to finish. Workers deliver rollouts on a first-ready basis, and the main process yields each one as it arrives. This maximizes hardware utilization and is ideal for off-policy algorithms where data staleness is acceptable.
+----------------------------------------------------------------------+
|              MultiAsyncCollector              |         Main         |
| Collector 1   | Collector 2   | Collector 3   |                      |
| yield batch 1 |               |               | collect, train       |
|               |               | yield batch 2 | collect, train       |
|               | yield batch 3 |               | collect, train       |
+----------------------------------------------------------------------+
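The first-ready behavior shown in the diagram can be modeled with concurrent.futures.as_completed, which returns futures in the order they finish. This is a conceptual sketch:

- Threads stand in for worker processes.
- rollout and its simulated delays are hypothetical.
- Each worker runs a single rollout, whereas the real collector keeps its workers collecting continuously.

```python
import time
from concurrent.futures import ThreadPoolExecutor, as_completed
from typing import Dict, List


def rollout(worker_id: int, delay_s: float) -> Dict[str, object]:
    """Hypothetical rollout; its duration is simulated with time.sleep."""
    time.sleep(delay_s)
    return {"worker_id": worker_id, "frames": list(range(10))}


def async_collect(delays: List[float]) -> List[int]:
    """Consume rollouts in completion order and return the worker ids in that order."""
    order: List[int] = []
    with ThreadPoolExecutor(max_workers=len(delays)) as pool:
        futures = [pool.submit(rollout, w, d) for w, d in enumerate(delays)]
        # as_completed hands back each future as soon as it finishes (first-ready)
        for future in as_completed(futures):
            batch = future.result()
            # A real loop would extend a replay buffer and maybe train here
            order.append(batch["worker_id"])
    return order


print(async_collect([0.8, 0.0, 0.4]))  # [1, 2, 0] (worker 1 is fastest)
```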

Constructor

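MultiAsyncCollector accepts essentially the same constructor arguments as MultiSyncCollector. The exception is cat_results, which only applies to MultiSyncCollector because each asynchronous batch comes from a single worker. The key behavioral difference is the non-blocking collection loop: workers run continuously in parallel and the main process retrieves batches as they arrive.
from torch import nn
from tensordict.nn import TensorDictModule
from torchrl.collectors import MultiAsyncCollector
from torchrl.data import LazyTensorStorage, ReplayBuffer
from torchrl.envs.libs.gym import GymEnv

if __name__ == "__main__":
    env_maker = lambda: GymEnv("Pendulum-v1", device="cpu")
    policy = TensorDictModule(nn.Linear(3, 1), in_keys=["observation"], out_keys=["action"])
    replay_buffer = ReplayBuffer(storage=LazyTensorStorage(100_000))
    min_size = 1_000  # illustrative warm-up threshold before training starts

    collector = MultiAsyncCollector(
        create_env_fn=[env_maker, env_maker],
        policy=policy,
        frames_per_batch=200,
        total_frames=20_000,
        device="cpu",
        storing_device="cpu",
    )

    for batch in collector:
        # Batches arrive from whichever worker finishes first
        replay_buffer.extend(batch)
        if len(replay_buffer) > min_size:
            train_step()  # train_step: user-defined, samples from replay_buffer (not shown)
            collector.update_policy_weights_(policy)

    collector.shutdown()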
Because workers keep collecting between next() calls, MultiAsyncCollector is not suitable for on-policy algorithms: the policy that collected a batch may be older than the current training policy. Use MultiSyncCollector for on-policy training. With MultiAsyncCollector, calling update_policy_weights_() frequently reduces this policy lag but does not eliminate it.
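A deliberately simple model shows why staleness is unavoidable. It rests on assumptions, not on TorchRL internals:

- Workers finish in round-robin order.
- The main process trains after every batch.
- A worker only picks up new weights when it starts its next rollout.

Under those assumptions the lag settles at num_workers - 1 policy versions, even though weights are updated after every batch.

```python
from collections import deque
from typing import List


def simulate_policy_lag(num_workers: int, num_batches: int) -> List[int]:
    """Toy model of asynchronous collection with round-robin completion.

    Each worker snapshots the policy version when it starts a rollout. The main
    process trains after every batch (bumping the version) and immediately
    relaunches the worker that just delivered, using the new weights.
    Returns the lag (current version minus version used) of every yielded batch.
    """
    version = 0
    # All workers start with version 0; deque order is the completion order
    running = deque((worker_id, version) for worker_id in range(num_workers))
    lags: List[int] = []
    for _ in range(num_batches):
        worker_id, used_version = running.popleft()  # first-ready worker
        lags.append(version - used_version)
        version += 1  # one training step / weight update per batch
        running.append((worker_id, version))  # relaunch with fresh weights
    return lags


print(simulate_policy_lag(num_workers=1, num_batches=4))  # [0, 0, 0, 0]
print(simulate_policy_lag(num_workers=3, num_batches=6))  # [0, 1, 2, 2, 2, 2]
```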

DistributedCollector

DistributedCollector (from torchrl.collectors.distributed) scales collection across multiple machines using the torch.distributed backend (gloo, nccl, mpi, or ucc). Each remote node runs a Collector, MultiSyncCollector, or MultiAsyncCollector instance determined by the collector_class argument. The main process coordinates nodes through a TCPStore and assembles the results.

Constructor

from torchrl.collectors.distributed import DistributedCollector

DistributedCollector(
    create_env_fn,              # list of env callables — one per remote node
    policy=None,
    *,
    policy_factory=None,
    frames_per_batch,
    total_frames=-1,
    device=None,
    storing_device=None,
    env_device=None,
    policy_device=None,
    max_frames_per_traj=-1,
    collector_class=Collector,  # class or "single"/"sync"/"async"
    collector_kwargs=None,
    num_workers_per_collector=1,
    sync=False,
    backend="gloo",
    launcher="submitit",
    slurm_kwargs=None,
    tcp_port=None,
    update_after_each_batch=False,
    max_weight_update_interval=-1,
    weight_sync_schemes=None,
    weight_recv_schemes=None,
    ...
)

Key Parameters

create_env_fn
list[Callable]
required
One callable per remote node; each callable returns an EnvBase. For collector_class=MultiSyncCollector, each node creates num_workers_per_collector sub-processes internally.
frames_per_batch
int
required
Frames collected per remote node per iteration. With sync=True, the main process assembles all node results into one batch of size frames_per_batch * num_nodes.
collector_class
type or str
default:"Collector"
Class instantiated on each remote node. Accepts Collector, MultiSyncCollector, MultiAsyncCollector, or the shorthand strings "single", "sync", "async".
num_workers_per_collector
int
default:"1"
When collector_class is a multi-process class (MultiSyncCollector / MultiAsyncCollector), this sets the number of sub-workers on each remote node.
sync
bool
default:"False"
If True, the main process waits for all nodes and yields their combined result as a single stacked TensorDict. If False, each node’s result is yielded as it arrives (first-ready, first-served).
backend
str
default:"\"gloo\""
torch.distributed backend used for communication (collected data and policy weights) between the main process and the remote nodes. One of "gloo", "mpi", "nccl", or "ucc". Use "nccl" for GPU-to-GPU transfers.
launcher
str
default:"\"submitit\""
How remote processes are launched. "submitit" submits SLURM jobs (requires the submitit package) and supports multi-node clusters. "mp" uses Python multiprocessing on a single machine. "submitit_delayed" defers launch for clusters that forbid spawning from existing jobs.
update_after_each_batch
bool
default:"False"
If True, automatically pushes updated policy weights to all (sync) or the contributing (async) remote nodes after each collected batch.
weight_sync_schemes
dict[str, WeightSyncScheme]
default:"None"
Dictionary mapping model identifiers (e.g. "policy") to WeightSyncScheme instances that control how weights flow from the main process to remote nodes. Defaults to DistributedWeightSyncScheme.
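The sync parameter and update_after_each_batch=True together determine a schedule: which nodes contribute to each yielded batch, and which nodes then receive fresh weights. The helper below is a hypothetical toy model of the behavior described in the parameter list above, not TorchRL code. completion_order is an assumed sequence of node ids, listed in the order they finish.

```python
from typing import List, Sequence, Tuple

# (nodes whose data is in the yielded batch, nodes that receive fresh weights afterwards)
Event = Tuple[Tuple[int, ...], Tuple[int, ...]]


def schedule(num_nodes: int, sync: bool, completion_order: Sequence[int] = (),
             num_iters: int = 1) -> List[Event]:
    """Toy model of sync/async yielding with update_after_each_batch=True."""
    events: List[Event] = []
    if sync:
        all_nodes = tuple(range(num_nodes))
        for _ in range(num_iters):
            # Wait for every node, yield one combined batch, then update every node
            events.append((all_nodes, all_nodes))
    else:
        for node in completion_order:
            if not 0 <= node < num_nodes:
                raise ValueError(f"unknown node {node}")
            # First-ready: one node per batch; only the contributing node is updated
            events.append(((node,), (node,)))
    return events


print(schedule(num_nodes=2, sync=True, num_iters=2))
# [((0, 1), (0, 1)), ((0, 1), (0, 1))]
print(schedule(num_nodes=3, sync=False, completion_order=[2, 0, 2]))
# [((2,), (2,)), ((0,), (0,)), ((2,), (2,))]
```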

Usage

from torchrl.envs.libs.gym import GymEnv
from tensordict.nn import TensorDictModule
from torch import nn
from torchrl.collectors import Collector
from torchrl.collectors.distributed import DistributedCollector

if __name__ == "__main__":
    env_maker = lambda: GymEnv("Pendulum-v1", device="cpu")
    policy = TensorDictModule(nn.Linear(3, 1), in_keys=["observation"], out_keys=["action"])

    # Two remote nodes, each running a single-env Collector, using multiprocessing launcher
    collector = DistributedCollector(
        create_env_fn=[env_maker, env_maker],
        policy=policy,
        frames_per_batch=200,
        total_frames=20_000,
        collector_class=Collector,
        backend="gloo",
        launcher="mp",          # single-machine multiprocessing
        sync=True,
    )

    for batch in collector:
        loss = train_step(batch)  # train_step: user-defined training function (not shown)
        collector.update_policy_weights_(policy)

    collector.shutdown()
DistributedCollector with launcher="submitit" requires the submitit package (pip install submitit) and a running SLURM scheduler. To simulate several nodes on a single machine, launcher="mp" is sufficient.

RPCCollector

RPCCollector (from torchrl.collectors.distributed) uses PyTorch's RPC framework (torch.distributed.rpc) to coordinate remote workers. DistributedCollector coordinates nodes through a TCPStore and the torch.distributed process group. RPCCollector instead invokes methods directly on remote references (RRefs) to the collectors it creates on each node. This enables fine-grained remote method invocation and tight integration with TorchRL's @accept_remote_rref_udf_invocation decorator.

Constructor

from torchrl.collectors.distributed import RPCCollector

RPCCollector(
    create_env_fn,
    policy=None,
    *,
    policy_factory=None,
    frames_per_batch,
    total_frames=-1,
    device=None,
    collector_class=Collector,
    collector_kwargs=None,
    num_workers_per_collector=1,
    sync=False,
    backend="gloo",
    launcher="submitit",
    tcp_port=None,
    visible_devices=None,
    tensorpipe_options=None,
    update_after_each_batch=False,
    max_weight_update_interval=-1,
    weight_sync_schemes=None,
    weight_recv_schemes=None,
    trajs_per_batch=None,
    ...
)

Key Parameters

visible_devices
list[int or torch.device]
default:"None"
A list of device identifiers (one per remote node) indicating which device is used to transfer data back to the main process. Needed when mixing CPU and GPU nodes.
tensorpipe_options
dict
default:"None"
Keyword arguments forwarded to torch.distributed.rpc.TensorPipeRpcBackendOptions. Use this to tune transport-level settings such as num_worker_threads.
backend
str
default:"\"gloo\""
torch.distributed backend used for weight synchronization alongside RPC. Usually "gloo" (CPU) or "nccl" (GPU).

Setup requirements

RPC initialization requires a process group to be set up before RPCCollector is created. The launcher handles this automatically, but for manual setups:
import os
import torch.distributed as dist
import torch.distributed.rpc as rpc

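# rank and world_size are normally provided by your launcher (e.g. torchrun sets RANK and WORLD_SIZE)
rank = int(os.environ["RANK"])
world_size = int(os.environ["WORLD_SIZE"])

# Must be called before RPCCollector
os.environ["MASTER_ADDR"] = "localhost"
os.environ["MASTER_PORT"] = "29500"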

dist.init_process_group(backend="gloo", rank=rank, world_size=world_size)
rpc.init_rpc(
    f"COLLECTOR_NODE_{rank}",
    rank=rank,
    world_size=world_size,
    backend=rpc.BackendType.TENSORPIPE,
)

RayCollector

RayCollector (from torchrl.collectors.distributed) uses Ray to distribute collection across a Ray cluster. Each remote collector runs inside a Ray actor, and RayCollector coordinates them with synchronous or asynchronous scheduling. Compared to DistributedCollector, RayCollector requires no manual process-group initialization and supports heterogeneous clusters where workers have different resource requirements.

Constructor

from torchrl.collectors.distributed import RayCollector

RayCollector(
    create_env_fn,              # list of env callables — one per remote actor
    policy=None,
    *,
    policy_factory=None,
    frames_per_batch,
    total_frames=-1,
    device=None,
    collector_class=Collector,
    collector_kwargs=None,
    num_workers_per_collector=1,
    sync=False,
    ray_init_config=None,
    remote_configs=None,
    num_collectors=None,
    update_after_each_batch=False,
    max_weight_update_interval=-1,
    replay_buffer=None,
    weight_sync_schemes=None,
    weight_recv_schemes=None,
    trajs_per_batch=None,
    ...
)

Key Parameters

create_env_fn
list[Callable] or Callable
required
List of callables, each creating an EnvBase instance. Length determines the number of remote collector actors unless num_collectors is provided.
num_collectors
int
default:"None"
Explicit number of remote collector actors to create. When set, create_env_fn, collector_kwargs, and remote_configs are broadcast to all num_collectors actors if they are not already lists.
sync
bool
default:"False"
If True, all actors must finish their rollout before the main process assembles and yields the combined TensorDict. If False, each actor’s result is yielded first-ready.
ray_init_config
dict
default:"None"
Kwargs forwarded to ray.init(). If None, Ray auto-detects an existing cluster or starts a local one. Set address="auto" to connect to an existing cluster.
remote_configs
dict or list[dict]
default:"None"
Resource specifications for ray.remote() — controls CPU, GPU, and memory per actor. Defaults to {"num_cpus": 1, "num_gpus": 0.2, "memory": 2 * 1024**3}. A single dict is broadcast to all actors; a list assigns per-actor resources.
replay_buffer
RayReplayBuffer
default:"None"
When provided, remote actors write directly to this RayReplayBuffer instead of returning data to the main process. Must be a RayReplayBuffer — regular ReplayBuffer instances cannot be shared across Ray actor boundaries.
weight_sync_schemes
dict[str, WeightSyncScheme]
default:"None"
Mapping from model IDs to WeightSyncScheme instances for pushing weights to remote actors. Defaults to {"policy": RayWeightSyncScheme()}.
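The broadcasting rules for num_collectors and remote_configs can be expressed as a small helper. resolve_actor_specs and broadcast are hypothetical names that model the documented behavior:

- A non-list argument is repeated for every actor.
- A list supplies one entry per actor.

Treating a single callable without num_collectors as one actor is an additional assumption.

```python
from typing import Any, Callable, Dict, List, Optional, Union

# remote_configs default as documented on this page (2 * 1024**3 bytes = 2 GiB)
DEFAULT_REMOTE_CONFIG: Dict[str, Any] = {"num_cpus": 1, "num_gpus": 0.2, "memory": 2 * 1024**3}


def broadcast(value: Any, n: int, name: str) -> List[Any]:
    """Repeat a non-list value n times; check that a list already has n entries."""
    if isinstance(value, list):
        if len(value) != n:
            raise ValueError(f"{name} has {len(value)} entries, expected {n}")
        return value
    return [value] * n  # the same object is shared, so treat it as read-only


def resolve_actor_specs(
    create_env_fn: Union[Callable, List[Callable]],
    num_collectors: Optional[int] = None,
    collector_kwargs: Union[None, Dict, List[Dict]] = None,
    remote_configs: Union[None, Dict, List[Dict]] = None,
) -> List[Dict[str, Any]]:
    """One spec per remote actor (hypothetical helper, not part of TorchRL)."""
    if num_collectors is None:
        # Assumption: a lone callable without num_collectors means one actor
        num_collectors = len(create_env_fn) if isinstance(create_env_fn, list) else 1
    env_fns = broadcast(create_env_fn, num_collectors, "create_env_fn")
    kwargs = broadcast(collector_kwargs if collector_kwargs is not None else {},
                       num_collectors, "collector_kwargs")
    resources = broadcast(remote_configs if remote_configs is not None else DEFAULT_REMOTE_CONFIG,
                          num_collectors, "remote_configs")
    return [
        {"create_env_fn": e, "collector_kwargs": k, "remote_config": r}
        for e, k, r in zip(env_fns, kwargs, resources)
    ]


make_env = lambda: "env"  # hypothetical env factory standing in for GymEnv(...)
specs = resolve_actor_specs(make_env, num_collectors=3)
print(len(specs), specs[0]["remote_config"]["memory"])  # 3 2147483648

hetero = resolve_actor_specs([make_env, make_env], remote_configs=[{"num_gpus": 1}, {"num_gpus": 0}])
print([s["remote_config"]["num_gpus"] for s in hetero])  # [1, 0]
```

The second call shows the heterogeneous case: a list of remote_configs gives each actor its own resource budget.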

Usage

from torch import nn
from tensordict.nn import TensorDictModule
from torchrl.envs.libs.gym import GymEnv
from torchrl.collectors import Collector
from torchrl.collectors.distributed import RayCollector
from torchrl.data import LazyTensorStorage, ReplayBuffer

env_maker = lambda: GymEnv("Pendulum-v1", device="cpu")
policy = TensorDictModule(nn.Linear(3, 1), in_keys=["observation"], out_keys=["action"])
# Local buffer filled by the main process; unrelated to RayCollector's replay_buffer argument
replay_buffer = ReplayBuffer(storage=LazyTensorStorage(100_000))
min_size = 1_000  # illustrative warm-up threshold before training starts

collector = RayCollector(
    create_env_fn=[env_maker, env_maker, env_maker],
    policy=policy,
    collector_class=Collector,
    frames_per_batch=200,
    total_frames=20_000,
    max_frames_per_traj=50,
    collector_kwargs={"device": "cpu", "storing_device": "cpu"},
    num_collectors=3,
    sync=False,                     # async: yield each actor's result as ready
    remote_configs={
        "num_cpus": 2,
        "num_gpus": 0,
        "memory": 1 * 1024**3,
    },
)

for batch in collector:
    replay_buffer.extend(batch)
    if len(replay_buffer) > min_size:
        train_step()  # train_step: user-defined, samples from replay_buffer (not shown)
        collector.update_policy_weights_(policy)

collector.shutdown()
RayCollector supports heterogeneous clusters natively. Pass remote_configs as a list to assign different GPU/CPU/memory budgets to each actor — useful when some nodes are more powerful than others.

Choosing a Collector

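| Scenario | Recommended Collector |
| --- | --- |
| On-policy, single env | Collector |
| On-policy, multiple envs | MultiSyncCollector |
| Off-policy, multiple envs | MultiAsyncCollector |
| Off-policy, async fill of replay buffer | AsyncCollector or MultiAsyncCollector |
| Multiple machines, torch.distributed or SLURM | DistributedCollector or RPCCollector |
| Multiple machines, Ray cluster | RayCollector |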

Weight Update Strategies

All distributed collectors push fresh policy weights to remote workers through update_policy_weights_(). The underlying mechanism is controlled by the weight_sync_schemes argument.

VanillaWeightUpdater

VanillaWeightUpdater is the simplest updater: it performs an in-place weight copy on each worker's policy. It is suitable when workers share memory (e.g., MultiSyncCollector on one machine). The constructor requires policy_weights, a locked TensorDict of the policy's parameters. Alternatively, the from_policy classmethod builds the updater directly from a policy in one line.
VanillaWeightUpdater (and the entire WeightUpdaterBase hierarchy) is deprecated. Use WeightSyncScheme from torchrl.weight_update.weight_sync_schemes for new code.
from tensordict import TensorDict
from torchrl.collectors import MultiSyncCollector, VanillaWeightUpdater

if __name__ == "__main__":
    # env_maker and policy as defined in the MultiSyncCollector example above
    policy_weights = TensorDict.from_module(policy).lock_()

    collector = MultiSyncCollector(
        [env_maker] * 4,
        policy=policy,
        frames_per_batch=400,
        weight_updater=VanillaWeightUpdater(policy_weights=policy_weights),
    )
    collector.update_policy_weights_(policy)
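VanillaWeightUpdater relies on two properties: an in-place copy and a locked set of keys. Both can be illustrated without TorchRL. In this sketch, plain Python lists stand in for parameter tensors. copy_weights_inplace is a hypothetical helper, not VanillaWeightUpdater's implementation. The copy mutates the existing buffers, so any holder of a reference sees the new values, and a mismatched key set is rejected.

```python
from typing import Dict, List

Weights = Dict[str, List[float]]  # plain lists stand in for parameter tensors


def copy_weights_inplace(src: Weights, dst: Weights) -> None:
    """Overwrite dst's buffers element-wise without replacing the buffer objects."""
    if src.keys() != dst.keys():
        # Like a locked TensorDict: the key set cannot change during an update
        raise KeyError(f"key mismatch: {sorted(src.keys() ^ dst.keys())}")
    for name, values in src.items():
        buf = dst[name]
        if len(buf) != len(values):
            raise ValueError(f"size mismatch for {name!r}")
        buf[:] = values  # slice assignment mutates the existing list in place


worker_weights: Weights = {"weight": [0.0, 0.0, 0.0], "bias": [0.0]}
worker_view = worker_weights["weight"]  # a reference held elsewhere, e.g. by a worker's module
copy_weights_inplace({"weight": [0.1, 0.2, 0.3], "bias": [0.5]}, worker_weights)
print(worker_view)  # [0.1, 0.2, 0.3]
```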

DistributedWeightSyncScheme (torch.distributed)

Used by DistributedCollector by default. Sends weights over the torch.distributed process group using isend/irecv for non-blocking transfer.
from torchrl.weight_update import DistributedWeightSyncScheme
from torchrl.collectors.distributed import DistributedCollector

if __name__ == "__main__":
    # env_maker and policy as defined in the earlier examples
    collector = DistributedCollector(
        create_env_fn=[env_maker, env_maker],
        policy=policy,
        frames_per_batch=200,
        weight_sync_schemes={"policy": DistributedWeightSyncScheme()},
        backend="gloo",
        launcher="mp",
    )

RayWeightUpdater / RayWeightSyncScheme

Used by RayCollector by default. Transfers weights to the remote actors through Ray's object store.
from torchrl.collectors.distributed import RayCollector

# Default weight sync is already Ray-based; override only for custom behavior:
collector = RayCollector(
    create_env_fn=[env_maker],
    policy=policy,
    frames_per_batch=200,
    # weight_sync_schemes defaults to {"policy": RayWeightSyncScheme()}
)
collector.update_policy_weights_(policy)

Multi-level hierarchies

Weight updates cascade automatically in hierarchical setups. A DistributedCollector with MultiSyncCollector sub-workers will push weights from the main process → distributed node → sub-workers using the appropriate scheme at each level.
from torchrl.collectors import MultiSyncCollector
from torchrl.collectors.distributed import DistributedCollector

if __name__ == "__main__":
    # env_maker and policy as defined in the earlier examples
    # Main process → DistributedCollector node → MultiSyncCollector workers
    collector = DistributedCollector(
        create_env_fn=[env_maker, env_maker],
        policy=policy,
        frames_per_batch=200,
        collector_class=MultiSyncCollector,     # each node runs a MultiSyncCollector
        num_workers_per_collector=2,
        backend="gloo",
        launcher="mp",
    )
    # update_policy_weights_ cascades through all levels automatically
    collector.update_policy_weights_(policy)
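The cascade can be pictured as a recursive push down a tree of collectors. The sketch below is a toy model, with a hypothetical Node class (not a TorchRL type) and a plain dict standing in for the weights. It only shows the propagation order. In the real setup, each hop uses the sync scheme appropriate to that level.

```python
from dataclasses import dataclass, field
from typing import Dict, List, Optional


@dataclass
class Node:
    """Hypothetical collector node in a hierarchy (main -> node -> sub-workers)."""
    name: str
    children: List["Node"] = field(default_factory=list)
    weights: Optional[Dict[str, float]] = None

    def update_policy_weights_(self, weights: Dict[str, float]) -> None:
        self.weights = dict(weights)  # each level keeps its own copy
        for child in self.children:
            # Forward to the level below; the real transport depends on the level
            child.update_policy_weights_(weights)


def leaves(node: Node) -> List[Node]:
    """Collect the bottom-level workers of the tree, left to right."""
    if not node.children:
        return [node]
    return [leaf for child in node.children for leaf in leaves(child)]


# Mirrors the example above: 2 distributed nodes, each with 2 sub-workers
root = Node("main", [
    Node(f"node{i}", [Node(f"node{i}/worker{j}") for j in range(2)]) for i in range(2)
])
root.update_policy_weights_({"w": 1.0})
print([leaf.name for leaf in leaves(root)])
# ['node0/worker0', 'node0/worker1', 'node1/worker0', 'node1/worker1']
print(all(leaf.weights == {"w": 1.0} for leaf in leaves(root)))  # True
```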

Setup: torch.distributed and RPC

import os
import torch.distributed as dist

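# rank and world_size are normally provided by your launcher (e.g. torchrun sets RANK and WORLD_SIZE)
rank = int(os.environ["RANK"])
world_size = int(os.environ["WORLD_SIZE"])

# Environment variables must be set before init_process_group
os.environ["MASTER_ADDR"] = "10.0.0.1"   # rank-0 machine IP
os.environ["MASTER_PORT"] = "29500"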

dist.init_process_group(
    backend="gloo",       # or "nccl" for GPU
    init_method="env://",
    rank=rank,
    world_size=world_size,
)
# DistributedCollector handles the rest via its launcher
