Overview
This example demonstrates training GLM-4.5-355B-A32B, a large-scale Mixture-of-Experts model, on 64×H100 GPUs across 8 nodes. The configuration combines 4D parallelism (TP8, PP4, CP2, EP16) with CPU-offloaded Adam, and optionally runs rollouts in FP8 with DeepEP.
Model Specifications
- Model: GLM-4.5-355B-A32B (zai-org/GLM-4.5)
- Architecture: Mixture-of-Experts (MoE)
- Parameters: 355 billion
- Hardware: 64×H100 GPUs (8 nodes × 8 GPUs)
- Parallelism: TP8, PP4, CP2, EP16
- Memory: CPU Adam with ~1.4-1.5TB host memory per node
Dataset
- Training prompts: dapo-math-17k (dapo-math-17k.jsonl)
- Evaluation: AIME 2024 (aime-2024.jsonl)
Environment Setup
Download model to shared storage
Download GLM-4.5 to a directory accessible by all machines ($BASE_DIR):

hf download zai-org/GLM-4.5 --local-dir $BASE_DIR/GLM-4.5-355B-A32B
Convert checkpoint (2 nodes)
Convert the HF checkpoint to torch_dist format using 2 nodes (16 GPUs):

cd slime/
source scripts/models/glm4.5-355B-A32B.sh
PYTHONPATH=/root/Megatron-LM/ torchrun \
--nproc-per-node 8 \
--master-addr ${MASTER_ADDR} --master-port 12345 \
--nnodes=2 --node-rank ${NODE_RANK} \
tools/convert_hf_to_torch_dist.py \
${MODEL_ARGS[@]} \
--hf-checkpoint $BASE_DIR/GLM-4.5-355B-A32B/ \
--save $BASE_DIR/GLM-4.5-355B-A32B_torch_dist/
- MASTER_ADDR: IP address of node 0
- NODE_RANK: node index (0 on node 0, 1 on node 1)
Training Execution
Start Ray on node 0
On node 0, run:

cd slime/
bash scripts/run-glm4.5-355B-A32B.sh
Join Ray cluster on worker nodes
On each worker node, join the Ray cluster:

ray start --address=${MASTER_ADDR}:6379 --num-gpus 8 \
--node-ip-address ${WORKER_IP} --disable-usage-stats
Automated Worker Setup
If you have an MPI hostfile (each line: ip slots=8), add the following to scripts/run-glm4.5-355B-A32B.sh after ray start --head to start the workers automatically:
for WORKER_IP in $(awk '{print $1}' $BASE_DIR/mpi_hostfile); do
if [[ "$WORKER_IP" == "$MASTER_ADDR" ]]; then
continue
fi
echo "Starting Ray worker on ${WORKER_IP}"
ssh root@"${WORKER_IP}" \
"pkill -9 sglang ; ray stop --force ; pkill -9 python ; \
ray start --address=${MASTER_ADDR}:6379 --num-gpus 8 \
--node-ip-address ${WORKER_IP} --disable-usage-stats" &
done
wait
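The loop above can also be driven from a small launcher script. Below is a sketch of the hostfile parsing it performs (worker_ips is a hypothetical helper, not part of slime; the sample hostfile IPs are placeholders):

```python
# Hypothetical helper mirroring the awk/ssh loop above: parse an MPI
# hostfile (one "ip slots=8" entry per line) and return the worker IPs,
# skipping the head node.
def worker_ips(hostfile_text: str, master_addr: str) -> list:
    ips = []
    for line in hostfile_text.splitlines():
        line = line.strip()
        if not line or line.startswith("#"):
            continue
        ip = line.split()[0]  # first column is the IP, as in awk '{print $1}'
        if ip != master_addr:
            ips.append(ip)
    return ips

hosts = """\
10.0.0.1 slots=8
10.0.0.2 slots=8
10.0.0.3 slots=8
"""
print(worker_ips(hosts, "10.0.0.1"))  # → ['10.0.0.2', '10.0.0.3']
```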
Training Configuration
MODEL_ARGS
Load model configuration from scripts/models/glm4.5-355B-A32B.sh:
SCRIPT_DIR="$(cd -- "$(dirname -- "${BASH_SOURCE[0]}")" &>/dev/null && pwd)"
source "${SCRIPT_DIR}/models/glm4.5-355B-A32B.sh"
These are Megatron parameters defining the model architecture. Megatron cannot read configs from checkpoints, so manual configuration is required.
Advanced Parallelism (PERF_ARGS)
PERF_ARGS=(
--tensor-model-parallel-size 8
--sequence-parallel
--pipeline-model-parallel-size 4
--context-parallel-size 2
--expert-model-parallel-size 16
--expert-tensor-parallel-size 1
--recompute-granularity full
--recompute-method uniform
--recompute-num-layers 1
--use-dynamic-batch-size
--max-tokens-per-gpu 16384
)
4D parallelism: TP8 × PP4 × CP2 spans all 64 GPUs; in MoE layers the 16 non-pipeline ranks of each stage are regrouped as EP16 × ETP1. With dynamic batch sizing, each GPU processes up to 16,384 tokens per micro-batch.
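A quick sanity check of the rank arithmetic (grouping details are Megatron-internal; this only verifies the counts):

```python
# Sanity check of the 4D layout. Dense layers use TP x PP x CP; in MoE
# layers the non-pipeline ranks of each stage are regrouped as EP x ETP.
TP, PP, CP, EP, ETP = 8, 4, 2, 16, 1
world_size = TP * PP * CP              # data-parallel size is 1 here
assert world_size == 64
ranks_per_pp_stage = TP * CP           # 16 ranks share each pipeline stage
assert ranks_per_pp_stage == EP * ETP  # MoE layers remap these as EP x ETP
print(world_size, ranks_per_pp_stage)  # → 64 16
```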
GRPO Configuration
GRPO_ARGS=(
--advantage-estimator grpo
--use-kl-loss
--kl-loss-coef 0.00
--kl-loss-type low_var_kl
--entropy-coef 0.00
--eps-clip 0.2
--eps-clip-high 0.28
)
To train without a reference model, remove --use-kl-loss; note that --kl-loss-coef 0.00 already disables the KL penalty.
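For reference, the asymmetric clipping implied by --eps-clip 0.2 and --eps-clip-high 0.28 looks roughly like this (a minimal sketch; the function name and signature are illustrative, not slime's actual loss code):

```python
import math

# Minimal sketch of the clipped policy-gradient objective with an
# asymmetric clip range: lower bound 1 - 0.2, upper bound 1 + 0.28.
def clipped_pg_loss(log_probs, old_log_probs, advantages,
                    eps_low=0.2, eps_high=0.28):
    total = 0.0
    for lp, olp, adv in zip(log_probs, old_log_probs, advantages):
        ratio = math.exp(lp - olp)                        # pi_new / pi_old
        clipped = min(max(ratio, 1.0 - eps_low), 1.0 + eps_high)
        total += -min(ratio * adv, clipped * adv)         # pessimistic bound
    return total / len(advantages)

# With ratio = 2 and a positive advantage, the update is capped at 1.28:
print(clipped_pg_loss([math.log(2.0)], [0.0], [1.0]))  # → -1.28
```

The looser upper bound (0.28 vs. 0.2) lets high-advantage tokens move the policy slightly more in the positive direction, a common DAPO-style tweak.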
CPU Adam Optimization
OPTIMIZER_ARGS=(
--optimizer adam
--lr 1e-6
--lr-decay-style constant
--weight-decay 0.1
--adam-beta1 0.9
--adam-beta2 0.98
--optimizer-cpu-offload
--overlap-cpu-optimizer-d2h-h2d
--use-precision-aware-optimizer
)
CPU Adam saves GPU memory but requires ~1.4-1.5TB of host memory per node (8×H100). If your nodes lack sufficient host memory, increase parallelism with more nodes so each one holds a smaller optimizer shard.
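The per-node figure can be sanity-checked with rough arithmetic (assumptions, not exact slime/Megatron accounting: TP8 stays inside a node and CP replicas live on other nodes, so each node holds about 1/PP of the parameters, with fp32 master weights, both Adam moments, and fp32 gradient buffers at 4 bytes each):

```python
# Back-of-envelope check of the ~1.4-1.5 TB/node host-memory figure.
params_total = 355e9
pp = 4
params_per_node = params_total / pp        # TP8 spans the node's 8 GPUs
bytes_per_param = 4 + 4 + 4 + 4            # master + exp_avg + exp_avg_sq + grads
tb = params_per_node * bytes_per_param / 1e12
print(f"~{tb:.2f} TB per node")  # → ~1.42 TB per node
```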
SGLang Configuration
SGLANG_ARGS=(
--rollout-num-gpus-per-engine 32
--sglang-mem-fraction-static 0.7
--sglang-enable-dp-attention
--sglang-dp-size 4
)
SGLang uses 32 GPUs per engine with DP attention (DP=4) for efficient inference.
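The implied engine layout can be worked out from the flags (a sketch, assuming rollout colocates on all 64 training GPUs):

```python
# Engine layout implied by the SGLang flags (illustrative arithmetic,
# not SGLang internals). Assumes rollout runs on all 64 GPUs.
rollout_gpus = 64
gpus_per_engine = 32                             # --rollout-num-gpus-per-engine
dp_size = 4                                      # --sglang-dp-size
num_engines = rollout_gpus // gpus_per_engine    # 2 inference engines
tp_per_dp_replica = gpus_per_engine // dp_size   # 8-way TP inside each DP replica
print(num_engines, tp_per_dp_replica)  # → 2 8
```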
Miscellaneous (MISC_ARGS)
MISC_ARGS=(
--attention-dropout 0.0
--hidden-dropout 0.0
--accumulate-allreduce-grads-in-fp32
--attention-softmax-in-fp32
--attention-backend flash
# Use deepep for Megatron
--moe-enable-deepep
--moe-token-dispatcher-type flex
)
The flex token dispatcher with DeepEP enables efficient MoE all-to-all communication in Megatron.
FP8 Rollout (Advanced)
The open-source FP8 checkpoint of GLM-4.5 uses per-channel quantization, which is incompatible with DeepEP in SGLang. Convert it to 128×128 per-block quantization:
python tools/convert_hf_to_fp8.py \
--model-dir $BASE_DIR/GLM-4.5-355B-A32B/ \
--save-dir $BASE_DIR/GLM-4.5-355B-A32B-FP8/ \
--strategy block --block-size 128 128 \
--max-workers 4
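To illustrate what per-block quantization means, here is a toy sketch of the per-block scale computation (FP8 E4M3 max ≈ 448; shown with a 2×2 block on a small matrix, whereas the real tool operates on the checkpoint's weight shards with 128×128 blocks):

```python
# Toy sketch of per-block scale computation behind "--strategy block":
# each (B x B) block gets scale = amax(block) / FP8_E4M3_MAX, and the
# block's weights are stored as w / scale in fp8.
FP8_E4M3_MAX = 448.0

def block_scales(w, B):
    rows, cols = len(w), len(w[0])
    scales = []
    for i in range(0, rows, B):
        row = []
        for j in range(0, cols, B):
            amax = max(abs(w[r][c])
                       for r in range(i, min(i + B, rows))
                       for c in range(j, min(j + B, cols)))
            row.append(amax / FP8_E4M3_MAX)
        scales.append(row)
    return scales

w = [[1.0, -2.0, 0.5, 4.0],
     [0.25, 3.0, -1.0, 0.0]]
print(block_scales(w, 2))  # one scale per 2x2 block
```

Per-block scales keep a single outlier from inflating the quantization step of an entire channel, which is what makes the format DeepEP-friendly.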
Then update the checkpoint path:
CKPT_ARGS=(
--hf-checkpoint $BASE_DIR/GLM-4.5-355B-A32B-FP8/
--ref-load $BASE_DIR/GLM-4.5-355B-A32B_torch_dist/
--load $BASE_DIR/GLM-4.5-355B-A32B_slime/
--save $BASE_DIR/GLM-4.5-355B-A32B_slime/
--save-interval 20
)
FP8 SGLang Configuration
SGLANG_ARGS=(
--rollout-num-gpus-per-engine 32
--sglang-mem-fraction-static 0.7
--sglang-enable-dp-attention
--sglang-dp-size 32
--sglang-ep-size 32
--sglang-moe-dense-tp-size 1
--sglang-enable-dp-lm-head
--sglang-cuda-graph-bs 1 2 4 8 $(seq 16 8 128)
--sglang-moe-a2a-backend deepep
--sglang-deepep-mode auto
)
FP8 rollout with DeepEP significantly improves inference throughput for large MoE models.
Rollout Configuration
ROLLOUT_ARGS=(
--prompt-data /root/dapo-math-17k/dapo-math-17k.jsonl
--input-key prompt
--label-key label
--apply-chat-template
--rollout-shuffle
--rm-type deepscaler
--num-rollout 3000
--rollout-batch-size 32
--n-samples-per-prompt 8
--rollout-max-response-len 8192
--rollout-temperature 1
--global-batch-size 256
--balance-data
)
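These settings line up so that each rollout produces exactly one optimizer step:

```python
# Rollout-to-training batch arithmetic from ROLLOUT_ARGS.
rollout_batch_size = 32        # prompts per rollout
n_samples_per_prompt = 8       # responses sampled per prompt
global_batch_size = 256        # samples per optimizer step
samples_per_rollout = rollout_batch_size * n_samples_per_prompt
steps_per_rollout = samples_per_rollout // global_batch_size
print(samples_per_rollout, steps_per_rollout)  # → 256 1
```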
Evaluation
EVAL_ARGS=(
--eval-interval 20
--eval-prompt-data aime /root/aime-2024/aime-2024.jsonl
--n-samples-per-eval-prompt 16
--eval-max-response-len 16384
--eval-top-p 1
)
Checkpoint Configuration
CKPT_ARGS=(
# HF checkpoint for SGLang and tokenizer
--hf-checkpoint $BASE_DIR/GLM-4.5-355B-A32B/
# Reference model checkpoint (frozen)
--ref-load $BASE_DIR/GLM-4.5-355B-A32B_torch_dist/
# Actor model checkpoint (if empty, loads from ref-load)
--load $BASE_DIR/GLM-4.5-355B-A32B_slime/
--save $BASE_DIR/GLM-4.5-355B-A32B_slime/
--save-interval 20
)
Dynamic Batch Sizing
With --max-tokens-per-gpu 16384 and --context-parallel-size 2, a CP group of 2 GPUs can jointly hold a sequence of up to 32,768 tokens. Dynamic batch sizing packs variable-length samples up to this budget, reducing padding and memory pressure without changing the loss calculation.
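A minimal sketch of the greedy packing idea behind --use-dynamic-batch-size (illustrative only; the actual implementation may differ, e.g. in how it balances load across data-parallel ranks):

```python
# Greedy packing sketch: group variable-length samples into micro-batches
# whose token totals stay under the per-GPU budget.
def pack(sample_lens, max_tokens):
    batches, cur, cur_tokens = [], [], 0
    for n in sample_lens:
        if cur and cur_tokens + n > max_tokens:
            batches.append(cur)       # budget exceeded: start a new batch
            cur, cur_tokens = [], 0
        cur.append(n)
        cur_tokens += n
    if cur:
        batches.append(cur)
    return batches

print(pack([8000, 6000, 9000, 2000, 5000], 16384))
# → [[8000, 6000], [9000, 2000, 5000]]
```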