TorchRL treats multi-agent reinforcement learning (MARL) as a natural extension of the single-agent TensorDict data model: agent-specific tensors are nested one level deeper under a group key such as "agents". The shape (*batch, n_agents, obs_dim) replaces the single-agent shape (*batch, obs_dim), and nested keys like ("agents", "observation") replace flat keys like "observation". Every collector, replay buffer, and loss module that works for single agents also works for multi-agent setups—you only need to tell each component which keys to read and write. This tutorial builds a full MAPPO/IPPO training loop on the VMAS vectorised multi-agent simulator, closely following the minimal example in the TorchRL repository.
Install dependencies
pip install torchrl tensordict torch vmas
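VMAS (Vectorized Multi-Agent Simulator) is a GPU-accelerated simulator built on PyTorch; the command above already installs it. The PettingZoo example at the end of this tutorial also needs PettingZoo, which TorchRL wraps as PettingZooEnv: pip install pettingzoo. MPE tasks such as simple_spread_v3 may need extra dependencies, for example pip install "pettingzoo[mpe]".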
Understand the multi-agent data model
In TorchRL, MARL environments return TensorDicts where agent-specific data lives under a named group key. For VMAS with the default "agents" group:
TensorDict(
    agents: TensorDict(
        observation: Tensor(shape=[n_envs, n_agents, obs_dim])
        action:      Tensor(shape=[n_envs, n_agents, act_dim])
        reward:      Tensor(shape=[n_envs, n_agents, 1])
        done:        Tensor(shape=[n_envs, n_agents, 1])   # per-agent done, if the env provides one
    )
    done: Tensor(shape=[n_envs, 1])          # team-level done (VMAS writes done/terminated here)
)
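Nested keys are expressed as tuples: ("agents", "observation"), ("agents", "action"). In collected rollouts, step outputs such as rewards, done flags and next observations sit under the "next" sub-TensorDict, e.g. ("next", "agents", "reward"). The group name and agent list are configured by the group_map argument when constructing the environment.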
Create a VMAS environment
VmasEnv wraps VMAS scenarios and builds the group map automatically. The num_envs argument controls vectorisation: VMAS steps all environments together as one batched tensor computation, on GPU or CPU.
import torch
from torchrl.envs import RewardSum, TransformedEnv, VmasEnv

device = "cuda" if torch.cuda.is_available() else "cpu"
N_ENVS = 60       # vectorised environments
MAX_STEPS = 100   # episode horizon

base_env = VmasEnv(
    scenario="navigation",     # cooperative navigation task
    num_envs=N_ENVS,
    continuous_actions=True,
    max_steps=MAX_STEPS,
    device=device,
    seed=0,
)
# Add per-agent episode reward accumulation for logging
env = TransformedEnv(
    base_env,
    RewardSum(
        in_keys=[base_env.reward_key],       # e.g. ("agents", "reward")
        out_keys=[("agents", "episode_reward")],
    ),
)

print("n_agents:", env.n_agents)
print("action_key:", env.action_key)            # ("agents", "action")
print("reward_key:", env.reward_key)            # ("agents", "reward")
print("obs spec:", env.observation_spec["agents", "observation"])
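env.action_key and env.reward_key are convenience properties that return the environment's action and reward keys, here the nested ("agents", "action") and ("agents", "reward"). They raise an error when the environment has more than one such key (for example, with several agent groups); in that case use the plural env.action_keys and env.reward_keys. Prefer these properties to hard-coded strings so your code stays portable across VMAS scenarios.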
Build the decentralised actor
Each agent acts on its local observation only. MultiAgentMLP maps inputs of shape (*batch, n_agents, obs_dim) to outputs of shape (*batch, n_agents, 2 * act_dim). The factor of 2 holds both the loc and the scale of the TanhNormal (tanh-squashed Gaussian) policy. With share_params=True, all agents use one set of weights, but each still conditions only on its own observation.
from torch import nn
from tensordict.nn import TensorDictModule
from tensordict.nn.distributions import NormalParamExtractor
from torchrl.modules import MultiAgentMLP, ProbabilisticActor, TanhNormal

obs_dim = env.observation_spec["agents", "observation"].shape[-1]
action_dim = env.action_spec.shape[-1]

# Decentralised: centralized=False means each agent only sees its own obs
actor_backbone = nn.Sequential(
    MultiAgentMLP(
        n_agent_inputs=obs_dim,
        n_agent_outputs=2 * action_dim,   # loc + scale for TanhNormal
        n_agents=env.n_agents,
        centralized=False,               # IPPO / MAPPO: decentralised actor
        share_params=True,               # all agents share weights (parameter sharing)
        depth=2,
        num_cells=256,
        activation_class=nn.Tanh,
    ),
    NormalParamExtractor(),              # splits output into (loc, scale)
)

actor_module = TensorDictModule(
    actor_backbone,
    in_keys=[("agents", "observation")],
    out_keys=[("agents", "loc"), ("agents", "scale")],
)

actor = ProbabilisticActor(
    module=actor_module,
    in_keys=[("agents", "loc"), ("agents", "scale")],
    out_keys=[env.action_key],           # ("agents", "action")
    distribution_class=TanhNormal,
    distribution_kwargs={
        "low": env.full_action_spec_unbatched[env.action_key].space.low,
        "high": env.full_action_spec_unbatched[env.action_key].space.high,
    },
    return_log_prob=True,
)
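The decentralised, parameter-shared actor can be approximated in plain PyTorch as a single MLP applied to an (n_envs, n_agents, obs_dim) tensor. This is a sketch of the idea, not MultiAgentMLP's implementation. The sizes are illustrative, and softplus is an assumed stand-in for NormalParamExtractor's default scale mapping.

```python
import torch
from torch import nn

n_envs, n_agents, obs_dim, act_dim = 8, 3, 18, 2  # illustrative sizes only

# One MLP shared by every agent (the share_params=True idea). nn.Linear acts on the
# last dimension, so the agent dimension behaves like an extra batch dimension.
shared_actor = nn.Sequential(
    nn.Linear(obs_dim, 256), nn.Tanh(),
    nn.Linear(256, 256), nn.Tanh(),        # depth=2: two hidden layers of 256 units
    nn.Linear(256, 2 * act_dim),
)

obs = torch.randn(n_envs, n_agents, obs_dim)
out = shared_actor(obs)                    # (n_envs, n_agents, 2 * act_dim)

# Split into loc and a strictly positive scale (softplus is an assumed stand-in)
loc, raw_scale = out.chunk(2, dim=-1)
scale = nn.functional.softplus(raw_scale) + 1e-4

# Decentralisation check: changing agent 0's observation leaves agent 1's output unchanged
obs_perturbed = obs.clone()
obs_perturbed[:, 0] += 1.0
same_for_agent_1 = torch.allclose(shared_actor(obs_perturbed)[:, 1], out[:, 1])
```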
Build the centralised critic (MAPPO) or decentralised critic (IPPO)
MAPPO uses a centralised critic: the value function observes the concatenated observations of all agents. IPPO uses a decentralised critic: each agent's value is estimated from its own observation only. With share_params=True, the agents still share the critic's weights.
MAPPO — Centralised Critic
from torchrl.modules import MultiAgentMLP
from tensordict.nn import TensorDictModule

# centralized=True: the MLP sees the concatenated observations of all agents
mappo_critic = TensorDictModule(
    MultiAgentMLP(
        n_agent_inputs=obs_dim,
        n_agent_outputs=1,
        n_agents=env.n_agents,
        centralized=True,        # <-- key difference: global state
        share_params=True,
        depth=2,
        num_cells=256,
        activation_class=nn.Tanh,
    ),
    in_keys=[("agents", "observation")],
    out_keys=[("agents", "state_value")],
)
IPPO — Decentralised Critic
# centralized=False: each agent estimates its own value independently
ippo_critic = TensorDictModule(
    MultiAgentMLP(
        n_agent_inputs=obs_dim,
        n_agent_outputs=1,
        n_agents=env.n_agents,
        centralized=False,       # <-- independent critics
        share_params=True,
        depth=2,
        num_cells=256,
        activation_class=nn.Tanh,
    ),
    in_keys=[("agents", "observation")],
    out_keys=[("agents", "state_value")],
)
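The two critics differ mainly in their input. The plain-PyTorch sketch below uses illustrative sizes and does not reproduce MultiAgentMLP's internals. The IPPO-style critic is applied per agent. The MAPPO-style critic flattens all observations into one vector and broadcasts a single team value to every agent. That broadcast is our assumption about the shared-parameter, centralised case.

```python
import torch
from torch import nn

n_envs, n_agents, obs_dim = 8, 3, 18   # illustrative sizes only
obs = torch.randn(n_envs, n_agents, obs_dim)

# IPPO-style critic: input is each agent's own observation
ippo_critic = nn.Sequential(nn.Linear(obs_dim, 256), nn.Tanh(), nn.Linear(256, 1))
v_ippo = ippo_critic(obs)                                   # (n_envs, n_agents, 1)

# MAPPO-style critic: input is the concatenation of every agent's observation
mappo_critic = nn.Sequential(
    nn.Linear(n_agents * obs_dim, 256), nn.Tanh(), nn.Linear(256, 1)
)
global_obs = obs.reshape(n_envs, n_agents * obs_dim)        # (n_envs, n_agents * obs_dim)
v_team = mappo_critic(global_obs)                           # (n_envs, 1)
v_mappo = v_team.unsqueeze(-2).expand(n_envs, n_agents, 1)  # same value for every agent
```

The centralised input grows linearly with the number of agents. This is where MAPPO's extra training-time cost comes from.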
Construct MAPPOLoss or IPPOLoss
MAPPOLoss and IPPOLoss are thin specialisations of ClipPPOLoss that default to MultiAgentGAE as their value estimator. set_keys maps the loss to the correct nested keys for the VMAS environment.
from torchrl.modules import PopArtValueNorm
from torchrl.objectives import MAPPOLoss, IPPOLoss

# Choose algorithm
USE_MAPPO = True
LossCls = MAPPOLoss if USE_MAPPO else IPPOLoss
critic = mappo_critic if USE_MAPPO else ippo_critic

# PopArt normalises value targets to stabilise critic training
value_norm = PopArtValueNorm(shape=1, device=device) if USE_MAPPO else None

loss_module = LossCls(
    actor_network=actor,
    critic_network=critic,
    value_norm=value_norm,
    clip_epsilon=0.2,
    entropy_coeff=0.01,
    critic_coeff=1.0,
)

# Tell the loss where VMAS writes its per-agent signals
loss_module.set_keys(
    value=("agents", "state_value"),
    action=env.action_key,               # ("agents", "action")
    reward=env.reward_key,               # ("agents", "reward")
    done=("agents", "done"),
    terminated=("agents", "terminated"),
)
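ClipPPOLoss optimises PPO's clipped surrogate objective. The plain-PyTorch sketch below shows that objective on per-agent tensors; it is an illustration, not TorchRL's implementation. In the multi-agent case, make sure log-probabilities and advantages share the same per-agent shape before combining them.

```python
import math
import torch

def clipped_surrogate(log_prob, old_log_prob, advantage, clip_epsilon=0.2):
    """PPO clipped policy loss (sketch). All inputs share one shape, e.g. (batch, n_agents)."""
    ratio = (log_prob - old_log_prob).exp()            # pi_new / pi_old, per agent
    unclipped = ratio * advantage
    clipped = ratio.clamp(1 - clip_epsilon, 1 + clip_epsilon) * advantage
    # Pessimistic bound, negated because optimisers minimise; mean over batch and agents
    return -torch.min(unclipped, clipped).mean()

batch, n_agents = 5, 3
old = torch.zeros(batch, n_agents)
new = torch.full((batch, n_agents), math.log(2.0))     # probability ratio of 2 everywhere
loss_pos = clipped_surrogate(new, old, torch.ones(batch, n_agents))   # gain capped at 1.2
loss_neg = clipped_surrogate(new, old, -torch.ones(batch, n_agents))  # penalty kept at -2
```

The clip caps how much a single update can gain from a positive advantage. Negative advantages keep their full penalty.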
MAPPOLoss and IPPOLoss default to MultiAgentGAE for the value estimator. MultiAgentGAE broadcasts team-level done/reward tensors along the agent dimension before computing returns, which is needed when the environment shares a single done flag across all agents.
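TorchRL's multi-agent PPO tutorial performs a similar broadcast by hand, expanding the root ("next", "done") flag onto the agent dimension before running GAE. The sketch below shows that broadcast and a plain-PyTorch GAE computed independently for every agent. It is an illustration, not MultiAgentGAE's implementation. For simplicity, done both stops bootstrapping and cuts the trace, whereas TorchRL distinguishes terminated from done.

```python
import torch

def gae(reward, value, next_value, done, gamma=0.99, lmbda=0.95):
    """Generalised Advantage Estimation over time (dim 1), per agent.

    All inputs are shaped (n_envs, T, n_agents, 1).
    """
    not_done = (~done).to(reward.dtype)
    delta = reward + gamma * next_value * not_done - value   # TD errors
    advantage = torch.zeros_like(reward)
    running = torch.zeros_like(reward[:, 0])
    for t in reversed(range(reward.shape[1])):
        running = delta[:, t] + gamma * lmbda * not_done[:, t] * running
        advantage[:, t] = running
    return advantage, advantage + value                       # (advantage, value target)

n_envs, T, n_agents = 2, 3, 4                 # illustrative sizes only
reward = torch.ones(n_envs, T, n_agents, 1)   # per-agent rewards
value = torch.zeros_like(reward)
next_value = torch.zeros_like(reward)

# Team-level done at the root, (n_envs, T, 1), broadcast to every agent
team_done = torch.zeros(n_envs, T, 1, dtype=torch.bool)
team_done[:, 1] = True                        # an episode ends after step 1
done = team_done.unsqueeze(-1).expand_as(reward)   # (n_envs, T, n_agents, 1)

adv, target = gae(reward, value, next_value, done, gamma=1.0, lmbda=1.0)
```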
Set up the collector and replay buffer
The Collector and TensorDictReplayBuffer work identically to the single-agent case—multi-agent nesting is transparent.
from torchrl.collectors import Collector
from torchrl.data import LazyTensorStorage, TensorDictReplayBuffer
from torchrl.data.replay_buffers.samplers import SamplerWithoutReplacement

FRAMES_PER_BATCH = 6_000
MINIBATCH_SIZE = 400
TOTAL_FRAMES = 200_000

collector = Collector(
    env,
    actor,
    device=device,
    storing_device=device,
    frames_per_batch=FRAMES_PER_BATCH,
    total_frames=TOTAL_FRAMES,
)

replay_buffer = TensorDictReplayBuffer(
    storage=LazyTensorStorage(FRAMES_PER_BATCH, device=device),
    sampler=SamplerWithoutReplacement(),
    batch_size=MINIBATCH_SIZE,
)
Training loop
import time

optim = torch.optim.Adam(loss_module.parameters(), lr=3e-4)
EPOCHS = 4

start = time.time()
for it, td in enumerate(collector):
    # Compute value targets and advantages with MultiAgentGAE
    with torch.no_grad():
        loss_module.value_estimator(
            td,
            params=loss_module.critic_network_params,
            target_params=loss_module.target_critic_network_params,
        )

    # Fill replay buffer and run PPO epochs
    replay_buffer.extend(td.reshape(-1))

    for _ in range(EPOCHS):
        for _ in range(FRAMES_PER_BATCH // MINIBATCH_SIZE):
            subdata = replay_buffer.sample()
            losses = loss_module(subdata)
            loss = (
                losses["loss_objective"]
                + losses["loss_critic"]
                + losses["loss_entropy"]
            )
            optim.zero_grad()
            loss.backward()
            torch.nn.utils.clip_grad_norm_(loss_module.parameters(), 1.0)
            optim.step()

    # Push updated policy weights to the collector
    collector.update_policy_weights_()

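    # Log the mean per-agent return of episodes that ended in this batch (nan if none ended)
    done = td.get(("next", "done")).squeeze(-1)   # team-level done, shape (n_envs, T)
    ep_reward = td.get(("next", "agents", "episode_reward"))[done].mean().item()
    print(
        f"iter={it:03d}  frames={(it + 1) * FRAMES_PER_BATCH:>7d}  "
        f"reward={ep_reward:+.3f}  elapsed={time.time() - start:.1f}s"
    )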

collector.shutdown()
env.close()
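Averaging episode_reward over every step mixes partial, mid-episode sums with completed returns. It is therefore more informative to read it only where the team-level done flag is set. The synthetic check below is plain PyTorch with a constant reward of 1 and one 4-step episode per environment. cumsum stands in for what RewardSum accumulates within a single episode.

```python
import torch

n_envs, T, n_agents = 2, 4, 3                      # illustrative sizes only
reward = torch.ones(n_envs, T, n_agents, 1)
episode_reward = reward.cumsum(dim=1)              # running return: 1, 2, 3, 4
done = torch.zeros(n_envs, T, 1, dtype=torch.bool)
done[:, -1] = True                                 # each episode ends at the last step

all_steps = episode_reward.mean().item()                   # mixes partial sums
finished = episode_reward[done.squeeze(-1)].mean().item()  # completed episodes only
print(all_steps, finished)                                 # 2.5 4.0
```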
Using PettingZooEnv for competitive / mixed settings
from torchrl.envs import PettingZooEnv

env = PettingZooEnv(
    task="simple_spread_v3",    # from PettingZoo MPE
    parallel=True,              # parallel API (all agents step at once)
    continuous_actions=True,
    seed=0,
)
print(env.group_map)           # dict mapping group name to list of agent names
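With PettingZooEnv, group names come from env.group_map rather than being fixed to "agents". For MPE tasks, agents are typically grouped by name prefix, such as "agent" and "adversary". Deriving nested keys from whatever group map the environment reports keeps code portable. group_keys below is a hypothetical helper, and the group_map literals are assumed example values, not captured output.

```python
def group_keys(group_map: dict[str, list[str]], leaf: str) -> list[tuple[str, str]]:
    """Hypothetical helper: one nested key per agent group, e.g. ("agent", "observation")."""
    return [(group, leaf) for group in group_map]

# Illustrative group maps (assumed values, not printed output)
cooperative = {"agent": ["agent_0", "agent_1", "agent_2"]}
mixed = {"adversary": ["adversary_0"], "agent": ["agent_0", "agent_1"]}

print(group_keys(cooperative, "observation"))  # [('agent', 'observation')]
print(group_keys(mixed, "action"))             # [('adversary', 'action'), ('agent', 'action')]
```

With more than one group, each group gets its own keys and usually its own policy and critic.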

MAPPO vs IPPO: when to use each

Multi-Agent PPO with a centralised critic. The critic observes the concatenated observations of all agents, giving it a global view of the system.
  • Typically outperforms IPPO on cooperative tasks where rewards depend on joint behaviour.
  • Higher memory cost at training time (the critic processes n_agents × obs_dim inputs).
  • Policies remain decentralised at execution time—each agent only sees its local observation.
  • Use PopArtValueNorm to stabilise critic training when reward magnitudes vary.
from torchrl.objectives import MAPPOLoss
loss = MAPPOLoss(actor_network=actor, critic_network=mappo_critic, ...)
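PopArt ("Preserving Outputs Precisely, while Adaptively Rescaling Targets") works in three parts. It keeps running statistics of the value targets and trains the critic on normalised targets. Whenever the statistics change, it rescales the critic's last linear layer so that unnormalised predictions stay the same. The sketch below illustrates that idea in plain PyTorch. PopArtHead and its beta parameter are hypothetical names, not TorchRL's PopArtValueNorm API.

```python
import torch
from torch import nn

class PopArtHead(nn.Module):
    """Hypothetical PopArt value head: a sketch of the technique, not TorchRL's PopArtValueNorm."""

    def __init__(self, in_features: int, beta: float = 3e-4):
        super().__init__()
        self.linear = nn.Linear(in_features, 1)
        self.beta = beta                                # EMA rate for target statistics
        self.register_buffer("mu", torch.zeros(1))      # running mean of targets
        self.register_buffer("nu", torch.ones(1))       # running mean of squared targets

    @property
    def sigma(self) -> torch.Tensor:
        return (self.nu - self.mu ** 2).clamp(min=1e-8).sqrt()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.linear(x)                           # prediction in normalised units

    def normalize(self, target: torch.Tensor) -> torch.Tensor:
        return (target - self.mu) / self.sigma          # regress forward() onto this

    def unnormalize(self, v: torch.Tensor) -> torch.Tensor:
        return v * self.sigma + self.mu

    @torch.no_grad()
    def update_stats(self, targets: torch.Tensor) -> None:
        mu_old, sigma_old = self.mu.clone(), self.sigma
        self.mu.mul_(1 - self.beta).add_(self.beta * targets.mean())
        self.nu.mul_(1 - self.beta).add_(self.beta * (targets ** 2).mean())
        sigma_new = self.sigma
        # Rescale the last layer so unnormalize(forward(x)) is unchanged ("preserving outputs")
        self.linear.weight.mul_(sigma_old / sigma_new)
        self.linear.bias.mul_(sigma_old).add_(mu_old - self.mu).div_(sigma_new)

torch.manual_seed(0)
head = PopArtHead(in_features=4, beta=0.5)
x = torch.randn(8, 4)
before = head.unnormalize(head(x))
head.update_stats(torch.randn(32) * 10.0 + 50.0)    # targets on a much larger scale
after = head.unnormalize(head(x))
preserved = torch.allclose(before, after, atol=1e-3)
```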

Nested key reference

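| Data | TensorDict key | Shape |
| --- | --- | --- |
| Agent observations | ("agents", "observation") | (n_envs, n_agents, obs_dim) |
| Agent actions | ("agents", "action") | (n_envs, n_agents, act_dim) |
| Per-agent rewards | ("agents", "reward") | (n_envs, n_agents, 1) |
| Per-agent values | ("agents", "state_value") | (n_envs, n_agents, 1) |
| Per-agent advantages | ("agents", "advantage") | (n_envs, n_agents, 1) |
| Per-agent done | ("agents", "done") | (n_envs, n_agents, 1) |
| Team-level done | "done" | (n_envs, 1) |

Shapes are per environment step. Batches returned by the collector carry an extra time dimension, e.g. (n_envs, T, n_agents, obs_dim).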
