TorchRL’s LLM post-training stack brings the same composable, data-model-first design philosophy to language model fine-tuning. Rather than treating RLHF or GRPO as a bespoke pipeline, TorchRL models the generation process as a TensorDict-passing loop: a conversation environment produces prompts, a policy module generates responses, reward transforms score them, and a standard GRPOLoss or SFTLoss module closes the training loop. HuggingFace Transformers, vLLM, and SGLang backends are all supported through a common LLMWrapperBase interface, making it easy to swap fast inference engines into the same training code.
The LLM Data Model
TorchRL represents a conversation as a TensorDict whose leaves are either token tensors or non-tensor NonTensorData objects (for raw strings and History objects). The two key abstractions are:
- History: a structured chat history object that wraps a list of {"role": ..., "content": ...} dicts and provides apply_chat_template() for tokenisation.
- ChatEnv: an EnvBase subclass where each step() appends the model’s response tokens to the running conversation.
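To make the idea of a structured chat history concrete, here is a minimal sketch in plain Python. ToyHistory and its render() method are hypothetical stand-ins written for this example, not TorchRL's History class or its apply_chat_template() method, and the ChatML-style markers are an assumed template.

```python
from dataclasses import dataclass, field


@dataclass
class ToyHistory:
    """Hypothetical stand-in for a chat history; not TorchRL's History class."""

    messages: list = field(default_factory=list)

    def append(self, role: str, content: str) -> None:
        # Each turn is stored as a {"role": ..., "content": ...} dict.
        self.messages.append({"role": role, "content": content})

    def render(self, add_generation_prompt: bool = False) -> str:
        # ChatML-style markers, assumed here purely for illustration.
        parts = [
            f"<|im_start|>{m['role']}\n{m['content']}<|im_end|>\n"
            for m in self.messages
        ]
        if add_generation_prompt:
            # Open an assistant turn so the model knows it should answer.
            parts.append("<|im_start|>assistant\n")
        return "".join(parts)


history = ToyHistory()
history.append("user", "What is 2 + 2?")
prompt = history.render(add_generation_prompt=True)

# After generation, the response is appended to the running conversation.
history.append("assistant", "4")
```

The point of the design is that the conversation stays structured until the last moment, so the same history can be rendered with whichever template a given tokenizer expects.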
The following keys are read or written during a ChatEnv step:
| Key | Description |
|---|---|
| "history" | Conversation history (History object, input_mode="history") |
| "text" | Current prompt string (input_mode="text") |
| "text_response" | Model response string |
| "tokens" | Prompt token IDs (input_mode="tokens") |
| "tokens_response" | Response token IDs |
| "attention_mask" | Attention mask |
| "reward" | Scalar reward assigned by reward transforms |
LLM Policy Wrappers
All policy backends implement LLMWrapperBase and share the same __call__(tensordict) interface. Three backends are available:
- HuggingFace Transformers
- vLLM
- SGLang
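Because every backend exposes the same call signature, the surrounding code does not need to know which engine is behind it. The sketch below illustrates that idea with plain dicts standing in for TensorDicts. ToyPolicy, UppercaseBackend, ReverseBackend, and rollout are hypothetical names invented for this example, not TorchRL classes.

```python
from typing import Protocol


class ToyPolicy(Protocol):
    """Hypothetical interface: take a dict of inputs, return it with a response."""

    def __call__(self, data: dict) -> dict: ...


class UppercaseBackend:
    # Stand-in for one inference engine.
    def __call__(self, data: dict) -> dict:
        return {**data, "text_response": data["text"].upper()}


class ReverseBackend:
    # Stand-in for a different engine with the same call signature.
    def __call__(self, data: dict) -> dict:
        return {**data, "text_response": data["text"][::-1]}


def rollout(policy: ToyPolicy, prompt: str) -> str:
    # The caller depends only on the shared interface, never on the backend.
    return policy({"text": prompt})["text_response"]


upper_out = rollout(UppercaseBackend(), "hi")
reverse_out = rollout(ReverseBackend(), "abc")
```

Swapping the backend changes only the object passed to rollout; the calling code is untouched.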
LLM Collectors and Weight Synchronisation
LLMCollector is a lightweight Collector subclass designed for auto-regressive generation. It handles variable-length responses and can write directly into a ReplayBuffer or yield batches.
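Variable-length responses have to be reconciled before they can be batched. One common approach, shown here as a sketch, is right-padding with an accompanying attention mask. This is an illustration of the general technique only: pad_responses is a hypothetical helper, and LLMCollector may handle ragged data differently (for example through lazy stacking).

```python
def pad_responses(responses, pad_id=0):
    """Right-pad token-ID lists to a common length and build a 0/1 mask."""
    max_len = max(len(r) for r in responses)
    tokens, mask = [], []
    for r in responses:
        n_pad = max_len - len(r)
        tokens.append(list(r) + [pad_id] * n_pad)
        # 1 marks a real token, 0 marks padding.
        mask.append([1] * len(r) + [0] * n_pad)
    return tokens, mask


tokens, mask = pad_responses([[5, 6, 7], [8]])
```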
Weight synchronisation between the training model and the inference engine is handled by WeightUpdaterBase subclasses. The GRPO reference implementation uses a VLLMWeightSyncScheme or SGLangWeightSyncScheme.
GRPO and SFT Objectives
GRPOLoss
GRPOLoss implements Group Relative Policy Optimisation, a clipped importance-weighted objective similar to PPO, adapted for token-level LLM training.
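As a rough illustration of what a clipped importance-weighted objective computes at the token level, here is a minimal sketch in plain Python. It follows the standard PPO clipped surrogate; clipped_token_objective, its clip_eps default of 0.2, and the per-sequence masked mean are assumptions made for this example and are not taken from the GRPOLoss implementation.

```python
import math


def clipped_token_objective(logp_new, logp_old, advantage, mask, clip_eps=0.2):
    """Masked mean of the PPO-style clipped surrogate over one sequence.

    logp_new / logp_old: per-token log-probabilities under the current policy
    and under the policy that generated the data. mask: 1 for tokens that
    count towards the loss, 0 otherwise.
    """
    total, count = 0.0, 0
    for new, old, keep in zip(logp_new, logp_old, mask):
        if not keep:
            continue
        ratio = math.exp(new - old)  # importance weight
        clipped = min(max(ratio, 1.0 - clip_eps), 1.0 + clip_eps)
        # Pessimistic bound: take the smaller of the two surrogate terms.
        total += min(ratio * advantage, clipped * advantage)
        count += 1
    return total / count if count else 0.0


objective = clipped_token_objective(
    logp_new=[-1.0, -0.5, -2.0],
    logp_old=[-1.0, -1.0, -2.0],
    advantage=1.0,
    mask=[0, 1, 1],  # the first token is excluded
)
```

The training loss is the negative of this objective. With a positive advantage the clip caps how much a single token can gain from moving away from the collection-time policy, which is what keeps the update conservative.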
The masking_strategy argument selects which tokens contribute to the loss:
| Strategy | Tokens included | Use case |
|---|---|---|
| "sft" | Response tokens only | Single-turn chat |
| "rlhf" | Assistant tokens only | Multi-turn conversations |
| "generic" | All valid (non-pad) tokens | Custom sequences |
Keep the training model in eval() mode during GRPO optimisation. A mismatch between train() and eval() modes is the most common cause of importance-sampling instability: dropout stochasticity makes the train-time log-probabilities inconsistent with the collection-time log-probabilities.
MCAdvantage normalises group-relative rewards into the advantage estimates that GRPOLoss expects. It is implemented as a Transform that can be added to a replay buffer or applied inline.
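The group-relative normalisation can be sketched in a few lines. The version below standardises the rewards of all responses sampled for the same prompt. It is an illustration of the general idea, not the MCAdvantage code: the function name, the use of the population standard deviation, and the eps value are assumptions.

```python
from statistics import mean, pstdev


def group_relative_advantages(rewards, eps=1e-6):
    """Standardise the rewards of one group of responses to the same prompt."""
    mu = mean(rewards)
    sigma = pstdev(rewards)  # population std; an assumption for this sketch
    # eps avoids division by zero when every response gets the same reward.
    return [(r - mu) / (sigma + eps) for r in rewards]


# Four sampled answers to one prompt: two correct, two incorrect.
advantages = group_relative_advantages([1.0, 0.0, 0.0, 1.0])
```

Responses that beat their group's average get a positive advantage and the rest a negative one, so no learned value function is needed. When every response in a group receives the same reward, all advantages are zero and the group contributes no gradient.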
SFTLoss
SFTLoss implements supervised fine-tuning (negative log-likelihood on assistant tokens). It can optionally add a KL penalty toward a reference model to limit over-fitting.
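The two terms can be illustrated with a small plain-Python sketch. The function below averages the negative log-likelihood over assistant tokens and adds a mean log-ratio against the reference model as a simple stand-in for the KL penalty. The name toy_sft_loss, the kl_coeff parameter, and the choice of penalty estimator are assumptions for this example; the estimator SFTLoss actually uses may differ.

```python
def toy_sft_loss(logp_policy, logp_ref, assistant_mask, kl_coeff=0.0):
    """NLL on assistant tokens plus an optional log-ratio penalty."""
    kept = [
        (p, r)
        for p, r, m in zip(logp_policy, logp_ref, assistant_mask)
        if m  # user and system tokens are skipped
    ]
    if not kept:
        return 0.0
    nll = -sum(p for p, _ in kept) / len(kept)
    # Mean log-ratio to the reference model; a crude proxy for a KL term.
    penalty = sum(p - r for p, r in kept) / len(kept)
    return nll + kl_coeff * penalty


loss_plain = toy_sft_loss([-0.5, -1.0, -2.0], [-0.7, -1.2, -2.5], [0, 1, 1])
loss_with_kl = toy_sft_loss(
    [-0.5, -1.0, -2.0], [-0.7, -1.2, -2.5], [0, 1, 1], kl_coeff=0.1
)
```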
Reward Transforms and Tool Use
TorchRL provides a growing set of reward transforms that run inside the ChatEnv pipeline.
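Conceptually, a reward transform reads the model's response and writes a scalar under the "reward" key. The sketch below shows an exact-match scorer over a plain dict. exact_match_reward and its last-number extraction rule are hypothetical and are not one of TorchRL's reward transforms.

```python
import re


def exact_match_reward(data: dict, target: str) -> dict:
    """Score 1.0 if the last number in the response equals the target string."""
    numbers = re.findall(r"-?\d+(?:\.\d+)?", data["text_response"])
    predicted = numbers[-1] if numbers else None
    # Return a new dict so the input is left untouched.
    return {**data, "reward": 1.0 if predicted == target else 0.0}


scored = exact_match_reward(
    {"text_response": "3 + 4 = 7, so the answer is 7."}, target="7"
)
```

Keeping the scorer as a separate step from generation means the same policy and environment can be reused with a different reward by swapping only this function.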
GRPO Training Setup
The following is a condensed version of the synchronous GRPO training loop from sota-implementations/grpo/grpo-sync.py:
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

from torchrl.collectors.llm import RayLLMCollector
from torchrl.data import LazyStackStorage, ReplayBuffer, SamplerWithoutReplacement
from torchrl.envs.llm import GSM8KEnv  # or CountdownEnv, MATHEnv, IFEvalEnv
from torchrl.modules.llm import TransformersWrapper, vLLMWrapper
from torchrl.objectives.llm.grpo import GRPOLoss, MCAdvantage
from torchrl.weight_update.llm import VLLMWeightSyncScheme

model_name = "Qwen/Qwen2.5-7B-Instruct"
tokenizer = AutoTokenizer.from_pretrained(model_name)

# Training model (HuggingFace): computes log-probabilities, does not generate
policy_train = TransformersWrapper(
    model=AutoModelForCausalLM.from_pretrained(model_name),
    tokenizer=tokenizer,
    generate=False,
    return_log_probs=True,
)

# Inference model (vLLM for throughput)
policy_inference = vLLMWrapper(
    model=model_name,
    generate=True,
)

env = GSM8KEnv(batch_size=[8], reasoning=True)

loss_fn = GRPOLoss(
    actor_network=policy_train,
    kl_to_ref_coeff=0.05,
    entropy_coeff=0.01,
    masking_strategy="rlhf",
    device=torch.device("cuda:0"),
)

replay_buffer = ReplayBuffer(
    storage=LazyStackStorage(max_size=4096),
    sampler=SamplerWithoutReplacement(),
    batch_size=64,
)

# The collector writes generated dialogs directly into the replay buffer
collector = RayLLMCollector(
    env=env,
    policy=policy_inference,
    dialog_turns_per_batch=512,
    replay_buffer=replay_buffer,
)

# Weight synchronisation from the training model to the vLLM engine.
# NOTE: `metadata` is not defined in this condensed excerpt; see grpo-sync.py
# for how it is built.
weight_sync = VLLMWeightSyncScheme(engine=policy_inference.model)
sender = weight_sync.create_sender()
sender.register_model(policy_train.model)
sender.init_all_workers_group(metadata, vllm_engine=policy_inference.model)
sender.update_weights()  # initial sync

optimizer = torch.optim.Adam(policy_train.parameters(), lr=1e-6)

for data in collector:
    # Wait until the collector has written something to train on
    if not len(replay_buffer):
        continue
    for batch in replay_buffer:
        loss_out = loss_fn(batch)
        optimizer.zero_grad()
        loss_out.loss_objective.backward()
        torch.nn.utils.clip_grad_norm_(policy_train.parameters(), 1.0)
        optimizer.step()
    # Push updated weights to vLLM workers
    sender.update_weights()
Pre-built Task Environments
TorchRL ships several ready-to-use LLM task environments under torchrl.envs.llm:
| Class | Task | Reward signal |
|---|---|---|
| GSM8KEnv | Grade-school math (GSM8K) | Exact answer match |
| MATHEnv | Competition math (MATH) | Symbolic equivalence |
| CountdownEnv | Countdown number puzzle | Solution correctness |
| IFEvalEnv | Instruction following (IFEval) | Constraint satisfaction |