Reproducibility is a cornerstone of scientific progress. By combining SGLang’s deterministic inference with Megatron-LM’s deterministic mode, slime can reproduce experiments exactly, down to bitwise-identical results.

Enabling Deterministic Training

To enable fully deterministic training, you need to configure both SGLang (for rollout) and Megatron (for training) to use deterministic operations.

Prerequisites

You must uninstall Flash Attention 3 to enable deterministic mode:
pip uninstall flash_attn_3 -y
Flash Attention 3 uses non-deterministic algorithms for performance. Deterministic training requires using FlashInfer or standard attention backends instead.

Configuration Flags

Add these flags to your training configuration:
# SGLang configuration for deterministic rollout
SGLANG_ARGS=(
  --sglang-enable-deterministic-inference
  --sglang-attention-backend flashinfer
)

# Megatron configuration for deterministic training
TRAINING_ARGS=(
  --deterministic-mode
)

Environment Variables

Set these environment variables to ensure deterministic operations throughout the stack:
RUNTIME_ENV_JSON="{
  \"env_vars\": {
    \"NCCL_ALGO\": \"Ring\",
    \"NVTE_ALLOW_NONDETERMINISTIC_ALGO\": \"0\",
    \"CUBLAS_WORKSPACE_CONFIG\": \":4096:8\"
  }
}"
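When scripting the launch in Python instead of bash, building the runtime environment as a dict and serializing it with `json.dumps` avoids the escaped-quote mistakes that hand-written JSON strings invite. This is an illustrative sketch; the keys mirror the snippet above:

```python
import json

# Deterministic-mode environment variables, as a plain dict
# instead of a hand-escaped JSON string.
env_vars = {
    "NCCL_ALGO": "Ring",
    "NVTE_ALLOW_NONDETERMINISTIC_ALGO": "0",
    "CUBLAS_WORKSPACE_CONFIG": ":4096:8",
}

# Serialize for --runtime-env-json; json.dumps handles all quoting.
runtime_env_json = json.dumps({"env_vars": env_vars})
print(runtime_env_json)
```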
Environment Variable Explanations:
Variable                            Purpose                                               Value
----------------------------------  ----------------------------------------------------  --------
NCCL_ALGO                           Force NCCL to use the Ring algorithm                  Ring
NVTE_ALLOW_NONDETERMINISTIC_ALGO    Disable non-deterministic ops in TransformerEngine    0
CUBLAS_WORKSPACE_CONFIG             Enable deterministic cuBLAS operations                :4096:8
The CUBLAS_WORKSPACE_CONFIG format is :size:count. The value :4096:8 allocates 8 workspaces of 4096 bytes each for deterministic algorithms. This may need adjustment for very large models.
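The arithmetic behind the `:size:count` format can be sketched as a tiny helper. This is illustrative only; it is not part of cuBLAS or slime:

```python
def cublas_workspace_bytes(config: str) -> int:
    """Total workspace memory implied by a CUBLAS_WORKSPACE_CONFIG value.

    The format is ":size:count", where size is in bytes.
    """
    _, size, count = config.split(":")
    return int(size) * int(count)

# The documented default reserves 8 workspaces of 4096 bytes = 32 KiB total.
print(cublas_workspace_bytes(":4096:8"))  # → 32768
```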

Complete Example: GSM8K with Qwen2.5-0.5B

We provide a fully deterministic training example using Qwen2.5-0.5B on GSM8K. This example demonstrates bitwise reproducible training.

Setup

Step 1: Download Data and Model

# Download GSM8K dataset
hf download --repo-type dataset zhuzilin/gsm8k --local-dir /root/gsm8k

# Download Qwen2.5-0.5B-Instruct
hf download Qwen/Qwen2.5-0.5B-Instruct --local-dir /root/Qwen2.5-0.5B-Instruct

Step 2: Convert Checkpoint

cd slime/
source scripts/models/qwen2.5-0.5B.sh

PYTHONPATH=/root/Megatron-LM/ python \
   tools/convert_hf_to_torch_dist.py \
   ${MODEL_ARGS[@]} \
   --hf-checkpoint /root/Qwen2.5-0.5B-Instruct \
   --save /root/Qwen2.5-0.5B-Instruct_torch_dist/

Step 3: Uninstall Flash Attention 3

pip uninstall flash_attn_3 -y

Step 4: Run Training

bash scripts/run-qwen2.5-0.5B-reproducibility.sh

Training Script

Here’s the complete training script with all deterministic settings:
run-qwen2.5-0.5B-reproducibility.sh
#!/bin/bash

set -ex

SCRIPT_DIR="$(cd -- "$(dirname -- "${BASH_SOURCE[0]}")" &>/dev/null && pwd)"
source "${SCRIPT_DIR}/scripts/models/qwen2.5-0.5B.sh"

CKPT_ARGS=(
   --hf-checkpoint /root/Qwen2.5-0.5B-Instruct
   --ref-load /root/Qwen2.5-0.5B-Instruct_torch_dist/
   --load /root/Qwen2.5-0.5B-Instruct_torch_dist/
   --save /root/qwen2.5-0.5B-reproducibility/
   --save-interval 10
)

ROLLOUT_ARGS=(
   --prompt-data /root/gsm8k/gsm8k.jsonl
   --input-key question
   --label-key answer
   --apply-chat-template
   --rollout-shuffle
   --rm-type exact_match
   
   --num-rollout 100
   --rollout-batch-size 32
   --n-samples-per-prompt 8
   --rollout-max-response-len 512
   --rollout-temperature 1.0
   
   --global-batch-size 256
   --balance-data
)

PERF_ARGS=(
   --tensor-model-parallel-size 1
   --pipeline-model-parallel-size 1
   --micro-batch-size 8
)

GRPO_ARGS=(
   --advantage-estimator grpo
   --use-kl-loss
   --kl-loss-coef 0.01
)

OPTIMIZER_ARGS=(
   --optimizer adam
   --lr 1e-6
   --lr-decay-style constant
   --weight-decay 0.1
)

SGLANG_ARGS=(
   # Deterministic inference configuration
   --sglang-enable-deterministic-inference
   --sglang-attention-backend flashinfer
   
   --rollout-num-gpus-per-engine 1
   --sglang-mem-fraction-static 0.7
)

MISC_ARGS=(
   # Deterministic training configuration
   --deterministic-mode
   
   --attention-dropout 0.0
   --hidden-dropout 0.0
   --attention-backend flash
)

# Launch Ray cluster
export MASTER_ADDR=${MASTER_ADDR:-"127.0.0.1"}
ray start --head --node-ip-address ${MASTER_ADDR} --num-gpus 8 \
  --disable-usage-stats --dashboard-host=0.0.0.0 --dashboard-port=8265

# Build runtime environment with deterministic settings
RUNTIME_ENV_JSON="{
  \"env_vars\": {
    \"PYTHONPATH\": \"/root/Megatron-LM/\",
    \"CUDA_DEVICE_MAX_CONNECTIONS\": \"1\",
    \"NCCL_ALGO\": \"Ring\",
    \"NVTE_ALLOW_NONDETERMINISTIC_ALGO\": \"0\",
    \"CUBLAS_WORKSPACE_CONFIG\": \":4096:8\"
  }
}"

ray job submit --address="http://127.0.0.1:8265" \
   --runtime-env-json="${RUNTIME_ENV_JSON}" \
   -- python3 train.py \
   --actor-num-nodes 1 \
   --actor-num-gpus-per-node 8 \
   --colocate \
   ${MODEL_ARGS[@]} \
   ${CKPT_ARGS[@]} \
   ${ROLLOUT_ARGS[@]} \
   ${OPTIMIZER_ARGS[@]} \
   ${GRPO_ARGS[@]} \
   ${PERF_ARGS[@]} \
   ${SGLANG_ARGS[@]} \
   ${MISC_ARGS[@]}

Verification

The reproducibility example includes WandB logging to verify bitwise determinism. Running the same script multiple times should produce identical:
  • Loss curves
  • Reward curves
  • Model outputs
  • Checkpoint weights
See PR #370 for WandB screenshots demonstrating bitwise reproducibility.
For true bitwise reproducibility, you must also fix:
  • Random seeds (automatically handled by slime)
  • Dataset order (use --rollout-shuffle consistently)
  • Number of workers and GPUs
  • CUDA version and driver version

Performance Impact

Deterministic mode has some performance trade-offs:

Throughput Reduction

Expect 10-30% reduction in throughput compared to non-deterministic mode, primarily due to:
  • FlashInfer in place of Flash Attention 3
  • The Ring-only NCCL algorithm
  • Deterministic cuBLAS operations

Memory Overhead

Deterministic operations require additional workspace memory:
  • cuBLAS workspaces: ~32KB (configured via CUBLAS_WORKSPACE_CONFIG)
  • Minimal impact on large models

What Is Deterministic?

In deterministic mode, the following are guaranteed to be bitwise identical across runs:

Guaranteed Deterministic

  • Matrix multiplications (GEMM)
  • Attention operations
  • Activation functions
  • Optimizer updates
  • Gradient computations
  • NCCL collectives
  • Random number generation (with fixed seed)
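The last bullet is the seeded-RNG property the rest of the stack relies on: two identically seeded generators replay the same stream. A minimal illustration with Python’s stdlib `random` (not slime’s actual RNG stack):

```python
import random

# Two independent generators with the same seed produce
# bitwise-identical sample streams.
g1 = random.Random(42)
g2 = random.Random(42)

seq1 = [g1.random() for _ in range(5)]
seq2 = [g2.random() for _ in range(5)]
print(seq1 == seq2)  # → True
```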

Not Deterministic

The following may still introduce non-determinism:
  • Multi-process race conditions: unsynchronized ordering of operations across worker processes
  • External services: Calls to external APIs or services
  • File system operations: Order of file reads in multi-threaded contexts
  • Hardware failures: ECC errors or other hardware issues

Advanced Configuration

Custom cuBLAS Workspace Size

For very large models, you may need to increase the cuBLAS workspace size:
# Default: :4096:8 (8 workspaces of 4KB each)
export CUBLAS_WORKSPACE_CONFIG=":16384:8"  # 8 workspaces of 16KB each

Deterministic Data Loading

Ensure data loading is deterministic:
ROLLOUT_ARGS=(
  --rollout-shuffle  # Use consistent shuffling with fixed seed
  --seed 42          # Set explicit random seed
)
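Why a fixed seed pins the dataset order can be sketched as follows. The shuffle shown is Python’s stdlib, not slime’s loader; the flag names above are what it models:

```python
import random

# A seeded shuffle yields the same dataset order on every run,
# which is what --rollout-shuffle plus a fixed --seed provides.
prompts = list(range(10))

order_run1 = prompts.copy()
random.Random(42).shuffle(order_run1)

order_run2 = prompts.copy()
random.Random(42).shuffle(order_run2)

print(order_run1 == order_run2)  # → True
```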

Multi-Node Determinism

For multi-node training, ensure:
  1. All nodes use the same CUDA version
  2. All nodes use the same NCCL version
  3. Network topology is consistent
  4. Nodes are synchronized (NTP)
# Verify NCCL version consistency
ray exec /root/ray_cluster.yaml "python -c 'import torch; print(torch.cuda.nccl.version())'"

# Verify CUDA version consistency
ray exec /root/ray_cluster.yaml "nvcc --version"

Debugging Non-Determinism

If you encounter non-determinism despite enabling deterministic mode:

1. Check Flash Attention

pip show flash_attn_3 >/dev/null 2>&1 && echo "Flash Attention 3 installed" || echo "Not installed"
If Flash Attention 3 is installed, uninstall it:
pip uninstall flash_attn_3 -y

2. Verify Environment Variables

import os
print("NCCL_ALGO:", os.environ.get("NCCL_ALGO"))
print("NVTE_ALLOW_NONDETERMINISTIC_ALGO:", os.environ.get("NVTE_ALLOW_NONDETERMINISTIC_ALGO"))
print("CUBLAS_WORKSPACE_CONFIG:", os.environ.get("CUBLAS_WORKSPACE_CONFIG"))

3. Check PyTorch Determinism

import torch
print("PyTorch deterministic:", torch.are_deterministic_algorithms_enabled())
print("cuDNN deterministic:", torch.backends.cudnn.deterministic)
print("cuDNN benchmark:", torch.backends.cudnn.benchmark)
These should print True, True, and False, respectively, in deterministic mode.

4. Compare Checksums

After training, compare checkpoint checksums:
# Run 1: save the checksum first -- the second run overwrites this checkpoint
md5sum /root/qwen2.5-0.5B-reproducibility/iter_0000100/model_optim_rng.pt | tee /tmp/run1.md5

# Run 2: the checksum should match the one saved in /tmp/run1.md5
md5sum /root/qwen2.5-0.5B-reproducibility/iter_0000100/model_optim_rng.pt
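If `md5sum` is unavailable, an equivalent streamed digest can be computed in Python. `file_md5` is an illustrative helper, not part of slime:

```python
import hashlib

def file_md5(path: str) -> str:
    """MD5 of a file, streamed in 1 MiB chunks (same digest as `md5sum`)."""
    h = hashlib.md5()
    with open(path, "rb") as f:
        for chunk in iter(lambda: f.read(1 << 20), b""):
            h.update(chunk)
    return h.hexdigest()

# Example, using paths from this guide's checkpoint layout:
# file_md5(".../iter_0000100/model_optim_rng.pt")
```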

Best Practices

Enable deterministic mode for:
  • Paper results that need exact reproduction
  • Debugging training instabilities
  • Ablation studies requiring precise comparison
Disable for routine training where performance is more important.
When publishing results, document:
  • CUDA version: nvcc --version
  • PyTorch version: python -c 'import torch; print(torch.__version__)'
  • Slime commit hash: git rev-parse HEAD
  • GPU type: nvidia-smi --query-gpu=name --format=csv,noheader
This helps others reproduce your results.
Run a short deterministic training test early in your project:
# Run 1
bash scripts/run-qwen2.5-0.5B-reproducibility.sh
cp /root/qwen2.5-0.5B-reproducibility/iter_0000010/model_optim_rng.pt /tmp/run1.pt

# Run 2
bash scripts/run-qwen2.5-0.5B-reproducibility.sh
cp /root/qwen2.5-0.5B-reproducibility/iter_0000010/model_optim_rng.pt /tmp/run2.pt

# Compare
cmp -s /tmp/run1.pt /tmp/run2.pt && echo "Deterministic!" || echo "Non-deterministic!"
Consider a hybrid approach:
  • Use deterministic mode for final benchmark runs
  • Use non-deterministic mode for development and hyperparameter tuning
  • Re-run important experiments in deterministic mode to verify

Future Improvements

Planned enhancements to reproducibility in slime:
  • Automatic detection of non-deterministic operations with warnings
  • Per-component determinism flags (e.g., deterministic rollout but non-deterministic training)
  • Better performance optimization for deterministic mode
  • Reproducibility verification tests in CI/CD
