Overview

Training script for a diffusion model on the MNIST dataset. It implements a simple DDPM (denoising diffusion probabilistic model) with a cosine beta schedule, early stopping, and periodic visualization of the forward noising and sampling processes.
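
The beta schedule itself lives inside DiffusionProcess. For reference, here is a minimal sketch of the standard cosine schedule (Nichol & Dhariwal, 2021) that this kind of scheduler follows; the function name and the exact clipping value are assumptions, not the script's confirmed code:

import math
import torch

def cosine_beta_schedule(noise_steps: int = 1000, s: float = 0.008) -> torch.Tensor:
    # Cumulative signal level alpha_bar(t) follows a squared cosine; the small
    # offset s keeps beta finite at t = 0.
    t = torch.linspace(0, noise_steps, noise_steps + 1)
    f = torch.cos(((t / noise_steps) + s) / (1 + s) * math.pi / 2) ** 2
    alphas_cumprod = f / f[0]
    # beta_t = 1 - alpha_bar(t) / alpha_bar(t-1), clipped for numerical stability.
    betas = 1.0 - alphas_cumprod[1:] / alphas_cumprod[:-1]
    return betas.clamp(max=0.999)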

Usage

python src/training/train_diffusion.py
The script runs with hardcoded configuration parameters and does not accept command-line arguments.

Configuration parameters

Training hyperparameters

| Parameter | Type | Default | Description |
|---|---|---|---|
| epochs | int | 50 | Maximum number of training epochs |
| batch_size | int | 128 | Batch size for training |
| image_size | int | 28 | Image dimensions (28×28 for MNIST) |
| channels | int | 1 | Number of input channels (grayscale) |
| save_dir | str | "samples" | Directory for saving visualization outputs |

Model architecture

| Parameter | Type | Default | Description |
|---|---|---|---|
| hidden_dims | list[int] | [128, 256, 512] | Hidden dimensions for U-Net layers |
| noise_steps | int | 1000 | Number of diffusion timesteps |

Early stopping

| Parameter | Type | Default | Description |
|---|---|---|---|
| patience | int | 4 | Number of epochs to wait before early stopping |
| min_delta | float | 1e-4 | Minimum loss improvement threshold |

Device configuration

  • Auto-detection: Uses CUDA if available, otherwise CPU
  • CUDA optimizations: Enables cuDNN benchmark and high precision matmul when available
  • Data loading: Automatically configures num_workers (min of 8 or CPU count); see the sketch below
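
A minimal sketch of this setup using standard PyTorch calls (the script's exact code may differ):

import os
import torch

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
if device.type == "cuda":
    torch.backends.cudnn.benchmark = True        # let cuDNN pick the fastest conv kernels
    torch.set_float32_matmul_precision("high")   # allow TF32 matmuls where supported
num_workers = min(8, os.cpu_count() or 1)        # cap DataLoader workers at 8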

Data preprocessing

Applies the following transformations to MNIST images:
from torchvision import transforms

transform = transforms.Compose([
    transforms.ToTensor(),                 # PIL image → float tensor in [0, 1]
    transforms.Normalize((0.5,), (0.5,)),  # [0, 1] → [-1, 1]
])
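
For context, a sketch of wiring this transform into a DataLoader (the dataset root and pin_memory flag are assumptions; num_workers follows the rule in Device configuration):

from torch.utils.data import DataLoader
from torchvision import datasets

dataset = datasets.MNIST(root="data", train=True, download=True, transform=transform)
loader = DataLoader(dataset, batch_size=128, shuffle=True,
                    num_workers=num_workers, pin_memory=True)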

Training process

Loss function

Mean squared error (MSE) between predicted and actual noise.
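
A sketch of one training step under this objective. The noising helper (q_sample) and the model call signature are assumptions about DiffusionProcess, not its documented API; x0 is a batch from the data loader:

import torch
import torch.nn.functional as F

t = torch.randint(0, noise_steps, (x0.size(0),), device=device)  # uniform random timesteps
noise = torch.randn_like(x0)                                     # target noise ~ N(0, I)
x_t = diffusion.q_sample(x0, t, noise)                           # hypothetical forward-noising helper
loss = F.mse_loss(diffusion.model(x_t, t), noise)                # MSE between predicted and actual noise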

Optimization

  • Optimizer: Adam (configured in DiffusionProcess)
  • Loss tracking: Records the average loss per epoch
  • Best model saving: Saves a checkpoint (best_model.pt) whenever the average loss improves by more than min_delta (1e-4)

Early stopping logic

# Compare each epoch's average loss against the best seen so far (min_delta = 1e-4).
if avg_loss < best_loss - 1e-4:
    best_loss = avg_loss
    wait = 0  # improvement: reset the patience counter
    torch.save(diffusion.model.state_dict(), "best_model.pt")
else:
    wait += 1  # no meaningful improvement this epoch
    if wait >= patience:  # patience = 4
        print(f"Early stopping at epoch {epoch+1}")
        break

Outputs

Generated files

| File | Description |
|---|---|
| samples/beta_schedule.png | Visualization of the cosine beta schedule |
| samples/noising_epoch{n}.png | Forward noising visualization (epoch 1, then every 10 epochs) |
| samples/samples_epoch{n}.png | Generated samples (epoch 1, then every 10 epochs) |
| samples/training_curve.png | Training loss curve |
| DDPM.png | Final generated samples (16 images) |
| best_model.pt | Best model checkpoint based on training loss |

Visualization schedule

  • Epoch 1: Initial noising and samples
  • Every 10 epochs: Updated noising and samples (sketched below)
  • End of training: Final samples and loss curve
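
In code this schedule reduces to a single epoch check; a sketch, assuming a condition consistent with the bullets above:

if epoch == 0 or (epoch + 1) % 10 == 0:
    visualize_noising(x0, diffusion, fname=f"{save_dir}/noising_epoch{epoch + 1}.png")
    visualize_sampling(diffusion, fname=f"{save_dir}/samples_epoch{epoch + 1}.png")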

Code example

Customize training parameters by modifying the configuration section:
# Config
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
epochs = 100  # Increase training epochs
batch_size = 256  # Larger batch size
image_size = 28
channels = 1
save_dir = "my_samples"

# Model & Diffusion Process
diffusion = DiffusionProcess(
    image_size=image_size,
    channels=channels,
    hidden_dims=[128, 256, 512],  # Customize architecture
    device=device
)

Utilities

visualize_noising

Visualizes the forward noising process at specific timesteps; a usage sketch follows the parameter list. Parameters:
  • x0: Initial clean images (batch)
  • diffusion: DiffusionProcess instance
  • timesteps: List of timesteps to visualize (default: [0, 200, 400, 600, 800, 999])
  • fname: Output filename
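
A hypothetical call using the defaults listed above (the output filename is an assumption matching the files table):

visualize_noising(x0, diffusion,
                  timesteps=[0, 200, 400, 600, 800, 999],
                  fname="samples/noising_epoch1.png")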

visualize_sampling

Visualizes the reverse denoising process (sampling); a usage sketch follows the parameter list. Parameters:
  • diffusion: DiffusionProcess instance
  • num_samples: Number of samples to generate (default: 16)
  • steps: Timesteps to visualize (default: [999, 800, 400, 0])
  • fname: Output filename
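
A hypothetical call with the documented defaults (again, the output filename is an assumption):

visualize_sampling(diffusion, num_samples=16,
                   steps=[999, 800, 400, 0],
                   fname="samples/samples_epoch1.png")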

Source

Location: src/training/train_diffusion.py
