XPK (Accelerated Processing Kit) simplifies running MaxDiffusion on Google Kubernetes Engine (GKE) for both experimentation and production workloads.
Prerequisites
Verify that your account or service account has these IAM roles:
- Storage Admin
- Kubernetes Engine Admin
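You can check both roles from the command line. The sketch below is a hypothetical helper, not part of MaxDiffusion or XPK: `has_required_roles` reads role IDs on stdin and assumes the standard IDs `roles/storage.admin` (Storage Admin) and `roles/container.admin` (Kubernetes Engine Admin). The example `gcloud` pipeline in the comments only sees roles bound directly to your user. Roles inherited through groups, or broader roles such as Owner, will show up as missing. For a service account, filter on `serviceAccount:` instead of `user:`.

```shell
#!/bin/sh
# Hypothetical helper (not part of MaxDiffusion or XPK).
# Reads IAM role IDs, one per line, on stdin and returns 0 only if both
# roles listed in the prerequisites are present. Assumed role IDs:
#   roles/storage.admin    -> Storage Admin
#   roles/container.admin  -> Kubernetes Engine Admin
has_required_roles() {
  roles=$(cat)
  missing=0
  for role in roles/storage.admin roles/container.admin; do
    # -x matches the whole line, -F treats the role as a fixed string.
    if ! printf '%s\n' "$roles" | grep -qxF "$role"; then
      echo "missing role: $role" >&2
      missing=1
    fi
  done
  return "$missing"
}

# Example (needs an authenticated gcloud; lists only roles bound directly
# to the active user account, not roles inherited through groups):
# gcloud projects get-iam-policy "$PROJECT_ID" \
#   --flatten="bindings[].members" \
#   --filter="bindings.members:user:$(gcloud config get-value account)" \
#   --format="value(bindings.role)" | has_required_roles
```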
Set up XPK
Install system dependencies
Install kubectl and gke-gcloud-auth-plugin. First, install kubectl:
sudo apt-get update
sudo apt install snapd
sudo snap install kubectl --classic
Then install the GKE authentication plugin:
echo "deb [signed-by=/usr/share/keyrings/cloud.google.gpg] https://packages.cloud.google.com/apt cloud-sdk main" | sudo tee -a /etc/apt/sources.list.d/google-cloud-sdk.list
curl https://packages.cloud.google.com/apt/doc/apt-key.gpg | sudo apt-key --keyring /usr/share/keyrings/cloud.google.gpg add -
sudo apt update && sudo apt-get install google-cloud-sdk-gke-gcloud-auth-plugin
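Before moving on, it can help to confirm that the tools this guide relies on are on your PATH. The `check_tools` function below is a hypothetical POSIX sh sketch, not part of XPK. It uses the shell built-in `command -v` to look up each name.

```shell
#!/bin/sh
# Hypothetical helper: report which of the given commands are not on PATH.
# Returns 0 when every command is found, 1 otherwise.
check_tools() {
  rc=0
  for tool in "$@"; do
    if ! command -v "$tool" >/dev/null 2>&1; then
      echo "not found: $tool" >&2
      rc=1
    fi
  done
  return "$rc"
}

# Usage, after the installation steps above:
# check_tools kubectl gke-gcloud-auth-plugin gcloud docker
```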
Authenticate gcloud
Authenticate your gcloud installation:
gcloud auth login
Configure Docker to use gcloud credentials:
gcloud auth configure-docker us-docker.pkg.dev
Test the Docker installation:
docker run hello-world
If you get a permission error, run:
sudo usermod -aG docker $USER
Then log out and log back in to the machine.
Install XPK
Install XPK using pip:
pip install xpk
Alternatively, clone the XPK repository:
git clone https://github.com/google/xpk.git
Build Docker image
Clone MaxDiffusion
Clone the MaxDiffusion repository:
git clone https://github.com/google/MaxDiffusion.git
cd MaxDiffusion
Build dependency image
Build the MaxDiffusion base image. You only need to rerun this when you change dependencies:
# Default will pick stable versions of dependencies
bash docker_build_dependency_image.sh
Using JAX AI Images
Build the MaxDiffusion Docker image using JAX AI base images for a more reliable build environment:
bash docker_build_dependency_image.sh \
MODE=jax_ai_image \
BASEIMAGE=us-docker.pkg.dev/cloud-tpu-images/jax-ai-image/tpu:jax0.5.2-rev2
JAX AI Images is currently in the experimental phase.
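XPK refers to the base image by name, and Docker rejects repository names that contain uppercase letters. The sketch below is a simplified, hypothetical validator (`is_valid_image_ref` is not part of MaxDiffusion, XPK, or Docker). It approximates Docker's reference rules: lowercase path components separated by `/`, plus an optional tag. It does not handle registry ports or digests.

```shell
#!/bin/sh
# Simplified approximation of Docker's image reference grammar:
# lowercase path components joined by '/', optional ':tag'.
# NOT handled in this sketch: registry ports (host:5000/...) and digests (@sha256:...).
COMPONENT='[a-z0-9]+((\.|__?|-+)[a-z0-9]+)*'
TAG='[A-Za-z0-9_][A-Za-z0-9_.-]{0,127}'

is_valid_image_ref() {
  printf '%s\n' "$1" | grep -Eq "^${COMPONENT}(/${COMPONENT})*(:${TAG})?\$"
}

# To confirm that the image also exists locally:
# docker image inspect maxdiffusion_base_image >/dev/null && echo "image found"
```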
Run workloads
After you build maxdiffusion_base_image, you do not need to rebuild it for code changes. XPK builds each workload image on top of the base image and adds your current working directory.
When using XPK, include pip install . in your command to install the package from the current directory. This ensures local changes are applied within the container.
gcloud config set project $PROJECT_ID
gcloud config set compute/zone $ZONE
# GCS buckets for outputs and datasets (create them first if they do not exist)
BASE_OUTPUT_DIR=gs://output_bucket/
DATASET_PATH=gs://dataset_bucket/
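The commands on this page read several shell variables that you must set yourself, such as PROJECT_ID, ZONE, CLUSTER_NAME and the bucket paths above. An unset variable silently expands to an empty string inside an xpk command. The hypothetical `require_vars` guard below is a POSIX sh sketch that fails early instead. It assumes CLUSTER_NAME names a cluster you have already created.

```shell
#!/bin/sh
# Hypothetical guard: fail if any named shell variable is unset or empty.
require_vars() {
  rc=0
  for name in "$@"; do
    # Validate the name first so eval below only ever sees a plain identifier.
    case $name in
      ''|[!A-Za-z_]*|*[!A-Za-z0-9_]*)
        echo "invalid variable name: $name" >&2
        rc=1
        continue ;;
    esac
    # Indirect lookup that works in POSIX sh (no bash-only ${!name}).
    eval "value=\${$name:-}"
    if [ -z "$value" ]; then
      echo "unset or empty: $name" >&2
      rc=1
    fi
  done
  return "$rc"
}

# Usage:
# require_vars PROJECT_ID ZONE CLUSTER_NAME BASE_OUTPUT_DIR DATASET_PATH
```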
Create workload
Using pip-installed XPK:
xpk workload create \
--cluster ${CLUSTER_NAME} \
--base-docker-image maxdiffusion_base_image \
--workload ${USER}-first-job \
--tpu-type=v4-8 \
--num-slices=1 \
--command "pip install . && python src/maxdiffusion/train.py \
src/maxdiffusion/configs/base_2_base.yml \
run_name='my_run' \
output_dir='gs://your-bucket/'"
Using the XPK repository:
python3 xpk/xpk.py workload create \
--cluster ${CLUSTER_NAME} \
--base-docker-image maxdiffusion_base_image \
--workload ${USER}-first-job \
--tpu-type=v4-8 \
--num-slices=1 \
--command "pip install . && python src/maxdiffusion/train.py \
src/maxdiffusion/configs/base_2_base.yml \
run_name='my_run' \
output_dir='gs://your-bucket/'"
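XPK uses the workload name to name Kubernetes objects. A value like ${USER}-first-job can therefore be rejected if USER contains uppercase letters, dots or underscores. The hypothetical `sanitize_workload_name` helper below maps any string onto the Kubernetes DNS-1035 label shape: lowercase letters, digits and `-`, starting with a letter and ending with a letter or digit. The 40-character default limit is an assumption chosen to stay conservative, not a documented XPK value.

```shell
#!/bin/sh
# Hypothetical helper: coerce a string into a DNS-1035-style name.
# $1: raw name; $2: max length (default 40, an ASSUMED conservative limit).
sanitize_workload_name() {
  max_len=${2:-40}
  name=$(printf '%s' "$1" \
    | tr '[:upper:]' '[:lower:]' \
    | tr -c 'a-z0-9-' '-' \
    | sed -e 's/--*/-/g' -e 's/^[^a-z]*//')
  # Truncate, then drop any trailing hyphens left at the cut point.
  name=$(printf '%s' "$name" | cut -c "1-$max_len" | sed -e 's/-*$//')
  printf '%s\n' "$name"
}
```

For example, pass `--workload "$(sanitize_workload_name "${USER}-first-job")"` to either command above.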
Advanced usage
Large-scale training example
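For large-scale Wan 2.1 training on v5p-256 (in addition to the variables used above, set BUCKET_NAME and IMAGE_DIR, the base Docker image for the workload):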
RUN_NAME=wan-training-${RANDOM}
OUTPUT_DIR=gs://$BUCKET_NAME/wan/
DATASET_DIR=gs://$BUCKET_NAME/tfrecords_dataset/train/
EVAL_DATA_DIR=gs://$BUCKET_NAME/tfrecords_dataset/eval_timesteps/
LIBTPU_INIT_ARGS="--xla_tpu_enable_async_collective_fusion_fuse_all_gather=true \
--xla_tpu_megacore_fusion_allow_ags=false \
--xla_enable_async_collective_permute=true"
python3 ~/xpk/xpk.py workload create \
--cluster=$CLUSTER_NAME \
--project=$PROJECT_ID \
--zone=$ZONE \
--device-type=v5p-256 \
--num-slices=1 \
--command="pip install . && \
export LIBTPU_INIT_ARGS='${LIBTPU_INIT_ARGS}' && \
HF_HUB_CACHE=/mnt/disks/external_disk/maxdiffusion_hf_cache/ \
python src/maxdiffusion/train_wan.py \
src/maxdiffusion/configs/base_wan_14b.yml \
attention='flash' \
run_name=${RUN_NAME} \
output_dir=${OUTPUT_DIR} \
train_data_dir=${DATASET_DIR} \
per_device_batch_size=0.25 \
ici_data_parallelism=32 \
ici_fsdp_parallelism=4 \
max_train_steps=5000" \
--base-docker-image=${IMAGE_DIR} \
--workload=${RUN_NAME}
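The parallelism flags in this example should multiply out to the number of devices in the slice. Under the assumptions in the code comments, the arithmetic is 256 TensorCores / 2 = 128 chips (JAX devices), 32 × 4 = 128, and a global batch of 0.25 × 128 = 32. The hypothetical `check_mesh` function below repeats that calculation so you can check modified values. It assumes that any ici_* parallelism settings not passed on the command line are 1, and that the global batch is per_device_batch_size times the device count.

```shell
#!/bin/sh
# Hypothetical sanity check for mesh and batch settings.
# Assumptions: a v5p-N slice name counts TensorCores, with two per chip and
# one JAX device per chip; other ici_* axes are 1; and
# global batch = per_device_batch_size * device count.
check_mesh() {
  tensorcores=$1 data=$2 fsdp=$3 per_device=$4
  devices=$((tensorcores / 2))
  if [ $((data * fsdp)) -ne "$devices" ]; then
    echo "ici_data_parallelism * ici_fsdp_parallelism = $((data * fsdp)), expected $devices" >&2
    return 1
  fi
  # per_device_batch_size may be fractional, so compute the product in awk.
  awk -v pd="$per_device" -v n="$devices" \
    'BEGIN { printf "devices=%d global_batch=%g\n", n, pd * n }'
}

check_mesh 256 32 4 0.25   # prints: devices=128 global_batch=32
```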
Resources