
TorchRL collectors are the central engine that drives the policy-environment interaction loop. A collector steps an environment with a policy, stacks the per-step TensorDict results into batches of a requested size, and yields them to your training loop as a standard Python iterator. All collector classes descend from BaseCollector, which implements the shared iteration protocol, trajectory assembly, weight-update machinery, and profiling hooks. For single-machine workloads the two primary classes are Collector (in-process, synchronous) and AsyncCollector (collection runs on a background process). AsyncBatchedCollector collects from many environments asynchronously within a single process, while Evaluator provides a dedicated, callback-driven evaluation harness that can run in a thread, a process, or a Ray actor.
All multi-process collector classes (AsyncCollector, MultiSyncCollector, MultiAsyncCollector) must be created inside an if __name__ == "__main__": guard on Windows and macOS, where Python's multiprocessing uses the "spawn" start method by default and re-imports the main module in every child process.

BaseCollector

BaseCollector (from torchrl.collectors) is the abstract base class for every collector in TorchRL. It extends torch.utils.data.IterableDataset, so collectors can be passed directly to a DataLoader. Concrete subclasses implement iterator(), shutdown(), set_seed(), and state_dict().

Common interface

Member | Description
--- | ---
__iter__() | Yields TensorDictBase batches until total_frames is reached.
next() | Returns the next batch (one iteration of __iter__); returns None once the collector is exhausted.
start() | Begins asynchronous background collection (requires a replay_buffer).
pause() | Context manager that temporarily pauses a running async collection.
async_shutdown() | Stops a collector started via start().
shutdown() | Stops workers and closes environments; call when done iterating.
update_policy_weights_() | Pushes fresh policy weights to workers.
set_seed(seed, static_seed) | Propagates a seed to all sub-environments and returns the next seed.
state_dict() | Returns an OrderedDict of collector state for checkpointing.
pre_collect_hook | Settable property: a Callable[[], None] invoked before each rollout.
post_collect_hook | Settable property: a Callable[[TensorDictBase], None] invoked after each rollout.
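The iteration contract in the table above can be illustrated without TorchRL. The sketch below is hypothetical (ToyCollector and its plain-dict "steps" are stand-ins, not TorchRL classes); it only mimics the documented behavior that __iter__ yields full batches until total_frames is reached and that next() returns None, rather than raising, once the collector is exhausted.

```python
from typing import Iterator, List, Optional

Batch = List[dict]


class ToyCollector:
    """Hypothetical, torch-free stand-in for a collector's iteration protocol."""

    def __init__(self, frames_per_batch: int, total_frames: int) -> None:
        self.frames_per_batch = frames_per_batch
        self.total_frames = total_frames
        self._frames = 0                          # frames collected so far
        self._iterator: Optional[Iterator[Batch]] = None

    def __iter__(self) -> Iterator[Batch]:
        # Yield full batches until the total_frames budget is used up.
        while self._frames < self.total_frames:
            batch = [{"step": self._frames + i} for i in range(self.frames_per_batch)]
            self._frames += self.frames_per_batch
            yield batch

    def next(self) -> Optional[Batch]:
        # One iteration of __iter__; returns None instead of raising StopIteration.
        if self._iterator is None:
            self._iterator = iter(self)
        try:
            return next(self._iterator)
        except StopIteration:
            return None

    def shutdown(self) -> None:
        self._iterator = None                     # a real collector would also close envs


collector = ToyCollector(frames_per_batch=200, total_frames=1000)
print([len(batch) for batch in collector])        # [200, 200, 200, 200, 200]
print(collector.next())                           # None: budget exhausted
collector.shutdown()
```

Because __iter__ is a generator, the same object can back a for loop or be stepped manually through next().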

Trajectory batching

When trajs_per_batch is set, the collector switches from fixed-frame batches to complete-episode batches. Each yielded batch then has shape (trajs_per_batch, max_traj_len): shorter trajectories are zero-padded, and a ("collector", "mask") entry flags the valid steps. With traj_format="cat", the trajectories are instead concatenated into flat, unpadded batches.
When using a replay buffer with trajs_per_batch, complete trajectories are written to the buffer as flat 1-D sequences — exactly the layout expected by SliceSampler(end_key=("next", "done")). The collector yields None on each write.
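The padded and flat layouts can be compared on a toy batch. In this hypothetical, pure-Python sketch (pad_with_mask and cat_trajectories are illustrative names, and lists of rewards stand in for TensorDicts), three complete episodes of lengths 3, 1 and 2 are zero-padded to shape (3, 3) with a validity mask, then concatenated into the flat form described for traj_format="cat" and for replay-buffer writes.

```python
from typing import List, Tuple


def pad_with_mask(trajs: List[List[float]], pad_value: float = 0.0
                  ) -> Tuple[List[List[float]], List[List[bool]]]:
    """Pad every trajectory to the longest one; return (padded, mask)."""
    max_len = max(len(t) for t in trajs)
    padded = [t + [pad_value] * (max_len - len(t)) for t in trajs]
    # True marks a real step, False marks padding
    mask = [[True] * len(t) + [False] * (max_len - len(t)) for t in trajs]
    return padded, mask


def cat_trajectories(trajs: List[List[float]]) -> List[float]:
    """Flat, unpadded layout: trajectories are simply concatenated."""
    return [step for t in trajs for step in t]


episodes = [[1.0, 1.0, 1.0], [2.0], [3.0, 3.0]]  # three complete episodes (rewards)
padded, mask = pad_with_mask(episodes)
print(padded)                      # [[1.0, 1.0, 1.0], [2.0, 0.0, 0.0], [3.0, 3.0, 0.0]]
print(mask)                        # [[True, True, True], [True, False, False], [True, True, False]]
print(cat_trajectories(episodes))  # [1.0, 1.0, 1.0, 2.0, 3.0, 3.0]
print([sum(row) for row in mask])  # [3, 1, 2]: the mask recovers true lengths
```

In the flat form, episode boundaries are no longer implicit in the shape, which is why a sampler such as SliceSampler relies on a done key to find them.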

Collector

Collector (from torchrl.collectors) is the standard single-process, synchronous collector. It creates one environment instance, runs the policy in-process at every step, and yields TensorDict batches to the caller. This is the recommended starting point for most online RL workflows.
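from torchrl.envs.libs.gym import GymEnv
from tensordict.nn import TensorDictModule
from torchrl.collectors import Collector
from torch import nn


def train_step(batch):
    # Placeholder: compute a loss from `batch` and run an optimizer step here
    return None


env_maker = lambda: GymEnv("Pendulum-v1", device="cpu")
# Pendulum-v1 has a 3-dim observation and a 1-dim action
policy = TensorDictModule(nn.Linear(3, 1), in_keys=["observation"], out_keys=["action"])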

collector = Collector(
    create_env_fn=env_maker,
    policy=policy,
    frames_per_batch=200,
    total_frames=10_000,
    max_frames_per_traj=50,
    device="cpu",
)

for batch in collector:
    # batch: TensorDict with keys action, observation, next, collector, done, ...
    loss = train_step(batch)
    collector.update_policy_weights_()

collector.shutdown()
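With the settings in the example above, the loop body runs total_frames // frames_per_batch = 50 times. The hypothetical helper below (frame_budget is not a TorchRL function) spells out that arithmetic and also applies the rounding rule documented for init_random_frames in the parameter list; it assumes a finite total_frames that is a multiple of frames_per_batch.

```python
import math
from typing import Tuple


def frame_budget(frames_per_batch: int, total_frames: int,
                 init_random_frames: int = 0) -> Tuple[int, int]:
    """Hypothetical helper: (batches yielded, batches that use random actions)."""
    num_batches = total_frames // frames_per_batch
    # init_random_frames is rounded up to a multiple of frames_per_batch,
    # so the random-action phase always covers whole batches
    random_batches = math.ceil(init_random_frames / frames_per_batch)
    return num_batches, random_batches


print(frame_budget(200, 10_000))                          # (50, 0)
print(frame_budget(200, 10_000, init_random_frames=500))  # (50, 3)
```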

Constructor

Collector(
    create_env_fn,
    policy=None,
    *,
    policy_factory=None,
    frames_per_batch,
    total_frames=-1,
    device=None,
    storing_device=None,
    env_device=None,
    policy_device=None,
    create_env_kwargs=None,
    max_frames_per_traj=None,
    init_random_frames=None,
    reset_at_each_iter=False,
    postproc=None,
    split_trajs=False,
    trajs_per_batch=None,
    traj_format=None,
    exploration_type=ExplorationType.RANDOM,
    return_same_td=False,
    replay_buffer=None,
    compile_policy=False,
    cudagraph_policy=False,
    ...
)

Parameters

create_env_fn
Callable or EnvBase
required
A callable that returns an EnvBase instance, or the env itself. The callable form is preferred because it allows the collector to recreate the environment if needed (e.g., after seeding or reset).
policy
Callable
default:"None"
Policy to execute at each step. Must accept a TensorDictBase and write actions back into it. Subclasses of TensorDictModuleBase are recommended. If None, a RandomPolicy sampling from env.action_spec is used automatically. Mutually exclusive with policy_factory.
frames_per_batch
int
required
Number of environment steps per yielded batch. The collector accumulates exactly this many frames before yielding, so batches are always full.
total_frames
int
default:"-1"
Total frames to collect over the lifetime of the collector. Iteration stops once this threshold is reached. Pass -1 for an endless collector. Must be divisible by frames_per_batch when finite.
device
str or torch.device
default:"None"
Convenience device applied to storing_device, policy_device, and env_device when they are not explicitly set. Setting a single device is the simplest way to run everything on one GPU.
storing_device
str or torch.device
default:"None"
Device on which the output TensorDict is stored. For long trajectories or GPU environments, you may want to store on CPU to avoid GPU memory pressure.
max_frames_per_traj
int
default:"None"
Maximum steps per trajectory before the environment is forcibly reset. Tracked per parallel environment independently. Negative values disable the limit.
init_random_frames
int
default:"None"
Number of frames at the start of collection during which the policy is ignored and random actions are taken instead. Rounded up to the nearest multiple of frames_per_batch. Useful for filling replay buffers before training begins.
reset_at_each_iter
bool
default:"False"
If True, all environments are reset at the start of every batch. This guarantees that no trajectory spans across batches, at the cost of discarding partial rollouts.
postproc
Callable
default:"None"
A post-processing transform applied to each batch before it is yielded. Accepts torchrl.envs.Transform instances or MultiStep transforms. Not applied when writing directly to a replay buffer with extend_buffer=False.
split_trajs
bool
default:"False"
If True, each batch is split along trajectory boundaries using split_trajectories() and padded to a uniform length within that batch. Note that trajectories spanning two batches remain split. For whole-trajectory collection see trajs_per_batch.
trajs_per_batch
int
default:"None"
When set, the collector yields complete-episode batches of exactly this many trajectories instead of fixed-frame batches. Partial episodes are held internally until they terminate. Pairs naturally with a replay buffer and SliceSampler.
exploration_type
ExplorationType
default:"ExplorationType.RANDOM"
Exploration mode for data collection. One of DETERMINISTIC, RANDOM, MODE, or MEAN from torchrl.envs.utils.ExplorationType.
replay_buffer
ReplayBuffer
default:"None"
When provided, the collector writes directly to this buffer instead of yielding batches. Combined with trajs_per_batch, complete trajectories are written as flat 1-D sequences compatible with SliceSampler.
compile_policy
bool or dict
default:"False"
If True, wraps the policy with torch.compile(). A dict of kwargs is also accepted and forwarded to torch.compile.
cudagraph_policy
bool or dict
default:"False"
If True, wraps the policy in CudaGraphModule for CUDA graph capture. A dict of kwargs can be passed for advanced configuration.
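The buffering behind trajs_per_batch (partial episodes are held until they terminate, and each batch carries exactly trajs_per_batch complete episodes) can be sketched in plain Python. The function below is hypothetical: it consumes a stream of (env_id, reward, done) tuples in place of TensorDicts and ignores padding, so it models the documented behavior rather than reproducing TorchRL's code.

```python
from collections import defaultdict
from typing import Dict, Iterable, List, Tuple


def complete_episode_batches(
    stream: Iterable[Tuple[int, float, bool]], trajs_per_batch: int
) -> List[List[List[float]]]:
    """Hypothetical sketch: group per-env steps into batches of complete episodes."""
    partial: Dict[int, List[float]] = defaultdict(list)  # env_id -> running episode
    finished: List[List[float]] = []
    batches: List[List[List[float]]] = []
    for env_id, reward, done in stream:
        partial[env_id].append(reward)
        if done:
            # the episode is complete: release it from the per-env buffer
            finished.append(partial.pop(env_id))
            if len(finished) == trajs_per_batch:
                batches.append(finished)
                finished = []
    return batches


# Two environments interleaving steps; env 1 has shorter episodes
stream = [
    (0, 0.1, False), (1, 1.0, False), (1, 1.1, True),
    (0, 0.2, False), (1, 1.2, True), (0, 0.3, True),
    (1, 1.3, False),  # still running when the stream ends: held back
]
print(complete_episode_batches(stream, trajs_per_batch=2))  # [[[1.0, 1.1], [1.2]]]
```

Env 0's episode finishes third, so it waits for a partner and is not yielded in this short stream.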

Output TensorDict

Each yielded TensorDict has the following top-level keys (depending on env spec):
Key | Shape | Description
--- | --- | ---
"action" | (frames_per_batch, *action_shape) | Actions selected by the policy
"observation" | (frames_per_batch, *obs_shape) | Observations at step t
"done" | (frames_per_batch, 1) | Episode termination flag at step t
"next" | nested TensorDict | observation, reward, done at step t+1
("collector", "traj_ids") | (frames_per_batch,) | Unique trajectory identifier per step
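The relationship between the root entries at step t and the "next" entries, and how ("collector", "traj_ids") changes at episode boundaries, can be sketched with plain dicts. The toy environment below is hypothetical: its observation is the step count within the current trajectory, it terminates after episode_len steps, and it is force-reset after max_frames_per_traj steps. For simplicity, truncation is folded into the single "done" flag here.

```python
from typing import Dict, List


def toy_rollout(frames_per_batch: int, episode_len: int,
                max_frames_per_traj: int) -> List[Dict]:
    """Hypothetical single-env rollout returning one dict per step."""
    steps: List[Dict] = []
    traj_id, obs = 0, 0  # obs doubles as the per-trajectory step counter
    for _ in range(frames_per_batch):
        next_obs = obs + 1
        # natural termination or forced reset, whichever comes first
        done = next_obs >= min(episode_len, max_frames_per_traj)
        steps.append({
            "observation": obs,
            "action": 0.0,
            "next": {"observation": next_obs, "reward": 1.0, "done": done},
            ("collector", "traj_ids"): traj_id,
        })
        if done:
            traj_id, obs = traj_id + 1, 0   # reset starts a new trajectory
        else:
            obs = next_obs                  # "next" becomes the root at t+1


    return steps


batch = toy_rollout(frames_per_batch=7, episode_len=10, max_frames_per_traj=3)
print([s[("collector", "traj_ids")] for s in batch])  # [0, 0, 0, 1, 1, 1, 2]
```

Within a trajectory, ("next", "observation") at step t equals "observation" at step t+1; across a reset, the trajectory id increments instead.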

AsyncCollector

AsyncCollector (from torchrl.collectors) runs a single collector on a background process. It is a thin wrapper around MultiAsyncCollector with num_workers=1. Because data collection continues in the background between calls to next(), AsyncCollector is well suited to off-policy settings, where the policy being trained may differ from the behavior policy that collected the data.
For on-policy RL, where the training and behavior policies must be the same, prefer Collector: it is simpler and has lower overhead. AsyncCollector collects data with stale weights by default.
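from torchrl.collectors import AsyncCollector
from torchrl.envs.libs.gym import GymEnv
from tensordict.nn import TensorDictModule
from torch import nn


def train_step(batch):
    # Placeholder: compute a loss from `batch` and run an optimizer step here
    return None
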
if __name__ == "__main__":
    env_maker = lambda: GymEnv("Pendulum-v1", device="cpu")
    policy = TensorDictModule(nn.Linear(3, 1), in_keys=["observation"], out_keys=["action"])

    collector = AsyncCollector(
        create_env_fn=env_maker,
        policy=policy,
        frames_per_batch=200,
        total_frames=10_000,
        device="cpu",
    )

    for batch in collector:
        # Training step while the next batch is being collected in the background
        loss = train_step(batch)
        collector.update_policy_weights_()

    collector.shutdown()
AsyncCollector accepts the same parameters as Collector plus the multi-process arguments shared with MultiAsyncCollector (see the Distributed Collectors page for the full MultiCollector parameter list).
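Why AsyncCollector data is "stale" can be seen with a thread standing in for the background process. In this hypothetical sketch (ToyAsyncCollector and weight_version are illustrative, not TorchRL API), the worker tags every batch with the weight version it read at collection time; because the worker runs ahead of the training loop, each batch was collected with weights at least one update older than those just pushed.

```python
import queue
import threading


class ToyAsyncCollector:
    """Hypothetical background collector that tags batches with a weight version."""

    def __init__(self, num_batches: int) -> None:
        self.weight_version = 0                       # bumped by update_policy_weights_()
        self._queue: "queue.Queue" = queue.Queue(maxsize=1)  # bounded prefetch
        self._thread = threading.Thread(target=self._run, args=(num_batches,), daemon=True)
        self._thread.start()

    def _run(self, num_batches: int) -> None:
        for i in range(num_batches):
            # weights are read at collection time; training may update them later
            self._queue.put({"batch_idx": i, "weight_version": self.weight_version})
        self._queue.put(None)                         # sentinel: collection finished

    def __iter__(self):
        while (batch := self._queue.get()) is not None:
            yield batch

    def update_policy_weights_(self) -> None:
        self.weight_version += 1


collector = ToyAsyncCollector(num_batches=5)
lags = []
for batch in collector:
    # a training step would run here while the worker fills the next batch
    collector.update_policy_weights_()
    lags.append(collector.weight_version - batch["weight_version"])
print(lags)  # every entry is >= 1; exact values depend on thread timing
```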

AsyncBatchedCollector

AsyncBatchedCollector (from torchrl.collectors) collects from multiple environments asynchronously within a single process, yielding complete trajectories through the yield_completed_trajectories flag. This is the equivalent of trajs_per_batch for the async multi-env setting.
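from torchrl.collectors import AsyncBatchedCollector
from torchrl.envs import GymEnv
from tensordict.nn import TensorDictModule
from torch import nn


def train(batch):
    # Placeholder: compute a loss from `batch` and run an optimizer step here
    return None


if __name__ == "__main__":
    env_maker = lambda: GymEnv("Pendulum-v1")
    policy = TensorDictModule(nn.Linear(3, 1), in_keys=["observation"], out_keys=["action"])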

    collector = AsyncBatchedCollector(
        create_env_fn=[env_maker] * 4,
        policy=policy,
        frames_per_batch=400,
        total_frames=20_000,
        yield_completed_trajectories=True,
    )

    for batch in collector:
        train(batch)

    collector.shutdown()

Evaluator

Evaluator (from torchrl.collectors) provides a unified synchronous and asynchronous evaluation harness. Internally it uses a Collector with trajs_per_batch to collect complete episodes. Three backends are supported:
  • "thread" (default) — low-overhead daemon thread; suitable for GPU-bound evaluation.
  • "process" — child process with full CUDA isolation; requires env to be a callable and policy_factory to be provided.
  • "ray" — Ray actor; suitable for distributed multi-node setups.

Constructor

Evaluator(
    env,
    policy=None,
    *,
    policy_factory=None,
    num_trajectories=10,
    max_steps=None,
    frames_per_batch=None,
    collector_cls=None,
    collector_kwargs=None,
    weight_sync_schemes=None,
    log_prefix="eval",
    reward_keys=("next", "reward"),
    done_keys=("next", "done"),
    device=None,
    exploration_type=ExplorationType.DETERMINISTIC,
    metrics_fn=None,
    dump_video=True,
    on_result=None,
    busy_policy="skip",
    backend="thread",
    init_fn=None,
    num_gpus=1,
    ray_kwargs=None,
)

Parameters

env
EnvBase or Callable
required
Evaluation environment or a factory that creates one. The callable form is required for "process" and "ray" backends, and recommended for "thread" when combined with policy_factory (defers construction to the worker).
policy
Callable
default:"None"
Evaluation policy. Mutually exclusive with policy_factory. For "process" and "ray" backends, policy_factory must be used instead because the policy is constructed inside a subprocess.
num_trajectories
int
default:"10"
Number of complete episodes to collect per evaluation round.
max_steps
int
default:"None"
Maximum steps per episode, passed as max_frames_per_traj to the internal collector. None means no step limit.
on_result
Callable
default:"None"
Callback invoked after each completed evaluation with a flat tensordict of metrics. Useful for logging without blocking the training loop.
busy_policy
str
default:"\"skip\""
Behaviour when trigger_eval() is called while an evaluation is already running. "skip" silently drops the new request (default). "error" raises immediately. "queue" enqueues the new request (stores a copy of weights — can be memory-intensive for large models).
backend
str
default:"\"thread\""
Execution backend. One of "thread", "process", or "ray". The "process" backend provides full CUDA context isolation and requires env to be a callable and policy_factory to be provided.
exploration_type
ExplorationType
default:"ExplorationType.DETERMINISTIC"
Exploration mode during evaluation. Defaults to deterministic (greedy) policy evaluation.
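The three busy_policy modes differ only in how a request is admitted while an evaluation is running. The hypothetical sketch below (BusyGate is not part of TorchRL, and no evaluation is actually run) models that admission logic, including why "queue" needs its own copy of the weights: training keeps mutating them after the request is made.

```python
import copy
from collections import deque
from typing import Deque, Tuple


class BusyGate:
    """Hypothetical model of busy_policy for trigger_eval(); no evaluation is run."""

    def __init__(self, busy_policy: str = "skip") -> None:
        if busy_policy not in ("skip", "error", "queue"):
            raise ValueError(f"unknown busy_policy: {busy_policy!r}")
        self.busy_policy = busy_policy
        self.pending = False                       # an evaluation is running
        self.queued: Deque[Tuple[dict, int]] = deque()

    def trigger_eval(self, weights: dict, step: int) -> bool:
        if not self.pending:
            self.pending = True                    # start right away
            return True
        if self.busy_policy == "skip":
            return False                           # drop the request silently
        if self.busy_policy == "error":
            raise RuntimeError("an evaluation is already running")
        # "queue": snapshot the weights, since training keeps mutating them;
        # this copy is what makes queueing memory-hungry for large models
        self.queued.append((copy.deepcopy(weights), step))
        return True

    def finish(self) -> None:
        # The running evaluation ended: start the next queued one, if any.
        if self.queued:
            self.queued.popleft()
        else:
            self.pending = False


weights = {"w": [0.0]}
gate = BusyGate("queue")
print(gate.trigger_eval(weights, step=0))   # True: starts immediately
weights["w"][0] = 1.0                       # training updates the weights
print(gate.trigger_eval(weights, step=1))   # True: queued with a snapshot
weights["w"][0] = 2.0
print(gate.queued[0][0])                    # {'w': [1.0]}
```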

Methods

Method | Signature | Description
--- | --- | ---
evaluate | (weights, step) -> dict | Blocking evaluation; returns the metrics dict once the evaluation has finished.
trigger_eval | (weights, step) -> bool | Non-blocking; schedules an evaluation in the background and returns True if the request was accepted.
poll | () -> dict or None | Returns the result if ready, None otherwise.
wait | () -> dict | Blocks until the current evaluation completes.
pending | property (bool) | True while an asynchronous evaluation is in progress.
shutdown | () -> None | Stops the background thread/process and frees resources.

Usage examples

from torchrl.collectors import Evaluator

# make_eval_env, eval_policy, train_policy, train() and num_steps are defined elsewhere
evaluator = Evaluator(
    make_eval_env,
    policy=eval_policy,
    num_trajectories=10,
    max_steps=500,
    log_prefix="eval",
)

for step in range(num_steps):
    train(...)
    metrics = evaluator.evaluate(train_policy, step=step)
    print(f"eval/reward: {metrics['eval/reward']:.2f}")

evaluator.shutdown()
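The example above blocks on every evaluate() call. The non-blocking pattern suggested by trigger_eval(), poll() and pending can be sketched with a thread. Everything here is hypothetical: ToyAsyncEvaluator is not the TorchRL class, its trigger_eval takes only a step (the real one also takes weights), busy requests are dropped as with busy_policy="skip", and in this toy wait() only blocks while results are always read through poll().

```python
import queue
import threading
from typing import Callable, Optional


class ToyAsyncEvaluator:
    """Hypothetical thread-backed stand-in for trigger_eval / poll / pending."""

    def __init__(self, eval_fn: Callable[[int], dict]) -> None:
        self._eval_fn = eval_fn
        self._thread: Optional[threading.Thread] = None
        self._results: "queue.Queue[dict]" = queue.Queue()  # finished, unread results

    @property
    def pending(self) -> bool:
        return self._thread is not None and self._thread.is_alive()

    def trigger_eval(self, step: int) -> bool:
        if self.pending:
            return False                  # busy: drop the request, like "skip"
        self._thread = threading.Thread(target=self._run, args=(step,), daemon=True)
        self._thread.start()
        return True

    def _run(self, step: int) -> None:
        self._results.put(self._eval_fn(step))

    def poll(self) -> Optional[dict]:
        try:
            return self._results.get_nowait()   # never blocks
        except queue.Empty:
            return None

    def wait(self) -> None:
        if self._thread is not None:
            self._thread.join()


evaluator = ToyAsyncEvaluator(lambda step: {"step": step, "eval/reward": float(step)})
logged = []
for step in range(5):
    # ... a training step would run here ...
    if (result := evaluator.poll()) is not None:
        logged.append(result)             # pick up a finished evaluation, if any
    evaluator.trigger_eval(step)          # returns False while one is still running
evaluator.wait()                          # block once at the end
while (result := evaluator.poll()) is not None:
    logged.append(result)
print([r["step"] for r in logged])        # which steps appear depends on timing
```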

ProfileConfig

ProfileConfig (from torchrl.collectors) is a dataclass that configures PyTorch profiler integration for any collector. Its fields, listed in the table below, can be passed as keyword arguments to collector.enable_profile(...) to start tracing selected workers.
from torchrl.collectors import Collector, ProfileConfig

# env_maker and policy as in the Collector example; process() is your own handler
collector = Collector(env_maker, policy, frames_per_batch=200, total_frames=2000)
collector.enable_profile(
    workers=[0],
    num_rollouts=5,
    warmup_rollouts=2,
    save_path="./traces/worker_{worker_idx}.json",
)

for batch in collector:
    process(batch)
# Trace files written to ./traces/worker_0.json
Attribute | Type | Default | Description
--- | --- | --- | ---
workers | list[int] | [0] | Worker indices to profile (ignored for the single-process Collector).
num_rollouts | int | 3 | Total rollouts to profile (including warmup). Profiling stops afterwards.
warmup_rollouts | int | 1 | Rollouts to skip before actual profiling begins (allows JIT/compile warmup).
save_path | str or Path | None | Trace output path. Supports a {worker_idx} placeholder.
activities | list[str] | ["cpu", "cuda"] | Profiler activities.
record_shapes | bool | True | Record tensor shapes in the trace.
profile_memory | bool | False | Track memory allocations.
with_stack | bool | True | Capture Python/C++ stack traces.
with_flops | bool | False | Estimate FLOP counts.
on_trace_ready | Callable | None | Custom handler called when a trace interval completes.
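One reading of num_rollouts, warmup_rollouts and the {worker_idx} placeholder is sketched below. The helpers profiled_rollouts and trace_path are hypothetical, not TorchRL functions, and the sketch assumes rollouts are counted from 0 and that the placeholder is filled like Python's str.format; the actual integration uses the PyTorch profiler, which this toy does not touch.

```python
from typing import List, Optional


def profiled_rollouts(num_rollouts: int = 3, warmup_rollouts: int = 1) -> List[int]:
    """Hypothetical: indices of rollouts recorded in the trace.

    num_rollouts includes the warmup rollouts, so only the remainder is traced.
    """
    return list(range(warmup_rollouts, num_rollouts))


def trace_path(save_path: Optional[str], worker_idx: int) -> Optional[str]:
    # the {worker_idx} placeholder is filled in for each profiled worker
    return None if save_path is None else save_path.format(worker_idx=worker_idx)


print(profiled_rollouts(num_rollouts=5, warmup_rollouts=2))  # [2, 3, 4]
print(trace_path("./traces/worker_{worker_idx}.json", 0))    # ./traces/worker_0.json
```

Under this reading, the example above (num_rollouts=5, warmup_rollouts=2) records three rollouts after two warmup rollouts.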

Weight Synchronization

All collectors expose update_policy_weights_() to push fresh weights from the training process to the collection workers. For Collector this is a no-op unless a weight_updater or weight_sync_schemes is configured.
from tensordict import TensorDict

# `collector` and `policy` come from one of the examples above.
# After a gradient step, push the new parameters:
collector.update_policy_weights_(policy_or_weights=policy)

# Or pass a TensorDict of weights explicitly:
weights = TensorDict.from_module(policy).data
collector.update_policy_weights_(policy_or_weights=weights)

# Update only specific workers (multi-process collectors):
collector.update_policy_weights_(policy_or_weights=policy, worker_ids=[0, 2])
The weight_sync_schemes constructor argument accepts a dict[str, WeightSyncScheme] mapping model identifiers to sync strategies. This enables flexible setups such as shared-memory transfer for multi-process collectors or RPC-based updates in distributed hierarchies.
Set update_at_each_batch=True on multi-process collectors to automatically call update_policy_weights_() before (sync) or after (async) every batch — convenient for on-policy algorithms where weights must stay fresh.

VanillaWeightUpdater

VanillaWeightUpdater performs a simple in-place weight copy on the policy. It is imported from torchrl.collectors. The constructor requires policy_weights, a locked TensorDict of the policy’s parameters (use TensorDict.from_module(policy).lock_() to create one). Alternatively, use the from_policy classmethod to construct it from a policy directly.
VanillaWeightUpdater (and the entire WeightUpdaterBase hierarchy) is deprecated. Use WeightSyncScheme from torchrl.weight_update.weight_sync_schemes for new code.
from tensordict import TensorDict
from torchrl.collectors import Collector, VanillaWeightUpdater

# env_maker and policy as defined in the Collector example above
policy_weights = TensorDict.from_module(policy).lock_()

collector = Collector(
    env_maker,
    policy=policy,
    frames_per_batch=200,
    weight_updater=VanillaWeightUpdater(policy_weights=policy_weights),
)
collector.update_policy_weights_(policy)
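Selective updates (worker_ids) and in-place weight copies, both mentioned above, can be modeled without TorchRL. In this hypothetical sketch, ToyWorkerPool and its plain-list weights stand in for collection workers and their parameters. Copying into the existing lists mirrors the idea of an in-place copy: any object that already holds a reference to a worker's weights sees the update without being rebuilt.

```python
from typing import Dict, List, Optional, Sequence

Weights = Dict[str, List[float]]


class ToyWorkerPool:
    """Hypothetical pool of collection workers, each with its own weights."""

    def __init__(self, num_workers: int, init: Weights) -> None:
        self.workers: List[Weights] = [
            {name: list(values) for name, values in init.items()}
            for _ in range(num_workers)
        ]

    def update_policy_weights_(self, weights: Weights,
                               worker_ids: Optional[Sequence[int]] = None) -> None:
        targets = range(len(self.workers)) if worker_ids is None else worker_ids
        for i in targets:
            for name, values in weights.items():
                # slice assignment copies in place, keeping existing references valid
                self.workers[i][name][:] = values


pool = ToyWorkerPool(num_workers=3, init={"w": [0.0, 0.0]})
held = pool.workers[2]["w"]                  # a reference a worker's policy might keep
pool.update_policy_weights_({"w": [1.0, 2.0]}, worker_ids=[0, 2])
print([w["w"] for w in pool.workers])        # [[1.0, 2.0], [0.0, 0.0], [1.0, 2.0]]
print(held)                                  # [1.0, 2.0]
```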
