
Overview

While Megatron-LM is highly efficient for parallel training, it can lack the flexibility to support rapidly evolving model architectures. slime provides two approaches to handle cutting-edge models:
  1. HuggingFace Module Wrapping - Import and wrap official HF implementations into Megatron’s pipeline
  2. FSDP Backend - Use PyTorch’s Fully Sharded Data Parallel for maximum flexibility
This guide covers both approaches.

Approach 1: HuggingFace Module Wrapping

Instead of deeply re-engineering Megatron, slime can directly import and wrap a model’s official HuggingFace implementation, embedding it as a “black-box” module into Megatron’s parallel training pipeline.

How It Works

Megatron’s model instantiation is a two-step process:
  1. Generate a “layer specification” (ModuleSpec) based on configuration
  2. Instantiate actual PyTorch modules according to that spec
slime hijacks the spec generation stage to replace Megatron’s native modules with external implementations.

Core Components

Step 1: Replace the Megatron Module Spec

Use a custom function (e.g., get_qwen3_next_spec) to modify the standard ModuleSpec:
  • Retrieve the standard Decoder Block Spec
  • Point its self_attention field to a custom wrapper module
  • Enable model-specific configurations like qk_layernorm
Implementation: slime_plugins/models/qwen3_next.py
Step 2: Wrap the HuggingFace Implementation

The modified spec points to a wrapper layer (e.g., HuggingfaceAttention) that:
  • Inherits from Megatron’s MegatronModule
  • Handles data alignment for parallelism strategies (like sequence parallelism)
  • Internally calls the native attention module loaded from HuggingFace
Implementation: slime_plugins/models/hf_attention.py
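The wrapper's main job is data alignment between the two frameworks. As a minimal sketch (the class and method names here are assumptions, and `nn.Module` stands in for Megatron's `MegatronModule` so the example runs standalone), the wrapper converts Megatron's sequence-first activation layout to the batch-first layout HuggingFace modules expect, calls the native HF attention, and converts back:

```python
import torch
import torch.nn as nn

class HuggingfaceAttentionSketch(nn.Module):
    """Illustrative wrapper; the real class inherits Megatron's MegatronModule."""

    def __init__(self, hf_attention: nn.Module):
        super().__init__()
        # Native attention module loaded from the HuggingFace implementation
        self.hf_attention = hf_attention

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        # Megatron passes activations as [seq, batch, hidden];
        # HuggingFace modules expect [batch, seq, hidden]
        hidden_states = hidden_states.transpose(0, 1)
        out = self.hf_attention(hidden_states)
        # Convert back to Megatron's sequence-first layout
        return out.transpose(0, 1)

# Dummy stand-in for an HF attention block
hf_attn = nn.Linear(16, 16)
wrapper = HuggingfaceAttentionSketch(hf_attn)
x = torch.randn(8, 2, 16)   # [seq, batch, hidden]
print(wrapper(x).shape)     # torch.Size([8, 2, 16])
```

The real wrapper additionally gathers/scatters activations when sequence parallelism is enabled, since the HF module must see the full sequence.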
Step 3: Align Model Weights

Use the mbridge library to establish a naming map between HuggingFace checkpoints and Megatron parameters:
  • Enables seamless bidirectional conversion
  • Handles parameter name mapping automatically
Implementation: slime_plugins/mbridge/qwen3_next.py
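Conceptually, the bridge maintains a table from HuggingFace checkpoint keys to Megatron parameter names and applies it in both directions. The sketch below is hypothetical (illustrative patterns, not the real qwen3_next table), but shows the kind of templated mapping involved:

```python
# Hypothetical name map between HF checkpoint keys and Megatron parameters.
HF_TO_MEGATRON = {
    "model.embed_tokens.weight": "embedding.word_embeddings.weight",
    "model.layers.{i}.self_attn.q_proj.weight":
        "decoder.layers.{i}.self_attention.linear_q.weight",
    "lm_head.weight": "output_layer.weight",
}

def hf_to_megatron(name: str) -> str:
    """Map one HF parameter name to its Megatron counterpart."""
    for hf_pat, mg_pat in HF_TO_MEGATRON.items():
        prefix, _, suffix = hf_pat.partition("{i}")
        if "{i}" in hf_pat and name.startswith(prefix) and name.endswith(suffix):
            # Extract the layer index and substitute it into the target pattern
            layer = name[len(prefix):-len(suffix)] if suffix else name[len(prefix):]
            return mg_pat.replace("{i}", layer)
        if name == hf_pat:
            return mg_pat
    raise KeyError(name)

print(hf_to_megatron("model.layers.3.self_attn.q_proj.weight"))
# decoder.layers.3.self_attention.linear_q.weight
```

Because the map is bidirectional in principle, the same table supports loading HF checkpoints into Megatron and exporting trained weights back to HF format.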

Example: Qwen3Next 80B-A3B

# Custom spec generator for Qwen3Next (simplified sketch; the full
# version lives in slime_plugins/models/qwen3_next.py)
def get_qwen3_next_spec():
    # Standard TE-backed decoder layer spec, with model-specific
    # options such as qk_layernorm enabled
    spec = get_gpt_layer_with_transformer_engine_spec(qk_layernorm=True)
    # Point the attention submodule at the HuggingFace wrapper
    spec.submodules.self_attention = ModuleSpec(module=HuggingfaceAttention)
    return spec

Capabilities

With this approach, you can run complex model architectures (like Gated-Delta-Net) while retaining:
  • Model parallelism (PP, EP)
  • MoE acceleration
  • Pipeline scheduling
  • Sequence parallelism

Current Limitations

Tensor Parallelism Not Supported: This approach does not currently support Tensor Parallelism (TP) within the replaced module itself (e.g., the Attention layer).

Impact: In most large-scale MoE models, the parameter count of the Attention layer is relatively small, so this limitation typically has minimal effect on memory footprint and training throughput.

Alternative: If TP for the module is critical, you must revert to modifying Megatron’s native implementation.

Approach 2: FSDP Backend

For maximum flexibility with modern architectures, slime provides a native FSDP (Fully Sharded Data Parallel) backend that works directly with HuggingFace models.

Architecture

The FSDP backend (FSDPTrainRayActor) provides:
  • Direct HuggingFace model loading
  • PyTorch FSDP2 for efficient distributed training
  • Optional CPU offloading for large models
  • Full compatibility with slime’s RL training pipeline

Key Features

Native HF Support

Load models directly from HuggingFace without conversion

Memory Efficient

CPU offloading and mixed precision support

Flexible Parallelism

Data parallelism with FSDP2 sharding strategies

RL Compatible

Full integration with slime’s PPO/GRPO algorithms

Configuration

The FSDP backend supports several configuration options:
@dataclass
class FSDPArgs:
    # Optimizer configuration
    optimizer: str = "adam"
    lr: float = 2e-5
    weight_decay: float = 0.0
    adam_beta1: float = 0.9
    adam_beta2: float = 0.95
    adam_eps: float = 1e-8
    
    # Memory management
    gradient_checkpointing: bool = False
    fsdp_cpu_offload: bool = False  # Offload params/grads/optimizer to CPU
    fsdp_state_dict_cpu_offload: bool = True  # Offload checkpoints to CPU
    
    # Precision
    fp16: bool = False  # Use FP16 (default: BF16)
    attn_implementation: str = "flash_attention_2"

Usage Example

from slime.backends.fsdp_utils import FSDPTrainRayActor, fsdp_parse_args

# Parse arguments with FSDP defaults
args = fsdp_parse_args()

# Initialize FSDP actor
actor = FSDPTrainRayActor()
actor.init(
    args=args,
    role="actor",
    with_ref=True  # Create separate reference model
)

FSDP Implementation Details

Model Initialization

The FSDP backend uses an efficient initialization strategy:
def _get_init_weight_context_manager(self):
    """Rank 0: CPU, others: meta device for memory efficiency."""
    from accelerate import init_empty_weights
    
    use_meta_tensor = not self.hf_config.tie_word_embeddings
    
    if use_meta_tensor:
        return init_empty_weights if dist.get_rank() != 0 else cpu_init_weights
    else:
        return cpu_init_weights
Memory Optimization: Non-rank-0 processes use meta tensors (no memory allocation) unless tie_word_embeddings=True, which requires full CPU initialization to avoid hangs.
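Meta tensors are what make this cheap: a parameter created on the `meta` device carries shape and dtype metadata but allocates no storage, so a large model can be "constructed" on every non-rank-0 process before FSDP materializes the real shards. A minimal runnable illustration:

```python
import torch
import torch.nn as nn

# Parameters created on the "meta" device carry shape/dtype metadata
# but no storage, so even a huge layer costs no memory here.
with torch.device("meta"):
    layer = nn.Linear(4096, 4096)

print(layer.weight.is_meta)        # True: no memory allocated
print(tuple(layer.weight.shape))   # (4096, 4096)
```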

Data Packing

FSDP backend includes efficient sequence packing:
packed_batches = pack_sequences(
    tokens=rollout_data["tokens"],
    loss_masks=rollout_data["loss_masks"],
    rewards=rollout_data["rewards"],
    response_lengths=rollout_data["response_lengths"],
    num_packs=num_microbatches
)
This maximizes GPU utilization by packing variable-length sequences into fixed-size batches.
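One way to realize such packing is a longest-first greedy assignment that keeps per-microbatch token counts even. The helper below is a hypothetical sketch, not slime's actual `pack_sequences`:

```python
# Hypothetical length-balanced packing: greedily assign each sequence
# (longest first) to the currently lightest pack.
def pack_sequences(lengths, num_packs):
    packs = [[] for _ in range(num_packs)]
    totals = [0] * num_packs
    for idx in sorted(range(len(lengths)), key=lambda i: -lengths[i]):
        target = totals.index(min(totals))  # lightest pack so far
        packs[target].append(idx)
        totals[target] += lengths[idx]
    return packs, totals

packs, totals = pack_sequences([512, 128, 384, 256, 64], num_packs=2)
print(totals)  # [704, 640] — roughly balanced token counts
```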

Reference Model Management

For PPO/GRPO training, the FSDP backend maintains a separate reference model:
def _create_ref_model(self, ref_load_path: str):
    """Create FSDP2-wrapped reference model with CPU offload."""
    ref_model = self.get_model_cls().from_pretrained(
        ref_load_path,
        trust_remote_code=True,
        attn_implementation=self.args.attn_implementation
    )
    
    # Always use CPU offload for reference model to save memory
    ref_model = apply_fsdp2(
        ref_model, 
        mesh=self.dp_mesh, 
        cpu_offload=True, 
        args=self.args
    )
    
    return ref_model

Training Workflow

The FSDP training loop follows slime’s standard RL training pipeline:
Step 1: Process Rollout Data

Partition rollout data across data-parallel ranks:
rollout_data = process_rollout_data(
    args, rollout_data_ref, 
    self.dp_rank, self.dp_size
)
Step 2: Compute Advantages

Calculate advantages using GRPO/GSPO:
rollout_data["advantages"] = rollout_data["returns"] = [
    torch.tensor([rewards[i]] * response_lengths[i])
    for i in range(len(rewards))
]
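This broadcasts each sequence-level reward over that sequence's response tokens. With toy numbers, the same expression runs as:

```python
import torch

# Each sequence-level reward is repeated once per response token.
rewards = [1.0, -0.5]
response_lengths = [3, 2]

advantages = [
    torch.tensor([rewards[i]] * response_lengths[i])
    for i in range(len(rewards))
]
print(advantages[0].tolist())  # [1.0, 1.0, 1.0]
print(advantages[1].tolist())  # [-0.5, -0.5]
```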
Step 3: Pack Sequences

Pack variable-length sequences for efficient processing:
packed_batches, grad_accum = self._packed_data(rollout_data)
Step 4: Forward Pass

Compute log probabilities for actor and reference:
self._compute_log_prob("ref", packed_batches, store_prefix="ref_")
self._compute_log_prob("actor", packed_batches)
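Under the hood, a per-token log-probability computation gathers each sampled token's log-probability from the model's logits, shifted by one position because logits at step t score the token at step t+1. A sketch under assumed names (not slime's exact `_compute_log_prob` internals):

```python
import torch
import torch.nn.functional as F

def token_log_probs(logits: torch.Tensor, tokens: torch.Tensor) -> torch.Tensor:
    # Logits at position t predict the token at position t+1
    logp = F.log_softmax(logits[:, :-1], dim=-1)      # [b, s-1, vocab]
    return logp.gather(-1, tokens[:, 1:].unsqueeze(-1)).squeeze(-1)

logits = torch.randn(1, 5, 32)
tokens = torch.randint(0, 32, (1, 5))
print(token_log_probs(logits, tokens).shape)  # torch.Size([1, 4])
```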
Step 5: Backward and Optimize

Compute policy loss and update weights:
pg_loss, pg_clipfrac = compute_policy_loss(
    ppo_kl, advantages, 
    args.eps_clip, args.eps_clip_high
)
loss.backward()
optimizer.step()
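A PPO-style clipped surrogate consistent with the `compute_policy_loss` call above can be sketched as follows (argument names and the exact clipping convention are assumptions):

```python
import torch

def compute_policy_loss(ppo_kl, advantages, eps_clip, eps_clip_high):
    # ppo_kl = old_log_probs - log_probs, so the importance ratio is exp(-ppo_kl)
    ratio = torch.exp(-ppo_kl)
    loss1 = -advantages * ratio
    loss2 = -advantages * torch.clamp(ratio, 1 - eps_clip, 1 + eps_clip_high)
    pg_loss = torch.max(loss1, loss2)                  # pessimistic (clipped) objective
    pg_clipfrac = (loss2 > loss1).float().mean()       # fraction of clipped tokens
    return pg_loss.mean(), pg_clipfrac

kl = torch.tensor([0.0, -0.5, 0.3])
adv = torch.tensor([1.0, 1.0, -1.0])
loss, clipfrac = compute_policy_loss(kl, adv, eps_clip=0.2, eps_clip_high=0.28)
print(float(clipfrac))  # fraction of tokens where the clipped term dominates
```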

Advantages Over Megatron

| Feature | FSDP Backend | Megatron-LM |
| --- | --- | --- |
| HF Model Support | Native, no conversion | Requires torch_dist conversion |
| New Architectures | Immediate support | Manual implementation required |
| Tensor Parallelism | Data parallel only | TP, PP, EP support |
| Memory Efficiency | CPU offload, FSDP2 sharding | Gradient checkpointing, ZeRO |
| Development Speed | Fast prototyping | Production-grade performance |

When to Use FSDP

Choose the FSDP backend when:
  • Working with cutting-edge model architectures (Qwen3Next, Gemma2, etc.)
  • Rapid prototyping and experimentation
  • Models fit within data-parallel scaling limits
  • You need native HuggingFace compatibility
Choose Megatron-LM when:
  • Training massive models requiring tensor parallelism
  • Maximum training throughput is critical
  • Using well-supported architectures (GPT, LLaMA, Qwen)
  • Production deployment at scale

Supported Backends Comparison

Megatron-LM

Best for: Production training
  • Full 3D parallelism (TP/PP/DP)
  • Maximum throughput
  • Battle-tested at scale

HF Wrapping

Best for: New architectures with Megatron
  • Custom attention mechanisms
  • Partial Megatron parallelism
  • Quick integration

FSDP

Best for: Flexible experimentation
  • Native HF models
  • Fast iteration
  • Data-parallel scaling

Getting Started

Using HF Wrapping

  1. Implement custom spec generator in slime_plugins/models/
  2. Create HF wrapper module in slime_plugins/models/hf_attention.py
  3. Add weight mapping in slime_plugins/mbridge/

Using FSDP Backend

# Install slime with FSDP support
pip install -e .

# Run training with FSDP
python train.py \
  --backend fsdp \
  --hf-checkpoint Qwen/Qwen3-4B \
  --fsdp-cpu-offload \
  --gradient-checkpointing \
  --lr 1e-5
