MaxDiffusion supports multiple data input pipelines for training. This guide covers dataset formats, preprocessing scripts, and best practices.
Dataset types
MaxDiffusion supports four dataset types for real data, selected with the dataset_type flag:
| Pipeline | Location | Formats | Features |
|---|---|---|---|
| hf | HuggingFace Hub or Cloud Storage | parquet, arrow, json, csv, txt | Streaming, good for large datasets |
| tf | HuggingFace Hub (downloads to disk) | parquet, arrow, json, csv, txt | In-memory, works for small datasets |
| tfrecord | Local/Cloud Storage | TFRecord | Streaming, good for large datasets |
| grain | Local/Cloud Storage | ArrayRecord | Streaming, global shuffle, deterministic |
HuggingFace streaming (dataset_type=hf)
Stream data directly from HuggingFace Hub or cloud storage without downloading.
From HuggingFace Hub
dataset_type: hf
dataset_name: BleachNick/UltraEdit_500k
image_column: source_image
caption_column: source_caption
train_split: FreeForm
hf_access_token: '' # For gated datasets
From cloud storage
dataset_type: hf
dataset_name: parquet # or json, arrow, etc.
hf_train_files: gs://my-bucket/my-dataset/*-train-*.parquet
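To check which files a glob such as the one in hf_train_files would select, the following minimal sketch applies the same pattern to a list of file names. The file names are invented for illustration, and the storage layer that actually resolves the glob may treat edge cases differently from Python's fnmatch.

```python
from fnmatch import fnmatch

# File-name part of the hf_train_files glob shown above.
PATTERN = "*-train-*.parquet"

# Hypothetical object names, invented for illustration only.
candidates = [
    "shard-train-00000.parquet",
    "shard-train-00001.parquet",
    "shard-validation-00000.parquet",
    "shard-train-00000.json",
]


def matching_files(names, pattern=PATTERN):
    """Return the names that the glob pattern selects, in input order."""
    return [name for name in names if fnmatch(name, pattern)]


print(matching_files(candidates))
```

Only the two parquet files containing "-train-" are selected; the validation shard and the JSON file are skipped.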
tf.data in-memory (dataset_type=tf)
Downloads the entire dataset to disk and loads it into memory. Best for small datasets.
dataset_type: tf
dataset_name: diffusers/pokemon-gpt4-captions
dataset_save_location: '/tmp/pokemon-gpt4-captions_xl'
cache_latents_text_encoder_outputs: True
When cache_latents_text_encoder_outputs=True, the VAE and text encoder process images and captions during dataset creation, saving preprocessed latents and embeddings.
TFRecord (dataset_type=tfrecord)
Use TFRecord files for efficient streaming of large datasets.
dataset_type: tfrecord
train_data_dir: gs://my-bucket/my-dataset/ # Directory containing .tfrec files
Grain (dataset_type=grain)
Grain provides global shuffle and deterministic data iteration.
dataset_type: grain
grain_train_files: gs://my-bucket/my-dataset/*.arrayrecord
Wan dataset preprocessing
Wan models require special preprocessing to create TFRecord datasets with video latents and text embeddings.
Wan PusaV1 dataset example
This example uses the PusaV1 dataset.
Download the dataset
export HF_DATASET_DIR=/mnt/disks/external_disk/PusaV1_training/
export TFRECORDS_DATASET_DIR=/mnt/disks/external_disk/wan_tfr_dataset_pusa_v1
huggingface-cli download RaphaelLiu/PusaV1_training --repo-type dataset --local-dir $HF_DATASET_DIR
Create training dataset
python src/maxdiffusion/data_preprocessing/wan_pusav1_to_tfrecords.py \
src/maxdiffusion/configs/base_wan_14b.yml \
train_data_dir=$HF_DATASET_DIR \
tfrecords_dir=$TFRECORDS_DATASET_DIR/train \
no_records_per_shard=10 \
enable_eval_timesteps=False
Check progress:
ls -ll $TFRECORDS_DATASET_DIR/train
Create evaluation dataset
python src/maxdiffusion/data_preprocessing/wan_pusav1_to_tfrecords.py \
src/maxdiffusion/configs/base_wan_14b.yml \
train_data_dir=$HF_DATASET_DIR \
tfrecords_dir=$TFRECORDS_DATASET_DIR/eval \
no_records_per_shard=10 \
enable_eval_timesteps=True
The evaluation dataset contains 420 samples with timestep annotations, used for quality evaluation as described in Scaling Rectified Flow Transformers.
Remove duplicates from training set
Delete the first 420 samples from the training data, since they are already in the evaluation set:
printf "%s\n" $TFRECORDS_DATASET_DIR/train/file_*-*.tfrec | \
awk -F '[-.]' '$2+0 <= 420' | \
xargs -d '\n' rm
Verify deletion:
printf "%s\n" $TFRECORDS_DATASET_DIR/train/file_*-*.tfrec | \
awk -F '[-.]' '$2+0 <= 420' | \
xargs -d '\n' echo
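Note that the awk filter splits the whole path on "-" and ".", so it only works when the directory part of the path contains neither character. As a hedged alternative, this sketch parses only the file name. It assumes the file_<shard>-<last_sample>.tfrec naming used in this guide; the example paths are invented.

```python
import re
from pathlib import PurePosixPath

# Assumed naming scheme: file_<shard index>-<last sample number>.tfrec
SHARD_NAME = re.compile(r"^file_(\d+)-(\d+)\.tfrec$")


def shards_to_delete(paths, max_sample=420):
    """Return the paths whose last sample number is <= max_sample."""
    selected = []
    for path in paths:
        # Look only at the file name, so "-" or "." in directories is harmless.
        match = SHARD_NAME.match(PurePosixPath(path).name)
        if match and int(match.group(2)) <= max_sample:
            selected.append(path)
    return selected


# Hypothetical paths with "-" and "." in the directory name.
example = [
    "/data/my-set.v1/train/file_00-10.tfrec",
    "/data/my-set.v1/train/file_41-420.tfrec",
    "/data/my-set.v1/train/file_42-430.tfrec",
]
print(shards_to_delete(example))
```

The function only selects paths; pass the result to a deletion step of your choice after reviewing it.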
Clean up empty files
Remove any empty eval files:
rm $TFRECORDS_DATASET_DIR/eval_timesteps/file_42-430.tfrec 2>/dev/null || true
Directory structure
Your dataset should now have:
$TFRECORDS_DATASET_DIR/
├── train/
│   ├── file_00-10.tfrec
│   ├── file_01-20.tfrec
│   └── ...
└── eval_timesteps/
    ├── file_00-10.tfrec
    ├── file_01-20.tfrec
    └── ...
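The file names appear to encode the shard index and the running sample count. As a rough sketch under that assumption (the real script may name files differently), the expected shard names for a given sample count can be listed like this:

```python
def shard_names(num_samples: int, records_per_shard: int) -> list[str]:
    """List shard names under the assumed file_<index>-<last_sample> scheme."""
    names = []
    for index, start in enumerate(range(0, num_samples, records_per_shard)):
        last_sample = start + records_per_shard
        names.append(f"file_{index:02d}-{last_sample}.tfrec")
    return names


# 420 evaluation samples with no_records_per_shard=10.
eval_shards = shard_names(420, 10)
print(len(eval_shards), eval_shards[0], eval_shards[-1])
```

Under this assumption, 420 samples at 10 records per shard fill 42 files, ending at file_41-420.tfrec, which would explain why file_42-430.tfrec can be left empty.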
General text-to-video preprocessing
For other video datasets, use the general preprocessing script:
python src/maxdiffusion/data_preprocessing/wan_txt2vid_data_preprocessing.py \
src/maxdiffusion/configs/base_wan_14b.yml \
dataset_name="your-dataset/name" \
tfrecords_dir=$TFRECORDS_DATASET_DIR \
no_records_per_shard=10 \
caption_column="text" \
image_column="image" \
height=1280 \
width=720 \
seed=42
This script:
- Loads videos from HuggingFace datasets
- Encodes videos using the VAE
- Encodes captions using the T5 text encoder
- Saves latents and embeddings to TFRecord format
Configuration options
| Parameter | Description |
|---|---|
| train_data_dir | Path to downloaded dataset |
| tfrecords_dir | Output directory for TFRecord files |
| no_records_per_shard | Number of examples per TFRecord file |
| enable_eval_timesteps | Add timestep annotations for evaluation |
| timesteps_list | Timesteps for evaluation buckets |
| num_eval_samples | Number of evaluation samples (default: 420) |
Upload to cloud storage
Copy preprocessed data to GCS for distributed training:
BUCKET_NAME=my-bucket
gcloud storage cp --recursive $TFRECORDS_DATASET_DIR gs://$BUCKET_NAME/${TFRECORDS_DATASET_DIR##*/}
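The ${TFRECORDS_DATASET_DIR##*/} expansion keeps only the text after the last "/" in the local path. This small sketch reproduces that logic to preview the destination; the helper name gcs_destination is made up for illustration.

```python
import posixpath


def gcs_destination(local_dir: str, bucket: str) -> str:
    """Mirror ${VAR##*/}: keep only what follows the last '/' in local_dir."""
    return f"gs://{bucket}/{posixpath.basename(local_dir)}"


print(gcs_destination("/mnt/disks/external_disk/wan_tfr_dataset_pusa_v1", "my-bucket"))
```

If the local path ends with a trailing slash, both the shell expansion and this helper yield an empty directory name, so export TFRECORDS_DATASET_DIR without one.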
Using preprocessed data
For training
python src/maxdiffusion/train_wan.py \
src/maxdiffusion/configs/base_wan_14b.yml \
dataset_type='tfrecord' \
train_data_dir=gs://$BUCKET_NAME/wan_tfr_dataset_pusa_v1/train/ \
load_tfrecord_cached=True
For evaluation
python src/maxdiffusion/train_wan.py \
src/maxdiffusion/configs/base_wan_14b.yml \
eval_every=100 \
eval_data_dir=gs://$BUCKET_NAME/wan_tfr_dataset_pusa_v1/eval_timesteps/
Multihost dataloading
In multihost environments, optimal performance requires each data file to be accessed by only one host.
Best practices
- Number of files > number of hosts: each host reads a subset of files
- File assignment: files are distributed evenly across hosts
- Epoch handling: hosts may finish epochs at different times
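One simple way to give each host a disjoint subset of files is a round-robin split over the sorted file list. This is only an illustrative sketch, not necessarily how MaxDiffusion assigns files internally; files_for_host and the file names are hypothetical.

```python
def files_for_host(files, host_id, num_hosts):
    """Return the files assigned to one host using a round-robin split."""
    if len(files) < num_hosts:
        raise ValueError(
            f"{len(files)} files cannot cover {num_hosts} hosts; reshard first"
        )
    # Sorting makes the assignment identical on every host.
    return sorted(files)[host_id::num_hosts]


# Hypothetical example: 8 files shared by 4 hosts.
all_files = [f"file_{i:02d}.tfrec" for i in range(8)]
for host in range(4):
    print(host, files_for_host(all_files, host, 4))
```

Each host receives two files here, and no file is read by more than one host.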
Resharding datasets
If you have fewer files than hosts, reshard your dataset into more files:
# Decrease no_records_per_shard to create more, smaller files
no_records_per_shard=5 # Instead of 10
Or split existing files:
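python -c "
import glob
import tensorflow as tf
files = sorted(glob.glob('dataset/*.tfrec'))
target_shards = 64  # Number of output files
out_dir = 'dataset_resharded'  # Assumed output directory; change as needed
tf.io.gfile.makedirs(out_dir)
# Distribute serialized records round-robin across the output shards
writers = [tf.io.TFRecordWriter(f'{out_dir}/file_{i:03d}.tfrec') for i in range(target_shards)]
for i, record in enumerate(tf.data.TFRecordDataset(files)):
    writers[i % target_shards].write(record.numpy())
for writer in writers:
    writer.close()
"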
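To choose a shard size that yields at least one file per host, a quick calculation helps. The sketch below uses hypothetical numbers (1000 records, 64 hosts), and the function names are illustrative.

```python
import math


def max_records_per_shard(total_records: int, num_hosts: int) -> int:
    """Largest shard size that still yields at least num_hosts files."""
    if total_records < num_hosts:
        raise ValueError("fewer records than hosts; cannot give every host a file")
    return total_records // num_hosts


def num_shards(total_records: int, records_per_shard: int) -> int:
    """Number of files produced for a given shard size."""
    return math.ceil(total_records / records_per_shard)


# Hypothetical dataset: 1000 records read by 64 hosts.
shard_size = max_records_per_shard(1000, 64)
print(shard_size, num_shards(1000, shard_size))
```

Any no_records_per_shard value at or below the computed size keeps the file count at or above the host count.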
Synthetic data
For testing and benchmarking without real data:
dataset_type: 'synthetic'
synthetic_num_samples: null # Infinite samples
# Override dimensions
synthetic_override_height: 720
synthetic_override_width: 1280
synthetic_override_num_frames: 85
synthetic_override_max_sequence_length: 512
synthetic_override_text_embed_dim: 4096
synthetic_override_num_channels_latents: 16
synthetic_override_vae_scale_factor_spatial: 8
synthetic_override_vae_scale_factor_temporal: 4
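To see what tensor sizes these overrides imply, the sketch below derives the latent and text-embedding shapes. It assumes a causal video VAE convention in which the first frame is encoded on its own, so T frames map to (T - 1) / temporal_factor + 1 latent frames. The axis order shown is illustrative and may not match the shapes MaxDiffusion generates.

```python
def synthetic_latent_shape(height, width, num_frames,
                           num_channels_latents=16,
                           scale_spatial=8, scale_temporal=4):
    """Return (channels, latent_frames, latent_height, latent_width)."""
    # Assumed causal-VAE rule: first frame alone, then groups of scale_temporal.
    latent_frames = (num_frames - 1) // scale_temporal + 1
    return (num_channels_latents, latent_frames,
            height // scale_spatial, width // scale_spatial)


def synthetic_text_shape(max_sequence_length=512, text_embed_dim=4096):
    """Return (sequence_length, embedding_dim) for the text embeddings."""
    return (max_sequence_length, text_embed_dim)


# Values taken from the synthetic_override_* settings above.
print(synthetic_latent_shape(height=720, width=1280, num_frames=85))
print(synthetic_text_shape())
```

With the override values above, this gives 22 latent frames at 90 x 160 with 16 channels, and a 512 x 4096 text embedding per sample.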
Dataset configuration reference
Common parameters
image_column: 'image' # Column name for images/videos
caption_column: 'text' # Column name for captions
resolution: 1024 # Image resolution (for images)
height: 1280 # Video height (for videos)
width: 720 # Video width (for videos)
num_frames: 81 # Number of video frames
center_crop: False # Center crop images
random_flip: False # Random horizontal flip
enable_data_shuffling: True # Shuffle data during training
tokenize_captions_num_proc: 4 # Parallel workers for tokenization
transform_images_num_proc: 4 # Parallel workers for image processing
reuse_example_batch: False # Reuse same batch (for debugging)