This guide covers training a DDPM model on CIFAR-10, a more challenging dataset with 32×32 color images. The implementation includes exponential moving average (EMA), gradient accumulation, checkpointing, and both DDPM and DDIM sampling.

Prerequisites

Install the required dependencies:
pip install torch torchvision matplotlib tqdm

Training configuration

The CIFAR-10 training uses a production-ready configuration:
epochs = 2000
batch_size = 256
image_size = 32
channels = 3  # RGB color images
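As a rough sanity check on this configuration: CIFAR-10 has 50,000 training images, so with batch_size=256 each epoch runs 196 batches (assuming the last partial batch is kept, i.e. drop_last=False; this is a back-of-the-envelope sketch, not code from the script):

```python
import math

num_train_images = 50_000   # CIFAR-10 training set size
batch_size = 256

# Batches per epoch when the last partial batch is kept (drop_last=False)
batches_per_epoch = math.ceil(num_train_images / batch_size)
print(batches_per_epoch)  # -> 196
```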

Model architecture

The model uses a deeper U-Net with dropout for regularization:
diffusion = DiffusionProcessCIFAR(
    image_size=32,
    channels=3,
    dropout_p=0.1,
    device=device,
)

Data preparation

Dataset and augmentation

CIFAR-10 images are augmented with random horizontal flips and normalized to [-1, 1]:
src/training/train_diffusion_cifar.py
transform = transforms.Compose([
    transforms.Resize(image_size),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),  # [0,1] -> [-1,1]
])

dataset = datasets.CIFAR10(
    root="./data",
    train=True,
    download=True,
    transform=transform
)
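Normalize applies (x - mean) / std per channel, so with mean = std = 0.5 it maps pixel values from [0, 1] to [-1, 1]. A quick standalone sketch of the arithmetic (the helper name is illustrative):

```python
def normalize(x, mean=0.5, std=0.5):
    """Same per-channel affine map that transforms.Normalize applies."""
    return (x - mean) / std

print(normalize(0.0))  # -> -1.0 (black pixel)
print(normalize(0.5))  # ->  0.0 (mid gray)
print(normalize(1.0))  # ->  1.0 (white pixel)
```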

Optimized dataloader

The dataloader is configured for maximum throughput:
src/training/train_diffusion_cifar.py
num_workers = 16
DataLoader(
    dataset,
    batch_size=256,
    shuffle=True,
    num_workers=num_workers,
    pin_memory=(device.type == "cuda"),
    persistent_workers=True,
    prefetch_factor=2,
)
persistent_workers=True keeps worker processes alive between epochs, reducing startup overhead.

Training features

Exponential moving average (EMA)

EMA maintains a smoothed version of model weights for better sample quality:
# EMA is applied automatically during training
# The EMA model is used for sampling and checkpointing
samples = diffusion.sample(num_samples=16)  # Uses EMA model
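The update rule behind EMA is just an exponential blend of the live weights into a shadow copy. A minimal scalar sketch (decay=0.999 is a common choice; the script's actual decay value and internals may differ):

```python
def ema_update(ema_params, model_params, decay=0.999):
    """Blend the current weights into the shadow (EMA) copy."""
    return [decay * e + (1.0 - decay) * p
            for e, p in zip(ema_params, model_params)]

# Toy example with a single scalar "weight", using decay=0.5 so the
# convergence toward the live weight (1.0) is visible in a few steps
ema = [0.0]
for _ in range(3):
    ema = ema_update(ema, [1.0], decay=0.5)
print(ema)  # -> [0.875]
```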

Gradient accumulation

Gradient accumulation enables larger effective batch sizes on limited memory:
for x, _ in loader:
    x = x.to(device)
    loss = diffusion.train_step(x)  # Accumulates gradients
    epoch_loss += loss
    if diffusion.accum_steps == 0:  # Optimizer step occurred
        scheduler.step()
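Under the hood, gradient accumulation means calling backward on a loss scaled by 1/accum_steps for each micro-batch, and stepping the optimizer only every accum_steps batches, so the effective gradient equals the full-batch average. A minimal sketch of that bookkeeping (names are illustrative, not the script's actual internals; the torch calls are shown as comments so the control flow runs on its own):

```python
def run_epoch(num_batches, accum_steps=4):
    """Return the batch indices (1-based) at which the optimizer steps."""
    step_points = []
    for i in range(1, num_batches + 1):
        # loss = compute_loss(batch) / accum_steps  # scale before backward
        # loss.backward()                           # gradients accumulate
        if i % accum_steps == 0:
            # optimizer.step(); optimizer.zero_grad()
            step_points.append(i)
    return step_points

print(run_epoch(8, accum_steps=4))  # -> [4, 8]
```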

Learning rate schedule

Cosine annealing with warmup provides stable training:
src/training/train_diffusion_cifar.py
def get_cosine_schedule_with_warmup(optimizer, num_warmup_steps, num_training_steps):
    def lr_lambda(current_step: int) -> float:
        if current_step < num_warmup_steps:
            return float(current_step) / float(max(1, num_warmup_steps))
        progress = float(current_step - num_warmup_steps) / float(
            max(1, num_training_steps - num_warmup_steps)
        )
        return max(0.0, 0.5 * (1.0 + math.cos(math.pi * progress)))
    return torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda)

steps_per_epoch = len(loader) // diffusion.grad_accumulation_steps
total_steps = epochs * steps_per_epoch
warmup_steps = int(0.05 * total_steps)  # 5% warmup
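The multiplier returned by lr_lambda rises linearly from 0 to 1 during warmup, then follows a half-cosine back down to 0. A quick numeric check of that shape, using a pure-Python replica of the factor above (warmup/total values are made up for illustration):

```python
import math

def lr_factor(step, warmup, total):
    """Same multiplier as lr_lambda in get_cosine_schedule_with_warmup."""
    if step < warmup:
        return step / max(1, warmup)
    progress = (step - warmup) / max(1, total - warmup)
    return max(0.0, 0.5 * (1.0 + math.cos(math.pi * progress)))

warmup, total = 100, 1100
print(lr_factor(0, warmup, total))     # -> 0.0  (start of warmup)
print(lr_factor(50, warmup, total))    # -> 0.5  (halfway through warmup)
print(lr_factor(100, warmup, total))   # -> 1.0  (peak learning rate)
print(lr_factor(1100, warmup, total))  # -> 0.0  (end of training)
```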

Running training

1. Start training locally

Basic:
python src/training/train_diffusion_cifar.py

With environment variables:
EPOCHS=500 python src/training/train_diffusion_cifar.py

Resume from checkpoint:
RESUME_FROM_BEST=1 EPOCHS=3000 python src/training/train_diffusion_cifar.py

2. Monitor training

Training progress is printed every epoch:
Epoch 1/2000 | loss=0.1234
Epoch 2/2000 | loss=0.1198
...

3. Visualize samples

Every 25 epochs, the script generates:
  • Noising visualization at timesteps [0, 200, 400, 600, 800, 999]
  • 16 samples from the EMA model

Checkpointing system

Automatic checkpointing

Checkpoints are saved every 25 epochs and when a new best loss is achieved:
src/training/train_diffusion_cifar.py
def save_checkpoint(
    diffusion, optimizer, scheduler, epoch,
    loss_history, best_loss, wait,
    checkpoint_dir, is_best=False
):
    checkpoint = {
        "epoch": epoch,
        "model_state_dict": diffusion.model.state_dict(),
        "ema_model_state_dict": diffusion.ema_model.state_dict(),
        "optimizer_state_dict": optimizer.state_dict(),
        "scheduler_state_dict": scheduler.state_dict(),
        "loss_history": loss_history,
        "best_loss": best_loss,
        "wait": wait,
    }
    name = "checkpoint_best.pt" if is_best else "checkpoint_latest.pt"
    checkpoint_path = os.path.join(checkpoint_dir, name)
    torch.save(checkpoint, checkpoint_path)

Checkpoint files

Checkpoints are saved to $WORK/stable-diffusion-cifar/checkpoints/:
• checkpoint_latest.pt - Most recent checkpoint
• checkpoint_best.pt - Best loss checkpoint
• checkpoint_epoch{N}.pt - Periodic checkpoints

Resume training

Resume from the best checkpoint:
RESUME_FROM_BEST=1 python src/training/train_diffusion_cifar.py

Or specify a checkpoint path:
RESUME_FROM=/path/to/checkpoint.pt python src/training/train_diffusion_cifar.py

Environment variables

Customize training behavior with environment variables:

Variable           Default  Description
EPOCHS             2000     Total epochs to train
PATIENCE           0        Early stopping patience
EARLY_STOP         0        Enable early stopping (1)
RESUME_FROM        None     Checkpoint path to resume from
RESUME_FROM_BEST   1        Resume from best checkpoint
WORK               ~        Working directory for outputs

When resuming with EPOCHS, the value is the total epoch count, not additional epochs. If you stopped at epoch 1000 and want to train to epoch 3000, set EPOCHS=3000.

Output locations

All outputs are saved to $WORK/stable-diffusion-cifar/:
$WORK/stable-diffusion-cifar/
├── checkpoints/
│   ├── checkpoint_latest.pt
│   ├── checkpoint_best.pt
│   └── checkpoint_epoch{N}.pt
├── cifar_samples/
│   ├── beta_schedule_cifar.png
│   ├── noising_epoch{N}.png
│   ├── samples_epoch{N}.png
│   ├── training_curve_cifar.png
│   ├── DDPM_CIFAR.png
│   └── DDIM_CIFAR.png
└── best_model_cifar.pt

Sampling methods

The trained model supports both DDPM and DDIM sampling:

DDPM sampling (1000 steps)

final_samples = diffusion.sample(num_samples=16)
utils.save_image(
    torch.clamp((final_samples + 1) / 2, 0, 1),
    "DDPM_CIFAR.png",
    nrow=4,
)

DDIM sampling (50 steps)

DDIM provides faster sampling with comparable quality:
src/training/train_diffusion_cifar.py
final_ddim = diffusion.sample_ddim(num_samples=16, ddim_steps=50)
utils.save_image(
    torch.clamp((final_ddim + 1) / 2, 0, 1),
    "DDIM_CIFAR.png",
    nrow=4,
)

DDIM sampling is 20× faster than DDPM (50 steps vs 1000) with minimal quality loss.
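One deterministic DDIM update (eta = 0) can jump between non-adjacent timesteps because it first reconstructs x0 from the predicted noise, then re-noises to the earlier step. A scalar sketch of the update rule in the usual alpha-bar notation (not the script's actual implementation):

```python
import math

def ddim_step(x_t, eps_pred, alpha_bar_t, alpha_bar_prev):
    """Deterministic DDIM update: x_t -> x_{t_prev} using predicted noise."""
    # Reconstruct the clean sample implied by the noise prediction
    x0_pred = (x_t - math.sqrt(1 - alpha_bar_t) * eps_pred) / math.sqrt(alpha_bar_t)
    # Re-noise x0_pred to the earlier timestep with the same predicted noise
    return math.sqrt(alpha_bar_prev) * x0_pred + math.sqrt(1 - alpha_bar_prev) * eps_pred

# Sanity check: with a perfect noise prediction, one step lands exactly on
# the forward-process value at the earlier timestep.
x0, eps = 0.7, -0.3
ab_t, ab_prev = 0.5, 0.9
x_t = math.sqrt(ab_t) * x0 + math.sqrt(1 - ab_t) * eps
x_prev = ddim_step(x_t, eps, ab_t, ab_prev)
expected = math.sqrt(ab_prev) * x0 + math.sqrt(1 - ab_prev) * eps
print(abs(x_prev - expected) < 1e-12)  # -> True
```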

Performance optimization

GPU optimizations

src/training/train_diffusion_cifar.py
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
if device.type == "cuda":
    torch.backends.cudnn.benchmark = True
    if hasattr(torch, "set_float32_matmul_precision"):
        torch.set_float32_matmul_precision("high")

Memory management

export PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True

This lets the CUDA caching allocator grow existing memory segments instead of reserving new fixed-size blocks, which reduces fragmentation and helps avoid out-of-memory errors when allocation sizes vary during training.

Expected results

With default hyperparameters:
• Training time: 15-20 hours on an A100 GPU for 2000 epochs
• Memory usage: ~10-15 GB VRAM with batch_size=256
• Final loss: ~0.005-0.010 MSE
• Sample quality: Sharp CIFAR-10 images with correct colors and shapes

Training curve

Expect the loss to:
1. Drop rapidly in the first 100 epochs
2. Gradually decrease until ~500 epochs
3. Slowly improve with diminishing returns after 1000 epochs

Troubleshooting

Out of memory errors

Reduce batch size or enable gradient accumulation:
batch_size = 128  # or 64

Poor sample quality

Ensure you're using the EMA model for sampling. The raw model weights produce noisier samples.

Training divergence

If loss increases dramatically:
1. Check the learning rate (default: 1e-4 with AdamW)
2. Reduce the number of gradient accumulation steps
3. Add gradient clipping
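For step 3, PyTorch provides torch.nn.utils.clip_grad_norm_; the underlying rule rescales the whole gradient vector whenever its global L2 norm exceeds a threshold. A pure-Python sketch of that rule (scalar gradients for illustration):

```python
import math

def clip_grad_norm(grads, max_norm=1.0):
    """Rescale grads so their global L2 norm is at most max_norm."""
    total_norm = math.sqrt(sum(g * g for g in grads))
    if total_norm > max_norm:
        scale = max_norm / total_norm
        grads = [g * scale for g in grads]
    return grads

# Gradient [3, 4] has norm 5; clipping to norm 1 scales it by 0.2
clipped = clip_grad_norm([3.0, 4.0], max_norm=1.0)
print(clipped)  # -> approximately [0.6, 0.8]
```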

Next steps

• Scale training to HPC clusters: HPC SLURM Guide
• Explore the diffusion model implementation: src/models/diffusion_cifar.py
• Experiment with DDIM sampling: src/models/diffusion_cifar.py:sample_ddim()
