TorchRL’s LLM post-training stack applies the library’s composable, data-model-first design to language model fine-tuning. Rather than treating RLHF or GRPO as a bespoke pipeline, TorchRL models generation as a TensorDict-passing loop: a conversation environment produces prompts, a policy module generates responses, reward transforms score them, and a standard GRPOLoss or SFTLoss module closes the training loop. HuggingFace Transformers, vLLM, and SGLang backends are all supported through a common LLMWrapperBase interface, so a fast inference engine can be swapped into the same training code.
LLM post-training in TorchRL requires several optional dependencies. Before running any code in this guide, install the full dependency set:
# Quote the extras so shells such as zsh do not expand the brackets
pip install "torchrl[llm]"
# For GRPO training with vLLM:
pip install vllm ray transformers peft accelerate datasets wandb
# For SGLang:
pip install "sglang[all]"

The LLM Data Model

TorchRL represents a conversation as a TensorDict whose leaves are either token tensors or NonTensorData objects (used for raw strings and History objects). The two key abstractions are:
  • History — a structured chat history object that wraps a list of {"role": ..., "content": ...} dicts and provides apply_chat_template() for tokenisation.
  • ChatEnv — an EnvBase subclass where each step() appends the model’s response tokens to the running conversation.
from torchrl.envs.llm import ChatEnv
from torchrl.data.llm.history import History

# History-mode environment (default): observations are History objects
env = ChatEnv.from_dataloader(
    dataloader=my_dataloader,
    input_mode="history",   # "history" | "text" | "tokens"
    batch_size=[8],
)
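To make the History abstraction more concrete, the sketch below shows roughly what applying a chat template to a list of role/content messages involves. It is plain Python and does not use TorchRL: `render_chatml` is a hypothetical helper, and the ChatML-style markers are an assumption, since real templates come from the tokenizer of the chosen model.

```python
# Hypothetical stand-in for chat templating; this is not TorchRL's History API.
def render_chatml(messages, add_generation_prompt=True):
    """Render a list of {"role", "content"} dicts as a ChatML-style string."""
    parts = []
    for message in messages:
        parts.append(
            f"<|im_start|>{message['role']}\n{message['content']}<|im_end|>\n"
        )
    if add_generation_prompt:
        # Open an assistant turn so the model continues from this point.
        parts.append("<|im_start|>assistant\n")
    return "".join(parts)


conversation = [
    {"role": "system", "content": "You are a helpful assistant."},
    {"role": "user", "content": "What is 2 + 2?"},
]
prompt = render_chatml(conversation)

# A ChatEnv-style step appends the model's reply to the running conversation.
conversation.append({"role": "assistant", "content": "4"})
full_text = render_chatml(conversation, add_generation_prompt=False)
```

Keeping the conversation as structured messages and rendering it only at tokenisation time is what lets the same data flow through "history", "text", or "tokens" input modes.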
Common TensorDict keys in a typical ChatEnv step:
| Key | Description |
| --- | --- |
| "history" | Conversation history (History object, input_mode="history") |
| "text" | Current prompt string (input_mode="text") |
| "text_response" | Model response string |
| "tokens" | Prompt token IDs (input_mode="tokens") |
| "tokens_response" | Response token IDs |
| "attention_mask" | Attention mask |
| "reward" | Scalar reward assigned by reward transforms |

LLM Policy Wrappers

All policy backends implement LLMWrapperBase and share the same __call__(tensordict) interface.
from torchrl.modules.llm import TransformersWrapper
from transformers import AutoModelForCausalLM, AutoTokenizer

model = AutoModelForCausalLM.from_pretrained("Qwen/Qwen2.5-7B-Instruct")
tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-7B-Instruct")

# Inference wrapper (generate=True) for data collection
policy_inference = TransformersWrapper(
    model,
    tokenizer=tokenizer,
    generate=True,
    from_text=True,
    return_log_probs=True,
)

# Training wrapper (generate=False) for gradient computation
policy_train = TransformersWrapper(
    model,
    tokenizer=tokenizer,
    generate=False,
    from_text=True,
    return_log_probs=True,
)
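With return_log_probs=True, the wrappers report the log-probability of each generated or scored token. As a rough illustration of that quantity (not the wrappers' actual implementation), the following plain-Python sketch computes per-token log-probabilities from raw logits; the function names and the toy logits are made up for the example.

```python
import math


def log_softmax(logits):
    """Numerically stable log-softmax over one vocabulary-sized list of logits."""
    peak = max(logits)
    log_z = peak + math.log(sum(math.exp(x - peak) for x in logits))
    return [x - log_z for x in logits]


def token_log_probs(logits_per_step, token_ids):
    """Pick out the log-probability of the chosen token at every step."""
    return [
        log_softmax(logits)[token]
        for logits, token in zip(logits_per_step, token_ids)
    ]


# Toy example: a 3-token vocabulary and a 2-token response.
logits_per_step = [[2.0, 0.5, -1.0], [0.1, 0.1, 3.0]]
response_ids = [0, 2]
log_probs = token_log_probs(logits_per_step, response_ids)
sequence_log_prob = sum(log_probs)
```

The inference wrapper records these values at collection time and the training wrapper recomputes them under the current weights, which is what makes an importance ratio between the two possible.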

LLM Collectors and Weight Synchronisation

LLMCollector is a lightweight Collector subclass designed for auto-regressive generation. It handles variable-length responses and can write directly into a ReplayBuffer or yield batches.
from torchrl.collectors.llm import LLMCollector, RayLLMCollector
from torchrl.data import ReplayBuffer, LazyStackStorage
from torchrl.collectors import VanillaWeightUpdater

replay_buffer = ReplayBuffer(
    storage=LazyStackStorage(max_size=10_000),
    batch_size=64,
)

# Single-process collector
collector = LLMCollector(
    env=env,
    policy=policy_inference,
    dialog_turns_per_batch=128,
    total_dialog_turns=100_000,
    replay_buffer=replay_buffer,
    flatten_data=True,
)

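# Distributed Ray collector for multi-GPU setups.
# async_policy is assumed to be an inference wrapper created elsewhere
# (for example, one backed by an asynchronous vLLM engine).
collector = RayLLMCollector(
    env=env,
    policy=async_policy,
    dialog_turns_per_batch=512,
    replay_buffer=replay_buffer,
)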
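Variable-length responses are the reason ragged storage matters here. The sketch below is a plain-Python illustration, under the assumption that padding is applied only when a batch is assembled; `pad_batch` is a hypothetical helper and not part of TorchRL.

```python
def pad_batch(sequences, pad_id=0):
    """Right-pad ragged token sequences and build the matching attention mask."""
    max_len = max(len(seq) for seq in sequences)
    tokens = [list(seq) + [pad_id] * (max_len - len(seq)) for seq in sequences]
    mask = [[1] * len(seq) + [0] * (max_len - len(seq)) for seq in sequences]
    return tokens, mask


# Three responses of different lengths, stored as-is until a batch is needed.
ragged_responses = [[11, 12, 13], [21], [31, 32]]
tokens, attention_mask = pad_batch(ragged_responses)
```

Storing the ragged sequences and padding late avoids committing to a maximum response length up front; the attention mask then tells the loss which positions are real tokens.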
Weight synchronisation between the training model and the inference engine is handled by WeightUpdaterBase subclasses. The GRPO reference implementation uses a VLLMWeightSyncScheme or SGLangWeightSyncScheme:
from torchrl.weight_update.llm import VLLMWeightSyncScheme, SGLangWeightSyncScheme

# Create a vLLM weight sync scheme.
# vllm_engine and metadata are assumed to be created elsewhere in the training script.
weight_sync = VLLMWeightSyncScheme(engine=vllm_engine)
sender = weight_sync.create_sender()
sender.register_model(policy_train.model)  # HuggingFace model, not the wrapper
sender.init_all_workers_group(metadata, vllm_engine=vllm_engine)

# Call after each optimizer step to push weights to inference workers
sender.update_weights()
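The reason for pushing weights after each optimizer step can be shown with a toy, in-process model. Everything below is hypothetical (the `ToySender` and `ToyInferenceWorker` classes only mimic the register/update pattern above and share no code with TorchRL's schemes): workers hold their own copy of the weights, so they keep generating with stale parameters until the sender pushes again.

```python
import copy


class ToyInferenceWorker:
    """Holds a private copy of the weights, like a separate inference engine."""

    def __init__(self):
        self.weights = {}
        self.version = -1

    def load(self, weights, version):
        self.weights = copy.deepcopy(weights)
        self.version = version


class ToySender:
    """Pushes the registered training weights to every worker on request."""

    def __init__(self, workers):
        self.workers = workers
        self.model = None
        self.version = 0

    def register_model(self, model_weights):
        self.model = model_weights

    def update_weights(self):
        for worker in self.workers:
            worker.load(self.model, self.version)
        self.version += 1


train_weights = {"w": 0.0}
workers = [ToyInferenceWorker(), ToyInferenceWorker()]
toy_sender = ToySender(workers)
toy_sender.register_model(train_weights)
toy_sender.update_weights()                  # initial sync

train_weights["w"] -= 0.1                    # stands in for optimizer.step()
stale_value = workers[0].weights["w"]        # workers have not seen the update yet
toy_sender.update_weights()                  # push the new weights
fresh_value = workers[0].weights["w"]
```

Skipping the second push would leave the collection policy one or more updates behind the training policy, which widens the gap that the importance ratio has to correct.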

GRPO and SFT Objectives

GRPOLoss

GRPOLoss implements Group Relative Policy Optimisation — a clipped importance-weighted objective similar to PPO, adapted for token-level LLM training.
import torch
from torchrl.objectives.llm.grpo import GRPOLoss, MCAdvantage

loss_fn = GRPOLoss(
    actor_network=policy_train,
    clip_epsilon=0.2,                  # symmetric clipping [0.8, 1.2]
    # DAPO asymmetric clipping: clip_epsilon=(0.20, 0.28)
    entropy_bonus=True,
    entropy_coeff=0.01,
    kl_to_ref_coeff=0.05,              # KL penalty toward reference policy
    kl_to_inference_coeff=None,        # optional KL toward inference model
    aggregation="token_mean",          # "token_mean" | "prompt_mean" | "none"
    masking_strategy="sft",            # "sft" | "rlhf" | "generic"
    device=torch.device("cuda:0"),
)

# Forward returns a GRPOLossOutput TensorClass
loss_out = loss_fn(batch)
# Fields: loss_objective, clip_fraction, kl_approx, ESS, entropy, loss_entropy, loss_kl_to_ref
loss_out.loss_objective.backward()
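For readers less familiar with the clipped objective, here is a per-token sketch of the PPO-style surrogate that clip_epsilon controls, including the asymmetric (low, high) form mentioned in the comment above. It is plain Python written from the standard formulation, not extracted from GRPOLoss, and `clipped_surrogate` is a hypothetical name.

```python
import math


def clipped_surrogate(logp_new, logp_old, advantage, clip_epsilon=0.2):
    """PPO-style clipped objective for one token (to be maximised)."""
    if isinstance(clip_epsilon, tuple):
        eps_low, eps_high = clip_epsilon      # asymmetric, DAPO-style bounds
    else:
        eps_low = eps_high = clip_epsilon     # symmetric bounds
    ratio = math.exp(logp_new - logp_old)     # importance weight
    clipped_ratio = min(max(ratio, 1.0 - eps_low), 1.0 + eps_high)
    # Taking the minimum keeps the more pessimistic of the two estimates.
    return min(ratio * advantage, clipped_ratio * advantage)


inside = clipped_surrogate(math.log(1.1), 0.0, advantage=1.0)
capped = clipped_surrogate(math.log(1.5), 0.0, advantage=1.0)
asymmetric = clipped_surrogate(
    math.log(1.5), 0.0, advantage=1.0, clip_epsilon=(0.20, 0.28)
)
pessimistic = clipped_surrogate(math.log(0.5), 0.0, advantage=-1.0)
```

With a positive advantage the ratio stops contributing once it exceeds the upper bound, and with a negative advantage the clipped term dominates when the ratio falls below the lower bound. The loss minimised in practice is the negative of this quantity, aggregated over tokens.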
Masking strategies control which tokens contribute to the loss:
| Strategy | Tokens included | Use case |
| --- | --- | --- |
| "sft" | Response tokens only | Single-turn chat |
| "rlhf" | Assistant tokens only | Multi-turn conversations |
| "generic" | All valid (non-pad) tokens | Custom sequences |
Keep the training model in eval() mode during GRPO optimisation. A mismatch between train() and eval() modes is the most common cause of importance-sampling instability: with dropout active, the log-probabilities computed at training time are no longer consistent with the log-probabilities recorded at collection time.
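The three masking strategies can be pictured as different 0/1 masks over the same token sequence. The sketch below is one plausible reading of the table, stated as an assumption: "sft" is taken to mean the final assistant turn, "rlhf" every assistant token, and "generic" every non-pad token. `build_loss_mask` and `token_mean` are hypothetical helpers, and the per-token role labels are invented for the example.

```python
def build_loss_mask(roles, strategy):
    """Return a 0/1 mask over tokens, given a role label for each token."""
    if strategy == "generic":
        return [int(role != "pad") for role in roles]
    if strategy == "rlhf":
        return [int(role == "assistant") for role in roles]
    if strategy == "sft":
        # Assumption: only the last assistant turn counts as the response.
        mask = [0] * len(roles)
        i = len(roles) - 1
        while i >= 0 and roles[i] == "pad":
            i -= 1
        while i >= 0 and roles[i] == "assistant":
            mask[i] = 1
            i -= 1
        return mask
    raise ValueError(f"unknown masking strategy: {strategy!r}")


def token_mean(values, mask):
    """Average per-token values over the unmasked positions only."""
    return sum(v * m for v, m in zip(values, mask)) / max(sum(mask), 1)


roles = ["user", "user", "assistant", "user", "assistant", "assistant", "pad"]
sft_mask = build_loss_mask(roles, "sft")
rlhf_mask = build_loss_mask(roles, "rlhf")
generic_mask = build_loss_mask(roles, "generic")
mean_loss = token_mean([9.0, 9.0, 1.0, 9.0, 2.0, 3.0, 9.0], rlhf_mask)
```

In a single-turn conversation the "sft" and "rlhf" masks coincide; they differ only when earlier assistant turns exist, which is why the table recommends "rlhf" for multi-turn data.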
Monte-Carlo advantage (MCAdvantage) normalises group-relative rewards into the advantage estimates that GRPOLoss expects. It is implemented as a Transform that can be added to a replay buffer or applied inline:
from torchrl.objectives.llm.grpo import MCAdvantage
from torchrl.data import ReplayBuffer, LazyStackStorage

# MCAdvantage as a replay buffer transform
buffer = ReplayBuffer(
    storage=LazyStackStorage(max_size=10_000),
    transform=MCAdvantage(grpo_size=8),   # normalise within groups of 8 responses
)
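The group-relative normalisation itself is simple arithmetic. The sketch below standardises the rewards of all responses to one prompt; it is a plain-Python illustration, and the choice of population standard deviation and the `eps` constant are assumptions rather than details confirmed for MCAdvantage.

```python
import statistics


def group_relative_advantages(rewards, eps=1e-6):
    """Standardise rewards within one group of responses to the same prompt."""
    mean = statistics.fmean(rewards)
    std = statistics.pstdev(rewards)
    return [(reward - mean) / (std + eps) for reward in rewards]


# Eight sampled responses to one prompt, three of which were correct.
group_rewards = [1.0, 0.0, 0.0, 1.0, 1.0, 0.0, 0.0, 0.0]
advantages = group_relative_advantages(group_rewards)

# If every response earns the same reward, the group carries no learning signal.
flat = group_relative_advantages([1.0, 1.0, 1.0, 1.0])
```

Because the baseline is the group mean, no value network is needed: a response is rewarded only for doing better than its siblings on the same prompt.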

SFTLoss

SFTLoss implements supervised fine-tuning (negative log-likelihood on assistant tokens). It optionally adds a KL penalty toward a reference model to prevent over-fitting:
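from torchrl.modules.llm import TransformersWrapper
from torchrl.objectives.llm.sft import SFTLoss
from torchrl.envs.llm.transforms import RetrieveLogProb

# Reference model (frozen). ref_model is assumed to be a separately loaded
# copy of the base HuggingFace model whose weights are never updated.
policy_ref = TransformersWrapper(ref_model, tokenizer=tokenizer, generate=False, return_log_probs=True)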

# RetrieveLogProb transform fetches reference log-probs during data collection
kl_transform = RetrieveLogProb(
    policy=policy_ref,
    assistant_only=True,
    tokenizer=tokenizer,
)

sft_loss = SFTLoss(
    actor_network=policy_train,
    tokenizer=tokenizer,
    reduction="mean",
    normalize_by_seq_length=True,
    kl_to_ref_coeff=0.1,               # KL term weight; requires RetrieveLogProb
    loss_function="sft",               # "sft" or "minor_sft"
)

loss_out = sft_loss(batch)
# loss_out.loss_sft, loss_out.kl_to_ref
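As a rough picture of what the two output terms measure, the following plain-Python sketch computes a length-normalised negative log-likelihood over assistant tokens plus a log-ratio penalty against the reference model. `toy_sft_loss` is a hypothetical function, and the log-ratio average is only a simple proxy for a KL term, not necessarily the estimator SFTLoss uses.

```python
def toy_sft_loss(log_probs, assistant_mask, ref_log_probs=None, kl_to_ref_coeff=0.0):
    """Return (total, nll, kl_proxy), averaged over assistant tokens only."""
    n_tokens = max(sum(assistant_mask), 1)
    nll = -sum(lp * m for lp, m in zip(log_probs, assistant_mask)) / n_tokens
    kl_proxy = 0.0
    if ref_log_probs is not None:
        # Mean log-ratio between the trained policy and the frozen reference.
        kl_proxy = sum(
            (lp - ref) * m
            for lp, ref, m in zip(log_probs, ref_log_probs, assistant_mask)
        ) / n_tokens
    return nll + kl_to_ref_coeff * kl_proxy, nll, kl_proxy


# Two prompt tokens (masked out) followed by two assistant tokens.
policy_log_probs = [-0.1, -2.0, -0.5, -0.3]
reference_log_probs = [-0.1, -2.0, -0.7, -0.4]
mask = [0, 0, 1, 1]
total, nll, kl_proxy = toy_sft_loss(
    policy_log_probs, mask, reference_log_probs, kl_to_ref_coeff=0.1
)
```

The penalty grows as the fine-tuned policy assigns higher probability to the training tokens than the reference does, which is how the coefficient limits drift from the original model.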

Reward Transforms and Tool Use

TorchRL provides a growing set of transforms for reward shaping and prompt manipulation that run inside the ChatEnv pipeline:
from torchrl.envs.llm import KLRewardTransform, RetrieveKL, AddThinkingPrompt

# KL penalty reward (subtracts β·KL from the environment reward)
env = env.append_transform(
    KLRewardTransform(ref_policy=policy_ref, coeff=0.05)
)

# Reasoning / chain-of-thought prompt injection
env = env.append_transform(AddThinkingPrompt())
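The KL penalty reward amounts to one line of arithmetic per response. The sketch below is a plain-Python illustration using a summed log-ratio as the KL estimate, which is an assumption about the estimator; `kl_shaped_reward` is a hypothetical name and not the transform's API.

```python
def kl_shaped_reward(reward, log_probs, ref_log_probs, coeff=0.05):
    """Subtract coeff * (estimated KL to the reference) from the task reward."""
    kl_estimate = sum(lp - ref for lp, ref in zip(log_probs, ref_log_probs))
    return reward - coeff * kl_estimate


# A correct answer (reward 1.0) whose tokens the policy now prefers
# more strongly than the reference model does.
shaped = kl_shaped_reward(
    reward=1.0,
    log_probs=[-0.2, -0.4],
    ref_log_probs=[-0.5, -0.9],
    coeff=0.05,
)
```

Applying the penalty in the reward, rather than in the loss, means the advantage computation already accounts for how far a response strays from the reference policy.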
For tool-use scenarios where the model calls external functions, transform stacks can parse the model output and route tool calls before returning the next observation.
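A parse-and-route step of this kind might look like the sketch below. Everything in it is hypothetical: the `<tool>...</tool>` tag with a JSON payload, the `TOOLS` registry, and `route_tool_calls` are invented for illustration and do not describe TorchRL's tool transforms.

```python
import json
import re

# Hypothetical wire format: <tool>{"name": ..., "args": {...}}</tool>
TOOL_CALL = re.compile(r"<tool>(.*?)</tool>", re.DOTALL)

# Hypothetical registry mapping tool names to plain Python callables.
TOOLS = {"add": lambda a, b: a + b}


def route_tool_calls(response_text):
    """Run every tool call found in the response and return tool messages."""
    messages = []
    for payload in TOOL_CALL.findall(response_text):
        try:
            call = json.loads(payload)
            result = TOOLS[call["name"]](**call["args"])
            content = json.dumps({"result": result})
        except (json.JSONDecodeError, KeyError, TypeError) as err:
            # Report the failure to the model instead of crashing the rollout.
            content = json.dumps({"error": f"{type(err).__name__}: {err}"})
        messages.append({"role": "tool", "content": content})
    return messages


reply = 'Let me compute. <tool>{"name": "add", "args": {"a": 2, "b": 3}}</tool>'
tool_messages = route_tool_calls(reply)
unknown = route_tool_calls('<tool>{"name": "nope", "args": {}}</tool>')
```

The returned messages would be appended to the conversation history so that the next observation contains the tool output, letting the model continue the dialogue with the result in context.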

GRPO Training Setup

The following is a condensed version of the synchronous GRPO training loop from sota-implementations/grpo/grpo-sync.py:
Step 1: Initialise Models and Environment
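import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
from torchrl.modules.llm import TransformersWrapper, vLLMWrapper
from torchrl.envs.llm import GSM8KEnv  # or CountdownEnv, MATHEnv, IFEvalEnv

tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-7B-Instruct")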

# Training model (HuggingFace)
policy_train = TransformersWrapper(
    model=AutoModelForCausalLM.from_pretrained("Qwen/Qwen2.5-7B-Instruct"),
    tokenizer=tokenizer,
    generate=False,
    return_log_probs=True,
)

# Inference model (vLLM for throughput)
policy_inference = vLLMWrapper(
    model="Qwen/Qwen2.5-7B-Instruct",
    generate=True,
)

env = GSM8KEnv(batch_size=[8], reasoning=True)
Step 2: Create Loss, Collector, and Replay Buffer
from torchrl.objectives.llm.grpo import GRPOLoss, MCAdvantage
from torchrl.collectors.llm import RayLLMCollector
from torchrl.data import ReplayBuffer, LazyStackStorage, SamplerWithoutReplacement

loss_fn = GRPOLoss(
    actor_network=policy_train,
    kl_to_ref_coeff=0.05,
    entropy_coeff=0.01,
    masking_strategy="rlhf",
    device=torch.device("cuda:0"),
)

replay_buffer = ReplayBuffer(
    storage=LazyStackStorage(max_size=4096),
    sampler=SamplerWithoutReplacement(),
    batch_size=64,
)

collector = RayLLMCollector(
    env=env,
    policy=policy_inference,
    dialog_turns_per_batch=512,
    replay_buffer=replay_buffer,
)
Step 3: Weight Sync and Training Loop
from torchrl.weight_update.llm import VLLMWeightSyncScheme

weight_sync = VLLMWeightSyncScheme(engine=policy_inference.model)
sender = weight_sync.create_sender()
sender.register_model(policy_train.model)
# metadata is assumed to be built elsewhere in the training script
sender.init_all_workers_group(metadata, vllm_engine=policy_inference.model)
sender.update_weights()  # initial sync

optimizer = torch.optim.Adam(policy_train.parameters(), lr=1e-6)

for data in collector:
    if not len(replay_buffer):
        continue

    for batch in replay_buffer:
        loss_out = loss_fn(batch)
        optimizer.zero_grad()
        loss_out.loss_objective.backward()
        torch.nn.utils.clip_grad_norm_(policy_train.parameters(), 1.0)
        optimizer.step()

    # Push updated weights to vLLM workers
    sender.update_weights()

Pre-built Task Environments

TorchRL ships several ready-to-use LLM task environments under torchrl.envs.llm:
| Class | Task | Reward signal |
| --- | --- | --- |
| GSM8KEnv | Grade-school math (GSM8K) | Exact answer match |
| MATHEnv | Competition math (MATH) | Symbolic equivalence |
| CountdownEnv | Countdown number puzzle | Solution correctness |
| IFEvalEnv | Instruction following (IFEval) | Constraint satisfaction |
from torchrl.envs.llm.datasets.math import MATHEnv
from torchrl.envs.llm.datasets.countdown import CountdownEnv
from torchrl.envs.llm.datasets.ifeval import IFEvalEnv

env = MATHEnv(batch_size=[4], from_text=True)
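To give a sense of what an exact-answer-match reward involves, here is a plain-Python sketch. GSM8K reference solutions end with a line of the form "#### <number>"; the `<answer>...</answer>` tag expected from the model is an assumption for this example, and `exact_match_reward` is a hypothetical function rather than the scorer GSM8KEnv ships.

```python
import re

# Assumed response format: the final answer is wrapped in <answer> tags.
ANSWER_TAG = re.compile(r"<answer>(.*?)</answer>", re.DOTALL)


def normalise_number(text):
    """Parse a numeric answer, ignoring thousands separators and dollar signs."""
    cleaned = text.strip().replace(",", "").replace("$", "")
    try:
        return float(cleaned)
    except ValueError:
        return None


def exact_match_reward(response, reference_solution):
    """Return 1.0 if the tagged answer equals the reference answer, else 0.0."""
    match = ANSWER_TAG.search(response)
    if match is None:
        return 0.0
    predicted = normalise_number(match.group(1))
    target = normalise_number(reference_solution.split("####")[-1])
    if predicted is None or target is None:
        return 0.0
    return 1.0 if predicted == target else 0.0


reference = "48 / 2 = 24 clips in May. 48 + 24 = 72 clips in total.\n#### 72"
correct = exact_match_reward("<think>48 + 24</think><answer>72</answer>", reference)
wrong = exact_match_reward("<answer>70</answer>", reference)
untagged = exact_match_reward("The answer is 72.", reference)
```

Returning zero for untagged answers also rewards format compliance implicitly, a common design choice in verifiable-reward setups.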

Key Imports Reference

# Environments
from torchrl.envs.llm import ChatEnv, GSM8KEnv, KLRewardTransform, RetrieveKL, AddThinkingPrompt
from torchrl.envs.llm.datasets.math import MATHEnv
from torchrl.envs.llm.datasets.countdown import CountdownEnv
from torchrl.envs.llm.datasets.ifeval import IFEvalEnv

# Policy wrappers
from torchrl.modules.llm import (
    LLMWrapperBase,
    TransformersWrapper,
    RemoteTransformersWrapper,
    vLLMWrapper,
    AsyncVLLM,
    make_async_vllm_engine,
    SGLangWrapper,
    AsyncSGLang,
    RLSGLangEngine,
)

# Collectors
from torchrl.collectors.llm import LLMCollector, RayLLMCollector

# Loss modules
from torchrl.objectives.llm.grpo import GRPOLoss, MCAdvantage, GRPOLossOutput
from torchrl.objectives.llm.sft import SFTLoss, SFTLossOutput

# Weight synchronisation
from torchrl.weight_update.llm import VLLMWeightSyncScheme, SGLangWeightSyncScheme

# Data
from torchrl.data.llm.history import History
