Data collection is the primary bottleneck in on-policy reinforcement learning: the policy must interact with the environment to generate experience before any gradient update can happen. TorchRL provides a hierarchy of collector classes, from a single-process Collector to a multi-machine RayCollector, that all expose the same iterator interface. You can prototype on a laptop with Collector, then switch to MultiSyncCollector for a multi-core workstation or RayCollector for a cluster, with almost no changes to the training loop. This tutorial walks up the hierarchy, showing when and how to use each level.
Collector is the simplest option: policy and environment live in the same process. It is the right choice for debugging and for environments that run very fast.

import torch
import torch.nn as nn
from tensordict.nn import TensorDictModule
from torchrl.collectors import Collector
from torchrl.envs import GymEnv

# A minimal linear policy for demonstration
policy = TensorDictModule(
    nn.Linear(3, 1),
    in_keys=["observation"],
    out_keys=["action"],
)

collector = Collector(
    create_env_fn=lambda: GymEnv("Pendulum-v1", device="cpu"),
    policy=policy,
    frames_per_batch=200,
    total_frames=10_000,
    device="cpu",
    max_frames_per_traj=50,  # reset after 50 steps
)

for data in collector:
    # data is a TensorDict with shape (200,)
    print(data.shape, data.keys())
    break

collector.shutdown()

# Alternative: pass an already-constructed environment instead of a factory
env = GymEnv("Pendulum-v1")
collector = Collector(create_env_fn=env, policy=policy, frames_per_batch=200, total_frames=2000)
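Because a collector holds an environment (and, for the multiprocess variants, worker processes), it is worth making sure `shutdown()` runs even when the loop body raises. The sketch below shows one way to structure that with `try`/`finally`. `DummyCollector` and `run` are hypothetical stand-ins written only for this illustration so the snippet runs without TorchRL installed; a real collector would take the stand-in's place.

```python
class DummyCollector:
    """Hypothetical stand-in for a collector: yields batches, tracks shutdown."""

    def __init__(self, frames_per_batch, total_frames):
        self.frames_per_batch = frames_per_batch
        self.total_frames = total_frames
        self.closed = False

    def __iter__(self):
        # Yield one list of frame indices per batch
        for start in range(0, self.total_frames, self.frames_per_batch):
            yield list(range(start, start + self.frames_per_batch))

    def shutdown(self):
        self.closed = True


def run(collector, fail_on_batch=None):
    """Iterate over the collector and always release its resources."""
    n_batches = 0
    try:
        for i, batch in enumerate(collector):
            if i == fail_on_batch:
                raise RuntimeError("simulated training failure")
            n_batches += 1
    finally:
        # Runs on normal exit and on exceptions alike
        collector.shutdown()
    return n_batches


ok = DummyCollector(frames_per_batch=200, total_frames=2000)
completed = run(ok)

failing = DummyCollector(frames_per_batch=200, total_frames=2000)
try:
    run(failing, fail_on_batch=3)
except RuntimeError:
    pass

print(completed, ok.closed, failing.closed)
```

The same shape applies to any of the collectors in this tutorial: construct the collector, then wrap the iteration in `try` and place `shutdown()` in `finally`.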
Always call collector.shutdown() when done. This closes the underlying environment and releases any shared memory. Using a context manager (with statement) is not currently supported; use an explicit try/finally block in production code.

MultiSyncCollector spawns one subprocess per entry in create_env_fn. All workers collect data in parallel, but the main process waits for every worker to finish before yielding the next batch. This is ideal for on-policy algorithms (PPO, A2C) where the training step must see fresh data from all workers.

from torchrl.collectors import MultiSyncCollector
if __name__ == "__main__":  # required for multiprocessing on Windows/macOS
    env_maker = lambda: GymEnv("Pendulum-v1", device="cpu")
    collector = MultiSyncCollector(
        create_env_fn=[env_maker, env_maker, env_maker, env_maker],  # 4 workers
        policy=policy,
        frames_per_batch=400,  # 100 frames per worker per batch
        total_frames=40_000,
        max_frames_per_traj=50,
        device="cpu",
        storing_device="cpu",
        cat_results="stack",  # stack results along a new dim (vs. "cat" to concatenate)
    )
    for i, data in enumerate(collector):
        # Each batch holds 400 frames in total; with cat_results="stack" expect a
        # leading per-worker dimension rather than a flat (400,) shape
        print(f"batch {i}: {data.shape}")
        # After each batch, push updated weights to workers
        collector.update_policy_weights_()
        if i == 4:
            break
    collector.shutdown()
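The "100 frames per worker" figure in the snippet above comes from dividing the batch evenly across workers. A minimal sketch of that arithmetic follows; `frames_per_worker` is a hypothetical helper, not a TorchRL function, and it assumes an even split (what the library does with a remainder is not covered here).

```python
def frames_per_worker(frames_per_batch, num_workers):
    """Frames each worker contributes to one batch, assuming an even split."""
    if frames_per_batch % num_workers:
        raise ValueError("frames_per_batch should be divisible by num_workers")
    return frames_per_batch // num_workers


FRAMES_PER_BATCH = 400
TOTAL_FRAMES = 40_000
NUM_WORKERS = 4

per_worker = frames_per_worker(FRAMES_PER_BATCH, NUM_WORKERS)
num_batches = TOTAL_FRAMES // FRAMES_PER_BATCH

# Two layouts for the same 400 frames: one row per worker, or one flat axis
stacked_shape = (NUM_WORKERS, per_worker)
flat_shape = (NUM_WORKERS * per_worker,)

print(per_worker, num_batches, stacked_shape, flat_shape)
```

Choosing `frames_per_batch` as a multiple of the worker count keeps every worker's share identical, which makes batch shapes easier to reason about.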
Python's multiprocessing requires the collector construction to be inside an if __name__ == "__main__": guard on Windows and macOS. On Linux (fork start method) this is not strictly required, but it remains good practice.

MultiAsyncCollector is the asynchronous counterpart: workers never wait for each other. The main process receives batches as soon as any worker finishes collecting one. This maximises throughput at the cost of policy staleness: a worker may be running an older policy version when its batch arrives.

MultiAsyncCollector suits off-policy algorithms (SAC, TD3, DQN) where data from a slightly stale policy is still usable.

from torchrl.collectors import MultiAsyncCollector
if __name__ == "__main__":
    collector = MultiAsyncCollector(
        create_env_fn=[env_maker, env_maker, env_maker, env_maker],
        policy=policy,
        frames_per_batch=200,
        total_frames=40_000,
        max_frames_per_traj=50,
        device="cpu",
        storing_device="cpu",
        cat_results="stack",
    )
    for i, data in enumerate(collector):
        # Batches arrive from whichever worker finishes first
        print(f"batch from worker, shape={data.shape}")
        # Update only the worker that just delivered this batch
        collector.update_policy_weights_()
        if i == 10:
            break
    collector.shutdown()
MultiSyncCollector
Synchronous collection — all workers finish before the main process receives data.
- All data in a batch was collected with the same policy version.
- Lower throughput because fast workers idle while slow ones finish.
- Natural fit for on-policy algorithms (PPO, A2C) where fresh data is required.
- Simpler to reason about — each training step sees a clean, uniform batch.
MultiAsyncCollector
Asynchronous collection — the main process processes each batch as it arrives.
- Workers run continuously; no idle time.
- Batches may come from different policy versions (stale-data problem).
- Higher throughput for environments with variable episode lengths.
- Natural fit for off-policy algorithms (SAC, TD3, DQN) with a replay buffer.
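The difference between the two delivery orders can be illustrated without TorchRL. The sketch below uses Python's standard `concurrent.futures` with threads that sleep for different durations to mimic workers of uneven speed. It is only an analogy under that assumption: the names `collect`, `sync_round`, and `async_round` are hypothetical, and this is not how the collectors are implemented.

```python
import time
from concurrent.futures import ThreadPoolExecutor, as_completed

# Simulated collection time per worker, in seconds (worker 0 is the slowest)
DELAYS = [0.30, 0.10, 0.20, 0.05]


def collect(worker_id):
    """Pretend to collect a batch, then report which worker produced it."""
    time.sleep(DELAYS[worker_id])
    return worker_id


def sync_round(pool):
    """Synchronous style: wait for every worker, then hand over all results."""
    futures = [pool.submit(collect, w) for w in range(len(DELAYS))]
    return [f.result() for f in futures]  # blocks until the slowest is done


def async_round(pool):
    """Asynchronous style: hand over each result as soon as it is ready."""
    futures = [pool.submit(collect, w) for w in range(len(DELAYS))]
    return [f.result() for f in as_completed(futures)]


with ThreadPoolExecutor(max_workers=len(DELAYS)) as pool:
    sync_order = sync_round(pool)
    async_order = async_round(pool)

print("sync order: ", sync_order)
print("async order:", async_order)
```

In the synchronous round the results come back in worker order only after the slowest worker finishes; in the asynchronous round the fastest worker's result is available first, which is where both the throughput gain and the staleness risk come from.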
All collector classes expose update_policy_weights_() to push updated network weights from the training process into workers. For multiprocess collectors the transfer uses shared memory (on the same machine) or torch.distributed for cross-machine setups.

# Basic usage: push current policy weights to all workers
collector.update_policy_weights_()
# Explicit: pass a policy module
collector.update_policy_weights_(policy)
# For async collectors: update only the worker that just returned a batch
# (handled automatically when you call update_policy_weights_() after each batch)
collector.update_policy_weights_()
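To see why an explicit weight push matters, consider a conceptual model in which each worker holds its own copy of the policy parameters. The sketch below is plain Python and purely illustrative: `Worker` and `push_weights` are hypothetical names, and the real transfer mechanism (shared memory or torch.distributed, as described above) is not modelled.

```python
import copy


class Worker:
    """Hypothetical worker holding a private copy of the policy weights."""

    def __init__(self, weights):
        self.weights = copy.deepcopy(weights)

    def load(self, weights):
        self.weights = copy.deepcopy(weights)


def push_weights(workers, weights):
    """Conceptual analogue of a weight sync: copy trainer weights to workers."""
    for worker in workers:
        worker.load(weights)


trainer_weights = {"w": 0.0, "version": 0}
workers = [Worker(trainer_weights) for _ in range(4)]

# One optimisation step on the trainer side
trainer_weights["w"] += 0.1
trainer_weights["version"] += 1

# Workers still act with the old parameters until a push happens
versions_before = [w.weights["version"] for w in workers]

push_weights(workers, trainer_weights)
versions_after = [w.weights["version"] for w in workers]

print(versions_before, versions_after)
```

Until the push, every worker keeps collecting with version 0 even though the trainer has moved on, which is the staleness that the update call is meant to bound.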
For MultiSyncCollector, calling update_policy_weights_() after every collected batch is the standard pattern for on-policy training: it ensures the next batch is always generated with the latest policy.

RayCollector uses Ray to distribute collection across a Ray cluster. Each remote collector is a separate Ray actor; the main process coordinates them. The interface is identical to local collectors.

from torchrl.collectors.distributed import RayCollector
from torchrl.collectors import Collector  # used as the per-worker collector class
if __name__ == "__main__":
    # Default: RayCollector auto-detects an existing cluster or starts a local Ray instance
    collector = RayCollector(
        create_env_fn=lambda: GymEnv("Pendulum-v1"),
        policy=policy,
        frames_per_batch=400,
        total_frames=40_000,
        num_collectors=4,  # 4 remote Ray actors
        sync=True,  # True = synchronous, False = async (first-ready)
        collector_class=Collector,  # per-worker collector type
        collector_kwargs={
            "max_frames_per_traj": 50,
        },
        # Resource spec passed to ray.remote() for each actor
        remote_configs={
            "num_cpus": 1,
            "num_gpus": 0,
            "memory": 2 * 1024 ** 3,
        },
        update_after_each_batch=True,  # auto-sync weights after every batch
    )
    for i, data in enumerate(collector):
        print(f"batch {i}: {data.shape}")
        if i == 5:
            break
    collector.shutdown()

# Connecting to an existing cluster: point ray_init_config at the head node
collector = RayCollector(
    create_env_fn=lambda: GymEnv("Pendulum-v1"),
    policy=policy,
    frames_per_batch=400,
    total_frames=200_000,
    num_collectors=16,
    ray_init_config={
        "address": "ray://head-node:10001",  # address of the Ray head node
    },
    remote_configs={
        "num_cpus": 2,
        "num_gpus": 0.5,
        "memory": 4 * 1024 ** 3,
    },
)
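Since the resource spec in `remote_configs` applies to each actor, the cluster needs at least `num_collectors` times those amounts available for all actors to be scheduled. A quick back-of-the-envelope check for the 16-actor configuration above might look like this; the variable names are illustrative only, and the figures exclude whatever the driver and training process need.

```python
GIB = 1024 ** 3

# Per-actor resource request, copied from the cluster example above
remote_configs = {
    "num_cpus": 2,
    "num_gpus": 0.5,
    "memory": 4 * GIB,
}
num_collectors = 16

# Aggregate demand across all remote collectors
total = {name: amount * num_collectors for name, amount in remote_configs.items()}
total_memory_gib = total["memory"] / GIB

print(f"CPUs: {total['num_cpus']}, GPUs: {total['num_gpus']}, memory: {total_memory_gib} GiB")
```

A fractional `num_gpus` of 0.5 means two actors can share one GPU, so 16 actors ask for 8 GPUs in total.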
RPCCollector uses torch.distributed.rpc as the transport instead of Ray. It is appropriate when you are already running in a torch.distributed context (e.g., on a SLURM cluster with PyTorch's native launcher) and do not want to install Ray.

# RPCCollector is intended to be run from a launch script (torchrun / submitit).
# The snippet below shows the collector construction; see the TorchRL examples
# for a full launch script.
from torchrl.collectors.distributed import RPCCollector

collector = RPCCollector(
    create_env_fn=lambda: GymEnv("Pendulum-v1"),
    policy=policy,
    frames_per_batch=400,
    total_frames=100_000,
    num_collectors=4,
    sync=True,
)
for data in collector:
    collector.update_policy_weights_()
collector.shutdown()
RPCCollector requires torch.distributed.rpc to be initialised before instantiation. The torchrun launcher (or submitit for SLURM) handles this automatically. Refer to the TorchRL RPC example for the full multi-node setup.

The example below shows how the collector hierarchy slots into a real training loop. Switching between Collector, MultiSyncCollector, and RayCollector only requires changing the lines that construct the collector.

"""Multi-worker PPO training loop (sync variant)."""
from __future__ import annotations
import torch
import torch.nn as nn
import tqdm
from tensordict.nn import AddStateIndependentNormalScale, TensorDictModule
from torchrl.collectors import MultiSyncCollector
from torchrl.data import LazyTensorStorage, TensorDictReplayBuffer
from torchrl.data.replay_buffers.samplers import SamplerWithoutReplacement
from torchrl.envs import (
    ClipTransform,
    DoubleToFloat,
    ExplorationType,
    GymEnv,
    RewardSum,
    StepCounter,
    TransformedEnv,
    VecNorm,
)
from torchrl.modules import MLP, ProbabilisticActor, TanhNormal, ValueOperator
from torchrl.objectives import ClipPPOLoss, group_optimizers
from torchrl.objectives.value.advantages import GAE

ENV_NAME = "HalfCheetah-v4"
N_WORKERS = 4
FRAMES_PER_BATCH = 2048 # total across all workers
MINI_BATCH_SIZE = 256
TOTAL_FRAMES = 1_000_000
PPO_EPOCHS = 10

def make_env():
    env = GymEnv(ENV_NAME, device="cpu")
    env = TransformedEnv(env)
    env.append_transform(VecNorm(in_keys=["observation"], decay=0.99999, eps=1e-2))
    env.append_transform(ClipTransform(in_keys=["observation"], low=-10, high=10))
    env.append_transform(RewardSum())
    env.append_transform(StepCounter())
    env.append_transform(DoubleToFloat(in_keys=["observation"]))
    return env


# Throwaway environment used only to read observation and action specs
proof_env = make_env()
obs_dim = proof_env.observation_spec["observation"].shape[-1]
act_dim = proof_env.action_spec_unbatched.shape[-1]

policy_mlp = nn.Sequential(
    MLP(obs_dim, activation_class=nn.Tanh, out_features=act_dim, num_cells=[64, 64]),
    AddStateIndependentNormalScale(act_dim, scale_lb=1e-8),
)
actor = ProbabilisticActor(
    TensorDictModule(policy_mlp, in_keys=["observation"], out_keys=["loc", "scale"]),
    in_keys=["loc", "scale"],
    spec=proof_env.full_action_spec_unbatched,
    distribution_class=TanhNormal,
    distribution_kwargs={
        "low": proof_env.action_spec_unbatched.space.low,
        "high": proof_env.action_spec_unbatched.space.high,
        "tanh_loc": False,
    },
    return_log_prob=True,
    default_interaction_type=ExplorationType.RANDOM,
)
critic = ValueOperator(
    MLP(obs_dim, activation_class=nn.Tanh, out_features=1, num_cells=[64, 64]),
    in_keys=["observation"],
)

adv_module = GAE(gamma=0.99, lmbda=0.95, value_network=critic, average_gae=False)
loss_module = ClipPPOLoss(
    actor_network=actor,
    critic_network=critic,
    clip_epsilon=0.2,
    entropy_coeff=0.01,
    critic_coeff=0.5,
    normalize_advantage=True,
)
optim = group_optimizers(
    torch.optim.Adam(actor.parameters(), lr=3e-4, eps=1e-5),
    torch.optim.Adam(critic.parameters(), lr=3e-4, eps=1e-5),
)

if __name__ == "__main__":
    # ── swap this line to scale: ──────────────────────────────────────────────
    # Single process: Collector(make_env, actor, ...)
    # Multi-process:  MultiSyncCollector([make_env]*N_WORKERS, actor, ...)
    # Multi-machine:  RayCollector(make_env, actor, num_collectors=N, ...)
    # ─────────────────────────────────────────────────────────────────────────
    collector = MultiSyncCollector(
        create_env_fn=[make_env] * N_WORKERS,
        policy=actor,
        frames_per_batch=FRAMES_PER_BATCH,
        total_frames=TOTAL_FRAMES,
        device="cpu",
        storing_device="cpu",
    )
    data_buffer = TensorDictReplayBuffer(
        storage=LazyTensorStorage(FRAMES_PER_BATCH),
        sampler=SamplerWithoutReplacement(),
        batch_size=MINI_BATCH_SIZE,
    )
    pbar = tqdm.tqdm(total=TOTAL_FRAMES)
    for data in collector:
        pbar.update(data.numel())
        for _ in range(PPO_EPOCHS):
            # Recompute advantages with the current critic at the start of each epoch
            with torch.no_grad():
                data = adv_module(data)
            data_buffer.extend(data.reshape(-1))
            for batch in data_buffer:
                optim.zero_grad(set_to_none=True)
                loss = loss_module(batch)
                (
                    loss["loss_objective"]
                    + loss["loss_critic"]
                    + loss["loss_entropy"]
                ).backward()
                torch.nn.utils.clip_grad_norm_(loss_module.parameters(), 0.5)
                optim.step()
        # Synchronise updated weights with all worker processes
        collector.update_policy_weights_()
        ep_r = data["next", "episode_reward"][data["next", "done"]]
        if len(ep_r):
            pbar.set_description(f"reward={ep_r.mean().item():.2f}")
    collector.shutdown()

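The constants at the top of the training script determine how much optimisation happens per collected batch. The short calculation below works through those numbers. It assumes the batch is split evenly across workers and that a final partial batch counts as one collector iteration; both are assumptions of this sketch, not guarantees about the library.

```python
import math

N_WORKERS = 4
FRAMES_PER_BATCH = 2048
MINI_BATCH_SIZE = 256
TOTAL_FRAMES = 1_000_000
PPO_EPOCHS = 10

# Frames each worker contributes to one collected batch (even split assumed)
frames_per_worker = FRAMES_PER_BATCH // N_WORKERS

# Mini-batches drawn without replacement from one collected batch
minibatches_per_epoch = FRAMES_PER_BATCH // MINI_BATCH_SIZE

# Optimiser steps taken before the next weight synchronisation
grad_steps_per_batch = PPO_EPOCHS * minibatches_per_epoch

# Collector iterations needed to reach TOTAL_FRAMES (rounded up)
collected_batches = math.ceil(TOTAL_FRAMES / FRAMES_PER_BATCH)

print(frames_per_worker, minibatches_per_epoch, grad_steps_per_batch, collected_batches)
```

Workers therefore receive new weights once per collected batch, after all the gradient steps for that batch have run, which keeps the data for each batch on a single policy version.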
Collector comparison
| Collector | Workers | Synchronisation | Best for |
|---|---|---|---|
| Collector | 1 (same process) | n/a | Debugging, fast envs |
| MultiSyncCollector | N (subprocesses) | All workers finish before main gets batch | On-policy (PPO, A2C) |
| MultiAsyncCollector | N (subprocesses) | Main gets batches as they arrive | Off-policy (SAC, DQN) |
| RayCollector | N (Ray actors) | Configurable (sync=True/False) | Multi-machine clusters |
| RPCCollector | N (RPC workers) | Configurable (sync=True/False) | SLURM / torchrun setups |
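The table can be read as a small decision procedure. The function below is one possible encoding of it, offered as a rough sketch: `suggest_collector` and its parameters are hypothetical, it returns class names as strings rather than importing anything, and real projects will weigh more factors than these.

```python
def suggest_collector(num_machines=1, num_workers=1, on_policy=True, use_ray=True):
    """Map a coarse description of the setup to a collector name from the table."""
    if num_machines > 1:
        # Multi-machine: Ray clusters vs. torch.distributed.rpc launch setups
        return "RayCollector" if use_ray else "RPCCollector"
    if num_workers <= 1:
        # Single process: debugging and fast environments
        return "Collector"
    # Single machine, several subprocesses
    return "MultiSyncCollector" if on_policy else "MultiAsyncCollector"


laptop = suggest_collector()
workstation_ppo = suggest_collector(num_workers=8, on_policy=True)
workstation_sac = suggest_collector(num_workers=8, on_policy=False)
cluster = suggest_collector(num_machines=4, num_workers=16)
slurm = suggest_collector(num_machines=4, num_workers=16, use_ray=False)

print(laptop, workstation_ppo, workstation_sac, cluster, slurm)
```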