This guide covers common issues you may encounter when using MaxDiffusion and how to resolve them.

Compilation issues

Symptoms: Training or inference hangs during compilation, or compilation takes over 30 minutes.
Solutions:
  1. Use JAX compilation cache to avoid recompiling:
    python src/maxdiffusion/train_wan.py \
      src/maxdiffusion/configs/base_wan_14b.yml \
      jax_cache_dir=gs://your-bucket/jax_cache/
    
  2. Reduce model or batch size during initial testing:
    per_device_batch_size=0.125  # Smaller batch for faster compilation
    
  3. Check LIBTPU_INIT_ARGS, since some flag combinations can slow compilation:
    # Try disabling all flags first
    export LIBTPU_INIT_ARGS=""
    
  4. Enable profiler to see where it’s stuck:
    enable_profiler: True
    skip_first_n_steps_for_profiler: 1
    
Symptoms: Errors like “Shape mismatch” or “XLA compilation failed”.
Solutions:
  1. Verify parallelism settings match your hardware:
    # Check that product of ICI axes equals devices per slice
    ici_data_parallelism=2
    ici_fsdp_parallelism=4  # 2 * 4 = 8 devices
    ici_tensor_parallelism=1
    
  2. Check batch size divisibility:
    # Global batch must be evenly divisible by (data * fsdp) parallelism
    (per_device_batch_size * num_devices) % (ici_data_parallelism * ici_fsdp_parallelism) == 0
    
  3. For Wan models, verify that ici_tensor_parallelism (head parallelism) divides the 40 attention heads:
    # Valid values: 1, 2, 4, 5, 8, 10, 20, 40
    ici_tensor_parallelism=5  # OK
    ici_tensor_parallelism=3  # ERROR: 40 % 3 != 0
    
  4. Disable jit_initializers for debugging:
    jit_initializers: False  # Only for single-host debugging
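The three sizing rules above (the ICI product equals devices per slice, the global batch divides evenly by data * fsdp parallelism, and tensor parallelism divides the head count) can be checked in a few lines before starting a long compile. The helpers below are a hypothetical sketch, not part of MaxDiffusion; they assume a single slice (DCN axes are ignored) and default to 40 heads, as stated above for Wan.

```python
def valid_tensor_parallelism(num_heads=40):
    """All tensor-parallel degrees that divide the attention head count evenly."""
    return [d for d in range(1, num_heads + 1) if num_heads % d == 0]


def check_parallelism(devices_per_slice, per_device_batch_size,
                      ici_data_parallelism, ici_fsdp_parallelism,
                      ici_tensor_parallelism, num_heads=40):
    """Return human-readable problems; an empty list means every check passed.

    Hypothetical helper: single slice only, DCN parallelism is not considered.
    """
    problems = []

    # Rule 1: the ICI mesh must cover every device in the slice exactly once.
    ici_product = ici_data_parallelism * ici_fsdp_parallelism * ici_tensor_parallelism
    if ici_product != devices_per_slice:
        problems.append(
            f"ICI product {ici_product} != {devices_per_slice} devices per slice")

    # Rule 2: the global batch must split evenly across data * fsdp shards.
    global_batch = per_device_batch_size * devices_per_slice
    data_shards = ici_data_parallelism * ici_fsdp_parallelism
    if global_batch <= 0 or global_batch % data_shards != 0:
        problems.append(
            f"global batch {global_batch} is not divisible by data*fsdp = {data_shards}")

    # Rule 3: tensor (head) parallelism must divide the number of attention heads.
    if num_heads % ici_tensor_parallelism != 0:
        problems.append(
            f"{num_heads} heads are not divisible by "
            f"ici_tensor_parallelism={ici_tensor_parallelism}")

    return problems


print(valid_tensor_parallelism(40))  # [1, 2, 4, 5, 8, 10, 20, 40]
print(check_parallelism(8, 1, 2, 4, 1))  # []
print(check_parallelism(24, 1, 1, 8, 3))
```

Rule 2 is applied literally here, so fractional per_device_batch_size values only pass when the resulting global batch is still a multiple of data * fsdp. Treat a failure in that case as a prompt to double-check your configuration rather than as proof that MaxDiffusion will reject it.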
    
Symptoms: Errors about bfloat16/float32 incompatibility.
Solutions:
  1. Match weights and activations dtypes:
    weights_dtype: bfloat16
    activations_dtype: bfloat16
    
  2. Use float32 for higher precision (slower):
    weights_dtype: float32
    activations_dtype: float32
    precision: "HIGHEST"
    
  3. For GPU, ensure Transformer Engine is installed when using cudnn_flash_te:
    pip install "transformer_engine[jax]"
    NVTE_FUSED_ATTN=1 python src/maxdiffusion/train_sdxl.py ...
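The advice to match weights_dtype and activations_dtype can be illustrated in plain JAX. With jnp arithmetic, mixing bfloat16 and float32 usually promotes the result to float32 instead of failing, so a model can quietly run at a different precision than intended, while lower-level operations and kernels may reject mismatched inputs outright. The sketch below casts only the floating-point leaves of a parameter tree; the tree and the helper name are illustrative assumptions, not MaxDiffusion's own loading code.

```python
import jax
import jax.numpy as jnp


def cast_floating_leaves(tree, dtype):
    """Cast floating-point leaves of a pytree to `dtype`; leave other leaves untouched."""
    def _cast(x):
        if jnp.issubdtype(x.dtype, jnp.floating):
            return x.astype(dtype)
        return x
    return jax.tree_util.tree_map(_cast, tree)


# Illustrative parameter tree: float32 weights plus an integer step counter.
params = {
    "dense": {"kernel": jnp.ones((4, 4), jnp.float32),
              "bias": jnp.zeros((4,), jnp.float32)},
    "step": jnp.array(0, jnp.int32),
}
activations = jnp.ones((2, 4), jnp.bfloat16)

# jnp type promotion turns bfloat16 + float32 into float32 rather than raising.
print(jnp.promote_types(jnp.bfloat16, jnp.float32))  # float32

params_bf16 = cast_floating_leaves(params, jnp.bfloat16)
out = activations @ params_bf16["dense"]["kernel"] + params_bf16["dense"]["bias"]
print(out.dtype)  # bfloat16
```

Leaving integer leaves alone matters because counters and indices should never be cast to a floating type.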
    

Out of memory (OOM) errors

Symptoms: “Out of memory” or “HBM allocation failed” errors.
Solutions:
  1. Reduce batch size:
    per_device_batch_size=0.125  # Or even smaller like 0.0625
    
  2. Enable gradient checkpointing (rematerialization):
    remat_policy: "HIDDEN_STATE_WITH_OFFLOAD"  # For Wan
    remat_policy: "FULL"  # For maximum memory savings
    
  3. Use smaller flash block sizes:
    flash_block_sizes: {
      "block_q" : 512,
      "block_kv_compute" : 512,
      "block_kv" : 512,
      "block_q_dkv" : 512,
      "block_kv_dkv" : 512,
      "block_kv_dkv_compute" : 512,
      "block_q_dq" : 512,
      "block_kv_dq" : 512
    }
    
  4. Reduce resolution or number of frames:
    # For Wan models
    height=720  # Instead of 1280
    width=480   # Instead of 720
    num_frames=49  # Instead of 81
    
  5. Increase FSDP parallelism to shard model across more devices:
    ici_fsdp_parallelism=8  # More sharding = less memory per device
    
  6. For Wan, adjust scoped_vmem_limit:
    export LIBTPU_INIT_ARGS="--xla_tpu_scoped_vmem_limit_kib=32768"  # Reduce from 65536
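Rematerialization trades compute for memory: activations are recomputed in the backward pass instead of being stored. The sketch below shows the underlying JAX mechanism with jax.checkpoint and the nothing_saveable policy, which is roughly the "recompute everything" end of the trade-off. How MaxDiffusion's remat_policy names ("FULL", "HIDDEN_STATE_WITH_OFFLOAD") map onto JAX policies is an assumption here, and the toy MLP block is illustrative only.

```python
import jax
import jax.numpy as jnp


def block(x, w1, w2):
    # Toy two-layer block; `h` is the intermediate that remat avoids storing.
    h = jnp.tanh(x @ w1)
    return jnp.tanh(h @ w2)


# nothing_saveable: keep no intermediates, recompute all of them during backprop.
remat_block = jax.checkpoint(block, policy=jax.checkpoint_policies.nothing_saveable)


def make_loss(fn):
    def loss(w1, x, w2):
        return jnp.sum(fn(x, w1, w2) ** 2)
    return loss


k1, k2, k3 = jax.random.split(jax.random.PRNGKey(0), 3)
x = jax.random.normal(k1, (4, 8))
w1 = jax.random.normal(k2, (8, 16)) * 0.1
w2 = jax.random.normal(k3, (16, 8)) * 0.1

grad_plain = jax.grad(make_loss(block))(w1, x, w2)
grad_remat = jax.grad(make_loss(remat_block))(w1, x, w2)
print(jnp.allclose(grad_plain, grad_remat, atol=1e-5))
```

The gradients are the same either way; only peak memory and step time change, which is why remat is a safe first lever against OOM errors.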
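To see why fewer frames and a lower resolution help so much, you can estimate the transformer sequence length. The sketch assumes a Wan-style video VAE that compresses 4x in time (keeping the first frame) and 8x in space, followed by a (1, 2, 2) patchify step; those factors are assumptions, so check them against the model config you actually run. The helper name is hypothetical.

```python
def wan_latent_tokens(num_frames, height, width,
                      vae_temporal=4, vae_spatial=8, patch=(1, 2, 2)):
    """Approximate DiT sequence length; compression and patch sizes are assumptions."""
    latent_frames = (num_frames - 1) // vae_temporal + 1  # first frame kept, rest grouped
    latent_h = height // vae_spatial
    latent_w = width // vae_spatial
    pt, ph, pw = patch
    return (latent_frames // pt) * (latent_h // ph) * (latent_w // pw)


before = wan_latent_tokens(num_frames=81, height=1280, width=720)
after = wan_latent_tokens(num_frames=49, height=720, width=480)
print(before, after)  # 75600 17550
# Attention FLOPs grow roughly with the square of the token count.
print(f"~{(before / after) ** 2:.1f}x fewer attention FLOPs")
```

Under these assumptions, the reduced settings shrink the sequence by a bit over 4x, and activation memory shrinks at least in proportion.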
    
Symptoms: OOM when loading pretrained weights.
Solutions:
  1. Enable single replica checkpoint restoring:
    enable_single_replica_ckpt_restoring: True
    
  2. For Wan models, use an external disk for the HuggingFace cache:
    HF_HUB_CACHE=/mnt/disks/external_disk/maxdiffusion_hf_cache/ python ...
    
  3. Load weights in bfloat16:
    weights_dtype: bfloat16
    from_pt: True
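A back-of-envelope estimate shows how much loading in bfloat16 and sharding with FSDP reduce the weight footprint per device. This counts weights only (no activations, gradients, or optimizer state), and the 14e9 parameter count is a nominal figure for a "14B" model rather than an exact number. The helper is a hypothetical illustration.

```python
BYTES_PER_PARAM = {"float32": 4, "bfloat16": 2}


def weight_gb_per_device(num_params, dtype="bfloat16", fsdp=1):
    """Decimal GB of weights per device; ignores activations, gradients and optimizer state."""
    return num_params * BYTES_PER_PARAM[dtype] / fsdp / 1e9


NUM_PARAMS = 14e9  # nominal size of a "14B" model; the real count differs slightly
for dtype in ("float32", "bfloat16"):
    for fsdp in (1, 8):
        print(dtype, fsdp, weight_gb_per_device(NUM_PARAMS, dtype, fsdp))
```

Halving the bytes per parameter and sharding 8 ways together cut the per-device weight footprint by 16x (56 GB down to 3.5 GB in this example), before counting anything else that lives in HBM.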
    
Symptoms: OOM when creating TFRecord datasets.
Solutions:
  1. Process in smaller batches:
    # In wan_txt2vid_data_preprocessing.py, reduce batch_size
    batch_size = 5  # Default is 10
    
  2. Increase the number of shards by writing fewer records per shard:
    no_records_per_shard=5  # Smaller shards = less memory
    
  3. Use streaming dataset instead of in-memory:
    dataset_type: hf  # Instead of tf
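The first two fixes share one idea: never hold more than one batch or shard in memory at a time. The sketch below streams records into fixed-size shards and hands each shard to a callback. The names batched, write_in_shards, and write_shard are hypothetical; in a real pipeline the callback might open a tf.io.TFRecordWriter for a file named after shard_index. This is not the code in wan_txt2vid_data_preprocessing.py.

```python
from typing import Callable, Iterable, Iterator, List, TypeVar

T = TypeVar("T")


def batched(items: Iterable[T], size: int) -> Iterator[List[T]]:
    """Yield lists of at most `size` items without materializing the whole input."""
    batch: List[T] = []
    for item in items:
        batch.append(item)
        if len(batch) == size:
            yield batch
            batch = []  # drop the reference so the old batch can be freed
    if batch:
        yield batch


def write_in_shards(records: Iterable[T], records_per_shard: int,
                    write_shard: Callable[[int, List[T]], None]) -> int:
    """Pass each shard to `write_shard` (hypothetical callback); return the shard count."""
    num_shards = 0
    for shard_index, shard in enumerate(batched(records, records_per_shard)):
        write_shard(shard_index, shard)
        num_shards += 1
    return num_shards


sizes = []
n = write_in_shards(range(23), records_per_shard=5,
                    write_shard=lambda i, shard: sizes.append(len(shard)))
print(n, sizes)  # 5 [5, 5, 5, 5, 3]
```

Peak memory is bounded by records_per_shard rather than by the dataset size, and the number of output files is the record count divided by records_per_shard, rounded up.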
    

Disk space issues

Symptoms: “No space left on device” errors.
Solutions:
  1. Attach an external disk to the VM:
    # Follow: https://cloud.google.com/tpu/docs/attach-durable-block-storage
    # Then mount and use for cache:
    export HF_HUB_CACHE=/mnt/disks/external_disk/maxdiffusion_hf_cache/
    
  2. Save checkpoints to GCS instead of local disk:
    output_dir: gs://my-bucket/checkpoints/
    jax_cache_dir: gs://my-bucket/jax_cache/
    
  3. Disable checkpoint saving during debugging:
    checkpoint_every: -1
    save_final_checkpoint: False
    
  4. Clean up HuggingFace cache:
    rm -rf ~/.cache/huggingface/hub/*
    # Or point HF_HUB_CACHE at an external disk; it must be a local path, not a gs:// URL
    
  5. Use smaller dataset or streaming:
    dataset_type: hf  # Streams data without downloading
    max_train_samples: 1000  # Limit dataset size
    
Symptoms: Disk full when downloading datasets from HuggingFace.
Solutions:
  1. Use streaming dataset:
    dataset_type: hf  # No download needed
    dataset_name: BleachNick/UltraEdit_500k
    
  2. Download to external disk:
    export HF_DATASET_DIR=/mnt/disks/external_disk/datasets/
    huggingface-cli download RaphaelLiu/PusaV1_training --repo-type dataset --local-dir $HF_DATASET_DIR
    
  3. Stage the download locally, then move it to GCS:
    # Download to local scratch space, upload to GCS, then delete the local copy
    huggingface-cli download ... --local-dir /tmp/dataset
    gsutil -m cp -r /tmp/dataset gs://my-bucket/
    rm -rf /tmp/dataset
    

Permission and access errors

Symptoms: “401 Client Error: Unauthorized” or “Access denied”.
Solutions:
  1. Obtain access to the model on HuggingFace (e.g., Flux, Wan).
  2. Create a HuggingFace access token at https://huggingface.co/settings/tokens.
  3. Set token in config or environment:
    hf_access_token: 'hf_xxxxxxxxxxxxxxxxxxxx'
    
    Or:
    export HF_TOKEN='hf_xxxxxxxxxxxxxxxxxxxx'
    huggingface-cli login --token $HF_TOKEN
    
Symptoms: “403 Forbidden” or “Permission denied” when accessing GCS buckets.
Solutions:
  1. Authenticate gcloud:
    gcloud auth login
    gcloud auth application-default login
    
  2. Set project:
    gcloud config set project YOUR_PROJECT_ID
    
  3. Grant VM service account permissions:
    # Give Storage Admin role to TPU service account
    gcloud projects add-iam-policy-binding YOUR_PROJECT_ID \
      --member serviceAccount:SERVICE_ACCOUNT_EMAIL \
      --role roles/storage.admin
    
  4. Check bucket exists and is accessible:
    gsutil ls gs://my-bucket/
    
Symptoms: “Permission denied” when saving checkpoints locally.
Solutions:
  1. Check directory permissions:
    ls -la /tmp/
    chmod 777 /tmp/output  # Or appropriate permissions
    
  2. Use home directory or /tmp:
    output_dir: /tmp/checkpoints/
    dataset_save_location: /tmp/dataset/
    
  3. Take ownership of the output directory as your user:
    sudo chown -R $USER:$USER /path/to/output
    

Training and inference issues

Symptoms: Loss shows as NaN or increases dramatically.
Solutions:
  1. Reduce learning rate:
    learning_rate: 1.e-6  # Instead of 1.e-5
    
  2. Enable gradient clipping:
    max_grad_norm: 1.0  # Default, try 0.5 for more aggressive clipping
    
  3. Use float32 instead of bfloat16:
    weights_dtype: float32
    activations_dtype: float32
    
  4. Check data preprocessing: ensure images/videos are normalized to the range the model expects (commonly [-1, 1] for VAE inputs).
  5. Reduce batch size - very large batches can cause instability.
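Gradient clipping by global norm rescales all gradients together whenever their combined L2 norm exceeds a threshold, which is the usual meaning of a setting like max_grad_norm. The sketch below is a from-scratch JAX version for illustration (optax.clip_by_global_norm provides the library equivalent); whether MaxDiffusion applies max_grad_norm in exactly this way is an assumption. It also includes a cheap finiteness check you can log each step to catch the first NaN.

```python
import jax
import jax.numpy as jnp


def global_norm(tree):
    """L2 norm over every leaf of a gradient pytree, accumulated in float32."""
    leaves = jax.tree_util.tree_leaves(tree)
    return jnp.sqrt(sum(jnp.sum(jnp.square(x.astype(jnp.float32))) for x in leaves))


def clip_by_global_norm(grads, max_norm):
    """Rescale all gradients together so their global norm is at most `max_norm`."""
    norm = global_norm(grads)
    scale = jnp.minimum(1.0, max_norm / (norm + 1e-6))  # 1.0 when already small enough
    clipped = jax.tree_util.tree_map(lambda g: (g * scale).astype(g.dtype), grads)
    return clipped, norm


def all_finite(tree):
    """True only if no leaf contains NaN or Inf."""
    leaves = jax.tree_util.tree_leaves(tree)
    return jnp.all(jnp.stack([jnp.all(jnp.isfinite(x)) for x in leaves]))


grads = {"w": jnp.array([3.0, 4.0]), "b": jnp.array([0.0])}
clipped, norm = clip_by_global_norm(grads, max_norm=1.0)
print(norm)                  # 5.0
print(global_norm(clipped))  # approximately 1.0
print(all_finite({"w": jnp.array([1.0, jnp.nan])}))  # False
```

Clipping the global norm (rather than each tensor separately) preserves the direction of the update, so a lower threshold such as 0.5 only shortens the step.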
Symptoms: Outputs are blurry, distorted, or don’t match prompts.
Solutions:
  1. Increase inference steps:
    num_inference_steps=50  # Instead of 20
    
  2. Adjust guidance scale:
    guidance_scale=7.5  # Try values between 5 and 15
    
  3. For Wan models, set flow_shift:
    flow_shift=5.0  # Wan2.1 recommended value
    
  4. Use higher precision:
    weights_dtype: float32
    activations_dtype: float32
    
  5. Check if model loaded correctly - verify checkpoint path and weights.
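Two of the knobs above have simple formulas behind them. Classifier-free guidance extrapolates from the unconditional prediction toward the conditional one by guidance_scale, and flow_shift remaps flow-matching noise levels toward the noisier end of the schedule. The shift formula below is the one used by flow-matching schedulers such as diffusers' FlowMatchEulerDiscreteScheduler; that MaxDiffusion's Wan scheduler applies flow_shift identically is an assumption, and the function names are illustrative.

```python
import jax.numpy as jnp


def classifier_free_guidance(pred_uncond, pred_cond, guidance_scale):
    # guidance_scale = 1 reproduces the conditional prediction; larger values push further.
    return pred_uncond + guidance_scale * (pred_cond - pred_uncond)


def shift_sigmas(sigmas, flow_shift):
    # Flow-matching timestep shift: for flow_shift > 1, more steps are spent at high noise.
    return flow_shift * sigmas / (1.0 + (flow_shift - 1.0) * sigmas)


sigmas = jnp.linspace(1.0, 0.0, 5)  # [1.0, 0.75, 0.5, 0.25, 0.0]
print(shift_sigmas(sigmas, 5.0))     # approximately [1.0, 0.9375, 0.8333, 0.625, 0.0]
print(classifier_free_guidance(jnp.array(0.0), jnp.array(1.0), 7.5))  # 7.5
```

Very large guidance scales amplify the difference term, which is one reason outputs can look oversaturated or distorted when the scale is pushed too high.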
Symptoms: Step time is much slower than expected.
Solutions:
  1. Enable flash attention:
    attention='flash'
    flash_min_seq_length=0
    
  2. Tune LIBTPU_INIT_ARGS (see the optimization guide).
  3. Use appropriate flash block sizes for your TPU generation.
  4. Cache latents and text encodings:
    cache_latents_text_encoder_outputs: True
    
  5. Enable profiler to identify bottlenecks:
    enable_profiler: True
    skip_first_n_steps_for_profiler: 5
    profiler_steps: 10
    
  6. For GPU, use fused attention:
    NVTE_FUSED_ATTN=1 python ... attention="cudnn_flash_te"
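Before profiling, make sure your own step-time measurement is trustworthy. JAX dispatches work asynchronously, and the first call to a jitted function includes compilation, so naive timing can be far off in either direction. This is also why the profiler settings above skip the first steps. The sketch below is a generic JAX timing helper with a toy step function; neither is part of MaxDiffusion.

```python
import time

import jax
import jax.numpy as jnp


def time_step(step_fn, *args, warmup=1, iters=5):
    """Average wall-clock seconds per call, excluding warm-up (compiling) calls."""
    for _ in range(warmup):
        jax.block_until_ready(step_fn(*args))  # triggers and waits for compilation
    start = time.perf_counter()
    for _ in range(iters):
        # Block on each result so asynchronous dispatch doesn't hide device time.
        jax.block_until_ready(step_fn(*args))
    return (time.perf_counter() - start) / iters


@jax.jit
def toy_step(x):
    return jnp.tanh(x @ x.T).sum()


x = jnp.ones((256, 256))
print(f"{time_step(toy_step, x) * 1e3:.3f} ms/step")
```

If the warm-up call is far slower than the steady-state average, compilation rather than execution dominates, and the compilation-cache advice earlier in this guide applies.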
    

Multihost issues

Symptoms: Training hangs when running on multiple hosts.
Solutions:
  1. Enable distributed system initialization:
    skip_jax_distributed_system: False
    
  2. Ensure all hosts run the same code version:
    # On all workers:
    cd maxdiffusion && git pull && pip install -e .
    
  3. Check DCN parallelism settings:
    dcn_data_parallelism=-1  # Auto-shard across slices
    dcn_fsdp_parallelism=1
    dcn_tensor_parallelism=1
    
  4. Verify network connectivity between hosts.
  5. Use GCS for checkpoints instead of local disk:
    output_dir: gs://my-bucket/output/
    
Symptoms: Slow step times with multiple hosts.
Solutions:
  1. Ensure there are at least as many data files as hosts:
    # With 8 hosts, you need at least 8 TFRecord files
    no_records_per_shard=10  # Reduce to create more files
    
  2. Store data on GCS instead of local disk:
    train_data_dir: gs://my-bucket/dataset/
    
  3. Enable data shuffling:
    enable_data_shuffling: True
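The file-count rule follows from how file-level sharding usually works: each host reads a disjoint subset of files, so with fewer files than hosts some hosts get nothing, and uneven counts make the busiest host set the pace. The sketch assumes round-robin assignment (the same index-and-stride pattern as tf.data.Dataset.shard); MaxDiffusion's input pipeline may assign files differently, and the helper names are hypothetical.

```python
def files_for_host(files, host_id, num_hosts):
    """Round-robin assignment: host h reads files h, h + num_hosts, h + 2 * num_hosts, ..."""
    return files[host_id::num_hosts]


def check_file_coverage(files, num_hosts):
    """Return per-host file counts, or raise if any host would receive no files."""
    counts = [len(files_for_host(files, h, num_hosts)) for h in range(num_hosts)]
    if min(counts) == 0:
        raise ValueError(f"{len(files)} files for {num_hosts} hosts: some hosts get no data")
    return counts


files = [f"shard-{i:05d}.tfrecord" for i in range(10)]
print(check_file_coverage(files, 8))  # [2, 2, 1, 1, 1, 1, 1, 1]
```

With 10 files on 8 hosts, two hosts read twice as much data as the rest; a file count that is a multiple of the host count keeps the load even under this assumption.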
    

Getting help

If you’re still experiencing issues:
  1. Check the logs for detailed error messages
  2. Enable profiler to identify performance bottlenecks
  3. Search GitHub issues: https://github.com/AI-Hypercomputer/maxdiffusion/issues
  4. File a bug report with:
    • Complete error message and stack trace
    • Hardware type (TPU v5p, v6e, GPU model)
    • MaxDiffusion version and commit hash
    • Full command or config used
    • Steps to reproduce
