
Overview

This guide will walk you through training your first diffusion model on MNIST. You’ll see how to:
  • Load and preprocess the MNIST dataset
  • Initialize the diffusion process and U-Net model
  • Train the model with early stopping
  • Generate new digit samples using DDPM and DDIM
Training on MNIST takes approximately 5-10 minutes on a GPU (or 30-60 minutes on CPU). The code automatically detects and uses CUDA if available.
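
Device selection follows the standard PyTorch pattern; a minimal sketch (the variable name device is reused by the snippets later in this guide):
import torch

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using {device}")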

Train MNIST DDPM

1. Run the training script

The simplest way to get started is to run the MNIST training script directly:
python src/training/train_diffusion.py
This will:
  • Automatically download MNIST to the data/ directory
  • Train a U-Net based DDPM with a cosine beta schedule (a sketch of the schedule follows this list)
  • Save intermediate samples to samples/
  • Save the best model weights to best_model.pt
  • Generate a training loss curve at samples/training_curve.png
  • Create final samples at DDPM.png
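
For reference, a cosine beta schedule (Nichol & Dhariwal, 2021) is usually computed like the sketch below; this is illustrative and not necessarily the exact code in src/models/diffusion.py:
import torch

def cosine_beta_schedule(timesteps: int, s: float = 0.008) -> torch.Tensor:
    # Cumulative alphas follow a squared-cosine curve; betas come from the
    # ratio of consecutive cumulative alphas, clamped for numerical stability.
    t = torch.linspace(0, timesteps, timesteps + 1) / timesteps
    alphas_cumprod = torch.cos((t + s) / (1 + s) * torch.pi / 2) ** 2
    alphas_cumprod = alphas_cumprod / alphas_cumprod[0]
    betas = 1 - (alphas_cumprod[1:] / alphas_cumprod[:-1])
    return betas.clamp(1e-4, 0.999)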

2. Monitor training progress

During training, you’ll see output like this:
Training...: 100%|████████| 50/50 [05:23<00:00,  6.47s/it]
Epoch 1/50
Epoch 10/50
Epoch 20/50
Early stopping at epoch 24
Training complete. Samples saved to samples
The training loop includes:
  • Early stopping with patience=4 to prevent overfitting
  • Sample generation every 10 epochs to visualize progress (sketched after this list)
  • Loss tracking to monitor convergence
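
The periodic sample generation reuses the sampling call shown later in this guide; roughly, something like the following runs inside the epoch loop every 10 epochs (utils here is torchvision.utils, as in the sampling snippet further down):
if (epoch + 1) % 10 == 0:
    samples = diffusion.sample(num_samples=16)  # DDPM sampling
    utils.save_image((samples + 1) / 2,         # denormalize to [0, 1]
                     f"samples/samples_epoch{epoch + 1}.png", nrow=4)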

3. View generated samples

After training completes, check the DDPM.png file in the repository root. You should see a 4x4 grid of generated MNIST digits. The samples/ directory contains:
  • noising_epoch*.png - Forward diffusion visualization
  • samples_epoch*.png - Generated samples at different epochs
  • training_curve.png - Loss over time
  • beta_schedule.png - Cosine beta schedule visualization
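
To take a quick look at any of these grids without leaving Python, a convenience snippet (not part of the repository):
import matplotlib.pyplot as plt
from PIL import Image

plt.imshow(Image.open("DDPM.png"), cmap="gray")
plt.axis("off")
plt.show()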

Understanding the code

Data loading

The training script uses standard PyTorch data loading with normalization to [-1, 1]:
import os
from torch.utils.data import DataLoader
from torchvision import datasets, transforms

transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5,), (0.5,)),  # [0,1] → [-1,1]
])

dataset = datasets.MNIST(root="./data", train=True, download=True, transform=transform)
loader = DataLoader(
    dataset,
    batch_size=128,
    shuffle=True,
    num_workers=min(8, os.cpu_count() or 4),
    pin_memory=(device.type == "cuda"),  # `device` from the snippet in the Overview
)
The normalization to [-1, 1] is important because the diffusion process expects inputs in this range.

Model initialization

The DiffusionProcess class handles both the diffusion schedule and the U-Net model:
diffusion = DiffusionProcess(
    image_size=28, 
    channels=1, 
    hidden_dims=[128, 256, 512], 
    device=device
)
Key parameters:
  • image_size=28 - MNIST images are 28x28 pixels
  • channels=1 - Grayscale images (CIFAR-10 uses channels=3)
  • hidden_dims=[128, 256, 512] - Channel dimensions for encoder/decoder levels
  • device - Automatically uses CUDA if available

Training loop

The training loop is simple and explicit:
epochs, patience = 50, 4
best_loss, wait = float("inf"), 0
loss_history = []

for epoch in range(epochs):
    epoch_loss = 0.0
    for x, _ in loader:
        x = x.to(device)
        loss = diffusion.train_step(x)
        epoch_loss += loss

    avg_loss = epoch_loss / len(loader)
    loss_history.append(avg_loss)

    # Early stopping: keep the best checkpoint, stop after `patience` stale epochs
    if avg_loss < best_loss - 1e-4:
        best_loss = avg_loss
        wait = 0
        torch.save(diffusion.model.state_dict(), "best_model.pt")
    else:
        wait += 1
        if wait >= patience:
            print(f"Early stopping at epoch {epoch+1}")
            break
The train_step method handles the entire training step: sampling timesteps, adding noise, predicting noise, computing loss, and backpropagation.

What happens in train_step?

Here’s the core logic from src/models/diffusion.py:train_step:
def train_step(self, x):
    x = x.to(self.device)
    # 1. Sample random timesteps for each image
    t = torch.randint(0, self.noise_steps, (x.shape[0],), device=self.device)
    
    # 2. Add noise according to the diffusion schedule
    x_t, noise = self.add_noise(x, t)
    
    # 3. Predict the noise with the U-Net
    self.optimizer.zero_grad(set_to_none=True)
    with self.autocast_ctx():
        noise_pred = self.model(x_t, t)
        
        # 4. MSE loss between predicted and actual noise
        loss = F.mse_loss(noise_pred, noise)
    
    # 5. Backpropagation with optional mixed precision
    if self.grad_scaler.is_enabled():
        self.grad_scaler.scale(loss).backward()
        self.grad_scaler.step(self.optimizer)
        self.grad_scaler.update()
    else:
        loss.backward()
        self.optimizer.step()
    
    return loss.item()
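
Step 2 calls add_noise, which applies the closed-form forward process q(x_t | x_0). A minimal sketch of what that method typically computes (attribute names such as alphas_cumprod are illustrative, not necessarily the repository's):
def add_noise(self, x0, t):
    # x_t = sqrt(alpha_bar_t) * x_0 + sqrt(1 - alpha_bar_t) * eps,  eps ~ N(0, I)
    noise = torch.randn_like(x0)
    alpha_bar = self.alphas_cumprod[t].view(-1, 1, 1, 1)  # per-sample cumulative alpha
    x_t = alpha_bar.sqrt() * x0 + (1 - alpha_bar).sqrt() * noise
    return x_t, noise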

Generate samples

Once you have a trained model (best_model.pt), you can generate new samples:
import torch
from torchvision import utils

# Load the trained model (DiffusionProcess lives in src/models/diffusion.py)
diffusion = DiffusionProcess(image_size=28, channels=1, hidden_dims=[128, 256, 512])
diffusion.model.load_state_dict(torch.load("best_model.pt"))

# Generate samples with standard DDPM (1000 steps)
samples = diffusion.sample(num_samples=16)
samples = (samples + 1) / 2  # Denormalize to [0, 1]

utils.save_image(samples, "samples_ddpm.png", nrow=4)
DDIM with ddim_steps=50 is about 20x faster than DDPM while producing comparable quality. Use eta=0.0 for deterministic sampling or eta=1.0 to recover DDPM behavior.
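
The speed-up comes from the DDIM update rule (Song et al., 2021), which strides across a subset of the 1000 timesteps. A single update looks roughly like this; it is a standalone sketch, and the method name and signature in the repository may differ:
import torch

@torch.no_grad()
def ddim_step(x_t, eps_pred, ab_t, ab_prev, eta=0.0):
    # ab_t / ab_prev: cumulative alphas (tensors) at the current and previous kept timesteps
    # Predict x_0 from the current noisy sample and the model's noise estimate
    x0_pred = (x_t - (1 - ab_t).sqrt() * eps_pred) / ab_t.sqrt()
    # eta controls stochasticity: 0.0 is deterministic, 1.0 behaves like DDPM
    sigma = eta * ((1 - ab_prev) / (1 - ab_t)).sqrt() * (1 - ab_t / ab_prev).sqrt()
    # Deterministic direction pointing back towards x_t
    dir_xt = (1 - ab_prev - sigma ** 2).clamp(min=0).sqrt() * eps_pred
    noise = sigma * torch.randn_like(x_t) if eta > 0 else 0.0
    return ab_prev.sqrt() * x0_pred + dir_xt + noise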

Analyze the diffusion process

After training, run the interpolation and timestep analysis script:
python src/utilities/interpolation_and_timesteps.py
This will:
  • Estimate noise-prediction MSE vs timestep to see which parts of the diffusion chain are hardest to learn (a rough version is sketched after this list)
  • Generate latent interpolations between random noise vectors using DDPM and DDIM
  • Save visualizations to interp.png and interp_ddim.png
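
The per-timestep error estimate is easy to approximate with the objects already defined above; a rough single-batch sketch (the script itself may average over more data):
import torch
import torch.nn.functional as F

x, _ = next(iter(loader))
x = x.to(device)
with torch.no_grad():
    for t_val in [0, 200, 400, 600, 800, 999]:
        t = torch.full((x.size(0),), t_val, device=device, dtype=torch.long)
        x_t, noise = diffusion.add_noise(x, t)
        mse = F.mse_loss(diffusion.model(x_t, t), noise).item()
        print(f"t={t_val:4d}  noise-prediction MSE: {mse:.4f}")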

Compare DDPM vs DDIM

Benchmark sampling speed and quality:
python src/utilities/ddim_comparison_mnist.py
This generates:
  • Sample grids at different step counts (10, 50, 100, 1000)
  • Timing analysis plots showing speed/quality trade-offs
  • A detailed analysis report at ddim_comparison_mnist/analysis_report.txt

DDPM
  • 1000 steps
  • ~5-10 seconds per batch
  • Highest quality

DDIM (50 steps)
  • 50 steps
  • ~0.2-0.5 seconds per batch
  • Near-identical quality
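
You can reproduce a rough timing comparison with the sampler from the previous section (only the DDPM path is shown here; wall-clock numbers depend heavily on your hardware):
import time

start = time.perf_counter()
samples = diffusion.sample(num_samples=16)  # full 1000-step DDPM sampling
print(f"DDPM: {time.perf_counter() - start:.1f}s for 16 samples")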

Visualizing the forward process

The visualize_noising utility shows how images are progressively corrupted:
import torch
import matplotlib.pyplot as plt
from torchvision import utils

@torch.no_grad()
def visualize_noising(x0, diffusion, timesteps=[0, 200, 400, 600, 800, 999]):
    x0 = x0[:8].to(device)  # `device` as selected earlier
    fig, axes = plt.subplots(1, len(timesteps), figsize=(15, 2))
    for i, t in enumerate(timesteps):
        t_batch = torch.full((x0.size(0),), t, device=device, dtype=torch.long)
        x_t, _ = diffusion.add_noise(x0, t_batch)
        grid = utils.make_grid((x_t + 1) / 2, nrow=8)
        axes[i].imshow(grid.permute(1, 2, 0).cpu())
        axes[i].set_title(f"t={t}")
At t=0, you see the original image. By t=999, it’s pure Gaussian noise.
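
For example, using the loader and diffusion objects from training:
x0, _ = next(iter(loader))        # one batch of training images
visualize_noising(x0, diffusion)  # panels at t = 0, 200, ..., 999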

Next steps

Train on CIFAR-10

Try the more advanced CIFAR-10 model, which adds EMA weights and multi-resolution attention:
python src/training/train_diffusion_cifar.py

Experiment with hyperparameters

Modify hidden_dims, noise_steps, the learning rate, or the beta schedule in the source files to see how they affect training.
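
For example, shrinking hidden_dims gives a smaller, faster U-Net at some cost in sample quality (the values below are illustrative):
diffusion = DiffusionProcess(
    image_size=28,
    channels=1,
    hidden_dims=[64, 128, 256],  # half the default widths
    device=device,
)
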
CIFAR-10 training takes significantly longer (2000 epochs recommended) and requires a GPU. Consider using the provided SLURM scripts if you have access to a cluster.
