Overview
The slime training loop follows a cyclical pattern of Data Sampling → Model Training → Weight Synchronization. This on-policy RL approach ensures that the model is always trained on data generated by its current policy.
Training Loop Implementation
From train.py:64-94, the main training loop:
for rollout_id in range(args.start_rollout_id, args.num_rollout):
    # 1. Initial evaluation (rollout 0 only)
    if args.eval_interval is not None and rollout_id == 0:
        ray.get(rollout_manager.eval.remote(rollout_id))
    # 2. Generate rollout data
    rollout_data_ref = ray.get(rollout_manager.generate.remote(rollout_id))
    # 3. Offload rollout engines (optional)
    if args.offload_rollout:
        ray.get(rollout_manager.offload.remote())
    # 4. Train actor and/or critic
    if args.use_critic:
        critic_train_handle = critic_model.async_train(rollout_id, rollout_data_ref)
        if rollout_id >= args.num_critic_only_steps:
            ray.get(actor_model.async_train(rollout_id, rollout_data_ref))
        ray.get(critic_train_handle)
    else:
        ray.get(actor_model.async_train(rollout_id, rollout_data_ref))
    # 5. Save checkpoints (if scheduled)
    if should_run_periodic_action(rollout_id, args.save_interval, ...):
        save(rollout_id)
    # 6. Sync weights to rollout engines
    offload_train(rollout_id)
    if args.offload_rollout:
        ray.get(rollout_manager.onload_weights.remote())
    actor_model.update_weights()
    # 7. Periodic evaluation
    if should_run_periodic_action(rollout_id, args.eval_interval, ...):
        ray.get(rollout_manager.eval.remote(rollout_id))
Batch Size Configuration
Critical constraint: the following equation must hold:

rollout_batch_size × n_samples_per_prompt = global_batch_size × num_steps_per_rollout

This ensures that all generated data is consumed during training.
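As a quick sanity check, this constraint can be verified before launching a run. The helper below is a hypothetical illustration, not part of slime:

# Hypothetical helper (not part of slime): validate the batch-size constraint
def check_batch_config(rollout_batch_size, n_samples_per_prompt,
                       global_batch_size, num_steps_per_rollout):
    generated = rollout_batch_size * n_samples_per_prompt
    consumed = global_batch_size * num_steps_per_rollout
    assert generated == consumed, (
        f"rollout generates {generated} samples but training consumes {consumed}"
    )

check_batch_config(16, 8, 128, 1)  # 16 × 8 == 128 × 1, as in the example configuration below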
Phase One: Data Sampling (Rollout)
--rollout-batch-size: Number of prompts sampled per rollout
--n-samples-per-prompt: Number of responses generated per prompt
Total samples generated: rollout_batch_size × n_samples_per_prompt
Phase Two: Model Training
--global-batch-size: Number of samples consumed by one optimizer.step()
--num-steps-per-rollout: Number of parameter updates per rollout (default: 1)
Total samples consumed: global_batch_size × num_steps_per_rollout
Example Configuration
ROLLOUT_ARGS=(
   --rollout-batch-size 16       # 16 prompts
   --n-samples-per-prompt 8      # 8 responses per prompt
   --num-steps-per-rollout 1     # 1 optimizer step
   --global-batch-size 128       # 128 samples per step
)
# 16 × 8 = 128 × 1 ✓
Advanced: Off-Policy Training
For off-policy training, set --num-steps-per-rollout > 1 to reuse the same rollout data for multiple updates:

--rollout-batch-size 8
--n-samples-per-prompt 8
--num-steps-per-rollout 2
--global-batch-size 32
# 8 × 8 = 32 × 2 ✓
Note: slime defaults to on-policy training (num_steps_per_rollout = 1), so every optimizer step uses freshly generated data.
Rollout Process Control
Number of Rollouts
# Option 1: Specify number of rollouts directly
--num-rollout 3000
# Option 2: Specify number of epochs
--num-epoch 5
# num_rollout will be calculated as: dataset_size / rollout_batch_size * num_epoch
Training Steps
Each rollout performs num_steps_per_rollout optimizer steps:
# Total optimizer steps in training
total_steps = num_rollout × num_steps_per_rollout
# Total samples seen
total_samples = num_rollout × rollout_batch_size × n_samples_per_prompt
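Putting the formulas together, a worked example (the dataset size and epoch count are assumed values for illustration):

# Assumed values for illustration
dataset_size = 32_000
rollout_batch_size = 16
n_samples_per_prompt = 8
num_steps_per_rollout = 1
num_epoch = 5

num_rollout = dataset_size // rollout_batch_size * num_epoch             # 10,000 rollouts
total_steps = num_rollout * num_steps_per_rollout                        # 10,000 optimizer steps
total_samples = num_rollout * rollout_batch_size * n_samples_per_prompt  # 1,280,000 samples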
Loss Computation
From loss.py:400-561, slime computes advantages and returns based on the chosen algorithm:
def compute_advantages_and_returns(args, rollout_data):
    log_probs = rollout_data.get("log_probs")
    ref_log_probs = rollout_data.get("ref_log_probs")
    rewards = rollout_data.get("rewards")
    values = rollout_data.get("values")
    total_lengths = rollout_data.get("total_lengths")
    response_lengths = rollout_data.get("response_lengths")
    loss_masks = rollout_data.get("loss_masks")

    # Compute KL divergence against the reference policy
    if args.kl_coef == 0:
        kl = [torch.zeros_like(x) for x in log_probs]
    else:
        kl = [
            compute_approx_kl(log_probs[i], ref_log_probs[i], args.kl_loss_type)
            for i in range(len(log_probs))
        ]

    # Compute advantages based on the chosen algorithm
    if args.advantage_estimator == "grpo":
        rewards = torch.tensor(rewards, dtype=torch.float32)
        returns = get_grpo_returns(rewards, kl)
        advantages = [r for r in returns]
    elif args.advantage_estimator == "ppo":
        # PPO with a value function: fold the KL penalty into per-token rewards
        old_rewards = rewards
        rewards = []
        for reward, k in zip(old_rewards, kl):
            k *= -args.kl_coef
            k[-1] += reward  # add the terminal (sequence-level) reward to the last token
            rewards.append(k)
        advantages, returns = get_advantages_and_returns_batch(
            total_lengths, response_lengths, values, rewards,
            args.gamma, args.lambd
        )

    # Normalize advantages (optional): whiten across the data-parallel group
    if args.normalize_advantages:
        chunk_lengths = [len(adv) for adv in advantages]
        all_advs = torch.cat(advantages)
        all_masks = torch.cat(loss_masks)
        whitened_advs = distributed_masked_whiten(
            all_advs, all_masks, process_group=dp_group  # dp_group: the data-parallel group
        )
        advantages = list(torch.split(whitened_advs, chunk_lengths))

    rollout_data["advantages"] = advantages
    rollout_data["returns"] = returns
Policy Loss
From loss.py:613-831, the policy loss uses PPO-style clipping:
def policy_loss_function(args, batch, logits, sum_of_sample_mean):
    # Compute current log probabilities (and entropy) under the policy
    _, log_probs_and_entropy = get_log_probs_and_entropy(
        logits,
        args=args,
        unconcat_tokens=batch["unconcat_tokens"],
        total_lengths=batch["total_lengths"],
        response_lengths=batch["response_lengths"],
        with_entropy=True,
    )
    log_probs = log_probs_and_entropy["log_probs"]
    old_log_probs = batch["log_probs"]
    advantages = torch.cat(batch["advantages"])

    # Compute PPO KL: log-ratio between the behavior policy and the current policy
    old_log_probs = torch.cat(old_log_probs)
    log_probs = torch.cat(log_probs)
    ppo_kl = old_log_probs - log_probs

    # Clipped policy loss
    ratio = (-ppo_kl).exp()
    pg_losses1 = -ratio * advantages
    pg_losses2 = -ratio.clamp(1 - args.eps_clip, 1 + args.eps_clip_high) * advantages
    pg_loss = torch.maximum(pg_losses1, pg_losses2)
    pg_loss = sum_of_sample_mean(pg_loss)

    # Entropy bonus
    entropy = torch.cat(log_probs_and_entropy["entropy"])
    entropy_loss = sum_of_sample_mean(entropy)
    loss = pg_loss - args.entropy_coef * entropy_loss

    # KL penalty (optional)
    if args.use_kl_loss:
        ref_log_probs = torch.cat(batch["ref_log_probs"])
        kl = compute_approx_kl(log_probs, ref_log_probs, args.kl_loss_type)
        kl_loss = sum_of_sample_mean(kl)
        loss = loss + args.kl_loss_coef * kl_loss

    # metrics (pg_loss, entropy, ppo_kl, ...) is assembled in the elided code
    return loss, metrics
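To see what the clipping does, consider a toy example with made-up numbers (the clip range here is an assumption, not a slime default):

import torch

old_log_probs = torch.tensor([-1.0, -1.0, -1.0])
log_probs = torch.tensor([-0.5, -1.0, -2.0])
advantages = torch.tensor([1.0, -1.0, 1.0])
eps = 0.2  # assumed symmetric clip range

ratio = (log_probs - old_log_probs).exp()                 # ≈ [1.65, 1.00, 0.37]
pg_losses1 = -ratio * advantages
pg_losses2 = -ratio.clamp(1 - eps, 1 + eps) * advantages
print(torch.maximum(pg_losses1, pg_losses2))              # ≈ [-1.20, 1.00, -0.37]
# The first token's contribution is clipped from -1.65 to -1.20: once the
# ratio drifts too far from 1, the clipped term caps the size of the update.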
Dynamic Batch Size
slime supports dynamic batch sizing to improve GPU utilization with variable-length sequences.
From the quick start guide:
PERF_ARGS=(
# Static micro-batch size (ignored when dynamic batching is enabled)
# --micro-batch-size 1
# Enable dynamic batching
--use-dynamic-batch-size
# Maximum tokens per GPU (controls batch size)
--max-tokens-per-gpu 4608
)
How it works:
slime packs samples into micro-batches such that total tokens ≈ max_tokens_per_gpu
If a single sample exceeds the limit, it forms its own batch
With context parallelism (CP), the N GPUs in a CP group together handle N × max_tokens_per_gpu tokens
Loss calculation remains correct through proper masking
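The packing step can be pictured with a minimal greedy sketch (an illustration of the idea, not slime's actual batching code):

def pack_micro_batches(sample_lengths, max_tokens_per_gpu):
    batches, current, tokens = [], [], 0
    for length in sample_lengths:
        # Flush the current micro-batch if this sample would exceed the budget
        if current and tokens + length > max_tokens_per_gpu:
            batches.append(current)
            current, tokens = [], 0
        current.append(length)
        tokens += length
    if current:
        batches.append(current)
    return batches

print(pack_micro_batches([4000, 700, 300, 5000], 4608))
# [[4000], [700, 300], [5000]]: the 5000-token sample exceeds the budget
# but still forms its own micro-batch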
Weight Synchronization
After each training step, weights are synchronized from Megatron to SGLang:
# From train.py:88
actor_model.update_weights()
This involves:
Serialization: Convert the Megatron distributed checkpoint to a transferable format
Transfer: Send weights to the rollout manager via Ray
Loading: Load weights into the SGLang engines
Verification (optional): Check weight equality with --check-weight-update-equal
For memory-constrained scenarios, use offloading:

# Offload rollout engines during training
--offload-rollout
# Offload training during rollout
--offload-train
From train.py:38-47:

def offload_train(rollout_id):
    if args.offload_train:
        if args.use_critic:
            critic_model.offload()
            if rollout_id >= args.num_critic_only_steps:
                actor_model.offload()
        else:
            actor_model.offload()
    else:
        actor_model.clear_memory()
Per-Sample vs Per-Token Loss
slime supports both loss calculation modes:
# Per-sample loss (default)
# loss = mean(sum(sample_i) / len(sample_i))
# Per-token loss
--calculate-per-token-loss
# loss = sum(sum(sample_i)) / sum(len(sample_i))
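The difference matters when response lengths vary. A toy comparison (the token losses are made-up values):

import torch

sample_losses = [torch.tensor([1.0, 2.0, 3.0]),  # 3-token response
                 torch.tensor([4.0])]            # 1-token response

# Per-sample: average within each sample, then across samples
per_sample = torch.stack([s.mean() for s in sample_losses]).mean()  # (2.0 + 4.0) / 2 = 3.0

# Per-token: every token weighted equally across the whole batch
all_tokens = torch.cat(sample_losses)
per_token = all_tokens.sum() / all_tokens.numel()  # 10.0 / 4 = 2.5

Per-sample loss gives each response equal weight regardless of length; per-token loss weights long responses proportionally more.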