Overview

The U-Net is the neural network backbone used in many image diffusion models. It takes a noisy image and a timestep as input and predicts the noise that should be removed. Its symmetric encoder-decoder structure with skip connections suits this task: the output has the same spatial shape as the input, and fine detail captured by the encoder can be reused directly in the decoder.

Architecture components

The U-Net consists of three main parts:
  1. Encoder (downsampling path) - Progressively reduces spatial dimensions while increasing channel depth
  2. Bottleneck - Processes features at the coarsest resolution with self-attention
  3. Decoder (upsampling path) - Reconstructs the output by upsampling and fusing encoder features via skip connections

Time embedding

Before processing the image, the timestep t is encoded into a high-dimensional embedding using sinusoidal position encoding followed by an MLP:
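import math

import torch
import torch.nn as nn
import torch.nn.functional as F  # F.silu is used in the U-Net forward pass


class TimeEmbedding(nn.Module):
    def __init__(self, dim):
        super().__init__()
        # dim must be even (sin and cos halves) and at least 4 (so half_dim - 1 > 0)
        assert dim % 2 == 0 and dim >= 4, "dim must be even and >= 4"
        self.dim = dim
        # MLP that expands to 4*dim and projects back to dim
        self.mlp = nn.Sequential(
            nn.Linear(dim, dim * 4),
            nn.GELU(),
            nn.Linear(dim * 4, dim),
        )

    def forward(self, t):
        # t: (B,) tensor of timesteps; returns a (B, dim) embedding
        half_dim = self.dim // 2
        # Frequencies decay geometrically from 1 to 1/10000
        freqs = torch.exp(
            -math.log(10000.0)
            * torch.arange(half_dim, device=t.device, dtype=torch.float32)
            / (half_dim - 1)
        )
        emb = t.float()[:, None] * freqs[None, :]
        emb = torch.cat((emb.sin(), emb.cos()), dim=-1)
        return self.mlp(emb)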
This time embedding is injected into every residual block, allowing the network to adjust its behavior based on the noise level.
The sinusoidal encoding is borrowed from the Transformer architecture and helps the model distinguish between different noise levels throughout the diffusion process.
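To make the encoding concrete, the following minimal sketch computes only the sinusoidal part (no MLP) for a few timesteps. The function name `sinusoidal_embedding` and the choice of `dim=8` are illustrative assumptions, not part of the module above.

```python
import math

import torch


def sinusoidal_embedding(t: torch.Tensor, dim: int) -> torch.Tensor:
    """Map timesteps of shape (B,) to sin/cos features of shape (B, dim)."""
    half_dim = dim // 2
    # Frequencies decay geometrically from 1 to 1/10000
    freqs = torch.exp(
        -math.log(10000.0) * torch.arange(half_dim, dtype=torch.float32) / (half_dim - 1)
    )
    angles = t.float()[:, None] * freqs[None, :]          # (B, half_dim)
    return torch.cat((angles.sin(), angles.cos()), dim=-1)


t = torch.tensor([0, 10, 999])
emb = sinusoidal_embedding(t, dim=8)
print(emb.shape)  # torch.Size([3, 8])
print(emb[0])     # at t = 0 every sin term is 0 and every cos term is 1
```

Each timestep gets a distinct vector, and nearby timesteps get nearby vectors, which gives the MLP a smooth input to work with.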

Residual blocks

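The fundamental building block is the ResBlock, which applies two normalization-activation-convolution sequences (GroupNorm, SiLU, 3x3 convolution) wrapped in a residual connection. Because GroupNorm is configured with 8 groups, in_ch and out_ch must both be divisible by 8: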
class ResBlock(nn.Module):
    def __init__(self, in_ch, out_ch, time_dim):
        super().__init__()
        self.block1 = nn.Sequential(
            nn.GroupNorm(8, in_ch),
            nn.SiLU(),
            nn.Conv2d(in_ch, out_ch, 3, padding=1),
        )
        self.block2 = nn.Sequential(
            nn.GroupNorm(8, out_ch),
            nn.SiLU(),
            nn.Conv2d(out_ch, out_ch, 3, padding=1),
        )
        self.time_emb = nn.Sequential(
            nn.SiLU(),
            nn.Linear(time_dim, out_ch)
        )
        self.shortcut = nn.Conv2d(in_ch, out_ch, 1) if in_ch != out_ch else nn.Identity()
    
    def forward(self, x, t_emb):
        h = self.block1(x)
        h = h + self.time_emb(t_emb)[:, :, None, None]  # Inject time
        h = self.block2(h)
        return self.shortcut(x) + h  # Residual connection
The time embedding is projected to out_ch values and added after the first convolution block as a per-channel offset, broadcast over height and width, so the features shift with the diffusion timestep.
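The broadcasting in the injection step can be checked in isolation. This is a small sketch with made-up sizes (`batch`, `time_dim`, `out_ch` are assumptions) and a zero feature map, so the added offset is easy to read off.

```python
import torch
import torch.nn as nn

batch, time_dim, out_ch = 2, 32, 16            # illustrative sizes
h = torch.zeros(batch, out_ch, 8, 8)           # stand-in for the output of block1
t_emb = torch.randn(batch, time_dim)           # stand-in for the time embedding

# Same projection shape as ResBlock.time_emb: SiLU followed by a linear layer
proj = nn.Sequential(nn.SiLU(), nn.Linear(time_dim, out_ch))
bias = proj(t_emb)                             # (B, out_ch)
h_cond = h + bias[:, :, None, None]            # broadcast over H and W

print(h_cond.shape)                                        # torch.Size([2, 16, 8, 8])
print(torch.equal(h_cond[0, 3], bias[0, 3].expand(8, 8)))  # True
```

Every spatial position in a channel receives the same offset, so the timestep changes the features per channel rather than per pixel.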

Encoder path

The encoder progressively downsamples the input using DownBlock modules:
class DownBlock(nn.Module):
    def __init__(self, in_ch, out_ch, time_dim):
        super().__init__()
        self.res = ResBlock(in_ch, out_ch, time_dim)
        # Strided 4x4 convolution (a learned downsample rather than a pooling op); halves H and W
        self.pool = nn.Conv2d(out_ch, out_ch, 4, stride=2, padding=1)
    
    def forward(self, x, t_emb):
        h = self.res(x, t_emb)
        return self.pool(h), h  # Return both pooled and pre-pool for skip
Each DownBlock returns two outputs:
  • The downsampled features (passed to the next layer)
  • The pre-downsampled features (saved as skip connections)
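As a rough check on the downsampling arithmetic, the sketch below applies the standard convolution output-size formula to the kernel 4, stride 2, padding 1 setting used in DownBlock. The helper name `conv_out_size`, the 64-pixel input, and the depth of three blocks are assumptions for illustration.

```python
def conv_out_size(size: int, kernel: int = 4, stride: int = 2, padding: int = 1) -> int:
    """Spatial output size of a convolution with dilation 1."""
    return (size + 2 * padding - kernel) // stride + 1


sizes = [64]
for _ in range(3):              # three DownBlocks (hypothetical depth)
    sizes.append(conv_out_size(sizes[-1]))

print(sizes)                    # [64, 32, 16, 8]
print(conv_out_size(7))         # 3: odd sizes round down
```

Even sizes are halved exactly. An odd size rounds down, and the decoder's upsampling would then no longer reproduce the skip's resolution, so inputs are usually chosen to be divisible by 2 raised to the number of DownBlocks.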

Bottleneck

The bottleneck processes features at the coarsest spatial resolution. It includes self-attention to capture global dependencies:
class BottleneckBlock(nn.Module):
    def __init__(self, ch, time_dim):
        super().__init__()
        self.res1 = ResBlock(ch, ch, time_dim)
        # SelfAttention is assumed to map (B, C, H, W) -> (B, C, H, W)
        self.attn = SelfAttention(ch)
        self.res2 = ResBlock(ch, ch, time_dim)
    
    def forward(self, x, t_emb):
        x = self.res1(x, t_emb)
        x = self.attn(x)
        x = self.res2(x, t_emb)
        return x
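Self-attention is computationally expensive, since its cost grows quadratically with the number of spatial positions (H × W). Placing it at the bottleneck, where the feature map is smallest, keeps that cost manageable while still letting every position attend to every other one, which is important for modeling long-range dependencies in the image.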
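The SelfAttention module used by BottleneckBlock is not shown here. One possible shape for it, offered only as a sketch and not as the original implementation, flattens the spatial grid into a token sequence and runs PyTorch's `nn.MultiheadAttention` over it. The normalization, the head count, and the residual connection are all assumptions.

```python
import torch
import torch.nn as nn


class SelfAttention(nn.Module):
    """Hypothetical spatial self-attention: (B, C, H, W) -> (B, C, H, W)."""

    def __init__(self, ch: int, num_heads: int = 4):
        super().__init__()
        self.norm = nn.GroupNorm(8, ch)        # ch must be divisible by 8
        self.attn = nn.MultiheadAttention(ch, num_heads, batch_first=True)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        b, c, h, w = x.shape
        # Treat each of the H*W positions as a token with C features
        tokens = self.norm(x).flatten(2).transpose(1, 2)      # (B, H*W, C)
        out, _ = self.attn(tokens, tokens, tokens, need_weights=False)
        out = out.transpose(1, 2).reshape(b, c, h, w)         # back to a feature map
        return x + out                                        # residual connection


attn = SelfAttention(ch=64)
x = torch.randn(2, 64, 8, 8)
y = attn(x)
print(y.shape)  # torch.Size([2, 64, 8, 8])
```

With an 8x8 feature map the sequence has only 64 tokens. The same layer on a 64x64 map would see 4096 tokens, which is why attention is typically kept to the coarser resolutions.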

Decoder path

The decoder reconstructs the image using UpBlock modules that fuse upsampled features with skip connections:
class UpBlock(nn.Module):
    def __init__(self, in_ch, skip_ch, out_ch, time_dim):
        super().__init__()
        self.up = nn.ConvTranspose2d(in_ch, out_ch, 4, stride=2, padding=1)
        self.res = ResBlock(out_ch + skip_ch, out_ch, time_dim)
    
    def forward(self, x, skip, t_emb):
        x = self.up(x)
        x = torch.cat([x, skip], dim=1)  # Fuse with skip connection
        x = self.res(x, t_emb)
        return x
The skip connections from the encoder provide high-resolution features that help the decoder reconstruct fine details.
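The shape bookkeeping inside UpBlock can be traced with a short sketch. The channel counts and spatial sizes below are assumptions chosen for illustration.

```python
import torch
import torch.nn as nn

in_ch, skip_ch, out_ch = 128, 64, 64           # hypothetical channel counts
x = torch.randn(1, in_ch, 8, 8)                # coarse decoder input
skip = torch.randn(1, skip_ch, 16, 16)         # saved encoder features

# Same settings as UpBlock.up: kernel 4, stride 2, padding 1 doubles H and W
up = nn.ConvTranspose2d(in_ch, out_ch, 4, stride=2, padding=1)
x_up = up(x)                                   # (1, 64, 16, 16)
fused = torch.cat([x_up, skip], dim=1)         # (1, 128, 16, 16): channels add up

print(x_up.shape, fused.shape)
```

The concatenation only works when the upsampled tensor and the skip share the same height and width, which is why the ResBlock that follows takes out_ch + skip_ch input channels.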

Forward pass

The complete U-Net forward pass follows this flow:
def forward(self, x, t):
    # 1. Embed timestep
    t_emb = self.time_mlp(t)
    
    # 2. Initial convolution
    x = self.init_conv(x)
    
    # 3. Encoder path - store skip connections
    skips = []
    for down in self.down_blocks:
        x, skip = down(x, t_emb)
        skips.append(skip)
    
    # 4. Bottleneck
    x = self.bottleneck(x, t_emb)
    
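    # 5. Decoder path - pop skips in reverse order (last saved, first used)
    #    so each UpBlock gets the skip at its own resolution
    for up in self.up_blocks:
        x = up(x, skips.pop(), t_emb)
    
    # 6. Final normalization, activation, and convolution
    #    (F is torch.nn.functional)
    x = self.out_conv(F.silu(self.out_norm(x)))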
    
    return x
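For this forward pass to work, the channel arguments of the DownBlock and UpBlock lists have to line up with the order in which skips are pushed and popped. The sketch below simulates that on channel counts only. The `widths` list and the way the configurations are derived from it are assumptions, since the constructor is not shown.

```python
widths = [64, 128, 256, 512]   # hypothetical channel widths, one per resolution
n = len(widths) - 1

# DownBlock(in_ch, out_ch): the skip it returns has out_ch channels
down_cfg = [(widths[i], widths[i + 1]) for i in range(n)]

# UpBlock(in_ch, skip_ch, out_ch), listed in the order the decoder runs
up_cfg = [(widths[i + 1], widths[i + 1], widths[i]) for i in reversed(range(n))]

# Replay the forward pass, tracking only the channel count
ch = widths[0]                 # channels after init_conv
skips = []
for in_ch, out_ch in down_cfg:
    assert ch == in_ch
    ch = out_ch
    skips.append(ch)           # pre-downsample features saved for the decoder
# The bottleneck keeps the channel count unchanged
for in_ch, skip_ch, out_ch in up_cfg:
    assert ch == in_ch and skips.pop() == skip_ch
    ch = out_ch

print(down_cfg)  # [(64, 128), (128, 256), (256, 512)]
print(up_cfg)    # [(512, 512, 256), (256, 256, 128), (128, 128, 64)]
print(ch)        # 64
```

Because `skips.pop()` returns the most recently saved tensor, the first UpBlock is paired with the deepest encoder features and the last UpBlock with the shallowest.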

Why U-Net for diffusion?

The U-Net architecture is well-suited for diffusion models because:
  1. Skip connections preserve fine-grained spatial information lost during downsampling
  2. Multi-scale processing allows the network to reason about both local textures and global structure
  3. Time conditioning enables the same network to denoise at different noise levels
  4. Residual connections facilitate gradient flow during training
The U-Net was originally designed for biomedical image segmentation but has become the standard architecture for diffusion models due to its effectiveness at image-to-image tasks.
