This guide covers all command-line parameters and configuration options for running slime training jobs.
Basic Training Command
All slime training jobs are launched through train.py:
python train.py \
--actor-num-nodes 1 \
--actor-num-gpus-per-node 8 \
--rollout-num-gpus 8 \
# ... additional parameters
Parameter Categories
slime parameters fall into three categories:
Megatron Arguments: all Megatron-LM training parameters (e.g., --tensor-model-parallel-size)
SGLang Arguments: SGLang inference parameters, passed with the --sglang- prefix (e.g., --sglang-mem-fraction-static)
slime Arguments: framework-specific parameters for RL training (e.g., --advantage-estimator)
Cluster Resource Allocation
Configure GPU allocation for training and inference:
Standard Configuration (Separate Resources)
python train.py \
--actor-num-nodes 1 \
--actor-num-gpus-per-node 4 \
--rollout-num-gpus 4 \
--rollout-num-gpus-per-engine 2
Actor: 4 GPUs for training
Rollout: 4 GPUs for inference (2 engines × 2 GPUs each)
Total: 8 GPUs
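The arithmetic behind this layout can be sanity-checked in a few lines (a sketch with hypothetical helper names, not slime code):

```python
def rollout_engines(rollout_num_gpus, gpus_per_engine):
    # The rollout pool is split into engines of gpus_per_engine GPUs each.
    assert rollout_num_gpus % gpus_per_engine == 0, "pool must divide evenly"
    return rollout_num_gpus // gpus_per_engine

engines = rollout_engines(4, 2)   # --rollout-num-gpus 4, --rollout-num-gpus-per-engine 2
total_gpus = 1 * 4 + 4            # actor nodes x GPUs per node, plus rollout GPUs
```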
Colocated Configuration (Shared Resources)
python train.py \
--actor-num-nodes 1 \
--actor-num-gpus-per-node 8 \
--colocate \
--sglang-mem-fraction-static 0.8
Shared: 8 GPUs used for both training and inference
Training and inference run sequentially on the same hardware
In colocated mode, reduce SGLang memory usage with --sglang-mem-fraction-static 0.8 to prevent OOM errors.
Key Parameters
--actor-num-nodes: Number of nodes for RL actor training
--actor-num-gpus-per-node: GPUs per node for actor training
--rollout-num-gpus: Total GPUs for rollout (ignored with --colocate)
--rollout-num-gpus-per-engine: GPUs per inference engine (similar to SGLang's tp_size)
--colocate: Enable shared training/inference on the same GPUs
--prefill-num-servers: Servers for prefill in PD disaggregation
Training Backend Selection
--train-backend megatron # Default: Megatron-LM
--train-backend fsdp # Experimental: PyTorch FSDP
The FSDP backend is experimental and allows direct loading of Hugging Face weights without conversion.
Model Configuration
Loading Model Parameters
Megatron requires explicit model architecture configuration. Load from provided templates:
source scripts/models/glm4-9B.sh # For GLM-4 9B
source scripts/models/qwen3-4B.sh # For Qwen3 4B
source scripts/models/qwen3-30b-a3b.sh # For Qwen3 30B MoE
Or define manually:
MODEL_ARGS=(
--num-layers 36
--hidden-size 2560
--ffn-hidden-size 9728
--swiglu
--vocab-size 151936
--disable-bias-linear
--num-attention-heads 32
--group-query-attention
--num-query-groups 8
--kv-channels 128
--qk-layernorm
--normalization "RMSNorm"
--norm-epsilon 1e-6
--use-rotary-position-embeddings
--rotary-base 1000000
)
Find all Megatron parameters and descriptions in your installed Megatron’s argument parser.
Checkpoint Configuration
CKPT_ARGS=(
--hf-checkpoint /path/to/hf_model # For tokenizer/config
--ref-load /path/to/ref_torch_dist # Reference model (Megatron)
--load /path/to/actor_checkpoint # Actor resume checkpoint
--save /path/to/save_dir # Save path
--save-interval 20 # Save every N rollouts
--ckpt-format torch_dist # Checkpoint format
)
Checkpoint Loading Logic
If --load points to a valid checkpoint → resume from there
If --load is empty or invalid → initialize actor from --ref-load
Reference model always loads from --ref-load
The torch_dist format supports automatic parallel sharding: different parallelism configs can share the same checkpoint.
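The loading rule above can be sketched as a small resolver (a hypothetical helper for illustration, not slime's internals):

```python
import os
import tempfile

def resolve_init_checkpoint(load, ref_load):
    # Resume from --load if it holds a checkpoint; otherwise fall back
    # to initializing the actor from --ref-load.
    if load and os.path.isdir(load) and os.listdir(load):
        return load
    return ref_load

# An empty save directory means no actor checkpoint yet:
# the actor initializes from the reference weights.
with tempfile.TemporaryDirectory() as empty_dir:
    src = resolve_init_checkpoint(empty_dir, "/path/to/ref_torch_dist")
```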
Parallelism Configuration
Parallelism Strategies
PERF_ARGS=(
# Tensor parallelism
--tensor-model-parallel-size 2
--sequence-parallel # Enable with TP (recommended)
# Pipeline parallelism
--pipeline-model-parallel-size 1
# Context parallelism (ring attention)
--context-parallel-size 2
# MoE parallelism
--expert-model-parallel-size 1 # Expert parallelism
--expert-tensor-parallel-size 1 # Expert TP (can differ from model TP)
)
Recomputation (Activation Checkpointing)
PERF_ARGS+=(
--recompute-granularity full # 'full' or 'selective'
--recompute-method uniform # Recomputation method
--recompute-num-layers 1 # Layers per recomputation group
)
Dynamic Batching
PERF_ARGS+=(
--use-dynamic-batch-size # Enable intelligent batch packing
--max-tokens-per-gpu 4608 # Max tokens per GPU
)
Dynamic batching is strongly recommended. It improves efficiency without affecting loss calculation accuracy.
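The idea behind dynamic batching is to pack variable-length samples into micro-batches whose total token count stays under the per-GPU budget. A greedy sketch (illustrative only; slime's actual packing logic may differ):

```python
def pack_by_tokens(sample_lengths, max_tokens_per_gpu):
    # Greedily group samples (longest first) into batches that each
    # stay within the --max-tokens-per-gpu budget.
    batches, current, used = [], [], 0
    for n in sorted(sample_lengths, reverse=True):
        if current and used + n > max_tokens_per_gpu:
            batches.append(current)
            current, used = [], 0
        current.append(n)
        used += n
    if current:
        batches.append(current)
    return batches

batches = pack_by_tokens([4000, 3000, 2500, 1000, 500], max_tokens_per_gpu=4608)
```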
Data Configuration
slime expects .jsonl files where each line is a JSON object:
{
  "prompt": [
    {
      "content": "Solve this problem: ...",
      "role": "user",
      "step_loss_mask": 1
    }
  ],
  "label": "42",
  "metadata": {"session_id": "abc123", "difficulty": "hard"}
}
Data Loading Parameters
ROLLOUT_ARGS=(
--prompt-data /path/to/train.jsonl
--input-key prompt # Field containing input
--label-key label # Field containing labels
--metadata-key metadata # Field with extra info
--apply-chat-template # Apply chat template to prompts
--rollout-shuffle # Shuffle data each rollout
)
step_loss_mask=0 excludes a turn from SFT loss. metadata can store arbitrary JSON for custom functions.
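Producing a file in this format is a one-liner per sample; a minimal sketch using the standard library:

```python
import json

# One training sample in the format above; each .jsonl line is one JSON object.
sample = {
    "prompt": [{"content": "Solve this problem: ...", "role": "user", "step_loss_mask": 1}],
    "label": "42",
    "metadata": {"session_id": "abc123", "difficulty": "hard"},
}
line = json.dumps(sample)       # write this line (plus "\n") to train.jsonl
restored = json.loads(line)     # slime reads it back field by field
```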
Rollout Configuration
Core Rollout Parameters
The training loop alternates between rollout (data generation) and training (weight update):
ROLLOUT_ARGS=(
--num-rollout 3000 # Total rollout iterations
--rollout-batch-size 16 # Prompts sampled per rollout
--n-samples-per-prompt 8 # Responses per prompt
--num-steps-per-rollout 1 # Training steps per rollout
--global-batch-size 128 # Samples per optimizer.step()
)
Critical constraint: rollout-batch-size × n-samples-per-prompt = global-batch-size × num-steps-per-rollout.
Example: 16 × 8 = 128 × 1 ✓
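The constraint just says the rollout must produce exactly as many samples as training consumes; a few lines make it checkable (parameter names mirror the flags; the helper itself is illustrative):

```python
def check_batch_constraint(rollout_batch_size, n_samples_per_prompt,
                           global_batch_size, num_steps_per_rollout):
    # Samples produced per rollout must equal samples consumed by training.
    produced = rollout_batch_size * n_samples_per_prompt
    consumed = global_batch_size * num_steps_per_rollout
    return produced == consumed

ok = check_batch_constraint(16, 8, 128, 1)   # the example configuration
bad = check_batch_constraint(16, 8, 100, 1)  # mismatched; would be rejected
```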
Sampling Parameters
ROLLOUT_ARGS+=(
--rollout-max-response-len 8192 # Max tokens in response
--rollout-temperature 1.0 # Sampling temperature
--rollout-top-p 1.0 # Nucleus sampling
--balance-data # Balance load across DP ranks
)
Reward Model Configuration
ROLLOUT_ARGS+=(
--rm-type deepscaler # Built-in reward model type
# or
--custom-rm-path path.to.module:reward_function
)
Custom Reward Function
Implement a custom reward function:
async def reward_func(args, sample: Sample, **kwargs) -> float:
    # Evaluate the sample and return a reward score
    return score
Specify with --custom-rm-path mymodule.rewards:reward_func
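The `module.path:attribute` string can be resolved with standard importlib machinery; a sketch of that convention (illustrative, not slime's internals), demonstrated here with a stdlib function:

```python
import importlib

def load_by_path(path):
    # Split "package.module:attribute" and import the named attribute.
    module_name, _, attr = path.partition(":")
    module = importlib.import_module(module_name)
    return getattr(module, attr)

# Demonstration with the stdlib: resolves to json.dumps.
fn = load_by_path("json:dumps")
```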
Evaluation Configuration
EVAL_ARGS=(
--eval-interval 5 # Evaluate every N rollouts
--eval-prompt-data aime /path/to/eval.jsonl
--n-samples-per-eval-prompt 16 # Samples per eval prompt
--eval-max-response-len 16384 # Max response length
--eval-top-p 1.0 # Sampling for evaluation
)
RL Algorithm Configuration
GRPO (Group Relative Policy Optimization)
GRPO_ARGS=(
--advantage-estimator grpo
--use-kl-loss
--kl-loss-coef 0.00 # KL penalty (0 = monitor only)
--kl-loss-type low_var_kl
--entropy-coef 0.00
--eps-clip 0.2
--eps-clip-high 0.28
--normalize-advantages # Normalize advantage values
)
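GRPO's advantage is group-relative: each reward is normalized by the mean and standard deviation of the n-samples-per-prompt responses to the same prompt. A minimal sketch of that normalization (illustrative only; slime's implementation lives behind --advantage-estimator grpo):

```python
from statistics import mean, pstdev

def grpo_advantages(rewards, eps=1e-6):
    # Normalize each reward against its prompt group's statistics.
    mu, sigma = mean(rewards), pstdev(rewards)
    return [(r - mu) / (sigma + eps) for r in rewards]

# Two correct and two incorrect responses to the same prompt.
advs = grpo_advantages([1.0, 0.0, 1.0, 0.0])
```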
PPO (Proximal Policy Optimization)
PPO_ARGS=(
--advantage-estimator ppo
# Critic configuration
--critic-num-nodes 1
--critic-num-gpus-per-node 4
--critic-load /path/to/critic_ckpt
--critic-save /path/to/critic_save
--critic-lr 1e-5
--critic-lr-warmup-iters 100
--num-critic-only-steps 10 # Train critic only first N steps
# PPO hyperparameters
--eps-clip 0.2
--value-clip 0.2
--kl-coef 0.01
)
PPO requires additional GPUs for the critic model, which runs in parallel with the actor.
Other Algorithms
--advantage-estimator gspo # GSPO
--advantage-estimator reinforce_plus_plus # Reinforce++
--advantage-estimator reinforce_plus_plus_baseline # Reinforce++ Baseline
Algorithm Modifiers
--use-opd # On-policy distillation
--opd-kl-coef 0.01 # OPD KL coefficient
--use-tis # Truncated importance sampling
--calculate-per-token-loss # Per-token instead of per-sample loss
--true-on-policy-mode # Strict on-policy training
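Truncated importance sampling (--use-tis) caps the importance ratio between the current and behavior policy to limit variance; a minimal sketch (the cap value here is illustrative, not slime's default):

```python
import math

def tis_weight(logp_new, logp_old, cap=2.0):
    # Importance ratio exp(logp_new - logp_old), truncated at `cap`.
    ratio = math.exp(logp_new - logp_old)
    return min(ratio, cap)

same_policy = tis_weight(0.0, 0.0)   # identical policies: ratio 1, untruncated
drifted = tis_weight(5.0, 0.0)       # large drift: ratio capped
```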
Optimizer Configuration
OPTIMIZER_ARGS=(
--optimizer adam
--lr 1e-6
--lr-decay-style constant
--lr-warmup-iters 0
--weight-decay 0.1
--adam-beta1 0.9
--adam-beta2 0.98
--clip-grad 1.0
)
SGLang Configuration
All SGLang parameters use the --sglang- prefix:
SGLANG_ARGS=(
--sglang-mem-fraction-static 0.9 # GPU memory fraction
--sglang-context-length 32768 # Max context length
--sglang-log-level INFO # Logging level
--sglang-enable-dp-attention # DP attention for MoE
--sglang-dp-size 2 # DP size
--sglang-ep-size 4 # EP size for MoE
--sglang-moe-a2a-backend deepep # MoE all-to-all backend
)
Router Configuration
slime uses sgl-router for load balancing:
--sglang-router-ip 127.0.0.1 # External router IP
--sglang-router-port 8080 # External router port
--router-balance-abs-threshold 0 # Force balanced distribution
If --sglang-router-ip is not set, slime starts an internal router automatically.
Advanced Features
Dynamic Sampling
Oversample and filter to improve data quality:
ROLLOUT_ARGS+=(
--over-sampling-batch-size 64 # Oversample prompts
--dynamic-sampling-filter-path \
slime.rollout.filter_hub.dynamic_sampling_filters.check_reward_nonzero_std
)
Example filter function:
import torch

def check_reward_nonzero_std(args, samples: list[Sample], **kwargs):
    rewards = [sample.get_reward_value(args) for sample in samples]
    keep = torch.tensor(rewards, dtype=torch.float).std() > 0.0
    return DynamicFilterOutput(keep=keep, reason=None if keep else "zero_std")
Partial Rollout
Cache and reuse partial generations:
ROLLOUT_ARGS+=(
--partial-rollout # Enable partial rollout caching
--buffer-filter-path \
slime.rollout.buffer_hub.buffer_filters.pop_first
)
bf16 Training + fp8 Inference
Reduce inference memory while maintaining training precision:
--hf-checkpoint /path/to/model-FP8 # Use FP8 Hugging Face checkpoint
--ref-load /path/to/bf16_torch_dist # Training stays bf16
The training checkpoint must remain bf16. Only the inference checkpoint uses fp8.
Custom Generation Function
Replace default generation logic:
--custom-generate-function-path mymodule.generate:generate
Implement the function:
async def generate(args, sample: Sample, sampling_params) -> Sample:
    # Custom generation logic
    # Must set: sample.tokens, sample.response_length, sample.status
    return sample
Multi-Turn and Agent Training
For complex agent scenarios:
CUSTOM_ARGS=(
--custom-generate-function-path myagent.multiturn:generate
--custom-rm-path myagent.multiturn:reward_func
--metadata-key metadata # Pass tool definitions, etc.
)
See the Quick Start guide’s “Multiturn Adaptation” section for implementing agent training with tool calling.
Complete Example
Here’s a complete training command with all major parameter groups. The inline # comments are annotations for readability; remove them before running, since a comment after a backslash line continuation breaks the command:
python train.py \
# Cluster resources
--actor-num-nodes 1 \
--actor-num-gpus-per-node 8 \
--rollout-num-gpus 8 \
--rollout-num-gpus-per-engine 2 \
\
# Model configuration (loaded from script)
"${MODEL_ARGS[@]}" \
\
# Checkpoints
--hf-checkpoint /root/GLM-Z1-9B-0414 \
--ref-load /root/GLM-Z1-9B-0414_torch_dist \
--load /root/GLM-Z1-9B-0414_slime/ \
--save /root/GLM-Z1-9B-0414_slime/ \
--save-interval 20 \
\
# Rollout
--prompt-data /root/dapo-math-17k/dapo-math-17k.jsonl \
--input-key prompt \
--label-key label \
--apply-chat-template \
--rollout-shuffle \
--rm-type deepscaler \
--num-rollout 3000 \
--rollout-batch-size 16 \
--n-samples-per-prompt 8 \
--num-steps-per-rollout 1 \
--global-batch-size 128 \
--rollout-max-response-len 8192 \
--rollout-temperature 1 \
--balance-data \
\
# Evaluation
--eval-interval 5 \
--eval-prompt-data aime /root/aime-2024/aime-2024.jsonl \
--n-samples-per-eval-prompt 16 \
--eval-max-response-len 16384 \
--eval-top-p 1 \
\
# Parallelism
--tensor-model-parallel-size 2 \
--sequence-parallel \
--pipeline-model-parallel-size 1 \
--context-parallel-size 2 \
--recompute-granularity full \
--recompute-method uniform \
--recompute-num-layers 1 \
--use-dynamic-batch-size \
--max-tokens-per-gpu 4608 \
\
# Algorithm (GRPO)
--advantage-estimator grpo \
--use-kl-loss \
--kl-loss-coef 0.00 \
--kl-loss-type low_var_kl \
--entropy-coef 0.00 \
--eps-clip 0.2 \
--eps-clip-high 0.28 \
\
# Optimizer
--optimizer adam \
--lr 1e-6 \
--lr-decay-style constant \
--weight-decay 0.1 \
--adam-beta1 0.9 \
--adam-beta2 0.98
Environment Variables
Sometimes environment variables are needed for specific scenarios:
# Multi-node network configuration
export SLIME_HOST_IP=$(hostname -I | awk '{print $1}')
export GLOO_SOCKET_IFNAME=eth0
export NCCL_SOCKET_IFNAME=eth0

# Megatron PYTHONPATH
export PYTHONPATH=/root/Megatron-LM:$PYTHONPATH
Next Steps
Quick Start See a complete working example
API Reference Explore detailed API documentation
Custom Functions Write custom generation and reward logic
Multi-Node Training Scale to hundreds of GPUs