Overview
Production training script for a CIFAR-10 diffusion model with advanced features, including exponential moving average (EMA), gradient accumulation, checkpoint resumption, and configurable early stopping.

Usage
Environment variables
| Variable | Type | Default | Description |
|---|---|---|---|
| `EPOCHS` | int | 2000 | Number of training epochs |
| `PATIENCE` | int | 0 | Early stopping patience (epochs) |
| `EARLY_STOP` | bool | false | Enable early stopping |
| `RESUME_FROM` | str | None | Path to checkpoint to resume from |
| `RESUME_FROM_BEST` | bool | true | Auto-resume from best checkpoint if it exists |
| `WORK` | str | `~` | Root directory for outputs |
Example with environment variables
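A sketch of launching the script from Python with the variables above set; all values are illustrative, and the scratch path is an assumption.

```python
import os
import subprocess
import sys

# Illustrative launch with a custom configuration via environment variables.
# Variable names come from the table above; the values are examples only.
env = dict(
    os.environ,
    EPOCHS="500",
    EARLY_STOP="true",
    PATIENCE="20",
    WORK="/scratch/diffusion",  # assumed scratch location
)
subprocess.run(
    [sys.executable, "src/training/train_diffusion_cifar.py"],
    env=env,
    check=True,
)
```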
Configuration parameters
Training hyperparameters
| Parameter | Type | Default | Description |
|---|---|---|---|
| `epochs` | int | 2000 | Maximum number of training epochs |
| `batch_size` | int | 256 | Batch size for training |
| `image_size` | int | 32 | Image dimensions (32x32 for CIFAR-10) |
| `channels` | int | 3 | Number of input channels (RGB) |
| `dropout_p` | float | 0.1 | Dropout probability for the model |
Model architecture
| Parameter | Type | Default | Description |
|---|---|---|---|
| `hidden_dims` | list[int] | Configured in `DiffusionProcessCIFAR` | Hidden dimensions for the U-Net |
| `noise_steps` | int | 1000 | Number of diffusion timesteps |
Learning rate schedule
| Parameter | Type | Default | Description |
|---|---|---|---|
| `warmup_ratio` | float | 0.05 | Warmup steps as a fraction of total steps |
| `schedule` | str | "cosine" | Learning rate schedule type |
Early stopping
| Parameter | Type | Default | Description |
|---|---|---|---|
| `patience` | int | 0 (disabled) | Epochs to wait before stopping |
| `min_delta` | float | 1e-5 | Minimum loss improvement threshold |
| `early_stop` | bool | false | Enable early stopping |
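A minimal sketch of how these settings typically interact; the helper and its argument names are hypothetical, not taken from the script.

```python
def early_stop_update(epoch_loss: float, best_loss: float, wait: int,
                      patience: int, min_delta: float = 1e-5):
    """Hypothetical helper mirroring the documented fields.

    Returns an updated (best_loss, wait, stop) triple; stop becomes True once
    the loss has failed to improve by min_delta for `patience` epochs.
    """
    if epoch_loss < best_loss - min_delta:
        return epoch_loss, 0, False            # improvement: reset the counter
    wait += 1
    return best_loss, wait, patience > 0 and wait >= patience
```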
Gradient accumulation
Configured automatically in `DiffusionProcessCIFAR` to simulate larger effective batch sizes.
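The general pattern looks like the sketch below; `loss_fn` and `accum_steps` are assumed names, and the script's own implementation lives inside `DiffusionProcessCIFAR`.

```python
def train_with_accumulation(loss_fn, dataloader, optimizer, scheduler,
                            accum_steps: int = 4) -> None:
    """Generic gradient-accumulation pattern (a sketch). loss_fn maps a
    batch of images to a scalar loss tensor."""
    optimizer.zero_grad(set_to_none=True)
    for step, (images, _) in enumerate(dataloader):
        # Scale each micro-batch loss so the accumulated gradient matches the
        # average over the larger effective batch.
        loss = loss_fn(images) / accum_steps
        loss.backward()
        if (step + 1) % accum_steps == 0:
            optimizer.step()
            scheduler.step()
            optimizer.zero_grad(set_to_none=True)
```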
Data preprocessing
Applies augmentation transforms to CIFAR-10 images (see `get_dataloader()`). DataLoader settings:

- `num_workers`: 16
- `pin_memory`: enabled on CUDA
- `persistent_workers`: True
- `prefetch_factor`: 2
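A sketch of a loader built with these settings; the flip-and-normalize transform pipeline is a common CIFAR-10 choice assumed here, not read from the script.

```python
import torch
from torch.utils.data import DataLoader
from torchvision import datasets, transforms

def build_cifar_loader(batch_size: int = 256, data_root: str = "./data") -> DataLoader:
    # Assumed augmentation pipeline; the script's exact transforms may differ.
    transform = transforms.Compose([
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),  # scale to [-1, 1]
    ])
    dataset = datasets.CIFAR10(data_root, train=True, download=True, transform=transform)
    return DataLoader(
        dataset,
        batch_size=batch_size,
        shuffle=True,
        num_workers=16,
        pin_memory=torch.cuda.is_available(),  # as documented: enabled on CUDA
        persistent_workers=True,
        prefetch_factor=2,
    )
```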
Directory structure
The `$WORK` root directory defaults to `~` (the home directory); `get_work_dirs()` creates the project work, samples, and checkpoint directories beneath it.
Checkpoint management
Checkpoint contents
Each checkpoint contains the diffusion model state (including EMA weights), optimizer and scheduler state, the current epoch, the loss history, the best loss, and the early-stopping counter (`wait`).

Checkpoint saving strategy
| Checkpoint | Frequency | Description |
|---|---|---|
| `checkpoint_epoch{n}.pt` | Every 25 epochs | Regular checkpoint |
| `checkpoint_best.pt` | On improvement | Best validation loss |
| `checkpoint_latest.pt` | Every checkpoint save | Most recent state |
Auto-resume logic
- If `RESUME_FROM` is set, resume from that checkpoint
- Else if `RESUME_FROM_BEST=1` and `checkpoint_best.pt` exists, resume from best
- Else if `checkpoint_latest.pt` exists, resume from latest
- Otherwise, start fresh training
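A minimal sketch of this selection order; the helper name is hypothetical.

```python
import os

def pick_resume_path(checkpoint_dir: str) -> str | None:
    """Hypothetical helper implementing the documented auto-resume order."""
    resume_from = os.environ.get("RESUME_FROM")
    if resume_from:
        return resume_from                                    # explicit path wins
    best = os.path.join(checkpoint_dir, "checkpoint_best.pt")
    if os.environ.get("RESUME_FROM_BEST", "1") == "1" and os.path.exists(best):
        return best                                           # best checkpoint
    latest = os.path.join(checkpoint_dir, "checkpoint_latest.pt")
    if os.path.exists(latest):
        return latest                                         # most recent state
    return None                                               # fresh training
```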
Training process
Loss function
Mean squared error (MSE) between the predicted and actual noise.
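A minimal sketch of this objective under assumed names; `model` is presumed to predict the noise added at timestep `t`.

```python
import torch
import torch.nn.functional as F

def diffusion_loss(model, x0, alpha_bar):
    """Minimal DDPM noise-prediction loss (all names are assumptions).

    x0: clean images in [-1, 1]; alpha_bar: (noise_steps,) cumulative
    products of (1 - beta_t) from the beta schedule.
    """
    t = torch.randint(0, alpha_bar.numel(), (x0.shape[0],), device=x0.device)
    noise = torch.randn_like(x0)
    ab = alpha_bar[t].view(-1, 1, 1, 1)
    x_t = ab.sqrt() * x0 + (1 - ab).sqrt() * noise  # forward noising step
    return F.mse_loss(model(x_t, t), noise)         # MSE on predicted noise
```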
Optimization

- Optimizer: AdamW (configured in `DiffusionProcessCIFAR`)
- EMA: exponential moving average of model weights for better sample quality (a generic update is sketched after this list)
- Gradient accumulation: simulates larger effective batch sizes
- Mixed precision: automatic if available
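A generic EMA update of the kind described above; the decay value is an assumption.

```python
import torch

@torch.no_grad()
def ema_update(ema_model, model, decay: float = 0.999) -> None:
    """Standard EMA weight update; the decay value is an assumption."""
    for ema_p, p in zip(ema_model.parameters(), model.parameters()):
        ema_p.mul_(decay).add_(p, alpha=1.0 - decay)
```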
Training loop
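A condensed, heavily assumed sketch of the per-epoch flow.

```python
def run_training(diffusion, dataloader, epochs: int, start_epoch: int = 0):
    """Condensed per-epoch flow; every name here is an assumption, and
    diffusion.training_step is presumed to handle the loss, backward pass,
    optimizer/scheduler steps, and EMA update described above."""
    loss_history: list[float] = []
    for epoch in range(start_epoch, epochs):
        running = 0.0
        for images, _ in dataloader:
            running += diffusion.training_step(images)
        loss_history.append(running / len(dataloader))
        if (epoch + 1) % 25 == 0:
            pass  # save checkpoint and visualizations, per the 25-epoch schedule
    return loss_history
```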
Outputs
Generated files
| File | Description |
|---|---|
| `beta_schedule_cifar.png` | Visualization of the linear beta schedule |
| `noising_epoch{n}.png` | Forward noising visualization (every 25 epochs) |
| `samples_epoch{n}.png` | Generated samples from the EMA model (every 25 epochs) |
| `training_curve_cifar.png` | Training loss curve |
| `DDPM_CIFAR.png` | Final DDPM samples (16 images) |
| `DDIM_CIFAR.png` | Final DDIM samples (16 images, 50 steps) |
| `best_model_cifar.pt` | Best EMA model state dict |
Visualization schedule
- Epoch 1: Initial noising and samples
- Every 25 epochs: Updated visualizations and checkpoint
- End of training: Final samples (DDPM and DDIM), loss curve, and checkpoint
Functions
get_device()
Select the training device and apply CUDA-specific optimizations.

Returns: `torch.device`
Optimizations:
- Enables cuDNN benchmark
- Sets high precision matmul on supported hardware
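A sketch consistent with this description; the real function may differ in detail.

```python
import torch

def get_device() -> torch.device:
    """Sketch of the documented device selection and CUDA tweaks."""
    if torch.cuda.is_available():
        torch.backends.cudnn.benchmark = True       # autotune conv kernels
        torch.set_float32_matmul_precision("high")  # TF32 on supported GPUs
        return torch.device("cuda")
    return torch.device("cpu")
```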
get_work_dirs()
Create and return output directory paths.

Returns: `tuple[str, str, str]` - (project_work_dir, samples_dir, checkpoint_dir)
get_dataloader()
Build the CIFAR-10 training dataloader with augmentation.

Parameters:
- `image_size`: image dimension
- `batch_size`: batch size
- `device`: training device

Returns: `DataLoader`
get_cosine_schedule_with_warmup()
Create a cosine learning rate schedule with linear warmup.

Parameters:
- `optimizer`: PyTorch optimizer
- `num_warmup_steps`: number of warmup steps
- `num_training_steps`: total training steps

Returns: `LambdaLR` scheduler
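A sketch of the standard warmup-plus-cosine schedule this function describes; the script's exact variant may differ.

```python
import math
from torch.optim.lr_scheduler import LambdaLR

def get_cosine_schedule_with_warmup(optimizer, num_warmup_steps, num_training_steps):
    # Linear warmup followed by a cosine decay to zero; a minimal sketch of
    # the standard schedule, not necessarily the script's exact formula.
    def lr_lambda(step: int) -> float:
        if step < num_warmup_steps:
            return step / max(1, num_warmup_steps)
        progress = (step - num_warmup_steps) / max(1, num_training_steps - num_warmup_steps)
        return 0.5 * (1.0 + math.cos(math.pi * progress))
    return LambdaLR(optimizer, lr_lambda)
```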
visualize_noising()
Save a panel showing the forward noising process.

Parameters:
- `x0`: initial clean images
- `diffusion`: `DiffusionProcessCIFAR` instance
- `timesteps`: list of timesteps to visualize
- `device`: training device
- `save_dir`: output directory
- `fname`: output filename
save_checkpoint()
Save a training checkpoint with all necessary state.

Parameters:
- `diffusion`: `DiffusionProcessCIFAR` instance
- `optimizer`: PyTorch optimizer
- `scheduler`: LR scheduler
- `epoch`: current epoch
- `loss_history`: list of epoch losses
- `best_loss`: best loss achieved
- `wait`: early stopping counter
- `checkpoint_dir`: checkpoint directory
- `is_best`: whether this is the best checkpoint

Returns: `str` - path to the saved checkpoint
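A sketch of the documented payload and saving strategy; the dictionary keys and the `diffusion.model` / `diffusion.ema_model` attributes are assumptions.

```python
import os
import torch

def save_checkpoint_sketch(diffusion, optimizer, scheduler, epoch,
                           loss_history, best_loss, wait,
                           checkpoint_dir, is_best=False) -> str:
    """Sketch of the documented checkpoint contents (key names assumed)."""
    state = {
        "epoch": epoch,
        "model": diffusion.model.state_dict(),          # assumed attribute
        "ema_model": diffusion.ema_model.state_dict(),  # assumed attribute
        "optimizer": optimizer.state_dict(),
        "scheduler": scheduler.state_dict(),
        "loss_history": loss_history,
        "best_loss": best_loss,
        "wait": wait,
    }
    path = os.path.join(checkpoint_dir, f"checkpoint_epoch{epoch}.pt")
    torch.save(state, path)
    torch.save(state, os.path.join(checkpoint_dir, "checkpoint_latest.pt"))
    if is_best:
        torch.save(state, os.path.join(checkpoint_dir, "checkpoint_best.pt"))
    return path
```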
load_checkpoint()
Load a training checkpoint and restore state.

Parameters:
- `checkpoint_path`: path to the checkpoint file
- `diffusion`: `DiffusionProcessCIFAR` instance
- `optimizer`: PyTorch optimizer
- `scheduler`: LR scheduler
- `device`: training device

Returns: `tuple[int, list[float], float, int]` - (epoch, loss_history, best_loss, wait)
Code example
Custom training configuration
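A hypothetical sketch of configuring the diffusion process directly; the constructor arguments mirror the parameter tables above, but the import path and exact signature are assumptions, not taken from the script.

```python
# Assumed import path and constructor signature.
from src.training.train_diffusion_cifar import DiffusionProcessCIFAR

diffusion = DiffusionProcessCIFAR(
    image_size=32,     # CIFAR-10 resolution
    channels=3,        # RGB
    noise_steps=1000,  # diffusion timesteps
    dropout_p=0.1,     # model dropout
)
```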
Resume interrupted training
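A sketch of resuming via the documented `RESUME_FROM` variable; the checkpoint path is an assumed example.

```python
import os
import subprocess
import sys

# Resume from an explicit checkpoint via RESUME_FROM; without it, the
# auto-resume logic above falls back to the best/latest checkpoints.
env = dict(
    os.environ,
    RESUME_FROM=os.path.expanduser("~/checkpoints/checkpoint_latest.pt"),
)
subprocess.run([sys.executable, "src/training/train_diffusion_cifar.py"],
               env=env, check=True)
```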
Source
Location: `src/training/train_diffusion_cifar.py`