Overview

Production training script for a CIFAR-10 diffusion model, with support for exponential moving average (EMA) of model weights, gradient accumulation, checkpoint resumption, and configurable early stopping.

Usage

python src/training/train_diffusion_cifar.py
The script supports configuration through environment variables.

Environment variables

| Variable | Type | Default | Description |
| --- | --- | --- | --- |
| EPOCHS | int | 2000 | Number of training epochs |
| PATIENCE | int | 0 | Early stopping patience (epochs) |
| EARLY_STOP | bool | false | Enable early stopping |
| RESUME_FROM | str | None | Path to checkpoint to resume from |
| RESUME_FROM_BEST | bool | true | Auto-resume from best checkpoint if it exists |
| WORK | str | ~ | Root directory for outputs |

Example with environment variables

# Train for 500 epochs with early stopping
EPOCHS=500 EARLY_STOP=1 PATIENCE=50 python src/training/train_diffusion_cifar.py

# Resume from specific checkpoint
RESUME_FROM=/path/to/checkpoint.pt python src/training/train_diffusion_cifar.py

# Disable auto-resume from best checkpoint
RESUME_FROM_BEST=0 python src/training/train_diffusion_cifar.py
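
A minimal sketch of how these variables can be read, with boolean flags accepting 1/true (variable names follow the table above; the exact parsing in the script may differ). The resume-related variables are covered under Auto-resume logic below.

import os

EPOCHS = int(os.environ.get("EPOCHS", 2000))
PATIENCE = int(os.environ.get("PATIENCE", 0))
EARLY_STOP = os.environ.get("EARLY_STOP", "false").lower() in ("1", "true")
WORK = os.path.expanduser(os.environ.get("WORK", "~"))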

Configuration parameters

Training hyperparameters

| Parameter | Type | Default | Description |
| --- | --- | --- | --- |
| epochs | int | 2000 | Maximum number of training epochs |
| batch_size | int | 256 | Training batch size |
| image_size | int | 32 | Image dimensions (32x32 for CIFAR-10) |
| channels | int | 3 | Number of input channels (RGB) |
| dropout_p | float | 0.1 | Dropout probability for the model |

Model architecture

| Parameter | Type | Default | Description |
| --- | --- | --- | --- |
| hidden_dims | list[int] | Configured in DiffusionProcessCIFAR | Hidden dimensions for the U-Net |
| noise_steps | int | 1000 | Number of diffusion timesteps |

Learning rate schedule

| Parameter | Type | Default | Description |
| --- | --- | --- | --- |
| warmup_ratio | float | 0.05 | Warmup steps as a fraction of total steps |
| schedule | str | "cosine" | Learning rate schedule type |

Schedule formula:
steps_per_epoch = len(loader) // grad_accumulation_steps
total_steps = epochs * steps_per_epoch
warmup_steps = int(0.05 * total_steps)

Early stopping

| Parameter | Type | Default | Description |
| --- | --- | --- | --- |
| patience | int | 0 (disabled) | Epochs to wait for improvement before stopping |
| min_delta | float | 1e-5 | Minimum loss improvement threshold |
| early_stop | bool | false | Enable early stopping |
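
The bookkeeping behind these parameters can be sketched as follows (the helper and its name are illustrative, not the script's exact code):

def update_early_stopping(avg_loss, best_loss, wait, patience, min_delta=1e-5):
    """Return (best_loss, wait, should_stop) after one epoch."""
    if avg_loss < best_loss - min_delta:   # loss improved by at least min_delta
        return avg_loss, 0, False          # reset the patience counter
    wait += 1                              # one more epoch without improvement
    return best_loss, wait, patience > 0 and wait >= patience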

Gradient accumulation

Configured automatically in DiffusionProcessCIFAR to simulate larger effective batch sizes.
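
The usual pattern looks like the toy example below (an assumed pattern; DiffusionProcessCIFAR wraps this internally and its exact code may differ):

import torch

model = torch.nn.Linear(8, 1)
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3)
accum_steps = 4                                        # hypothetical value
data = [(torch.randn(16, 8), torch.randn(16, 1)) for _ in range(8)]

for i, (x, y) in enumerate(data):
    loss = torch.nn.functional.mse_loss(model(x), y) / accum_steps
    loss.backward()                                    # gradients accumulate across micro-batches
    if (i + 1) % accum_steps == 0:
        optimizer.step()                               # one update per effective batch
        optimizer.zero_grad()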

Data preprocessing

Applies the following transformations to CIFAR-10 images:
transforms.Compose([
    transforms.Resize(32),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),  # [0,1] → [-1,1]
])
DataLoader configuration:
  • num_workers: 16
  • pin_memory: Enabled on CUDA
  • persistent_workers: True
  • prefetch_factor: 2
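
Putting the transforms and loader settings together, a sketch of the dataloader using the standard torchvision CIFAR-10 dataset (the data root is illustrative):

import torch
from torch.utils.data import DataLoader
from torchvision import datasets, transforms

transform = transforms.Compose([
    transforms.Resize(32),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
])
dataset = datasets.CIFAR10(root="./data", train=True, download=True, transform=transform)
loader = DataLoader(
    dataset,
    batch_size=256,
    shuffle=True,
    num_workers=16,
    pin_memory=torch.cuda.is_available(),   # enabled on CUDA
    persistent_workers=True,
    prefetch_factor=2,
)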

Directory structure

$WORK/stable-diffusion-cifar/
├── cifar_samples/              # Generated samples and visualizations
│   ├── beta_schedule_cifar.png
│   ├── noising_epoch*.png
│   ├── samples_epoch*.png
│   ├── training_curve_cifar.png
│   ├── DDPM_CIFAR.png
│   └── DDIM_CIFAR.png
├── checkpoints/                # Training checkpoints
│   ├── checkpoint_epoch*.pt
│   ├── checkpoint_best.pt
│   └── checkpoint_latest.pt
└── best_model_cifar.pt        # Best EMA model weights
Default $WORK directory: ~ (home directory)

Checkpoint management

Checkpoint contents

Each checkpoint contains:
{
    "epoch": int,
    "model_state_dict": dict,
    "ema_model_state_dict": dict,
    "optimizer_state_dict": dict,
    "scheduler_state_dict": dict,
    "loss_history": list[float],
    "best_loss": float,
    "wait": int,
    "accum_steps": int,
    "config": {
        "image_size": int,
        "channels": int,
        "hidden_dims": list[int],
        "noise_steps": int,
        "dropout_p": float
    }
}
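
For example, a checkpoint can be inspected directly (the path below is illustrative):

import torch

ckpt = torch.load("checkpoints/checkpoint_best.pt", map_location="cpu")
print(ckpt["epoch"], ckpt["best_loss"])
print(ckpt["config"]["hidden_dims"], ckpt["config"]["noise_steps"])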

Checkpoint saving strategy

| Checkpoint | Frequency | Description |
| --- | --- | --- |
| checkpoint_epoch{n}.pt | Every 25 epochs | Regular checkpoint |
| checkpoint_best.pt | On improvement | Lowest average training loss so far |
| checkpoint_latest.pt | Every checkpoint save | Most recent state |

Auto-resume logic

  1. If RESUME_FROM is set, resume from that checkpoint
  2. Else if RESUME_FROM_BEST=1 and checkpoint_best.pt exists, resume from best
  3. Else if checkpoint_latest.pt exists, resume from latest
  4. Otherwise, start fresh training
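
A sketch of this selection order (variable names are illustrative, not the script's identifiers):

import os

checkpoint_dir = "checkpoints"                        # illustrative path
resume_from = os.environ.get("RESUME_FROM")
resume_from_best = os.environ.get("RESUME_FROM_BEST", "1") not in ("0", "false")
best_ckpt = os.path.join(checkpoint_dir, "checkpoint_best.pt")
latest_ckpt = os.path.join(checkpoint_dir, "checkpoint_latest.pt")

if resume_from:
    ckpt_path = resume_from                           # 1. explicit checkpoint
elif resume_from_best and os.path.exists(best_ckpt):
    ckpt_path = best_ckpt                             # 2. best checkpoint
elif os.path.exists(latest_ckpt):
    ckpt_path = latest_ckpt                           # 3. latest checkpoint
else:
    ckpt_path = None                                  # 4. start fresh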

Training process

Loss function

Mean squared error (MSE) between predicted and actual noise.
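
A DDPM-style sketch of this objective (model, alphas_cumprod, and the function name are stand-ins, not the script's API):

import torch
import torch.nn.functional as F

def diffusion_loss(model, x0, alphas_cumprod, noise_steps=1000):
    t = torch.randint(0, noise_steps, (x0.shape[0],), device=x0.device)
    noise = torch.randn_like(x0)
    a_bar = alphas_cumprod[t].view(-1, 1, 1, 1)
    x_t = a_bar.sqrt() * x0 + (1 - a_bar).sqrt() * noise   # forward noising
    return F.mse_loss(model(x_t, t), noise)                # predicted vs. actual noise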

Optimization

  • Optimizer: AdamW (configured in DiffusionProcessCIFAR)
  • EMA: Exponential moving average of model weights for better sample quality
  • Gradient accumulation: Simulates larger batch sizes
  • Mixed precision: Automatic if available
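
A typical EMA update looks like the following (assumed pattern; the decay value and the exact implementation in DiffusionProcessCIFAR may differ). The EMA model is usually initialized as a deep copy of the training model.

import torch

@torch.no_grad()
def ema_update(ema_model, model, decay=0.999):
    for ema_p, p in zip(ema_model.parameters(), model.parameters()):
        ema_p.mul_(decay).add_(p, alpha=1 - decay)   # slow moving average of the weights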

Training loop

for epoch in range(start_epoch, epochs):
    epoch_loss = 0.0
    for x, _ in loader:
        x = x.to(device)
        loss = diffusion.train_step(x)   # noise-prediction step; accumulation is handled internally
        epoch_loss += loss
        if diffusion.accum_steps == 0:   # accumulation boundary: an optimizer update just completed
            scheduler.step()             # advance the LR schedule once per update

    avg_loss = epoch_loss / len(loader)
    # Save checkpoint if improved or at a regular interval

Outputs

Generated files

| File | Description |
| --- | --- |
| beta_schedule_cifar.png | Visualization of the linear beta schedule |
| noising_epoch{n}.png | Forward noising visualization (every 25 epochs) |
| samples_epoch{n}.png | Generated samples from the EMA model (every 25 epochs) |
| training_curve_cifar.png | Training loss curve |
| DDPM_CIFAR.png | Final DDPM samples (16 images) |
| DDIM_CIFAR.png | Final DDIM samples (16 images, 50 steps) |
| best_model_cifar.pt | Best EMA model state dict |

Visualization schedule

  • Epoch 1: Initial noising and samples
  • Every 25 epochs: Updated visualizations and checkpoint
  • End of training: Final samples (DDPM and DDIM), loss curve, and checkpoint

Functions

get_device()

Select training device and apply CUDA-specific optimizations.
Returns: torch.device
Optimizations:
  • Enables cuDNN benchmark
  • Sets high-precision matmul on supported hardware
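
A sketch consistent with this description (assumed code, not copied from the script):

import torch

def get_device():
    if torch.cuda.is_available():
        torch.backends.cudnn.benchmark = True          # let cuDNN autotune conv algorithms
        torch.set_float32_matmul_precision("high")     # allow TF32 matmuls where supported
        return torch.device("cuda")
    return torch.device("cpu")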

get_work_dirs()

Create and return output directory paths. Returns: tuple[str, str, str] - (project_work_dir, samples_dir, checkpoint_dir)

get_dataloader()

Build the CIFAR-10 training dataloader with augmentation. Parameters:
  • image_size: Image dimension
  • batch_size: Batch size
  • device: Training device
Returns: DataLoader

get_cosine_schedule_with_warmup()

Create a cosine learning rate schedule with linear warmup. Parameters:
  • optimizer: PyTorch optimizer
  • num_warmup_steps: Number of warmup steps
  • num_training_steps: Total training steps
Returns: LambdaLR scheduler
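
A common LambdaLR-based implementation of this schedule (assumed to mirror the script's behavior):

import math
from torch.optim.lr_scheduler import LambdaLR

def get_cosine_schedule_with_warmup(optimizer, num_warmup_steps, num_training_steps):
    def lr_lambda(step):
        if step < num_warmup_steps:
            return step / max(1, num_warmup_steps)         # linear warmup
        progress = (step - num_warmup_steps) / max(1, num_training_steps - num_warmup_steps)
        return 0.5 * (1.0 + math.cos(math.pi * progress))  # cosine decay to zero
    return LambdaLR(optimizer, lr_lambda)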

visualize_noising()

Save a panel showing the forward noising process. Parameters:
  • x0: Initial clean images
  • diffusion: DiffusionProcessCIFAR instance
  • timesteps: List of timesteps to visualize
  • device: Training device
  • save_dir: Output directory
  • fname: Output filename

save_checkpoint()

Save training checkpoint with all necessary state. Parameters:
  • diffusion: DiffusionProcessCIFAR instance
  • optimizer: PyTorch optimizer
  • scheduler: LR scheduler
  • epoch: Current epoch
  • loss_history: List of epoch losses
  • best_loss: Best loss achieved
  • wait: Early stopping counter
  • checkpoint_dir: Checkpoint directory
  • is_best: Whether this is the best checkpoint
Returns: str - Path to saved checkpoint
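
A sketch of this function writing the fields listed under Checkpoint contents (the attributes on diffusion are assumptions, not the script's exact API):

import os
import torch

def save_checkpoint(diffusion, optimizer, scheduler, epoch, loss_history,
                    best_loss, wait, checkpoint_dir, is_best=False):
    state = {
        "epoch": epoch,
        "model_state_dict": diffusion.model.state_dict(),          # assumed attribute
        "ema_model_state_dict": diffusion.ema_model.state_dict(),  # assumed attribute
        "optimizer_state_dict": optimizer.state_dict(),
        "scheduler_state_dict": scheduler.state_dict(),
        "loss_history": loss_history,
        "best_loss": best_loss,
        "wait": wait,
        "accum_steps": diffusion.accum_steps,                      # assumed attribute
        "config": diffusion.config,                                # assumed attribute
    }
    name = "checkpoint_best.pt" if is_best else "checkpoint_latest.pt"
    path = os.path.join(checkpoint_dir, name)
    torch.save(state, path)
    return path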

load_checkpoint()

Load training checkpoint and restore state. Parameters:
  • checkpoint_path: Path to checkpoint file
  • diffusion: DiffusionProcessCIFAR instance
  • optimizer: PyTorch optimizer
  • scheduler: LR scheduler
  • device: Training device
Returns: tuple[int, list[float], float, int] - (epoch, loss_history, best_loss, wait)

Code example

Custom training configuration

# Set custom work directory and train for 1000 epochs
WORK=/scratch/experiments \
EPOCHS=1000 \
EARLY_STOP=1 \
PATIENCE=100 \
python src/training/train_diffusion_cifar.py

Resume interrupted training

# Auto-resume from latest checkpoint
python src/training/train_diffusion_cifar.py

# Resume from specific checkpoint
RESUME_FROM=/path/to/checkpoint_epoch500.pt \
python src/training/train_diffusion_cifar.py

Source

Location: src/training/train_diffusion_cifar.py
