Overview
The U-Net architecture is the core neural network backbone used in diffusion models. It takes a noisy image and a timestep as input and predicts the noise present in the image, which the sampler then removes. The U-Net’s symmetric encoder-decoder structure with skip connections makes it particularly effective for this task.
Architecture components
The U-Net consists of three main paths:
- Encoder (downsampling path) - Progressively reduces spatial dimensions while increasing channel depth
- Bottleneck - Processes features at the coarsest resolution with self-attention
- Decoder (upsampling path) - Reconstructs the output by upsampling and fusing encoder features via skip connections
Time embedding
Before processing the image, the timestep t is encoded into a high-dimensional embedding using sinusoidal position encoding followed by an MLP:
```python
class TimeEmbedding(nn.Module):
    def __init__(self, dim):
        super().__init__()
        self.dim = dim
        self.mlp = nn.Sequential(
            nn.Linear(dim, dim * 4),
            nn.GELU(),
            nn.Linear(dim * 4, dim),
        )

    def forward(self, t):
        half_dim = self.dim // 2
        # Frequencies for sin/cos encoding
        freqs = torch.exp(
            -torch.arange(half_dim, device=t.device) *
            (torch.log(torch.tensor(10000.0, device=t.device)) / (half_dim - 1))
        )
        emb = t[:, None] * freqs[None, :]
        emb = torch.cat((emb.sin(), emb.cos()), dim=-1)
        return self.mlp(emb)
```
This time embedding is injected into every residual block, allowing the network to adjust its behavior based on the noise level.
The sinusoidal encoding is borrowed from the Transformer architecture and helps the model distinguish between different noise levels throughout the diffusion process.
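The sinusoidal half of the encoding can be sanity-checked in isolation. A minimal sketch mirroring the frequency computation in `forward` above (`dim=8` and the example timesteps are arbitrary choices):

```python
import torch

dim = 8
half_dim = dim // 2
t = torch.tensor([0.0, 100.0, 999.0])  # example timesteps

freqs = torch.exp(
    -torch.arange(half_dim) *
    (torch.log(torch.tensor(10000.0)) / (half_dim - 1))
)
emb = t[:, None] * freqs[None, :]                # (3, 4)
emb = torch.cat((emb.sin(), emb.cos()), dim=-1)  # (3, 8)

# t = 0 gives sin(0) = 0 in the first half and cos(0) = 1 in the second
```

Concatenating the sin and cos halves restores the full `dim` channels, and each timestep maps to a distinct, smoothly varying vector.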
Residual blocks
The fundamental building block is the ResBlock, which applies two normalization-activation-convolution sequences with a skip connection:
```python
class ResBlock(nn.Module):
    def __init__(self, in_ch, out_ch, time_dim):
        super().__init__()
        self.block1 = nn.Sequential(
            nn.GroupNorm(8, in_ch),
            nn.SiLU(),
            nn.Conv2d(in_ch, out_ch, 3, padding=1),
        )
        self.block2 = nn.Sequential(
            nn.GroupNorm(8, out_ch),
            nn.SiLU(),
            nn.Conv2d(out_ch, out_ch, 3, padding=1),
        )
        self.time_emb = nn.Sequential(
            nn.SiLU(),
            nn.Linear(time_dim, out_ch)
        )
        self.shortcut = nn.Conv2d(in_ch, out_ch, 1) if in_ch != out_ch else nn.Identity()

    def forward(self, x, t_emb):
        h = self.block1(x)
        h = h + self.time_emb(t_emb)[:, :, None, None]  # Inject time
        h = self.block2(h)
        return self.shortcut(x) + h  # Residual connection
```
The time embedding is injected after the first convolution block, modulating the features based on the diffusion timestep.
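The `[:, :, None, None]` indexing is what makes the injection work: it reshapes the `(B, C)` embedding to `(B, C, 1, 1)` so it broadcasts across the spatial grid. A minimal illustration (the shapes here are arbitrary):

```python
import torch

B, C, H, W = 2, 16, 8, 8
h = torch.randn(B, C, H, W)  # feature map after block1
t_emb = torch.randn(B, C)    # time embedding projected to C channels

# (B, C) -> (B, C, 1, 1), then broadcast over H and W:
# every spatial position in channel c is shifted by t_emb[b, c]
out = h + t_emb[:, :, None, None]  # shape (B, C, H, W)
```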
Encoder path
The encoder progressively downsamples the input using DownBlock modules:
```python
class DownBlock(nn.Module):
    def __init__(self, in_ch, out_ch, time_dim):
        super().__init__()
        self.res = ResBlock(in_ch, out_ch, time_dim)
        self.pool = nn.Conv2d(out_ch, out_ch, 4, stride=2, padding=1)

    def forward(self, x, t_emb):
        h = self.res(x, t_emb)
        return self.pool(h), h  # Return both pooled and pre-pool for skip
```
Each DownBlock returns two outputs:
- The downsampled features (passed to the next layer)
- The pre-downsampled features (saved as skip connections)
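The stride-2 convolution is what halves the resolution. A quick shape check (the channel count of 64 and input size of 32 are illustrative):

```python
import torch
import torch.nn as nn

pool = nn.Conv2d(64, 64, 4, stride=2, padding=1)
x = torch.randn(1, 64, 32, 32)
y = pool(x)
# Output spatial size: (32 + 2*1 - 4) // 2 + 1 = 16
```

Each DownBlock therefore halves the height and width, while the ResBlock before it sets the channel count.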
Bottleneck
The bottleneck processes features at the coarsest spatial resolution. It includes self-attention to capture global dependencies:
```python
class BottleneckBlock(nn.Module):
    def __init__(self, ch, time_dim):
        super().__init__()
        self.res1 = ResBlock(ch, ch, time_dim)
        self.attn = SelfAttention(ch)
        self.res2 = ResBlock(ch, ch, time_dim)

    def forward(self, x, t_emb):
        x = self.res1(x, t_emb)
        x = self.attn(x)
        x = self.res2(x, t_emb)
        return x
```
Self-attention at the bottleneck is computationally expensive but crucial for modeling long-range dependencies in the image.
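The SelfAttention module referenced above is not defined in this section. One plausible minimal single-head version (an assumed sketch, not necessarily the exact implementation) treats each spatial position as a token:

```python
import torch
import torch.nn as nn

class SelfAttention(nn.Module):
    """Minimal single-head self-attention over spatial positions (assumed sketch)."""
    def __init__(self, ch):
        super().__init__()
        self.norm = nn.GroupNorm(8, ch)  # assumes ch is divisible by 8
        self.qkv = nn.Conv2d(ch, ch * 3, 1)
        self.proj = nn.Conv2d(ch, ch, 1)

    def forward(self, x):
        B, C, H, W = x.shape
        q, k, v = self.qkv(self.norm(x)).chunk(3, dim=1)
        # Flatten spatial dims: (B, C, H, W) -> (B, H*W, C)
        q = q.flatten(2).transpose(1, 2)
        k = k.flatten(2).transpose(1, 2)
        v = v.flatten(2).transpose(1, 2)
        # Scaled dot-product attention over all H*W positions
        attn = torch.softmax(q @ k.transpose(1, 2) / C ** 0.5, dim=-1)
        out = (attn @ v).transpose(1, 2).reshape(B, C, H, W)
        return x + self.proj(out)  # residual, so the shape is preserved

out = SelfAttention(16)(torch.randn(2, 16, 8, 8))  # shape unchanged
```

The quadratic cost in `H*W` is why attention is applied only at the coarsest resolution.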
Decoder path
The decoder reconstructs the image using UpBlock modules that fuse upsampled features with skip connections:
```python
class UpBlock(nn.Module):
    def __init__(self, in_ch, skip_ch, out_ch, time_dim):
        super().__init__()
        self.up = nn.ConvTranspose2d(in_ch, out_ch, 4, stride=2, padding=1)
        self.res = ResBlock(out_ch + skip_ch, out_ch, time_dim)

    def forward(self, x, skip, t_emb):
        x = self.up(x)
        x = torch.cat([x, skip], dim=1)  # Fuse with skip connection
        x = self.res(x, t_emb)
        return x
```
The skip connections from the encoder provide high-resolution features that help the decoder reconstruct fine details.
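The transposed convolution doubles the spatial size before the concatenation fuses in the skip features. Tracing the shapes for one example level (128 → 64 channels and 8×8 → 16×16 are illustrative choices):

```python
import torch
import torch.nn as nn

up = nn.ConvTranspose2d(128, 64, 4, stride=2, padding=1)
x = torch.randn(1, 128, 8, 8)      # coarse decoder features
skip = torch.randn(1, 64, 16, 16)  # matching encoder skip

h = up(x)                        # (1, 64, 16, 16): resolution doubled
h = torch.cat([h, skip], dim=1)  # (1, 128, 16, 16): channels fused
```

The concatenated tensor then goes through a ResBlock, which is why its input channel count is `out_ch + skip_ch`.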
Forward pass
The complete U-Net forward pass follows this flow:
```python
def forward(self, x, t):
    # 1. Embed timestep
    t_emb = self.time_mlp(t)

    # 2. Initial convolution
    x = self.init_conv(x)

    # 3. Encoder path - store skip connections
    skips = []
    for down in self.down_blocks:
        x, skip = down(x, t_emb)
        skips.append(skip)

    # 4. Bottleneck
    x = self.bottleneck(x, t_emb)

    # 5. Decoder path - use skip connections in reverse
    for up in self.up_blocks:
        x = up(x, skips.pop(), t_emb)

    # 6. Final convolution
    x = self.out_conv(F.silu(self.out_norm(x)))
    return x
```
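Because `skips` is appended to on the way down and popped on the way up, the decoder consumes the skip connections in reverse order, pairing each UpBlock with the encoder features of matching resolution. A plain-Python illustration (the resolution labels are an example for a 3-level network):

```python
skips = []
for res in ["64x64", "32x32", "16x16"]:  # encoder: finest to coarsest
    skips.append(res)

# Decoder pops in LIFO order, so the first UpBlock gets the
# coarsest skip and the last gets the finest
order = [skips.pop() for _ in range(3)]
print(order)  # ['16x16', '32x32', '64x64']
```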
Why U-Net for diffusion?
The U-Net architecture is well-suited for diffusion models because:
- Skip connections preserve fine-grained spatial information lost during downsampling
- Multi-scale processing allows the network to reason about both local textures and global structure
- Time conditioning enables the same network to denoise at different noise levels
- Residual connections facilitate gradient flow during training
The U-Net was originally designed for biomedical image segmentation but has become the standard architecture for diffusion models due to its effectiveness at image-to-image tasks.