Overview

The trainer.py module extends the Hugging Face Trainer with custom implementations optimized for Qwen-VL models. It includes custom attention mechanisms using Flash Attention 2, monkey-patched forward methods for various Qwen model versions, and a specialized optimizer creation method that supports different learning rates for vision tower and multimodal projector components.

Custom Optimizer Creation

create_optimizer()

Creates an optimizer with separate learning rate configurations for different model components.
self.args.mm_projector_lr
float
default:"None"
Learning rate for the multimodal projector (merger) module. If set, the projector parameters will use this learning rate instead of the base learning rate.
self.args.vision_tower_lr
float
default:"None"
Learning rate for the vision tower (visual encoder) module. If set, the vision tower parameters will use this learning rate instead of the base learning rate.
self.args.weight_decay
float
Weight decay coefficient. Applied only to the parameter groups that use decay; bias terms are excluded.
Returns: Configured optimizer instance with parameter groups for:
  • Base model parameters (with and without weight decay)
  • Vision tower parameters (with and without weight decay, if vision_tower_lr is set)
  • Multimodal projector parameters (with and without weight decay, if mm_projector_lr is set)
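
The grouping described above can be sketched in plain Python. This is an illustrative approximation, not the module's actual code: the helper name `build_param_groups` is hypothetical, the bias test is simplified to a name suffix check, and checking "merger" before "visual" is an assumption made because merger parameter names may also contain "visual".

```python
def build_param_groups(named_params, base_lr, weight_decay,
                       vision_tower_lr=None, mm_projector_lr=None):
    """Split (name, param) pairs into up to six optimizer parameter groups.

    Hypothetical helper for illustration; `named_params` can be the result
    of `model.named_parameters()` or any iterable of (name, value) pairs.
    """
    lrs = {"base": base_lr}
    if vision_tower_lr is not None:
        lrs["vision"] = vision_tower_lr
    if mm_projector_lr is not None:
        lrs["projector"] = mm_projector_lr

    buckets = {}
    for name, param in named_params:
        # Assumption: "merger" is tested first, since a merger parameter
        # name may also contain "visual".
        if "merger" in name and "projector" in lrs:
            component = "projector"
        elif "visual" in name and "vision" in lrs:
            component = "vision"
        else:
            component = "base"
        # Simplification: only bias terms are excluded from weight decay.
        decay = not name.endswith("bias")
        buckets.setdefault((component, decay), []).append(param)

    return [
        {
            "params": params,
            "lr": lrs[component],
            "weight_decay": weight_decay if decay else 0.0,
        }
        for (component, decay), params in buckets.items()
    ]


# Placeholder strings stand in for real parameter tensors.
named = [
    ("model.layers.0.q_proj.weight", "w0"),
    ("model.layers.0.q_proj.bias", "b0"),
    ("visual.blocks.0.attn.qkv.weight", "w1"),
    ("visual.blocks.0.attn.qkv.bias", "b1"),
    ("visual.merger.mlp.0.weight", "w2"),
    ("visual.merger.mlp.0.bias", "b2"),
]
groups = build_param_groups(named, base_lr=2e-5, weight_decay=0.01,
                            vision_tower_lr=1e-5, mm_projector_lr=1e-4)
```

With real tensors, a list of dictionaries in this shape can be passed directly to a PyTorch optimizer such as `torch.optim.AdamW(groups)`, which reads `lr` and `weight_decay` per group.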

Flash Attention Methods

flash_attention_forward()

Custom Flash Attention 2 forward pass for efficient attention computation.
module
torch.nn.Module
required
The attention module.
query
torch.Tensor
required
Query tensor with shape (batch, head, seq_len, dim).
key
torch.Tensor
required
Key tensor with shape (batch, head, seq_len, dim).
value
torch.Tensor
required
Value tensor with shape (batch, head, seq_len, dim).
attention_mask
torch.Tensor
default:"None"
Cumulative sequence lengths tensor for variable-length attention. Despite the parameter name, this is not a conventional padding mask.
dropout
float
default:"0.0"
Dropout probability for attention weights.
scaling
float
default:"None"
Attention scaling factor.
sliding_window
int
default:"None"
Sliding window size for local attention.
softcap
float
default:"None"
Softcap value for attention logits.
Returns: Tuple of (attn_output, None) where attn_output is the attention output tensor.
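
To make the "cumulative sequence lengths" input more concrete, here is a minimal sketch of how such boundaries are typically derived when several samples are packed into one sequence. The helper names are hypothetical and plain lists are used instead of tensors; variable-length Flash Attention kernels generally expect the same values as an int32 tensor with one more entry than there are packed sequences.

```python
from itertools import accumulate


def cumulative_seq_lens(seq_lens):
    """Return boundaries [0, l1, l1 + l2, ...] for packed sequences."""
    return [0] + list(accumulate(seq_lens))


def max_seq_len(cu_seqlens):
    """Recover the longest individual sequence from the boundaries."""
    return max(end - start for start, end in zip(cu_seqlens, cu_seqlens[1:]))


# Three samples of lengths 3, 5 and 2 packed into one row of 10 tokens.
cu_seqlens = cumulative_seq_lens([3, 5, 2])
longest = max_seq_len(cu_seqlens)
```

Tokens between two consecutive boundaries belong to the same sample, so attention can be restricted to each segment without building a dense mask.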

qwen2vl_forward()

Custom forward method for Qwen2-VL and Qwen2.5-VL attention layers.
hidden_states
torch.Tensor
required
Input hidden states with shape (batch_size, seq_len, hidden_dim).
attention_mask
torch.Tensor
default:"None"
Attention mask tensor.
position_ids
torch.LongTensor
default:"None"
Position indices for positional embeddings.
past_key_values
Cache
default:"None"
Cached key/value states for efficient generation.
output_attentions
bool
default:"False"
Whether to return attention weights.
use_cache
bool
default:"False"
Whether to use key/value caching.
cache_position
torch.LongTensor
default:"None"
Position indices for cache updates.
position_embeddings
tuple[torch.Tensor, torch.Tensor]
default:"None"
Precomputed rotary position embeddings (cos, sin).
Returns: Tuple of (attn_output, attn_weights), where attn_weights may be None.
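
For readers unfamiliar with the (cos, sin) pair passed as position_embeddings, the following sketch shows the standard rotary formulation on a single head vector using plain lists. It is an assumption-laden simplification: the function names are illustrative, and the multimodal rotary variant used by Qwen-VL models rearranges the cos and sin tables across temporal and spatial sections before applying them, which is not shown here.

```python
import math


def rotate_half(x):
    """Map [x1, x2] (two halves) to [-x2, x1]."""
    half = len(x) // 2
    return [-v for v in x[half:]] + list(x[:half])


def apply_rotary(x, cos, sin):
    """Standard rotary embedding: x * cos + rotate_half(x) * sin."""
    rotated = rotate_half(x)
    return [xi * c + ri * s for xi, ri, c, s in zip(x, rotated, cos, sin)]


def rotary_tables(position, dim, base=10000.0):
    """Build cos/sin tables for one position (frequencies repeated twice)."""
    half = dim // 2
    angles = [position / (base ** (2 * i / dim)) for i in range(half)]
    cos = [math.cos(a) for a in angles] * 2
    sin = [math.sin(a) for a in angles] * 2
    return cos, sin


# Rotate a 4-dimensional query vector as if it sat at position 3.
query = [1.0, 0.0, 0.0, 1.0]
cos, sin = rotary_tables(position=3, dim=4)
rotated_query = apply_rotary(query, cos, sin)
```

Because each pair of coordinates is rotated by an angle, the vector's length is unchanged; only its direction encodes the position.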

qwen3vl_forward()

Custom forward method for Qwen3-VL and Qwen3-VL-MoE attention layers.
hidden_states
torch.Tensor
required
Input hidden states.
position_embeddings
tuple[torch.Tensor, torch.Tensor]
required
Rotary position embeddings (cos, sin).
attention_mask
torch.Tensor
default:"None"
Attention mask tensor.
past_key_values
Cache
default:"None"
Cached key/value states.
cache_position
torch.LongTensor
default:"None"
Position indices for cache updates.
Returns: Tuple of (attn_output, attn_weights).

Utility Functions

replace_qwen2_vl_attention_class()

Replaces the default attention forward methods in transformers with optimized Flash Attention implementations for all supported Qwen-VL model variants:
  • Qwen2-VL
  • Qwen2.5-VL
  • Qwen3-VL
  • Qwen3-VL-MoE
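
The replacement described above relies on monkey-patching: assigning a new function to a class attribute so that every instance picks it up. The sketch below demonstrates the technique on a stand-in class; `Attention`, `fast_forward`, and `replace_forward` are hypothetical names, not identifiers from transformers or from this module.

```python
class Attention:
    """Stand-in for a library attention layer."""

    def forward(self, hidden_states):
        return f"default({hidden_states})"


def fast_forward(self, hidden_states):
    """Replacement forward; a real one would call an optimized kernel."""
    return f"fast({hidden_states})"


def replace_forward(cls, new_forward):
    """Swap cls.forward and return the original so it can be restored."""
    original = cls.forward
    cls.forward = new_forward
    return original


layer = Attention()  # created before the patch is applied
before = layer.forward("h")
original_forward = replace_forward(Attention, fast_forward)
after = layer.forward("h")  # existing instances use the new method too
```

Since the patch is applied at class level, it should run before training starts, and it affects every model in the process that uses the patched class.
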
Also replaces causal mask creation functions to optimize for packed training.

Prints the trainable status of vision module components. Output:
  • Trainable/non-trainable attention block indices
  • Merger module trainable status
Prints the trainable status of language model components. Output:
  • Embed tokens trainable status
  • Trainable/non-trainable decoder layer indices
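
A report of trainable and non-trainable indices like the ones listed above could be computed along the following lines. This is a sketch under stated assumptions: the helper name is hypothetical, each block is represented only by the `requires_grad` flags of its parameters, and a block counts as trainable only when all of its parameters require gradients.

```python
def split_trainable_indices(blocks):
    """Return (trainable, frozen) index lists.

    `blocks` is a sequence where each item holds the requires_grad flags
    of one block's parameters, e.g.
    [p.requires_grad for p in block.parameters()].
    """
    trainable = [i for i, flags in enumerate(blocks) if all(flags)]
    frozen = [i for i, flags in enumerate(blocks) if not all(flags)]
    return trainable, frozen


# Block 0 fully trainable, block 1 frozen, block 2 partially frozen.
trainable, frozen = split_trainable_indices(
    [[True, True], [False, False], [True, False]]
)
print(f"Trainable block indices: {trainable}")
print(f"Non-trainable block indices: {frozen}")
```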

Usage Example

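from dataclasses import dataclass
from typing import Optional

import transformers
from qwenvl.train.trainer import replace_qwen2_vl_attention_class


# Stock transformers.TrainingArguments has no mm_projector_lr or
# vision_tower_lr fields, so this stand-in subclass declares them
# (the project is assumed to ship an equivalent arguments class).
@dataclass
class TrainingArguments(transformers.TrainingArguments):
    mm_projector_lr: Optional[float] = None
    vision_tower_lr: Optional[float] = None


# Apply Flash Attention optimizations
replace_qwen2_vl_attention_class()

# Configure training with separate learning rates
training_args = TrainingArguments(
    output_dir="./output",
    learning_rate=2e-5,
    mm_projector_lr=1e-4,  # Higher LR for projector
    vision_tower_lr=1e-5,  # Lower LR for vision tower
    weight_decay=0.01,
    # ... other arguments
)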

Notes

  • Flash Attention 2 is used by default for all attention computations
  • The custom optimizer builds up to 6 parameter groups: three possible learning rates (base, vision tower, projector), each split into groups with and without weight decay
  • Vision tower parameters are identified by “visual” in their name
  • Multimodal projector parameters are identified by “merger” in their name
  • Weight decay is not applied to bias parameters
