TorchRL collectors are the central engine that drives the policy-environment interaction loop. A collector steps an environment with a policy, stacks the per-step TensorDict results into batches of a requested size, and yields them to your training loop as a standard Python iterator. All collector classes descend from BaseCollector, which implements the shared iteration protocol, trajectory assembly, weight-update machinery, and profiling hooks. For single-machine workloads the two primary classes are Collector (in-process, synchronous) and AsyncCollector (policy run on a background process). AsyncBatchedCollector collects from many environments asynchronously in a single call, while Evaluator provides a dedicated, callback-driven evaluation harness that can run in a thread, a process, or a Ray actor.
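The collection loop described here can be modelled in a few lines of plain Python. This is a toy sketch rather than TorchRL code: toy_collector, counting_env, and the dict-based step records are hypothetical stand-ins for a collector, an environment, and TensorDict. Only the frames_per_batch and total_frames semantics come from this page.

```python
from typing import Callable, Dict, Iterator, List


def toy_collector(
    env_step: Callable[[int, int], Dict[str, int]],
    policy: Callable[[int], int],
    frames_per_batch: int,
    total_frames: int,
) -> Iterator[List[Dict[str, int]]]:
    """Yield lists of step records, frames_per_batch at a time (toy model)."""
    if total_frames % frames_per_batch != 0:
        raise ValueError("total_frames must be divisible by frames_per_batch")
    obs, collected = 0, 0
    while collected < total_frames:
        batch = []
        for _ in range(frames_per_batch):
            action = policy(obs)          # the policy reads the current observation
            step = env_step(obs, action)  # the environment advances by one step
            batch.append(step)
            obs = step["next_observation"]
        collected += frames_per_batch
        yield batch                       # hand a full batch to the training loop


def counting_env(obs: int, action: int) -> Dict[str, int]:
    # Hypothetical environment: the next observation is obs + action.
    return {"observation": obs, "action": action, "next_observation": obs + action}


batches = list(
    toy_collector(counting_env, lambda obs: 1, frames_per_batch=4, total_frames=12)
)
```

Because the function is a generator, a training loop can consume it with an ordinary for loop, which is the same shape of interaction the real collectors offer.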
All multi-process collector classes (AsyncCollector, MultiSyncCollector, MultiAsyncCollector) must be created inside an if __name__ == "__main__": guard on Windows and macOS due to Python’s multiprocessing requirements.
BaseCollector
BaseCollector (from torchrl.collectors) is the abstract base class for every collector in TorchRL. It extends torch.utils.data.IterableDataset, so collectors can be passed directly to a DataLoader. Concrete subclasses implement iterator(), shutdown(), set_seed(), and state_dict().
Common interface
| Member | Description |
|---|---|
| __iter__() | Yields TensorDictBase batches until total_frames is reached. |
| next() | Single-step variant; returns None when exhausted. |
| start() | Begins asynchronous background collection (requires a replay_buffer). |
| pause() | Context manager that temporarily pauses a running async collection. |
| async_shutdown() | Stops a collector started via start(). |
| shutdown() | Stops workers and closes environments; call when done iterating. |
| update_policy_weights_() | Pushes fresh policy weights to workers. |
| set_seed(seed, static_seed) | Propagates a seed to all sub-environments and returns the next seed. |
| state_dict() | Returns an OrderedDict of collector state for checkpointing. |
| pre_collect_hook | Settable property: a Callable[[], None] invoked before each rollout. |
| post_collect_hook | Settable property: a Callable[[TensorDictBase], None] invoked after each rollout. |
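The sketch below illustrates two rows of the table: next() returning None once the collector is exhausted, and the two hooks firing around each rollout. ToyCollector is a hypothetical plain-Python stand-in, not BaseCollector itself, and its batches are lists of integers rather than TensorDicts.

```python
from typing import Callable, List, Optional


class ToyCollector:
    """Hypothetical stand-in showing next() exhaustion and the two hooks."""

    def __init__(self, num_batches: int) -> None:
        self._remaining = num_batches
        self.pre_collect_hook: Optional[Callable[[], None]] = None
        self.post_collect_hook: Optional[Callable[[List[int]], None]] = None

    def next(self) -> Optional[List[int]]:
        if self._remaining == 0:
            return None                    # exhausted: nothing left to collect
        if self.pre_collect_hook is not None:
            self.pre_collect_hook()        # runs before the rollout
        batch = [self._remaining] * 2      # placeholder for a collected batch
        self._remaining -= 1
        if self.post_collect_hook is not None:
            self.post_collect_hook(batch)  # receives the finished batch
        return batch


events: List[str] = []
collector = ToyCollector(num_batches=2)
collector.pre_collect_hook = lambda: events.append("pre")
collector.post_collect_hook = lambda batch: events.append(f"post:{len(batch)}")

results = [collector.next() for _ in range(3)]  # the third call is past the end
```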
Trajectory batching
When trajs_per_batch is set, the collector switches from fixed-frame batches to complete-episode batches. Each yield has shape (trajs_per_batch, max_traj_len) with zero-padding and a ("collector", "mask") validity flag. With traj_format="cat", the batches are flat and unpadded instead.
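The two layouts can be pictured with nested lists. The helpers below are hypothetical and use plain Python lists instead of tensors. pad_trajectories mimics the zero-padded layout with its validity mask, and cat_trajectories mimics the flat, unpadded layout.

```python
from typing import List, Tuple


def pad_trajectories(
    trajs: List[List[float]],
) -> Tuple[List[List[float]], List[List[bool]]]:
    """Zero-pad variable-length trajectories and return a validity mask."""
    max_len = max(len(t) for t in trajs)
    padded = [t + [0.0] * (max_len - len(t)) for t in trajs]
    mask = [[i < len(t) for i in range(max_len)] for t in trajs]
    return padded, mask


def cat_trajectories(trajs: List[List[float]]) -> List[float]:
    """Flat, unpadded alternative: concatenate the trajectories end to end."""
    return [x for t in trajs for x in t]


rewards = [[1.0, 1.0, 1.0], [0.5], [2.0, 2.0]]  # three episodes of different length
padded, mask = pad_trajectories(rewards)         # 3 rows, each of length 3
flat = cat_trajectories(rewards)                 # 6 values, no padding
```

The mask matters whenever a loss or statistic is averaged over the padded layout, since the padded zeros would otherwise count as real steps.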
Collector
Collector (from torchrl.collectors) is the standard single-process, synchronous collector. It creates one environment instance, runs the policy in-process at every step, and yields TensorDict batches to the caller. This is the recommended starting point for most online RL workflows.
Constructor
Parameters
- A callable that returns an EnvBase instance, or the env itself. The callable form is preferred because it allows the collector to recreate the environment if needed (e.g., after seeding or reset).
- Policy to execute at each step. Must accept a TensorDictBase and write actions back into it. Subclasses of TensorDictModuleBase are recommended. If None, a RandomPolicy sampling from env.action_spec is used automatically. Mutually exclusive with policy_factory.
- Number of environment steps per yielded batch. The collector accumulates exactly this many frames before yielding, so batches are always full.
- Total frames to collect over the lifetime of the collector. Iteration stops once this threshold is reached. Pass -1 for an endless collector. Must be divisible by frames_per_batch when finite.
- Convenience device applied to storing_device, policy_device, and env_device when they are not explicitly set. Setting a single device is the simplest way to run everything on one GPU.
- Device on which the output TensorDict is stored. For long trajectories or GPU environments, you may want to store on CPU to avoid GPU memory pressure.
- Maximum steps per trajectory before the environment is forcibly reset. Tracked per parallel environment independently. Negative values disable the limit.
- Number of frames at the start of collection where the policy is ignored and random actions are taken instead. Rounded up to the nearest frames_per_batch. Useful for initializing replay buffers before training begins.
- If True, all environments are reset at the start of every batch. This guarantees that no trajectory spans across batches, at the cost of discarding partial rollouts.
- A post-processing transform applied to each batch before it is yielded. Accepts torchrl.envs.Transform instances or MultiStep transforms. Not applied when writing directly to a replay buffer with extend_buffer=False.
- If True, each batch is split along trajectory boundaries using split_trajectories() and padded to a uniform length within that batch. Note that trajectories spanning two batches remain split. For whole-trajectory collection see trajs_per_batch.
- When set, the collector yields complete-episode batches of exactly this many trajectories instead of fixed-frame batches. Partial episodes are held internally until they terminate. Pairs naturally with a replay buffer and SliceSampler.
- Exploration mode for data collection. One of DETERMINISTIC, RANDOM, MODE, or MEAN from torchrl.envs.utils.ExplorationType.
- When provided, the collector writes directly to this buffer instead of yielding batches. Combined with trajs_per_batch, complete trajectories are written as flat 1-D sequences compatible with SliceSampler.
- If True, wraps the policy with torch.compile(). A dict of kwargs is also accepted and forwarded to torch.compile.
- If True, wraps the policy in CudaGraphModule for CUDA graph capture. A dict of kwargs can be passed for advanced configuration.
Output TensorDict
Each yielded TensorDict has the following top-level keys (depending on env spec):
| Key | Shape | Description |
|---|---|---|
| "action" | (frames_per_batch, *action_shape) | Actions selected by the policy |
| "observation" | (frames_per_batch, *obs_shape) | Observations at step t |
| "done" | (frames_per_batch, 1) | Episode termination flag at step t |
| "next" | nested TensorDict | observation, reward, done at step t+1 |
| ("collector", "traj_ids") | (frames_per_batch,) | Unique trajectory identifier per step |
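To make the layout concrete, here is a hypothetical six-frame batch written as a plain dict, with lists standing in for tensors and invented values. The helper groups frame indices by the ("collector", "traj_ids") entry, which is one way to recover per-trajectory quantities from a flat batch.

```python
from collections import defaultdict
from typing import Dict, List

# Hypothetical flat batch of 6 frames; a plain dict stands in for a TensorDict.
batch = {
    "observation": [0, 1, 2, 0, 1, 0],
    "action": [1, 1, 1, 1, 1, 1],
    "done": [False, False, False, False, False, False],
    "next": {
        "observation": [1, 2, 3, 1, 2, 1],
        "reward": [0.0, 0.0, 1.0, 0.0, 1.0, 0.0],
        "done": [False, False, True, False, True, False],
    },
    ("collector", "traj_ids"): [7, 7, 7, 8, 8, 9],
}


def frames_by_trajectory(b: dict) -> Dict[int, List[int]]:
    """Map each trajectory id to the indices of its frames in the batch."""
    groups: Dict[int, List[int]] = defaultdict(list)
    for idx, traj_id in enumerate(b[("collector", "traj_ids")]):
        groups[traj_id].append(idx)
    return dict(groups)


groups = frames_by_trajectory(batch)
# Sum the step rewards of each trajectory (undiscounted return within this batch).
returns = {
    tid: sum(batch["next"]["reward"][i] for i in idxs) for tid, idxs in groups.items()
}
```

Trajectory 9 has a single frame and no terminal step in this batch, which mirrors the case where an episode continues into the next fixed-frame batch.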
AsyncCollector
AsyncCollector (from torchrl.collectors) runs a single collector on a background process. It is a thin wrapper around MultiAsyncCollector with num_workers=1. Because data collection continues in the background between calls to next(), AsyncCollector is well-suited for off-policy settings, where the training policy can differ from the behavior policy.
AsyncCollector accepts the same parameters as Collector plus the multi-process arguments shared with MultiAsyncCollector (see the Distributed Collectors page for the full MultiCollector parameter list).
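The background-collection pattern can be sketched with a producer thread and a queue. This is only an analogy under stated assumptions: AsyncCollector uses a separate process, whereas the hypothetical background_collect function below uses a thread to stay short and deterministic, and its batches are lists of frame indices.

```python
import queue
import threading
from typing import List


def background_collect(out: queue.Queue, num_batches: int, frames_per_batch: int) -> None:
    """Producer: keeps filling the queue whether or not the consumer is ready."""
    frame = 0
    for _ in range(num_batches):
        batch: List[int] = list(range(frame, frame + frames_per_batch))
        frame += frames_per_batch
        out.put(batch)
    out.put(None)  # sentinel: collection finished


batches: queue.Queue = queue.Queue()
worker = threading.Thread(target=background_collect, args=(batches, 3, 4), daemon=True)
worker.start()

received = []
while True:
    batch = batches.get()   # blocks only if no batch is ready yet
    if batch is None:
        break
    received.append(batch)  # a training step would run here while collection continues
worker.join()
```

Since the producer does not wait for the consumer, batches may have been gathered with weights older than the ones currently being trained.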
AsyncBatchedCollector
AsyncBatchedCollector (from torchrl.collectors) collects from multiple environments asynchronously within a single process, yielding complete trajectories through the yield_completed_trajectories flag. This is the equivalent of trajs_per_batch for the async multi-env setting.
Evaluator
Evaluator (from torchrl.collectors) provides a unified synchronous and asynchronous evaluation harness. Internally it uses a Collector with trajs_per_batch to collect complete episodes. Three backends are supported:
- "thread" (default): low-overhead daemon thread; suitable for GPU-bound evaluation.
- "process": child process with full CUDA isolation; requires env to be a callable and policy_factory to be provided.
- "ray": Ray actor; suitable for distributed multi-node setups.
Constructor
Parameters
- Evaluation environment or a factory that creates one. The callable form is required for "process" and "ray" backends, and recommended for "thread" when combined with policy_factory (defers construction to the worker).
- Evaluation policy. Mutually exclusive with policy_factory. For "process" and "ray" backends, policy_factory must be used instead because the policy is constructed inside a subprocess.
- Number of complete episodes to collect per evaluation round.
- Maximum steps per episode, passed as max_frames_per_traj to the internal collector. None means no step limit.
- Callback invoked after each completed evaluation with a flat tensordict of metrics. Useful for logging without blocking the training loop.
- Behaviour when trigger_eval() is called while an evaluation is already running. "skip" silently drops the new request (default). "error" raises immediately. "queue" enqueues the new request; this stores a copy of the weights, which can be memory-intensive for large models.
- Execution backend. One of "thread", "process", or "ray". The "process" backend provides full CUDA context isolation and requires env to be a callable and policy_factory to be provided.
- Exploration mode during evaluation. Defaults to deterministic (greedy) policy evaluation.
Methods
| Method | Signature | Description |
|---|---|---|
| evaluate | (weights, step) -> dict | Blocking evaluation; returns the metrics dict directly once the evaluation completes. |
| trigger_eval | (weights, step) -> bool | Non-blocking; schedules eval in background, returns True if accepted. |
| poll | () -> dict or None | Returns result if ready, None otherwise. |
| wait | () -> dict | Blocks until the current evaluation completes. |
| pending | bool (property) | True while an async evaluation is in progress. |
| shutdown | () -> None | Stops background thread/process and frees resources. |
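The relationship between these methods can be sketched with a single-worker thread pool. ToyAsyncEvaluator is a hypothetical stand-in, not the Evaluator class: it only mirrors the method names from the table, models the default "skip" behaviour for overlapping requests, and returns a placeholder metric instead of running episodes.

```python
from concurrent.futures import Future, ThreadPoolExecutor
from typing import Optional


class ToyAsyncEvaluator:
    """Hypothetical thread-backed evaluator mirroring the method table."""

    def __init__(self) -> None:
        self._pool = ThreadPoolExecutor(max_workers=1)
        self._future: Optional[Future] = None

    def _run(self, weights: dict, step: int) -> dict:
        # Placeholder metric: a real evaluator would roll out episodes here.
        return {"step": step, "score": sum(weights.values())}

    def evaluate(self, weights: dict, step: int) -> dict:
        return self._run(weights, step)  # blocking call in the caller's thread

    def trigger_eval(self, weights: dict, step: int) -> bool:
        if self.pending:
            return False                 # "skip": drop the request while busy
        self._future = self._pool.submit(self._run, dict(weights), step)
        return True

    @property
    def pending(self) -> bool:
        return self._future is not None and not self._future.done()

    def poll(self) -> Optional[dict]:
        if self._future is not None and self._future.done():
            return self._future.result()
        return None

    def wait(self) -> dict:
        if self._future is None:
            raise RuntimeError("no evaluation was triggered")
        return self._future.result()     # blocks until the evaluation finishes

    def shutdown(self) -> None:
        self._pool.shutdown(wait=True)


evaluator = ToyAsyncEvaluator()
accepted = evaluator.trigger_eval({"w": 1.5, "b": 0.5}, step=100)
metrics = evaluator.wait()   # block until the background evaluation is done
polled = evaluator.poll()    # the finished result is still available
evaluator.shutdown()
```

In a training loop, trigger_eval would typically be called every few iterations and poll checked on later iterations, so that training never blocks on evaluation.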
Usage examples
- Synchronous
- Async with callback
- Process backend
ProfileConfig
ProfileConfig (from torchrl.collectors) is a dataclass that configures PyTorch profiler integration for any collector. Pass it to collector.enable_profile(...) to start tracing selected workers.
| Attribute | Type | Default | Description |
|---|---|---|---|
| workers | list[int] | [0] | Worker indices to profile (ignored for single-process Collector). |
| num_rollouts | int | 3 | Total rollouts to profile (including warmup). Profiling stops afterwards. |
| warmup_rollouts | int | 1 | Rollouts to skip before actual profiling begins (allows JIT/compile warmup). |
| save_path | str or Path | None | Trace output path. Supports {worker_idx} placeholder. |
| activities | list[str] | ["cpu", "cuda"] | Profiler activities. |
| record_shapes | bool | True | Record tensor shapes in the trace. |
| profile_memory | bool | False | Track memory allocations. |
| with_stack | bool | True | Capture Python/C++ stack traces. |
| with_flops | bool | False | Estimate FLOP counts. |
| on_trace_ready | Callable | None | Custom handler when a trace interval completes. |
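Two of these attributes interact in a way that is easy to misread: num_rollouts counts the warmup rollouts too, so the defaults trace two rollouts, not three. The helpers below are a hypothetical illustration of that arithmetic and of the {worker_idx} placeholder. They assume str.format-style substitution and an invented path, and they do not call ProfileConfig.

```python
from typing import List


def traced_rollouts(num_rollouts: int, warmup_rollouts: int) -> List[int]:
    """Indices of the rollouts that are traced; warmup rollouts are skipped."""
    return list(range(warmup_rollouts, num_rollouts))


def trace_paths(save_path: str, workers: List[int]) -> List[str]:
    """Expand the {worker_idx} placeholder once per profiled worker.

    Assumes str.format-style substitution, which this page does not confirm.
    """
    return [save_path.format(worker_idx=w) for w in workers]


traced = traced_rollouts(num_rollouts=3, warmup_rollouts=1)      # table defaults
paths = trace_paths("traces/worker_{worker_idx}.json", workers=[0, 2])
```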
Weight Synchronization
All collectors expose update_policy_weights_() to push fresh weights from the training process to the collection workers. For Collector this is a no-op unless a weight_updater or weight_sync_schemes is configured.
The weight_sync_schemes constructor argument accepts a dict[str, WeightSyncScheme] mapping model identifiers to sync strategies. This enables flexible setups such as shared-memory transfer for multi-process collectors or RPC-based updates in distributed hierarchies.
VanillaWeightUpdater
VanillaWeightUpdater performs a simple in-place weight copy on the policy. It is imported from torchrl.collectors. The constructor requires policy_weights, a locked TensorDict of the policy’s parameters (use TensorDict.from_module(policy).lock_() to create one). Alternatively, use the from_policy classmethod to construct it from a policy directly.
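The reason an in-place copy matters can be shown without tensors. In this hypothetical sketch, lists of floats stand in for parameters and copy_weights_ is an invented helper, not the VanillaWeightUpdater API. Because the destination buffers are overwritten rather than replaced, any reference the collecting policy already holds sees the new values.

```python
from typing import Dict, List

Weights = Dict[str, List[float]]


def copy_weights_(dest: Weights, src: Weights) -> None:
    """Overwrite dest's buffers element by element, keeping the same list objects."""
    for name, values in src.items():
        dest[name][:] = values  # slice assignment mutates the existing list


trainer_weights: Weights = {"linear.weight": [0.1, 0.2], "linear.bias": [0.0]}
worker_weights: Weights = {"linear.weight": [0.0, 0.0], "linear.bias": [0.0]}
policy_view = worker_weights["linear.weight"]  # reference held by the collecting policy

trainer_weights["linear.weight"] = [0.5, 0.6]  # a training step produces new values
copy_weights_(worker_weights, trainer_weights)
# policy_view now reads [0.5, 0.6] without being reassigned
```

Rebinding worker_weights["linear.weight"] to a new list would leave policy_view pointing at the stale values, which is the failure mode an in-place update avoids.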