Overview

The DiffusionModel class implements a U-Net architecture with time embeddings for predicting noise in diffusion models. It features an encoder-decoder structure with skip connections, residual blocks, self-attention in the bottleneck, and sinusoidal time embeddings.

Constructor

DiffusionModel(
    image_size,
    channels,
    hidden_dims=[32, 64, 128],
    time_dim=128
)

Parameters

image_size
int
required
Height and width of the square input images.
channels
int
required
Number of image channels (e.g., 1 for grayscale, 3 for RGB).
hidden_dims
list[int]
default:"[32, 64, 128]"
List of hidden channel dimensions for each level of the U-Net. Determines the capacity and depth of the network; each pair of consecutive entries creates one down/up block pair (see the architecture diagram below).
time_dim
int
default:"128"
Dimensionality of the time embedding. This embedding is injected into each residual block to condition the network on the timestep.

Attributes

time_mlp
TimeEmbedding
Time embedding module that converts timestep integers to sinusoidal embeddings processed through an MLP.
init_conv
nn.Conv2d
Initial convolution layer mapping from channels to hidden_dims[0] channels with kernel size 3.
down_blocks
nn.ModuleList
List of downsampling blocks. Each DownBlock applies a residual block followed by 2x spatial downsampling.
bottleneck
BottleneckBlock
Bottleneck block at the coarsest resolution, containing two residual blocks with a self-attention layer in between.
up_blocks
nn.ModuleList
List of upsampling blocks. Each UpBlock performs 2x spatial upsampling, concatenates skip connections, and applies a residual block.
out_norm
nn.GroupNorm
Group normalization layer before the final convolution, with 8 groups.
out_conv
nn.Conv2d
Final convolution layer mapping from hidden_dims[0] back to channels with kernel size 3.
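
Taken together, these attributes could be assembled roughly as follows. This is a sketch only: the padding values, the skip-connection channel counts passed to each UpBlock, and the TimeEmbedding constructor signature are assumptions based on the component descriptions and architecture diagram below.

import torch.nn as nn

class DiffusionModel(nn.Module):
    def __init__(self, image_size, channels, hidden_dims=[32, 64, 128], time_dim=128):
        super().__init__()
        self.time_mlp = TimeEmbedding(time_dim)
        self.init_conv = nn.Conv2d(channels, hidden_dims[0], kernel_size=3, padding=1)

        # One DownBlock per transition between consecutive hidden dimensions
        self.down_blocks = nn.ModuleList([
            DownBlock(hidden_dims[i], hidden_dims[i + 1], time_dim)
            for i in range(len(hidden_dims) - 1)
        ])

        self.bottleneck = BottleneckBlock(hidden_dims[-1], time_dim)

        # UpBlocks mirror the encoder in reverse order; skip channels assumed per the diagram
        self.up_blocks = nn.ModuleList([
            UpBlock(hidden_dims[i + 1], hidden_dims[i], hidden_dims[i], time_dim)
            for i in reversed(range(len(hidden_dims) - 1))
        ])

        self.out_norm = nn.GroupNorm(8, hidden_dims[0])
        self.out_conv = nn.Conv2d(hidden_dims[0], channels, kernel_size=3, padding=1)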

Methods

forward

Forward pass through the U-Net model.
def forward(self, x, t)

Parameters

x
torch.Tensor
required
Input noisy images tensor of shape [batch_size, channels, height, width].
t
torch.Tensor
required
Timesteps tensor of shape [batch_size] containing integer timestep indices.

Returns

noise_prediction
torch.Tensor
Predicted noise tensor of shape [batch_size, channels, height, width].

Implementation

The forward pass follows these steps:
  1. Time embedding: Converts timestep indices to sinusoidal embeddings via time_mlp
  2. Initial convolution: Projects input to the first hidden dimension
  3. Encoder path: Processes through down_blocks, storing skip connections
  4. Bottleneck: Applies residual blocks with self-attention at the coarsest resolution
  5. Decoder path: Processes through up_blocks, incorporating skip connections from the encoder
  6. Output: Applies group normalization, SiLU activation, and final convolution to predict noise
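
In code, these steps map roughly onto the following method body. This is a sketch only; it assumes the DownBlock, BottleneckBlock, and UpBlock interfaces described under Architecture components below.

import torch.nn.functional as F

def forward(self, x, t):
    t_emb = self.time_mlp(t)               # 1. [batch_size] -> [batch_size, time_dim]
    h = self.init_conv(x)                  # 2. [B, C, H, W] -> [B, hidden_dims[0], H, W]

    skips = []
    for down in self.down_blocks:          # 3. encoder path, collecting skip connections
        h, skip = down(h, t_emb)
        skips.append(skip)

    h = self.bottleneck(h, t_emb)          # 4. residual + self-attention at the coarsest resolution

    for up in self.up_blocks:              # 5. decoder path, consuming skips in reverse order
        h = up(h, skips.pop(), t_emb)

    return self.out_conv(F.silu(self.out_norm(h)))   # 6. predicted noise, same shape as x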

Architecture components

TimeEmbedding

Converts integer timesteps to continuous embeddings using sinusoidal positional encoding followed by an MLP.
  • Input: Timestep tensor [batch_size]
  • Output: Time embedding [batch_size, time_dim]
  • Structure: Sinusoidal encoding → Linear(dim, dim*4) → GELU → Linear(dim*4, dim)
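
A minimal sketch of such a module, assuming the standard transformer-style sinusoidal encoding (the frequency base of 10000 is an assumption, not taken from this page):

import math
import torch
import torch.nn as nn

class TimeEmbedding(nn.Module):
    def __init__(self, dim):
        super().__init__()
        self.dim = dim
        self.mlp = nn.Sequential(
            nn.Linear(dim, dim * 4),
            nn.GELU(),
            nn.Linear(dim * 4, dim),
        )

    def forward(self, t):
        # Sinusoidal encoding of the integer timestep, then a small MLP
        half = self.dim // 2
        freqs = torch.exp(-math.log(10000) * torch.arange(half, device=t.device) / half)
        args = t[:, None].float() * freqs[None, :]                    # [B, dim/2]
        enc = torch.cat([torch.sin(args), torch.cos(args)], dim=-1)   # [B, dim]
        return self.mlp(enc)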

ResBlock

Residual block with time embedding injection. Parameters:
  • in_ch (int): Input channels
  • out_ch (int): Output channels
  • time_dim (int): Time embedding dimension
Structure:
  • GroupNorm(8) → SiLU → Conv2d(3x3)
  • Add projected time embedding
  • GroupNorm(8) → SiLU → Conv2d(3x3)
  • Skip connection with optional 1x1 conv for channel matching
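
A minimal sketch matching the structure above (the use of a single Linear layer to project the time embedding, and the 3x3 padding, are assumptions):

import torch.nn as nn

class ResBlock(nn.Module):
    def __init__(self, in_ch, out_ch, time_dim):
        super().__init__()
        self.block1 = nn.Sequential(
            nn.GroupNorm(8, in_ch), nn.SiLU(), nn.Conv2d(in_ch, out_ch, 3, padding=1)
        )
        self.time_proj = nn.Linear(time_dim, out_ch)   # projects the time embedding to out_ch
        self.block2 = nn.Sequential(
            nn.GroupNorm(8, out_ch), nn.SiLU(), nn.Conv2d(out_ch, out_ch, 3, padding=1)
        )
        # 1x1 convolution on the skip path only when the channel counts differ
        self.skip = nn.Conv2d(in_ch, out_ch, 1) if in_ch != out_ch else nn.Identity()

    def forward(self, x, t_emb):
        h = self.block1(x)
        h = h + self.time_proj(t_emb)[:, :, None, None]   # broadcast over spatial dimensions
        h = self.block2(h)
        return h + self.skip(x)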

DownBlock

Downsampling block combining residual processing with spatial reduction. Parameters:
  • in_ch (int): Input channels
  • out_ch (int): Output channels
  • time_dim (int): Time embedding dimension
Returns: Tuple of (downsampled features, skip connection)
Structure:
  • ResBlock → Conv2d(4x4, stride=2) for 2x downsampling
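
A minimal sketch following the structure above. Which tensor is returned as the skip connection is an assumption; here it is the block's input, matching the architecture diagram at the end of this page.

import torch.nn as nn

class DownBlock(nn.Module):
    def __init__(self, in_ch, out_ch, time_dim):
        super().__init__()
        self.res = ResBlock(in_ch, out_ch, time_dim)
        # 4x4 strided convolution halves the spatial resolution
        self.down = nn.Conv2d(out_ch, out_ch, kernel_size=4, stride=2, padding=1)

    def forward(self, x, t_emb):
        h = self.res(x, t_emb)
        return self.down(h), x   # (downsampled features, skip connection)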

UpBlock

Upsampling block with skip connection fusion. Parameters:
  • in_ch (int): Input channels before upsampling
  • skip_ch (int): Skip connection channels
  • out_ch (int): Output channels
  • time_dim (int): Time embedding dimension
Structure:
  • ConvTranspose2d(4x4, stride=2) for 2x upsampling
  • Concatenate with skip connection
  • ResBlock
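
A minimal sketch following the structure above (whether the transposed convolution also changes the channel count is an assumption; here it keeps in_ch channels):

import torch
import torch.nn as nn

class UpBlock(nn.Module):
    def __init__(self, in_ch, skip_ch, out_ch, time_dim):
        super().__init__()
        # 4x4 transposed convolution doubles the spatial resolution
        self.up = nn.ConvTranspose2d(in_ch, in_ch, kernel_size=4, stride=2, padding=1)
        self.res = ResBlock(in_ch + skip_ch, out_ch, time_dim)

    def forward(self, x, skip, t_emb):
        x = self.up(x)
        x = torch.cat([x, skip], dim=1)   # fuse the skip connection along the channel axis
        return self.res(x, t_emb)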

BottleneckBlock

Bottleneck processing at the coarsest resolution. Parameters:
  • ch (int): Channel dimension
  • time_dim (int): Time embedding dimension
Structure:
  • ResBlock → SelfAttention → ResBlock
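
A minimal sketch of the bottleneck, reusing the ResBlock and SelfAttention components described on this page:

import torch.nn as nn

class BottleneckBlock(nn.Module):
    def __init__(self, ch, time_dim):
        super().__init__()
        self.res1 = ResBlock(ch, ch, time_dim)
        self.attn = SelfAttention(ch)
        self.res2 = ResBlock(ch, ch, time_dim)

    def forward(self, x, t_emb):
        x = self.res1(x, t_emb)
        x = self.attn(x)           # global spatial attention at the coarsest resolution
        return self.res2(x, t_emb)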

SelfAttention

Multi-head self-attention layer for spatial feature relationships. Parameters:
  • ch (int): Channel dimension
  • num_heads (int, default=4): Number of attention heads
Structure:
  • GroupNorm → Multi-head attention → Projection → Residual connection
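
A minimal sketch using PyTorch's built-in nn.MultiheadAttention; flattening the spatial grid into a token sequence and the use of a Linear output projection are assumptions about the implementation:

import torch.nn as nn

class SelfAttention(nn.Module):
    def __init__(self, ch, num_heads=4):
        super().__init__()
        self.norm = nn.GroupNorm(8, ch)
        self.attn = nn.MultiheadAttention(ch, num_heads, batch_first=True)
        self.proj = nn.Linear(ch, ch)

    def forward(self, x):
        b, c, h, w = x.shape
        # Treat each spatial position as a token: [B, C, H, W] -> [B, H*W, C]
        seq = self.norm(x).flatten(2).transpose(1, 2)
        out, _ = self.attn(seq, seq, seq)
        out = self.proj(out)
        # Residual connection, then restore the spatial layout
        return x + out.transpose(1, 2).reshape(b, c, h, w)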

Usage example

import torch
from models.diffusion import DiffusionModel

# Initialize model for 32x32 RGB images
model = DiffusionModel(
    image_size=32,
    channels=3,
    hidden_dims=[64, 128, 256],
    time_dim=128
)

# Create sample inputs
batch_size = 8
noisy_images = torch.randn(batch_size, 3, 32, 32)
timesteps = torch.randint(0, 1000, (batch_size,))

# Forward pass
predicted_noise = model(noisy_images, timesteps)
print(predicted_noise.shape)  # Output: torch.Size([8, 3, 32, 32])

# Calculate loss
true_noise = torch.randn_like(noisy_images)
loss = torch.nn.functional.mse_loss(predicted_noise, true_noise)

Architecture diagram

For a model with hidden_dims=[32, 64, 128]:
Input (C, H, H)
    ↓ init_conv
(32, H, H) ─────────────────────────┐
    ↓ DownBlock                     │ skip1
(64, H/2, H/2) ─────────────────┐   │
    ↓ DownBlock                 │   │
(128, H/4, H/4)           skip2 │   │
    ↓ BottleneckBlock           │   │
(128, H/4, H/4)                 │   │
    ↓ UpBlock ←─────────────────┘   │
(64, H/2, H/2)                      │
    ↓ UpBlock ←─────────────────────┘
(32, H, H)
    ↓ out_norm → SiLU → out_conv
Output (C, H, H)
Time embeddings are injected into each ResBlock in the down, bottleneck, and up blocks.
