MaxDiffusion provides several optimization strategies to maximize training and inference performance on TPU and GPU hardware.
LIBTPU_INIT_ARGS flags
The LIBTPU_INIT_ARGS environment variable passes XLA compiler flags to libtpu for TPU training and inference. These flags control collective operations, memory management, and scheduling behavior.
Recommended configuration
For Wan2.1 training on TPU v5p:
export LIBTPU_INIT_ARGS='--xla_tpu_enable_async_collective_fusion_fuse_all_gather=true \
--xla_tpu_megacore_fusion_allow_ags=false \
--xla_enable_async_collective_permute=true \
--xla_tpu_enable_ag_backward_pipelining=true \
--xla_tpu_enable_data_parallel_all_reduce_opt=true \
--xla_tpu_data_parallel_opt_different_sized_ops=true \
--xla_tpu_enable_async_collective_fusion=true \
--xla_tpu_enable_async_collective_fusion_multiple_steps=true \
--xla_tpu_overlap_compute_collective_tc=true \
--xla_enable_async_all_gather=true \
--xla_tpu_scoped_vmem_limit_kib=65536 \
--xla_tpu_enable_async_all_to_all=true \
--xla_tpu_enable_all_experimental_scheduler_features=true \
--xla_tpu_enable_scheduler_memory_pressure_tracking=true \
--xla_tpu_host_transfer_overlap_limit=24 \
--xla_tpu_aggressive_opt_barrier_removal=ENABLED \
--xla_lhs_prioritize_async_depth_over_stall=ENABLED \
--xla_should_allow_loop_variant_parameter_in_chain=ENABLED \
--xla_should_add_loop_invariant_op_in_chain=ENABLED \
--xla_max_concurrent_host_send_recv=100 \
--xla_tpu_scheduler_percent_shared_memory_limit=100 \
--xla_latency_hiding_scheduler_rerun=2 \
--xla_tpu_use_minor_sharding_for_major_trivial_input=true \
--xla_tpu_relayout_group_size_threshold_for_reduce_scatter=1 \
--xla_tpu_assign_all_reduce_scatter_layout=true'
For Wan inference and lighter workloads:
export LIBTPU_INIT_ARGS="--xla_tpu_enable_async_collective_fusion=true \
--xla_tpu_enable_async_collective_fusion_fuse_all_reduce=true \
--xla_tpu_enable_async_collective_fusion_multiple_steps=true \
--xla_tpu_overlap_compute_collective_tc=true \
--xla_enable_async_all_reduce=true"
For SDXL and Stable Diffusion training, you can disable these flags by clearing the variable:
export LIBTPU_INIT_ARGS=""
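If the job is launched from a Python script rather than a shell, the same flag string can be assembled from a dictionary. The following is a minimal sketch, not part of MaxDiffusion: `build_libtpu_init_args` is a hypothetical helper, and it assumes the variable is set before JAX initializes the TPU backend, since libtpu reads it at startup.

```python
import os


def build_libtpu_init_args(flags):
    """Join a {flag_name: value} mapping into one LIBTPU_INIT_ARGS string."""
    parts = []
    for name, value in flags.items():
        # Python booleans become the lowercase true/false used by the flags above.
        if isinstance(value, bool):
            value = "true" if value else "false"
        parts.append(f"--{name}={value}")
    return " ".join(parts)


# The lighter inference configuration listed above.
INFERENCE_FLAGS = {
    "xla_tpu_enable_async_collective_fusion": True,
    "xla_tpu_enable_async_collective_fusion_fuse_all_reduce": True,
    "xla_tpu_enable_async_collective_fusion_multiple_steps": True,
    "xla_tpu_overlap_compute_collective_tc": True,
    "xla_enable_async_all_reduce": True,
}

# Assumption: this runs before `import jax` touches the TPU backend.
os.environ["LIBTPU_INIT_ARGS"] = build_libtpu_init_args(INFERENCE_FLAGS)
```

Keeping the flags in a dictionary makes it easier to toggle one flag per experiment and log the exact set that was used.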
Key flags explained
| Flag | Purpose |
|---|---|
| xla_tpu_enable_async_collective_fusion | Enables fusion of async collective operations for better performance |
| xla_enable_async_all_gather | Allows all-gather operations to run asynchronously |
| xla_tpu_scoped_vmem_limit_kib | Sets the scoped VMEM (on-chip vector memory) limit (65536 KiB = 64 MiB) |
| xla_tpu_enable_scheduler_memory_pressure_tracking | Lets the scheduler take memory pressure into account |
| xla_latency_hiding_scheduler_rerun | Reruns the latency hiding scheduler passes |
Flash attention block sizes
Flash attention block sizes significantly impact memory usage and performance. Different TPU generations require different configurations.
TPU v6e (Trillium) - Wan models
flash_block_sizes='{
"block_q" : 3024,
"block_kv_compute" : 1024,
"block_kv" : 2048,
"block_q_dkv" : 3024,
"block_kv_dkv" : 2048,
"block_kv_dkv_compute" : 1024,
"block_q_dq" : 3024,
"block_kv_dq" : 2048,
"use_fused_bwd_kernel": False
}'
TPU v5p - Wan models
flash_block_sizes='{
"block_q" : 3024,
"block_kv_compute" : 1024,
"block_kv" : 2048,
"block_q_dkv" : 1024,
"block_kv_dkv" : 3072,
"block_kv_dkv_compute" : 256,
"block_q_dq" : 1024,
"block_kv_dq" : 3072
}'
Default configuration
For other models or when unsure:
flash_block_sizes: {
"block_q" : 512,
"block_kv_compute" : 512,
"block_kv" : 512,
"block_q_dkv" : 512,
"block_kv_dkv" : 512,
"block_kv_dkv_compute" : 512,
"block_q_dq" : 512,
"block_kv_dq" : 512,
"use_fused_bwd_kernel": False
}
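Before launching a long run it can help to sanity-check a block-size dictionary. The sketch below is illustrative only: `check_flash_block_sizes` is a hypothetical helper, and the rule that each KV block must be a whole multiple of its compute sub-block is an assumption (it holds for all three configurations above), not a documented MaxDiffusion requirement.

```python
# Assumed pairing: (KV block, compute sub-block processed inside it).
KV_PAIRS = (
    ("block_kv", "block_kv_compute"),
    ("block_kv_dkv", "block_kv_dkv_compute"),
)


def check_flash_block_sizes(sizes):
    """Return a list of problems found in a flash_block_sizes mapping."""
    problems = []
    for key, value in sizes.items():
        if key == "use_fused_bwd_kernel":
            continue  # boolean switch, not a block size
        if not isinstance(value, int) or value <= 0:
            problems.append(f"{key} must be a positive integer, got {value!r}")
    for outer, inner in KV_PAIRS:
        big, small = sizes.get(outer), sizes.get(inner)
        if isinstance(big, int) and isinstance(small, int) and small > 0:
            if big % small:
                problems.append(f"{outer}={big} is not a multiple of {inner}={small}")
    return problems


# The TPU v5p Wan configuration from above.
V5P_WAN = {
    "block_q": 3024, "block_kv_compute": 1024, "block_kv": 2048,
    "block_q_dkv": 1024, "block_kv_dkv": 3072, "block_kv_dkv_compute": 256,
    "block_q_dq": 1024, "block_kv_dq": 3072,
}
print(check_flash_block_sizes(V5P_WAN))
```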
Setting flash attention
Enable flash attention in your training command:
python src/maxdiffusion/train_wan.py \
src/maxdiffusion/configs/base_wan_14b.yml \
attention='flash' \
flash_min_seq_length=0 \
flash_block_sizes='{...}'
Remat policies
Gradient checkpointing (rematerialization) trades extra computation for lower memory use: activations are recomputed during the backward pass instead of being stored. MaxDiffusion supports several remat policies.
Available policies
- NONE - No gradient checkpointing (fastest, highest memory usage)
- FULL - Full gradient checkpointing (slowest, lowest memory usage)
- MATMUL_WITHOUT_BATCH - Saves the outputs of linear/matmul operations, except those involving a batch dimension
- OFFLOAD_MATMUL_WITHOUT_BATCH - Same as MATMUL_WITHOUT_BATCH, but offloads those outputs to host memory instead of recomputing them
- HIDDEN_STATE_WITH_OFFLOAD - Offloads hidden states (recommended for Wan training)
- CUSTOM - Define specific operations to save or offload
Configuration
Set the remat policy in your config file:
remat_policy: 'HIDDEN_STATE_WITH_OFFLOAD'
Or via command line:
python src/maxdiffusion/train_wan.py \
src/maxdiffusion/configs/base_wan_14b.yml \
remat_policy='HIDDEN_STATE_WITH_OFFLOAD'
Custom policy
For fine-grained control, use the CUSTOM policy:
remat_policy: "CUSTOM"
names_which_can_be_saved: ['attn_output', 'query_proj']
names_which_can_be_offloaded: ['xq_out', 'xk_out', 'ffn_activation']
Available annotations: attn_output, query_proj, key_proj, value_proj, xq_out, xk_out, ffn_activation
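A CUSTOM policy is easy to mistype, so a small pre-flight check can be useful. This is a sketch under stated assumptions: `check_custom_remat` is a hypothetical helper, the rule that a name may not appear in both lists is an assumption, and so is the idea that annotations in neither list are recomputed.

```python
# Annotation names listed in this section.
AVAILABLE_ANNOTATIONS = {
    "attn_output", "query_proj", "key_proj", "value_proj",
    "xq_out", "xk_out", "ffn_activation",
}


def check_custom_remat(saved, offloaded):
    """Validate the two CUSTOM lists; return the names left in neither list."""
    saved, offloaded = set(saved), set(offloaded)
    unknown = (saved | offloaded) - AVAILABLE_ANNOTATIONS
    if unknown:
        raise ValueError(f"unknown annotation(s): {sorted(unknown)}")
    both = saved & offloaded
    if both:
        # Assumption: a tensor is either kept on device or offloaded, not both.
        raise ValueError(f"listed as both saved and offloaded: {sorted(both)}")
    # Assumption: whatever remains is recomputed in the backward pass.
    return sorted(AVAILABLE_ANNOTATIONS - saved - offloaded)


remaining = check_custom_remat(
    saved=["attn_output", "query_proj"],
    offloaded=["xq_out", "xk_out", "ffn_activation"],
)
print(remaining)
```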
Data type optimization
Weight and activation dtypes
Choose dtypes based on your hardware and quality requirements:
# Recommended for TPU v5p/v6e
weights_dtype: bfloat16
activations_dtype: bfloat16
# Alternative for higher precision (slower); use instead of the values above
# weights_dtype: float32
# activations_dtype: float32
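The memory cost of the dtype choice can be estimated with simple arithmetic. The sketch below assumes a model with 14 billion parameters (inferred from the base_wan_14b config name) and counts weights only; optimizer state, gradients, and activations are not included. `weight_memory_gib` is a hypothetical helper.

```python
# Bytes per element for the two dtypes discussed above.
BYTES_PER_ELEMENT = {"bfloat16": 2, "float32": 4}


def weight_memory_gib(num_params, dtype):
    """Memory needed to hold the weights alone, in GiB."""
    return num_params * BYTES_PER_ELEMENT[dtype] / 2**30


NUM_PARAMS = 14_000_000_000  # assumption based on the config name

for dtype in ("bfloat16", "float32"):
    print(f"{dtype}: {weight_memory_gib(NUM_PARAMS, dtype):.1f} GiB")
```

Under these assumptions bfloat16 weights need about 26 GiB and float32 weights twice that, before sharding across devices.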
Precision settings
Control matmul and conv precision:
# Options: DEFAULT, HIGH, HIGHEST
precision: "DEFAULT" # Fastest
# precision: "HIGHEST" # Most accurate with fp32; use instead of DEFAULT
Parallelism strategies
Wan models
Wan2.1 maps the parallelism axes as follows:
- ici_fsdp_parallelism - Sequence parallelism (try 2 or 4)
- ici_tensor_parallelism - Head parallelism (the value must divide 40 evenly)
- ici_data_parallelism - Data parallelism
python src/maxdiffusion/train_wan.py \
src/maxdiffusion/configs/base_wan_14b.yml \
per_device_batch_size=0.25 \
ici_data_parallelism=32 \
ici_fsdp_parallelism=4 \
ici_tensor_parallelism=1
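The axis sizes in a command like the one above can be checked before launch. This is a sketch, not MaxDiffusion code: `resolve_ici_mesh` is a hypothetical helper, and it assumes that the three axes must multiply to the number of devices and that a single axis set to -1 is filled in automatically. Under that assumption, the command above targets 32 * 4 * 1 = 128 devices.

```python
import math

WAN_ATTENTION_HEADS = 40  # taken from the "must divide 40" rule above


def resolve_ici_mesh(num_devices, data, fsdp, tensor):
    """Fill in a single -1 axis, validate the mesh, and return the resolved axes."""
    axes = {
        "ici_data_parallelism": data,
        "ici_fsdp_parallelism": fsdp,
        "ici_tensor_parallelism": tensor,
    }
    if any(v != -1 and v < 1 for v in axes.values()):
        raise ValueError("axis sizes must be positive, or -1 for auto")
    auto = [name for name, v in axes.items() if v == -1]
    if len(auto) > 1:
        raise ValueError("at most one axis may be -1")
    known = math.prod(v for v in axes.values() if v != -1)
    if auto:
        if num_devices % known:
            raise ValueError(f"{num_devices} devices not divisible by {known}")
        axes[auto[0]] = num_devices // known
    elif known != num_devices:
        raise ValueError(f"axes multiply to {known}, expected {num_devices}")
    if WAN_ATTENTION_HEADS % axes["ici_tensor_parallelism"]:
        raise ValueError("ici_tensor_parallelism must divide 40 evenly")
    return axes


print(resolve_ici_mesh(128, data=32, fsdp=4, tensor=1))
```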
SDXL and Stable Diffusion
python -m src.maxdiffusion.train_sdxl \
src/maxdiffusion/configs/base_xl.yml \
per_device_batch_size=1 \
ici_data_parallelism=-1 # Auto-shard
Fractional batch sizes
Wan training supports fractional batch sizes:
per_device_batch_size: 0.25 # Effective global batch = 0.25 * num_devices
The resulting global batch size (per_device_batch_size * num_devices) must be a whole number.
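The whole-number rule can be checked with exact arithmetic. A minimal sketch, with `global_batch_size` as a hypothetical helper; it uses `fractions.Fraction` so that values such as 0.1 are not affected by floating-point rounding.

```python
from fractions import Fraction


def global_batch_size(per_device_batch_size, num_devices):
    """Return per_device_batch_size * num_devices, or raise if it is not whole."""
    total = Fraction(str(per_device_batch_size)) * num_devices
    if total.denominator != 1:
        raise ValueError(
            f"{per_device_batch_size} * {num_devices} = {float(total)} "
            "is not a whole number"
        )
    return int(total)


print(global_batch_size(0.25, 128))
```

With per_device_batch_size=0.25, any device count that is a multiple of 4 passes the check.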
Caching latents
For faster training on small datasets, cache VAE latents and text encoder outputs:
cache_latents_text_encoder_outputs: True
dataset_save_location: '/tmp/cached_dataset'
HuggingFace transfer acceleration
Speed up model downloads from the Hugging Face Hub (requires the hf_transfer package to be installed):
export HF_HUB_ENABLE_HF_TRANSFER=1
JAX compilation cache
Avoid recompilation across runs:
python src/maxdiffusion/train_wan.py \
src/maxdiffusion/configs/base_wan_14b.yml \
jax_cache_dir=gs://your-bucket/jax_cache/
GPU-specific optimizations
For NVIDIA GPUs, use cudnn_flash_te attention (cuDNN flash attention through Transformer Engine):
NVTE_FUSED_ATTN=1 python -m src.maxdiffusion.train_sdxl \
src/maxdiffusion/configs/base_xl.yml \
hardware=gpu \
attention="cudnn_flash_te" \
weights_dtype=bfloat16
Batch parallelism
Enable batch parallelism on GPUs:
ici_fsdp_batch_parallelism: 2 # Does not support fractional batch sizes