Low precision training enables efficient RL scaling by reducing memory footprint and increasing throughput during both training and inference. Slime supports multiple quantization strategies including FP8 and INT4 quantization-aware training (QAT).

FP8 Rollout with BF16 Training

You can run FP8 rollout while keeping training in BF16 format by using a blockwise quantized checkpoint for inference.

Converting Models to FP8

Convert your BF16 model to FP8 format using the conversion tool:
python tools/convert_hf_to_fp8.py \
    --model-dir $BF16_MODEL \
    --save-dir $FP8_MODEL \
    --strategy block --block-size 128 128 \
    --max-workers 4
Ensure the converted checkpoint’s config.json contains the correct quantization_config so that slime can automatically use FP8 quantization during weight updates.
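As a point of reference, a blockwise FP8 checkpoint typically carries a quantization_config of roughly this shape (field names follow the convention used by public FP8 checkpoints such as DeepSeek-V3; verify against your converted config.json rather than copying this verbatim):

```json
{
  "quantization_config": {
    "quant_method": "fp8",
    "fmt": "e4m3",
    "activation_scheme": "dynamic",
    "weight_block_size": [128, 128]
  }
}
```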
The FP8 checkpoint is only used for rollout inference. Training weights remain in BF16 format.
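To make the blockwise scheme concrete, here is a minimal numpy sketch of 128x128 block quantization. This is illustrative only: the real conversion tool stores true float8_e4m3 tensors, while this sketch uses float32 stand-ins, and it assumes weight dimensions divisible by the block size.

```python
import numpy as np

E4M3_MAX = 448.0  # largest finite value representable in FP8 e4m3

def quantize_blockwise(weight: np.ndarray, block: int = 128):
    """Compute one scale per block x block tile and scale values into e4m3 range."""
    rows, cols = weight.shape
    scales = np.zeros((rows // block, cols // block), dtype=np.float32)
    q = np.zeros_like(weight)
    for i in range(0, rows, block):
        for j in range(0, cols, block):
            tile = weight[i:i + block, j:j + block]
            # One scale per tile maps the tile's dynamic range onto [-448, 448]
            scale = max(np.abs(tile).max() / E4M3_MAX, 1e-12)
            scales[i // block, j // block] = scale
            q[i:i + block, j:j + block] = np.clip(tile / scale, -E4M3_MAX, E4M3_MAX)
    return q, scales

def dequantize_blockwise(q: np.ndarray, scales: np.ndarray, block: int = 128):
    """Expand each per-tile scale back over its tile and rescale."""
    return q * np.repeat(np.repeat(scales, block, axis=0), block, axis=1)
```

Per-block scaling limits the blast radius of outlier weights: a large value only coarsens the resolution of its own 128x128 tile rather than the whole tensor.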

FP8 Training and Inference

For maximum efficiency and training stability, you can use FP8 for both training and inference. This provides:
  • Higher inference throughput
  • Lower training-inference mismatch
  • More stable training dynamics
See the FP8 RL blog post for detailed analysis and results.

Quick Start

Step 1: Convert Model to FP8

Use the conversion tool from the previous section:
python tools/convert_hf_to_fp8.py \
    --model-dir /path/to/bf16/model \
    --save-dir /path/to/fp8/model \
    --strategy block --block-size 128 128 \
    --max-workers 4
Step 2: Configure Training Flags

Add these flags to your training script:
--fp8-format e4m3
--fp8-recipe blockwise
# --fp8-param-gather  # Optional: Currently incompatible with CPU Adam
Enable the required environment variable:
export NVTE_FP8_BLOCK_SCALING_FP32_SCALES=1
Step 3: Start Training

Launch training with one of the provided scripts:
# Qwen3-4B FP8 training
bash scripts/low_precision/run-qwen3-4b-fp8.sh

# Qwen3-30B-A3B FP8 training (2 nodes)
bash scripts/low_precision/run-qwen3-30b-a3b-fp8.sh
Step 4: Use Saved Checkpoints

Note that TransformerEngine saves weights in the original precision (usually BF16), not FP8. To evaluate under FP8:
  1. Convert checkpoint from torch_dist to HuggingFace format
  2. Convert HuggingFace checkpoint to FP8 format using convert_hf_to_fp8.py

Implementation Details

Here’s how FP8 training works in slime:
  1. Initialization: When the FP8 recipe is enabled, layers are built in an FP8 context
  2. Training: Weights and activations are quantized online to FP8 (e4m3); cuBLAS FP8 GEMM kernels perform the GEMM computations in the forward and backward passes
  3. Weight Updates: During RL weight updates, Megatron dequantizes FP8 weights to BF16, then slime quantizes them back to FP8 before sending them to SGLang
  4. Checkpoint Saving: Checkpoints are dequantized to BF16 and saved in torch_dist format
Only Linear and GroupedLinear layers in TransformerEngine use FP8 format. The embedding and lm_head layers remain in their original precision. If --fp8-param-gather is not enabled, weights in TransformerEngine remain in BF16 format and are only cast to FP8 during GEMM operations.
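The weight-update path in step 3 above amounts to a dequantize → update → requantize round trip. A simplified numpy sketch, using per-tensor scaling for brevity and hypothetical function names (not slime's or Megatron's actual APIs):

```python
import numpy as np

E4M3_MAX = 448.0  # largest finite FP8 e4m3 value

def quantize(w: np.ndarray, scale: float) -> np.ndarray:
    return np.clip(w / scale, -E4M3_MAX, E4M3_MAX)

def dequantize(q: np.ndarray, scale: float) -> np.ndarray:
    return q * scale

def rl_weight_update(q_fp8: np.ndarray, scale: float, grad: np.ndarray, lr: float = 1e-3):
    """One RL weight-update round trip (per-tensor scale for brevity)."""
    w = dequantize(q_fp8, scale)                        # Megatron: FP8 -> BF16
    w = w - lr * grad                                   # optimizer step in high precision
    new_scale = max(np.abs(w).max() / E4M3_MAX, 1e-12)  # refresh the scale for the new range
    return quantize(w, new_scale), new_scale            # slime: BF16 -> FP8 for SGLang
```

The scale must be recomputed after each update because the optimizer step can shift the weight tensor's dynamic range; reusing a stale scale would clip or waste e4m3 resolution.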

Known Limitations

FP8 training has the following known issues:
  • FP8 weights (--fp8-param-gather) provide memory savings but currently require TransformerEngine’s FusedAdam, which conflicts with the commonly used Adam CPU offload technique in Megatron-LM
  • The additional dequantization/quantization during weight updates is not elegant but necessary for framework compatibility

INT4 QAT Training

INT4 quantization-aware training (QAT) uses a Straight-Through Estimator (STE) to train a model that stays accurate under INT4 inference. This significantly improves throughput during the rollout generation phase.
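The STE trick can be sketched in numpy as follows (forward pass only, and illustrative rather than slime's actual implementation; it assumes the tensor size is divisible by the group size, which in slime comes from OPEN_TRAINING_INT4_GROUP_SIZE):

```python
import numpy as np

def int4_fake_quant(w: np.ndarray, group_size: int = 128) -> np.ndarray:
    """Group-wise symmetric INT4 fake quantization.

    Each group of `group_size` values shares one scale; values are rounded to
    the INT4 range [-8, 7] and immediately dequantized. Under STE the backward
    pass treats this op as the identity, so gradients flow through unchanged
    and the model learns weights that survive INT4 rounding at rollout time.
    """
    groups = w.reshape(-1, group_size)
    scale = np.maximum(np.abs(groups).max(axis=1, keepdims=True), 1e-12) / 7.0
    q = np.clip(np.round(groups / scale), -8, 7)   # integer INT4 levels
    return (q * scale).reshape(w.shape)            # dequantize: the "fake" part
```

Training keeps full-precision master weights and only experiences quantization through this forward-pass rounding, while the rollout engine serves real INT4 weights with matching group structure.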

Quick Start

Step 1: Convert HuggingFace Weights to INT4

Use the direct conversion script:
python tools/convert_hf_to_int4_direct.py \
    --model-dir /path/to/your/original/models \
    --save-dir /path/to/your/save/models
Ensure config.json contains the correct quantization_config for automatic INT4 quantization during weight updates.
Step 2: Configure Environment Variables

Set the required environment variables for quantization:
  • OPEN_TRAINING_INT4_FAKE_QAT_FLAG: Enables fake quantization operations for INT4 training
  • OPEN_TRAINING_INT4_GROUP_SIZE: Specifies the block size (group size) for model quantization
Group Size Guidelines:
  • Set to 128 for Moonlight-16B-A3B, Qwen3-30B-A3B, and Qwen3-235B-A22B-INT4
  • Set to 32 for Kimi-k2-Thinking-INT4
Example configuration:
RUNTIME_ENV_JSON="{
  \"env_vars\": {
    \"OPEN_TRAINING_INT4_FAKE_QAT_FLAG\": \"1\",
    \"OPEN_TRAINING_INT4_GROUP_SIZE\": \"128\"
  }
}"
Step 3: Launch Training

Use one of the provided training scripts:
# Moonlight-16B-A3B Int4 training
bash scripts/low_precision/run-moonlight-16B-A3B-int4.sh

# Qwen3-30B-A3B Int4 training
bash scripts/low_precision/run-qwen3-30B-A3B-int4.sh

# Qwen3-235B-A22B Int4 training (8 nodes)
bash scripts/low_precision/run-qwen3-235B-A22B-int4.sh

# Kimi-k2-Thinking Int4 training (32 nodes)
bash scripts/low_precision/run-kimi-k2-Thinking-int4.sh
For multi-node environments, start the Ray service according to your cluster configuration before launching training.

INT4 Rollout Only

If you only want INT4 inference during rollout without QAT training, simply set --hf-checkpoint to the converted INT4 checkpoint. No additional environment variables are needed.

Example Configuration

Here’s a complete example configuration for FP8 training with Qwen3-4B:
run-qwen3-4b-fp8.sh
CKPT_ARGS=(
   --hf-checkpoint /root/Qwen3-4B-FP8
   --ref-load /root/Qwen3-4B_torch_dist
   --load /root/qwen3-4b_cp8_fp8
   --save /root/rl-model/qwen3-4b_cp8_fp8
   --save-interval 20
)

PERF_ARGS=(
   --tensor-model-parallel-size 2
   --sequence-parallel
   --pipeline-model-parallel-size 1
   --context-parallel-size 1
)

# Enable FP8 block scaling with FP32 scales
RUNTIME_ENV_JSON="{
  \"env_vars\": {
    \"NVTE_FP8_BLOCK_SCALING_FP32_SCALES\": \"1\"
  }
}"

ray job submit --address="http://127.0.0.1:8265" \
   --runtime-env-json="${RUNTIME_ENV_JSON}" \
   -- python3 train.py \
   --fp8-format e4m3 \
   --fp8-recipe blockwise \
   "${MODEL_ARGS[@]}" \
   "${CKPT_ARGS[@]}" \
   "${PERF_ARGS[@]}"
And for INT4 training with Qwen3-30B-A3B:
run-qwen3-30B-A3B-int4.sh
CKPT_ARGS=(
   --hf-checkpoint /root/Qwen3-30B-A3B-INT4/
   --ref-load /root/Qwen3-30B-A3B_torch_dist/
)

RUNTIME_ENV_JSON="{
  \"env_vars\": {
    \"OPEN_TRAINING_INT4_FAKE_QAT_FLAG\": \"1\",
    \"OPEN_TRAINING_INT4_GROUP_SIZE\": \"128\"
  }
}"

ray job submit --address="http://127.0.0.1:8265" \
   --runtime-env-json="${RUNTIME_ENV_JSON}" \
   -- python3 train.py \
   "${MODEL_ARGS[@]}" \
   "${CKPT_ARGS[@]}"

Best Practices

Choose the Right Precision

  • Use FP8 rollout + BF16 training for a simple efficiency boost
  • Use FP8 training + FP8 inference for maximum throughput and stability
  • Use INT4 QAT for the largest models when memory is constrained

Monitor Training Stability

  • Watch for divergence when switching to lower precision
  • FP8 typically provides better stability than INT4
  • Adjust learning rate if needed when changing precision
