slime is built on Ray for distributed execution, enabling training of large models across multiple nodes. This guide covers Ray cluster setup, multi-node configuration, and optimization for large-scale MOE models.
Overview
slime’s distributed architecture:
Ray Cluster: Manages resources and job scheduling across nodes
Training (Actor): Megatron-based model training
Inference (Rollout): SGLang-based response generation
Coordination: Ray handles communication and synchronization
Single Node Setup
For single-node training (one machine with multiple GPUs):
Start Ray Head Node
#!/bin/bash
# Start Ray head node
export MASTER_ADDR="127.0.0.1"

ray start --head \
  --node-ip-address ${MASTER_ADDR} \
  --num-gpus 8 \
  --disable-usage-stats \
  --dashboard-host=0.0.0.0 \
  --dashboard-port=8265
Submit Training Job
# Build runtime environment
RUNTIME_ENV_JSON="{
  \"env_vars\": {
    \"PYTHONPATH\": \"/root/Megatron-LM/\",
    \"CUDA_DEVICE_MAX_CONNECTIONS\": \"1\",
    \"NCCL_NVLS_ENABLE\": \"0\"
  }
}"

# Submit job
ray job submit --address="http://127.0.0.1:8265" \
  --runtime-env-json="${RUNTIME_ENV_JSON}" \
  -- python3 train.py \
  --actor-num-nodes 1 \
  --actor-num-gpus-per-node 8 \
  --rollout-num-gpus 8 \
  ${MODEL_ARGS[@]} \
  ${CKPT_ARGS[@]} \
  ${ROLLOUT_ARGS[@]} \
  ${OPTIMIZER_ARGS[@]} \
  ${GRPO_ARGS[@]}
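Because the runtime environment is passed as an escaped JSON string, a stray quote or comma only surfaces when the job is submitted. A quick local parse catches it earlier (a minimal sketch; the JSON body here is an illustrative stand-in for your real `RUNTIME_ENV_JSON`):

```shell
# Validate the runtime-env JSON locally before `ray job submit`
RUNTIME_ENV_JSON='{"env_vars": {"PYTHONPATH": "/root/Megatron-LM/"}}'
if echo "${RUNTIME_ENV_JSON}" | python3 -m json.tool > /dev/null 2>&1; then
  echo "runtime env JSON: valid"
else
  echo "runtime env JSON: malformed"
fi
```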
Resource Allocation
--actor-num-nodes
Number of nodes for the training actor.
--actor-num-gpus-per-node
GPUs per node allocated to training.
--rollout-num-gpus
Total GPUs for inference. Ignored when using --colocate.
Total GPUs available per node. Important when using --colocate with fewer than 8 GPUs.
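As a sanity check on how these flags combine, in disaggregated mode the cluster must supply the training allocation plus the rollout allocation; with --colocate, rollout shares the training GPUs. A sketch with illustrative numbers (2 training nodes plus 8 dedicated inference GPUs):

```shell
# Hypothetical disaggregated layout (numbers are illustrative)
ACTOR_NUM_NODES=2
ACTOR_NUM_GPUS_PER_NODE=8
ROLLOUT_NUM_GPUS=8

TRAIN_GPUS=$((ACTOR_NUM_NODES * ACTOR_NUM_GPUS_PER_NODE))
echo "training=${TRAIN_GPUS} rollout=${ROLLOUT_NUM_GPUS} total=$((TRAIN_GPUS + ROLLOUT_NUM_GPUS))"
```

With --colocate the total would instead be just the 16 training GPUs.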
Multi-Node Setup
For training large models across multiple machines:
Step 1: Start Ray Cluster
On Head Node (Node 0)
#!/bin/bash
# Set the master address
export MASTER_ADDR="192.168.1.100"  # Replace with actual IP

# Start Ray head node
ray start --head \
  --node-ip-address ${MASTER_ADDR} \
  --num-gpus 8 \
  --disable-usage-stats \
  --dashboard-host=0.0.0.0 \
  --dashboard-port=8265
On Worker Nodes (Node 1, 2, …)
#!/bin/bash
# Master address from head node
export MASTER_ADDR="192.168.1.100"

# Start Ray worker
ray start \
  --address="${MASTER_ADDR}:6379" \
  --num-gpus 8
Step 2: Verify Cluster
Check cluster status:
# On head node
ray status
Expected output (resource totals grow as each worker joins):
Resources
---------------------------------------------------------------
Usage:
0.0/64.0 CPU
0B/214.71GiB memory
0B/107.36GiB object_store_memory
0.00/8.00 GPU
Step 3: Submit Multi-Node Job
#!/bin/bash
ray job submit --address="http://127.0.0.1:8265" \
  --runtime-env-json='{
    "env_vars": {
      "PYTHONPATH": "/root/Megatron-LM/",
      "CUDA_DEVICE_MAX_CONNECTIONS": "1",
      "MASTER_ADDR": "'${MASTER_ADDR}'"
    }
  }' \
  -- python3 train.py \
  --actor-num-nodes 8 \
  --actor-num-gpus-per-node 8 \
  --colocate \
  ${MODEL_ARGS[@]} \
  ${CKPT_ARGS[@]} \
  ${ROLLOUT_ARGS[@]} \
  ${OPTIMIZER_ARGS[@]} \
  ${GRPO_ARGS[@]}
Colocated Mode
Colocated mode runs training and inference on the same GPUs, saving resources.
Configuration
ray job submit ... \
  -- python3 train.py \
  --actor-num-nodes 1 \
  --actor-num-gpus-per-node 8 \
  --colocate \
  --sglang-mem-fraction-static 0.8 \
  ...
Memory Management
In colocated mode, Megatron allocates GPU memory before offloading. Reduce SGLang’s memory fraction to prevent OOM:
# Recommended for colocated mode
SGLANG_ARGS=(
  --sglang-mem-fraction-static 0.8
)
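The arithmetic behind the 0.8 setting, for an illustrative 80 GB GPU (actual headroom also depends on CUDA context overhead and fragmentation):

```shell
# Static pool SGLang reserves per GPU vs. what remains for Megatron offload/reload
GPU_MEM_GB=80
FRACTION=0.8
awk -v m="$GPU_MEM_GB" -v f="$FRACTION" \
  'BEGIN { printf "SGLang static pool: %.0f GB, remaining: %.0f GB\n", m * f, m * (1 - f) }'
```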
When to Use Colocated Mode
Use Colocated
Limited GPU resources
Small to medium models
Training throughput is priority
Use Disaggregated
Large GPU clusters available
Maximum inference throughput needed
Independent scaling of training/inference
Network Configuration
Environment Variables
For multi-node setups, you may need to configure network interfaces:
# Detect and set network interface
export SLIME_HOST_IP=$(hostname -I | awk '{print $1}')
export GLOO_SOCKET_IFNAME=$(ip -o -4 addr show | awk '$4 ~ /^10\./ {print $2}')
export NCCL_SOCKET_IFNAME=$(ip -o -4 addr show | awk '$4 ~ /^10\./ {print $2}')
export NVSHMEM_BOOTSTRAP_UID_SOCK_IFNAME=$(ip -o -4 addr show | awk '$4 ~ /^10\./ {print $2}')
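The awk filter above picks the interface whose IPv4 address falls in the 10.x range; adjust the regex to your fabric's subnet. On canned `ip -o -4 addr show` output (field 2 is the interface name, field 4 the CIDR address) it behaves like this:

```shell
# Demo of the interface-matching awk on sample `ip -o -4 addr show` output
sample='1: lo    inet 127.0.0.1/8 scope host lo
2: eth0    inet 10.0.0.5/24 brd 10.0.0.255 scope global eth0
3: docker0    inet 172.17.0.1/16 scope global docker0'
printf '%s\n' "$sample" | awk '$4 ~ /^10\./ {print $2}'
# prints: eth0
```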
SLURM + Enroot Example
For SLURM clusters with enroot containers:
#!/bin/bash
# Auto-detect network configuration
export SLIME_HOST_IP=$(hostname -I | awk '{print $1}')
export GLOO_SOCKET_IFNAME=$(ip -o -4 addr show | awk '$4 ~ /^10\./ {print $2}')
export NCCL_SOCKET_IFNAME=$(ip -o -4 addr show | awk '$4 ~ /^10\./ {print $2}')
export NVSHMEM_BOOTSTRAP_UID_SOCK_IFNAME=$(ip -o -4 addr show | awk '$4 ~ /^10\./ {print $2}')

# Start Ray cluster
export MASTER_ADDR=${MLP_WORKER_0_HOST}
ray start --head --node-ip-address ${MASTER_ADDR} --num-gpus 8
Large-Scale MOE Models
slime is optimized for training massive Mixture of Experts models.
Example: GLM-4.5 (355B total parameters, 32B active)
Training configuration for 64xH100 GPUs (8 nodes × 8 GPUs):
The configuration spans cluster setup, parallelism, SGLang, optimizer settings, and MOE-specific optimizations.
Full Training Script
#!/bin/bash
# On head node
export MASTER_ADDR=${MLP_WORKER_0_HOST}
ray start --head \
  --node-ip-address ${MASTER_ADDR} \
  --num-gpus 8 \
  --disable-usage-stats

# Start workers on other nodes
for WORKER_IP in $(awk '{print $1}' /root/mpi_rack_hostfile); do
  if [[ "$WORKER_IP" == "$MLP_WORKER_0_HOST" ]]; then
    continue
  fi
  echo "Starting Ray worker on ${WORKER_IP}"
  ssh root@"${WORKER_IP}" \
    "ray stop --force; \
     ray start --address=${MASTER_ADDR}:6379 \
       --num-gpus 8 \
       --node-ip-address ${WORKER_IP}" &
done
wait
Parallelism Strategy
For large MOE models, carefully tune parallelism:
# Megatron world size = TP × PP × CP × DP
# (EP applies only to MOE layers and repartitions existing ranks,
#  so it does not multiply the GPU total)
# For 64 GPUs:
TP=8    # Tensor parallelism
PP=4    # Pipeline stages
EP=16   # Expert parallelism (MOE layers only)
CP=2    # Context parallelism
# Verify: TP × PP × CP = 8 × 4 × 2 = 64 GPUs (DP = 1)
Rule of thumb: Start with high TP for large models, then add PP for very deep models and EP for MOE layers. Use CP only for long sequences.
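A quick sanity check for any candidate plan, assuming the Megatron-style identity world_size = TP × PP × CP × DP (the helper name is illustrative):

```shell
# Verify that a parallelism plan fills the cluster exactly
check_plan() {
  local tp=$1 pp=$2 cp=$3 dp=$4 world=$5
  if [ $((tp * pp * cp * dp)) -eq "$world" ]; then
    echo "ok: ${tp}x${pp}x${cp}x${dp} = ${world}"
  else
    echo "mismatch: $((tp * pp * cp * dp)) != ${world}"
  fi
}
check_plan 8 4 2 1 64   # the 64-GPU layout above
check_plan 4 2 1 2 64   # an under-filled plan
```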
Advanced NCCL Tuning
For optimal multi-node performance:
export NCCL_CUMEM_ENABLE=0
export NCCL_IB_TC=160
export NCCL_PXN_DISABLE=0
export NCCL_IB_GID_INDEX=3
export NCCL_NET_GDR_LEVEL=4
export NCCL_IB_RETRY_CNT=7
export NCCL_IB_TIMEOUT=32
export NCCL_IB_QPS_PER_CONNECTION=8
export NCCL_P2P_LEVEL="NVL"
export TORCH_NCCL_AVOID_RECORD_STREAMS=1
export NCCL_NVLS_ENABLE=0
export NCCL_MIN_CTAS=4
Include these in your runtime environment:
ray job submit ... \
  --runtime-env-json='{
    "env_vars": {
      "PYTHONPATH": "/root/Megatron-LM/",
      "NCCL_IB_TC": "160",
      "NCCL_P2P_LEVEL": "NVL",
      ...
    }
  }' \
  -- python3 train.py ...
Monitoring
Ray Dashboard
Access the Ray dashboard at http://<head-node-ip>:8265:
View cluster resources
Monitor job status
Check GPU utilization
Inspect logs
Weights & Biases Integration
WANDB_ARGS=(
  --use-wandb
  --wandb-project slime-large-scale
  --wandb-group glm45-355b
  --wandb-key ${WANDB_KEY}
)

ray job submit ... \
  -- python3 train.py \
  ${WANDB_ARGS[@]} \
  ...
Checkpointing
For large models, use async checkpointing:
CKPT_ARGS=(
  --save /shared/checkpoints/model/
  --save-interval 20
  --async-save     # Enable async checkpointing
  --no-save-optim  # Skip optimizer state to reduce size
)
Troubleshooting
Common Issues
Ray workers not connecting
Symptoms: Workers don't appear in ray status
Solutions:
Verify firewall allows port 6379
Check MASTER_ADDR is correct and reachable
Ensure same Ray version on all nodes
Try ray stop --force on all nodes and restart
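To rule out the firewall, you can probe the head node's GCS port from a worker without extra tools, using bash's /dev/tcp (a sketch; substitute your actual head node address):

```shell
# Probe the Ray GCS port from a worker node
MASTER_ADDR=127.0.0.1   # replace with the head node IP
PORT=6379
if timeout 2 bash -c "echo > /dev/tcp/${MASTER_ADDR}/${PORT}" 2>/dev/null; then
  echo "port ${PORT} reachable"
else
  echo "port ${PORT} unreachable"
fi
```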
NCCL communication errors
Symptoms: NCCL error: unhandled system error
Solutions:
Increase timeout: --distributed-timeout-minutes 20
Check network interface: Set GLOO_SOCKET_IFNAME and NCCL_SOCKET_IFNAME
Verify InfiniBand configuration if using IB
Reduce --global-batch-size to decrease communication
Out of memory in colocated mode
Symptoms: CUDA OOM during rollout
Solutions:
Reduce --sglang-mem-fraction-static to 0.7 or lower
Decrease --max-tokens-per-gpu
Enable --recompute-granularity full
Use --optimizer-cpu-offload
Slow or stalled rollouts
Symptoms: Long wait times between rollouts
Solutions:
Use shared filesystem (NFS, Lustre) for datasets
Enable --balance-data for better load distribution
Increase --rollout-batch-size to amortize overhead
Check network bandwidth between nodes
Debugging Commands
# Check Ray cluster status
ray status

# List all Ray nodes
ray list nodes

# View job logs
ray job logs <job-id>

# SSH to worker and check processes
ssh worker-node
nvidia-smi
ps aux | grep python

# Test network connectivity
ping <worker-ip>
iperf3 -c <worker-ip>  # Bandwidth test
Prefetching and Pipelining
# Enable asynchronous operations
OPTIMIZER_ARGS=(
  --overlap-cpu-optimizer-d2h-h2d
  --overlap-grad-reduce
)
Dynamic Sampling with Buffer
ROLLOUT_ARGS=(
  # Dynamic sampling for quality
  --over-sampling-batch-size 256
  --rollout-batch-size 128
  --dynamic-sampling-filter-path \
    slime.rollout.filter_hub.dynamic_sampling_filters.check_reward_nonzero_std

  # Partial rollout to reduce waste
  --partial-rollout
  --buffer-filter-path slime.rollout.filter_hub.buffer_filters.pop_first
)
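With the numbers above, each step generates up to 256 candidate groups but keeps only 128 after the filter, i.e. roughly a 2x over-sampling budget; keep the ratio modest, since over-sampled generations that never pass the filter are wasted compute:

```shell
# Over-sampling budget implied by the batch sizes above
OVER_SAMPLING_BATCH=256
ROLLOUT_BATCH=128
echo "over-sampling ratio: $((OVER_SAMPLING_BATCH / ROLLOUT_BATCH))x"
```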
Load Balancing
ROLLOUT_ARGS=(
  --balance-data  # Distribute tokens evenly across GPUs
)
Resource Planning
Memory Estimation
Approximate GPU memory requirements:
Model memory = parameters × 2 bytes (bf16) × (1 + optimizer_multiplier)

For the Adam optimizer:
  optimizer_multiplier = 2 (momentum + variance)
  Total = parameters × 2 × 3 = parameters × 6 bytes

For a 355B model:
  355B × 6 bytes ≈ 2.1 TB (distributed across GPUs)

With 64 GPUs (80 GB each):
  Per GPU = 2.1 TB / 64 ≈ 33 GB
  Remaining = 80 - 33 = 47 GB for activations and inference
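The same estimate in script form, following the 6-bytes-per-parameter rule above (integer GB arithmetic, so the figures are approximate):

```shell
# Rough per-GPU memory budget for a bf16 model trained with Adam
PARAMS_B=355        # parameters, in billions
BYTES_PER_PARAM=6   # 2 (bf16 weights) + 2 (momentum) + 2 (variance)
NUM_GPUS=64
GPU_MEM_GB=80

TOTAL_GB=$((PARAMS_B * BYTES_PER_PARAM))
PER_GPU_GB=$((TOTAL_GB / NUM_GPUS))
echo "total=${TOTAL_GB}GB per_gpu=${PER_GPU_GB}GB headroom=$((GPU_MEM_GB - PER_GPU_GB))GB"
```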
Scaling Guidelines
Model Size | Recommended GPUs | Parallelism Strategy
-----------|------------------|---------------------
7-13B      | 8 GPUs           | TP=2, PP=1
30-70B     | 16-32 GPUs       | TP=4, PP=2
100-200B   | 32-64 GPUs       | TP=8, PP=4
300B+ MOE  | 64+ GPUs         | TP=8, PP=4, EP=16
Next Steps
Configuration Guide Review all configuration parameters
Multi-Turn Agents Train agents with tool calling