
Overview

The slime training loop follows a cyclical pattern of Data Sampling → Model Training → Weight Synchronization. With the default settings this is on-policy RL: the model is always trained on data generated by its current policy.

Training Loop Implementation

From train.py:64-94, the main training loop:
for rollout_id in range(args.start_rollout_id, args.num_rollout):
    # 1. Evaluate once before any training (only when starting from rollout 0)
    if args.eval_interval is not None and rollout_id == 0:
        ray.get(rollout_manager.eval.remote(rollout_id))
    
    # 2. Generate rollout data
    rollout_data_ref = ray.get(rollout_manager.generate.remote(rollout_id))
    
    # 3. Offload rollout engines (optional)
    if args.offload_rollout:
        ray.get(rollout_manager.offload.remote())
    
    # 4. Train actor and/or critic
    if args.use_critic:
        critic_train_handle = critic_model.async_train(rollout_id, rollout_data_ref)
        if rollout_id >= args.num_critic_only_steps:
            ray.get(actor_model.async_train(rollout_id, rollout_data_ref))
        ray.get(critic_train_handle)
    else:
        ray.get(actor_model.async_train(rollout_id, rollout_data_ref))
    
    # 5. Save checkpoints (if scheduled)
    if should_run_periodic_action(rollout_id, args.save_interval, ...):
        save(rollout_id)
    
    # 6. Sync weights to rollout engines
    offload_train(rollout_id)
    if args.offload_rollout:
        ray.get(rollout_manager.onload_weights.remote())
    actor_model.update_weights()
    
    # 7. Periodic evaluation
    if should_run_periodic_action(rollout_id, args.eval_interval, ...):
        ray.get(rollout_manager.eval.remote(rollout_id))
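The loop gates checkpointing and evaluation on `should_run_periodic_action`, whose extra arguments are elided above. As a rough sketch only, the hypothetical `is_periodic_step` helper below (not slime's function) assumes 0-based rollout IDs and fires after every `interval` completed rollouts; it ignores whatever the elided arguments control.

```python
from typing import Optional


def is_periodic_step(rollout_id: int, interval: Optional[int]) -> bool:
    """Hypothetical stand-in for should_run_periodic_action (not slime's code).

    Assumes rollout_id is 0-based and that the action fires after every
    `interval` completed rollouts; interval=None disables the action.
    """
    if interval is None:
        return False
    return (rollout_id + 1) % interval == 0


# With an interval of 20 (example value) over 100 rollouts:
save_points = [r for r in range(100) if is_periodic_step(r, 20)]
print(save_points)  # [19, 39, 59, 79, 99]
```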

Batch Size Configuration

Critical Constraint: The following equation must hold:
rollout_batch_size × n_samples_per_prompt = global_batch_size × num_steps_per_rollout
This ensures that all generated data is consumed during training.

Phase One: Data Sampling (Rollout)

  • --rollout-batch-size: Number of prompts sampled per rollout
  • --n-samples-per-prompt: Number of responses generated per prompt
  • Total samples generated: rollout_batch_size × n_samples_per_prompt

Phase Two: Model Training

  • --global-batch-size: Number of samples consumed by one optimizer.step()
  • --num-steps-per-rollout: Number of parameter updates per rollout (default: 1)
  • Total samples consumed: global_batch_size × num_steps_per_rollout
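The constraint can be checked before launching a run. The helper below is hypothetical (not part of slime); it computes the two products defined in the phases above and fails fast when they disagree.

```python
def check_batch_config(rollout_batch_size: int, n_samples_per_prompt: int,
                       global_batch_size: int, num_steps_per_rollout: int = 1) -> int:
    """Return the number of samples per rollout, or raise if the data would
    not be consumed exactly. Illustrative helper, not part of slime's API."""
    generated = rollout_batch_size * n_samples_per_prompt    # Phase One
    consumed = global_batch_size * num_steps_per_rollout     # Phase Two
    if generated != consumed:
        raise ValueError(
            f"rollout_batch_size x n_samples_per_prompt = {generated}, but "
            f"global_batch_size x num_steps_per_rollout = {consumed}"
        )
    return generated


print(check_batch_config(16, 8, 128, 1))  # 128
```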

Example Configuration

ROLLOUT_ARGS=(
    --rollout-batch-size 16      # 16 prompts
    --n-samples-per-prompt 8     # 8 responses per prompt
    --num-steps-per-rollout 1    # 1 optimizer step
    --global-batch-size 128      # 128 samples per step
)
# 16 × 8 = 128 × 1 ✓
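For off-policy training, set --num-steps-per-rollout > 1 to split each rollout's samples across several parameter updates; every update after the first trains on data generated by an earlier version of the policy: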
--rollout-batch-size 8
--n-samples-per-prompt 8
--num-steps-per-rollout 2
--global-batch-size 32
# 8 × 8 = 32 × 2 ✓
Note: slime defaults to on-policy training (num_steps_per_rollout=1), so every optimizer step uses data generated by the current policy.
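Under the batch-size constraint, each sample is consumed exactly once; with num_steps_per_rollout > 1 the rollout is simply cut into consecutive optimizer-step batches. The sketch below uses a hypothetical `split_into_steps` helper and ignores any shuffling slime may apply before splitting.

```python
from typing import List, Sequence, TypeVar

T = TypeVar("T")


def split_into_steps(samples: Sequence[T], global_batch_size: int) -> List[List[T]]:
    """Chunk one rollout's samples into consecutive optimizer-step batches.

    Hypothetical helper: assumes len(samples) is an exact multiple of
    global_batch_size and does not reproduce any shuffling slime may do.
    """
    if len(samples) % global_batch_size != 0:
        raise ValueError("len(samples) must be a multiple of global_batch_size")
    return [list(samples[i:i + global_batch_size])
            for i in range(0, len(samples), global_batch_size)]


# Off-policy example above: 8 prompts x 8 samples = 64 samples, global batch 32
steps = split_into_steps(list(range(64)), global_batch_size=32)
print(len(steps), [len(s) for s in steps])  # 2 [32, 32]
```

Only the first of the two steps sees data from the current policy; the second trains on samples produced before the first update, which is what makes this setting off-policy.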

Rollout Process Control

Number of Rollouts

# Option 1: Specify number of rollouts directly
--num-rollout 3000

# Option 2: Specify number of epochs
--num-epoch 5
# num_rollout will be calculated as: dataset_size / rollout_batch_size * num_epoch
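A worked version of the epoch formula, as a sketch. The rounding is an assumption: here a trailing partial batch is dropped (floor division), and slime's exact rounding may differ.

```python
def num_rollout_from_epochs(dataset_size: int, rollout_batch_size: int, num_epoch: int) -> int:
    # Assumption: an incomplete batch at the end of an epoch is dropped.
    rollouts_per_epoch = dataset_size // rollout_batch_size
    return rollouts_per_epoch * num_epoch


# 10,000 prompts, 32 prompts per rollout, 5 epochs: 312 rollouts per epoch
print(num_rollout_from_epochs(10_000, 32, 5))  # 1560
```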

Training Steps

Each rollout performs num_steps_per_rollout optimizer steps:
# Total optimizer steps in training
total_steps = num_rollout * num_steps_per_rollout

# Total samples seen
total_samples = num_rollout * rollout_batch_size * n_samples_per_prompt
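Plugging in the example configuration from this page (--num-rollout 3000, 16 prompts × 8 responses, one step per rollout) gives concrete totals. The function name is illustrative only.

```python
def training_totals(num_rollout: int, rollout_batch_size: int,
                    n_samples_per_prompt: int, num_steps_per_rollout: int = 1):
    total_steps = num_rollout * num_steps_per_rollout
    total_samples = num_rollout * rollout_batch_size * n_samples_per_prompt
    return total_steps, total_samples


print(training_totals(3000, 16, 8, 1))  # (3000, 384000)
```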

Loss Computation

From loss.py:400-561, slime computes advantages and returns based on the chosen algorithm:
def compute_advantages_and_returns(args, rollout_data):
    log_probs = rollout_data.get("log_probs")
    ref_log_probs = rollout_data.get("ref_log_probs")
    rewards = rollout_data.get("rewards")
    values = rollout_data.get("values")
    response_lengths = rollout_data.get("response_lengths")
    total_lengths = rollout_data.get("total_lengths")
    loss_masks = rollout_data.get("loss_masks")
    
    # Compute KL divergence
    if args.kl_coef == 0:
        kl = [torch.zeros_like(x) for x in log_probs]
    else:
        kl = [
            compute_approx_kl(log_probs[i], ref_log_probs[i], args.kl_loss_type)
            for i in range(len(log_probs))
        ]
    
    # Compute advantages based on algorithm
    if args.advantage_estimator == "grpo":
        rewards = torch.tensor(rewards, dtype=torch.float32)
        returns = get_grpo_returns(rewards, kl)
        advantages = [r for r in returns]
    
    elif args.advantage_estimator == "ppo":
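        # PPO with value function: per-token reward = -kl_coef * KL,
        # plus the scalar sequence reward on the last response token
        old_rewards = rewards
        rewards = []
        for reward, k in zip(old_rewards, kl):
            k = -args.kl_coef * k  # out-of-place, so the kl list is not modified
            k[-1] += reward  # Add terminal reward
            rewards.append(k)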
        advantages, returns = get_advantages_and_returns_batch(
            total_lengths, response_lengths, values, rewards, 
            args.gamma, args.lambd
        )
    
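    # Normalize (whiten) advantages across the data-parallel group (optional)
    if args.normalize_advantages:
        all_advs = torch.cat(advantages)
        all_masks = torch.cat(loss_masks)
        # dp_group: the data-parallel process group, defined elsewhere (not shown)
        whitened_advs = distributed_masked_whiten(
            all_advs, all_masks, process_group=dp_group
        )
        # Split the flat tensor back into per-sample chunks
        chunk_lengths = [adv.size(0) for adv in advantages]
        advantages = list(torch.split(whitened_advs, chunk_lengths))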
    
    rollout_data["advantages"] = advantages
    rollout_data["returns"] = returns
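The excerpt delegates the math to helpers (`get_grpo_returns`, `get_advantages_and_returns_batch`, `distributed_masked_whiten`) that are not shown, and it does not show where GRPO's per-prompt reward normalization happens. The functions below are textbook single-process versions of the three ideas, written for illustration; they are not slime's implementations, and the names are hypothetical.

```python
import torch


def grpo_advantages(rewards: torch.Tensor, n_samples_per_prompt: int,
                    eps: float = 1e-6) -> torch.Tensor:
    """Group-relative advantage: (r - group mean) / (group std + eps).

    Assumes the n samples of each prompt are stored contiguously.
    Each sample's advantage is then applied to every token of its response.
    """
    groups = rewards.view(-1, n_samples_per_prompt)
    mean = groups.mean(dim=1, keepdim=True)
    std = groups.std(dim=1, keepdim=True)  # unbiased std (torch default)
    return ((groups - mean) / (std + eps)).view(-1)


def gae(rewards: torch.Tensor, values: torch.Tensor, gamma: float, lambd: float):
    """Generalized Advantage Estimation for one response (1-D per-token tensors)."""
    advantages = torch.zeros_like(rewards)
    last_gae = 0.0
    T = rewards.size(0)
    for t in reversed(range(T)):
        next_value = values[t + 1] if t + 1 < T else 0.0  # no bootstrap past the end
        delta = rewards[t] + gamma * next_value - values[t]
        last_gae = delta + gamma * lambd * last_gae
        advantages[t] = last_gae
    returns = advantages + values
    return advantages, returns


def masked_whiten(x: torch.Tensor, mask: torch.Tensor, eps: float = 1e-8) -> torch.Tensor:
    """Single-process analogue of masked whitening: statistics use only mask==1 entries."""
    mask = mask.to(x.dtype)
    mean = (x * mask).sum() / mask.sum()
    var = (((x - mean) ** 2) * mask).sum() / mask.sum()
    return (x - mean) * torch.rsqrt(var + eps)


adv = grpo_advantages(torch.tensor([1.0, 0.0, 1.0, 0.0]), n_samples_per_prompt=2)
```

With two samples per prompt and rewards of 1 and 0, each group gets advantages of roughly +0.71 and -0.71, so advantages within a group always sum to zero. In the PPO branch, the per-token rewards built above (scaled KL plus the terminal reward) are what a function like `gae` would consume.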

Policy Loss

From loss.py:613-831, the policy loss uses PPO-style clipping:
def policy_loss_function(args, batch, logits, sum_of_sample_mean):
    # Compute current log probabilities
    _, log_probs_and_entropy = get_log_probs_and_entropy(
        logits,
        args=args,
        unconcat_tokens=batch["unconcat_tokens"],
        total_lengths=batch["total_lengths"],
        response_lengths=batch["response_lengths"],
        with_entropy=True,
    )
    
    log_probs = log_probs_and_entropy["log_probs"]
    old_log_probs = batch["log_probs"]
    advantages = torch.cat(batch["advantages"])
    
    # Compute PPO KL
    old_log_probs = torch.cat(old_log_probs)
    log_probs = torch.cat(log_probs)
    ppo_kl = old_log_probs - log_probs
    
    # Clipped policy loss (asymmetric clip range [1 - eps_clip, 1 + eps_clip_high])
    ratio = (-ppo_kl).exp()
    pg_losses1 = -ratio * advantages
    pg_losses2 = -ratio.clamp(1 - args.eps_clip, 1 + args.eps_clip_high) * advantages
    pg_loss = torch.maximum(pg_losses1, pg_losses2)
    
    pg_loss = sum_of_sample_mean(pg_loss)
    
    # Entropy bonus
    entropy = torch.cat(log_probs_and_entropy["entropy"])
    entropy_loss = sum_of_sample_mean(entropy)
    
    loss = pg_loss - args.entropy_coef * entropy_loss
    
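    # Illustrative metrics dict (the real function may log other keys)
    metrics = {"pg_loss": pg_loss.detach(), "entropy_loss": entropy_loss.detach()}
    
    # KL penalty against the reference model (optional)
    if args.use_kl_loss:
        ref_log_probs = torch.cat(batch["ref_log_probs"])
        kl = compute_approx_kl(log_probs, ref_log_probs, args.kl_loss_type)
        kl_loss = sum_of_sample_mean(kl)
        loss = loss + args.kl_loss_coef * kl_loss
        metrics["kl_loss"] = kl_loss.detach()
    
    return loss, metrics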
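The clipping step can be checked in isolation. The sketch below reproduces only the per-token clipped surrogate from the function above, using the same ratio exp(log_probs - old_log_probs) as `(-ppo_kl).exp()`. The clip values of 0.2 are a common PPO choice used here as an assumption, not slime's defaults.

```python
import torch


def clipped_pg_loss(log_probs: torch.Tensor, old_log_probs: torch.Tensor,
                    advantages: torch.Tensor, eps_clip: float = 0.2,
                    eps_clip_high: float = 0.2) -> torch.Tensor:
    """Per-token PPO clipped surrogate with separate lower/upper clip ranges."""
    ratio = torch.exp(log_probs - old_log_probs)
    pg_losses1 = -ratio * advantages
    pg_losses2 = -ratio.clamp(1 - eps_clip, 1 + eps_clip_high) * advantages
    # Elementwise max of the two losses = pessimistic (clipped) objective
    return torch.maximum(pg_losses1, pg_losses2)


log_probs = torch.log(torch.tensor([1.0, 1.5, 0.5]))  # ratios 1.0, 1.5, 0.5
old_log_probs = torch.zeros(3)
advantages = torch.tensor([1.0, 1.0, -1.0])
losses = clipped_pg_loss(log_probs, old_log_probs, advantages)
```

The second token shows the upper clip: with a positive advantage, the loss stops improving once the ratio exceeds 1 + eps_clip_high (loss -1.2 instead of -1.5). The third token shows the lower clip with a negative advantage: the loss is held at 0.8 instead of 0.5, so pushing the ratio further below 1 - eps_clip earns nothing.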

Dynamic Batch Size

slime supports dynamic batch sizing to improve GPU utilization with variable-length sequences.
From the quick start guide:
PERF_ARGS=(
    # Static micro-batch size (ignored when dynamic batching is enabled)
    # --micro-batch-size 1
    
    # Enable dynamic batching
    --use-dynamic-batch-size
    
    # Maximum tokens per GPU (controls batch size)
    --max-tokens-per-gpu 4608
)
How it works:
  1. slime packs samples into micro-batches such that total tokens ≈ max_tokens_per_gpu
  2. If a single sample exceeds the limit, it forms its own batch
  3. With context parallelism (CP), the N GPUs in a CP group share a combined budget of N × max_tokens_per_gpu tokens per micro-batch
  4. Loss calculation remains correct through proper masking
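The packing rules above can be sketched with a simple greedy pass over sample lengths. This is an illustrative `pack_micro_batches` helper (hypothetical name); slime's actual partitioning algorithm may sort or balance samples differently, so treat this as a model of rules 1 to 3, not the implementation.

```python
from typing import List


def pack_micro_batches(lengths: List[int], max_tokens: int) -> List[List[int]]:
    """Greedily pack sample indices into micro-batches under a token budget.

    A sample longer than max_tokens ends up alone in its own micro-batch.
    """
    batches: List[List[int]] = []
    current: List[int] = []
    current_tokens = 0
    for idx, n in enumerate(lengths):
        # Close the current micro-batch if this sample would overflow it
        if current and current_tokens + n > max_tokens:
            batches.append(current)
            current, current_tokens = [], 0
        current.append(idx)
        current_tokens += n
    if current:
        batches.append(current)
    return batches


lengths = [1000, 3000, 2000, 6000, 500]
print(pack_micro_batches(lengths, 4608))      # [[0, 1], [2], [3], [4]]
print(pack_micro_batches(lengths, 2 * 4608))  # CP size 2: [[0, 1, 2], [3, 4]]
```

A sequential greedy pass is easy to reason about but can leave gaps (sample 2 sits alone even though it would fit with sample 4); sorting or first-fit strategies pack more tightly at the cost of reordering samples.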

Weight Synchronization

After the training steps for each rollout, the updated weights are synchronized from Megatron (training) to SGLang (rollout):
# From train.py:88
actor_model.update_weights()
This involves:
  1. Serialization: Convert Megatron distributed checkpoint to a transferable format
  2. Transfer: Send weights to rollout manager via Ray
  3. Loading: Load weights into SGLang engines
  4. Verification (optional): Check weight equality with --check-weight-update-equal
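As a single-process analogy only, the four steps map onto saving a PyTorch `state_dict` to bytes, moving the bytes, loading them into a second model, and comparing tensors. In slime the trainer and the SGLang engines run in separate processes and the real transfer path is not shown here; `sync_weights` is a hypothetical name.

```python
import io

import torch
import torch.nn as nn


def sync_weights(train_model: nn.Module, rollout_model: nn.Module,
                 check_equal: bool = True) -> None:
    source = train_model.state_dict()
    # 1. Serialization: state_dict -> bytes (stand-in for the real conversion)
    buffer = io.BytesIO()
    torch.save(source, buffer)
    # 2. Transfer: here just rewinding an in-memory buffer
    buffer.seek(0)
    # 3. Loading into the "rollout" copy
    rollout_model.load_state_dict(torch.load(buffer, weights_only=True))
    # 4. Verification, analogous in spirit to --check-weight-update-equal
    if check_equal:
        for name, tensor in rollout_model.state_dict().items():
            if not torch.equal(tensor, source[name]):
                raise RuntimeError(f"weight mismatch in {name}")


torch.manual_seed(0)
trainer = nn.Linear(4, 4)
rollout = nn.Linear(4, 4)  # starts with different random weights
sync_weights(trainer, rollout)
```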
For memory-constrained scenarios, use offloading:
# Offload rollout engines during training
--offload-rollout

# Offload training during rollout
--offload-train
From train.py:38-47:
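def offload_train(rollout_id):
    if args.offload_train:
        # Move training state off the GPU so the rollout engines can use the memory
        if args.use_critic:
            critic_model.offload()
            # The actor only trains after the critic-only warm-up steps
            if rollout_id >= args.num_critic_only_steps:
                actor_model.offload()
        else:
            actor_model.offload()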
    else:
        actor_model.clear_memory()

Per-Sample vs Per-Token Loss

slime supports both loss calculation modes:
# Per-sample loss (default)
# loss = mean(sum(sample_i) / len(sample_i))

# Per-token loss
--calculate-per-token-loss
# loss = sum(sum(sample_i)) / sum(len(sample_i))
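The two formulas weight tokens differently. Per-sample loss averages within each sample first, so every sample counts equally and tokens in short responses carry more weight; per-token loss gives every token equal weight. A minimal sketch with made-up per-token losses (loss masks ignored for simplicity):

```python
from typing import List

import torch


def per_sample_loss(token_losses: List[torch.Tensor]) -> torch.Tensor:
    # mean over samples of (sum of token losses / sample length)
    return torch.stack([x.sum() / x.numel() for x in token_losses]).mean()


def per_token_loss(token_losses: List[torch.Tensor]) -> torch.Tensor:
    # sum of all token losses / total number of tokens
    return torch.cat(token_losses).sum() / sum(x.numel() for x in token_losses)


# A short response (2 tokens, loss 1.0 each) and a long one (8 tokens, loss 0.0)
samples = [torch.ones(2), torch.zeros(8)]
print(per_sample_loss(samples).item())  # 0.5
print(per_token_loss(samples).item())   # 0.2
```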

