Image segmentation assigns a class label to every pixel in an image, producing a dense output map rather than a single prediction. UNet, originally designed for biomedical image segmentation, has become a standard architecture across many visual domains.

Segmentation types

| Type     | Description                                                     | Example                              |
| -------- | --------------------------------------------------------------- | ------------------------------------ |
| Semantic | Each pixel gets a class label; no distinction between instances | All cars are “car”                   |
| Instance | Separate mask per object instance                               | Car #1, Car #2                       |
| Panoptic | Combines semantic + instance segmentation                       | Background classes + counted objects |
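The three flavors differ mainly in how the labels are stored. A toy NumPy sketch (the scene, class ids, and the panoptic encoding are illustrative):

```python
import numpy as np

# Toy 4x4 scene: background (class 0) and two cars (class 1).

# Semantic segmentation: one class id per pixel; both cars share class 1.
semantic = np.array([[0, 1, 1, 0],
                     [0, 1, 1, 0],
                     [0, 0, 0, 1],
                     [0, 0, 1, 1]])

# Instance segmentation: a separate binary mask per object instance.
instances = np.array([[[0, 1, 1, 0],          # car #1
                       [0, 1, 1, 0],
                       [0, 0, 0, 0],
                       [0, 0, 0, 0]],
                      [[0, 0, 0, 0],          # car #2
                       [0, 0, 0, 0],
                       [0, 0, 0, 1],
                       [0, 0, 1, 1]]])

# Panoptic: every pixel carries (class id, instance id); "stuff" classes
# such as background keep instance id 0.
instance_ids = instances[0] * 1 + instances[1] * 2
panoptic = np.stack([semantic, instance_ids], axis=-1)

print(semantic.shape, instances.shape, panoptic.shape)  # (4, 4) (2, 4, 4) (4, 4, 2)
```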

UNet architecture

UNet follows an encoder-decoder design with skip connections that copy feature maps from the encoder directly to the corresponding decoder level.
Input

  ├── Conv→Conv (64)  ──────────────────────────────────┐ skip
  │       ↓ MaxPool                                      │
  ├── Conv→Conv (128) ─────────────────────────────┐ skip│
  │       ↓ MaxPool                                 │    │
  ├── Conv→Conv (256) ────────────────────────┐ skip│    │
  │       ↓ MaxPool                           │    │    │
  ├── Conv→Conv (512) ───────────────────┐ skip│    │    │
  │       ↓ MaxPool                      │    │    │    │
  └── Bottleneck (1024)                  │    │    │    │
          ↓ UpConv                        │    │    │    │
        Concat ←──────────────────────────┘    │    │    │
        Conv→Conv (512)                        │    │    │
          ↓ UpConv                             │    │    │
        Concat ←───────────────────────────────┘    │    │
        Conv→Conv (256)                             │    │
          ↓ UpConv                                  │    │
        Concat ←────────────────────────────────────┘    │
        Conv→Conv (128)                                  │
          ↓ UpConv                                       │
        Concat ←─────────────────────────────────────────┘
        Conv→Conv (64)
          ↓ 1×1 Conv
        Output (num_classes)

Encoder path (contracting)

The encoder applies repeated blocks of:
  1. Two 3×3 convolutions, each followed by ReLU
  2. 2×2 max pooling (stride 2), which halves the spatial dimensions; each successive block doubles the channel count (64 → 128 → 256 → 512)
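One contracting step can be sketched in isolation (channel widths and input size here are illustrative):

```python
import torch
import torch.nn as nn

# One encoder step: two 3x3 convs with ReLU, then 2x2 max pooling.
block = nn.Sequential(
    nn.Conv2d(64, 128, kernel_size=3, padding=1),
    nn.ReLU(inplace=True),
    nn.Conv2d(128, 128, kernel_size=3, padding=1),
    nn.ReLU(inplace=True),
)
pool = nn.MaxPool2d(2)

x = torch.randn(1, 64, 128, 128)   # (B, C, H, W)
features = block(x)                # -> (1, 128, 128, 128): channels doubled
down = pool(features)              # -> (1, 128, 64, 64):   H and W halved
print(features.shape, down.shape)
```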

Bottleneck

The deepest block (1024 channels) operates at the lowest spatial resolution and captures the most abstract, high-level features; no further pooling is applied after it.

Decoder path (expanding)

The decoder mirrors the encoder:
  1. 2×2 transposed convolution (stride 2), which doubles the spatial dimensions
  2. Concatenation with the skip connection from the corresponding encoder level
  3. Two 3×3 convolutions + ReLU
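One expanding step, sketched in isolation (the shapes are illustrative — here the level below has 256 channels and the matching encoder map has 128):

```python
import torch
import torch.nn as nn

# One decoder step: upsample, concatenate the skip, then two 3x3 convs.
up = nn.ConvTranspose2d(256, 128, kernel_size=2, stride=2)

decoder_in = torch.randn(1, 256, 32, 32)   # from the level below
skip       = torch.randn(1, 128, 64, 64)   # matching encoder feature map

upsampled = up(decoder_in)                    # (1, 128, 64, 64): H, W doubled
merged = torch.cat([upsampled, skip], dim=1)  # (1, 256, 64, 64): channels stacked
out = nn.Sequential(
    nn.Conv2d(256, 128, 3, padding=1), nn.ReLU(inplace=True),
    nn.Conv2d(128, 128, 3, padding=1), nn.ReLU(inplace=True),
)(merged)                                     # (1, 128, 64, 64)
print(out.shape)
```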

Output layer

A 1×1 convolution maps the 64-channel feature maps to num_classes channels, followed by softmax (or sigmoid for binary segmentation).
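A sketch of the head in isolation (num_classes=3 is illustrative). Note that PyTorch's nn.CrossEntropyLoss applies log-softmax internally, so during training the raw logits from the 1×1 conv are fed to the loss; the explicit softmax is only needed at inference time:

```python
import torch
import torch.nn as nn

# Segmentation head: 1x1 conv to num_classes, softmax over the class dim.
head = nn.Conv2d(64, 3, kernel_size=1)

features = torch.randn(2, 64, 128, 128)     # decoder output (B, 64, H, W)
logits = head(features)                     # (2, 3, 128, 128)
probs = logits.softmax(dim=1)               # per-pixel class distribution
pred = probs.argmax(dim=1)                  # (2, 128, 128) hard class map

print(torch.allclose(probs.sum(dim=1), torch.ones(2, 128, 128)))  # True
```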

PyTorch implementation

import torch
import torch.nn as nn

class DoubleConv(nn.Module):
    def __init__(self, in_ch, out_ch):
        super().__init__()
        self.net = nn.Sequential(
            nn.Conv2d(in_ch, out_ch, 3, padding=1),
            nn.BatchNorm2d(out_ch),
            nn.ReLU(inplace=True),
            nn.Conv2d(out_ch, out_ch, 3, padding=1),
            nn.BatchNorm2d(out_ch),
            nn.ReLU(inplace=True),
        )
    def forward(self, x): return self.net(x)

class UNet(nn.Module):
    def __init__(self, in_channels=3, num_classes=2):
        super().__init__()
        # Encoder
        self.enc1 = DoubleConv(in_channels, 64)
        self.enc2 = DoubleConv(64, 128)
        self.enc3 = DoubleConv(128, 256)
        self.enc4 = DoubleConv(256, 512)
        self.pool = nn.MaxPool2d(2)
        # Bottleneck
        self.bottleneck = DoubleConv(512, 1024)
        # Decoder
        self.up4   = nn.ConvTranspose2d(1024, 512, 2, stride=2)
        self.dec4  = DoubleConv(1024, 512)
        self.up3   = nn.ConvTranspose2d(512, 256, 2, stride=2)
        self.dec3  = DoubleConv(512, 256)
        self.up2   = nn.ConvTranspose2d(256, 128, 2, stride=2)
        self.dec2  = DoubleConv(256, 128)
        self.up1   = nn.ConvTranspose2d(128, 64, 2, stride=2)
        self.dec1  = DoubleConv(128, 64)
        self.out   = nn.Conv2d(64, num_classes, 1)

    def forward(self, x):
        e1 = self.enc1(x)
        e2 = self.enc2(self.pool(e1))
        e3 = self.enc3(self.pool(e2))
        e4 = self.enc4(self.pool(e3))
        b  = self.bottleneck(self.pool(e4))
        d4 = self.dec4(torch.cat([self.up4(b),  e4], dim=1))
        d3 = self.dec3(torch.cat([self.up3(d4), e3], dim=1))
        d2 = self.dec2(torch.cat([self.up2(d3), e2], dim=1))
        d1 = self.dec1(torch.cat([self.up1(d2), e1], dim=1))
        return self.out(d1)
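Because these convolutions use padding=1, the output spatial size matches the input (the original paper used unpadded convolutions and cropped the skip connections instead). A quick shape sanity check, using a self-contained two-level miniature of the same encoder/skip/decoder structure (the channel widths are illustrative):

```python
import torch
import torch.nn as nn

class TinyUNet(nn.Module):
    """Two-level miniature UNet, small enough to be self-contained."""
    def __init__(self, in_channels=3, num_classes=2):
        super().__init__()
        def block(i, o):
            return nn.Sequential(
                nn.Conv2d(i, o, 3, padding=1), nn.ReLU(inplace=True),
                nn.Conv2d(o, o, 3, padding=1), nn.ReLU(inplace=True))
        self.enc1 = block(in_channels, 16)
        self.pool = nn.MaxPool2d(2)
        self.bottleneck = block(16, 32)
        self.up1 = nn.ConvTranspose2d(32, 16, 2, stride=2)
        self.dec1 = block(32, 16)          # 32 in: upsampled 16 + skip 16
        self.out = nn.Conv2d(16, num_classes, 1)

    def forward(self, x):
        e1 = self.enc1(x)
        b = self.bottleneck(self.pool(e1))
        d1 = self.dec1(torch.cat([self.up1(b), e1], dim=1))
        return self.out(d1)

x = torch.randn(1, 3, 64, 64)
y = TinyUNet()(x)
print(y.shape)  # torch.Size([1, 2, 64, 64]) -- same H, W as the input
```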

Loss functions for segmentation

Cross-entropy loss

The standard choice for multi-class segmentation, applied pixel-wise:

$$\mathcal{L}_{\text{CE}} = -\frac{1}{HW} \sum_{h,w} \sum_{c} y_{h,w,c} \log \hat{p}_{h,w,c}$$

It can be weighted per class to handle class imbalance (background vs. small objects).
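In PyTorch, nn.CrossEntropyLoss computes exactly this pixel-wise average, which can be verified against the formula by hand (shapes here are illustrative; per-class weighting would be passed via the `weight` argument):

```python
import torch
import torch.nn as nn

logits = torch.randn(2, 3, 4, 4)           # (B, C, H, W) raw scores
target = torch.randint(0, 3, (2, 4, 4))    # (B, H, W) class ids

loss = nn.CrossEntropyLoss()(logits, target)

# Manual check: mean over all pixels of -log p_hat at the true class.
log_p = logits.log_softmax(dim=1)                     # (B, C, H, W)
manual = -log_p.gather(1, target.unsqueeze(1)).mean() # pick true-class log-prob
print(torch.allclose(loss, manual))  # True
```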

Dice loss

Dice loss directly optimizes the Dice coefficient (the pixel-level F1 score), which is more robust to class imbalance:

$$\mathcal{L}_{\text{Dice}} = 1 - \frac{2 \sum_i p_i g_i}{\sum_i p_i^2 + \sum_i g_i^2}$$

where $p_i$ are predicted probabilities and $g_i$ are ground-truth labels.
A combined loss $\mathcal{L} = \mathcal{L}_{\text{CE}} + \mathcal{L}_{\text{Dice}}$ is common in practice and often outperforms either alone.
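A minimal sketch of the combined loss, assuming one-hot ground truth and the squared-denominator Dice variant above (the epsilon term and shapes are illustrative):

```python
import torch
import torch.nn as nn

def dice_loss(probs, target_onehot, eps=1e-6):
    """Soft Dice loss, summed over all pixels and classes."""
    num = 2 * (probs * target_onehot).sum()
    den = (probs ** 2).sum() + (target_onehot ** 2).sum()
    return 1 - num / (den + eps)

logits = torch.randn(2, 2, 8, 8)            # (B, C, H, W)
target = torch.randint(0, 2, (2, 8, 8))     # (B, H, W) class ids

probs = logits.softmax(dim=1)
onehot = nn.functional.one_hot(target, 2).permute(0, 3, 1, 2).float()

total = nn.CrossEntropyLoss()(logits, target) + dice_loss(probs, onehot)
```

A perfect prediction (probs equal to the one-hot target) drives the Dice term to zero, which is the property that makes it insensitive to how many background pixels surround a small object.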

Training and evaluation

device    = "cuda" if torch.cuda.is_available() else "cpu"
model     = UNet(in_channels=3, num_classes=2).to(device)
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)

num_epochs = 20  # adjust to your dataset

for epoch in range(num_epochs):
    model.train()
    for images, masks in train_loader:
        images, masks = images.to(device), masks.to(device)
        preds = model(images)           # (B, C, H, W) logits
        loss  = criterion(preds, masks) # masks: (B, H, W) long
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
Evaluation metrics: mean IoU (mIoU) across classes is the standard benchmark.

$$\text{mIoU} = \frac{1}{C} \sum_{c=1}^{C} \frac{\text{TP}_c}{\text{TP}_c + \text{FP}_c + \text{FN}_c}$$
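The formula can be computed directly from per-class TP/FP/FN counts; a minimal NumPy sketch (skipping classes absent from both prediction and ground truth is one common convention, not the only one):

```python
import numpy as np

def mean_iou(pred, target, num_classes):
    """mIoU over classes present in prediction or ground truth."""
    ious = []
    for c in range(num_classes):
        tp = np.sum((pred == c) & (target == c))
        fp = np.sum((pred == c) & (target != c))
        fn = np.sum((pred != c) & (target == c))
        if tp + fp + fn > 0:
            ious.append(tp / (tp + fp + fn))
    return float(np.mean(ious))

pred   = np.array([[0, 0, 1, 1],
                   [0, 0, 1, 1]])
target = np.array([[0, 0, 1, 0],
                   [0, 0, 1, 1]])

# class 0: IoU = 4/5 = 0.8; class 1: IoU = 3/4 = 0.75
print(mean_iou(pred, target, num_classes=2))  # 0.775
```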

Resources

UNet Segmentation Examples

Colab notebook with UNet segmentation on real datasets.

Exercise E08: Segmentation with UNet

Hands-on exercise: train UNet for image segmentation.

Original UNet Paper

Ronneberger et al. (2015) — the original UNet paper for biomedical image segmentation.

Video: UNet, GAN & Anomaly Detection

Recorded lecture covering UNet, GANs, and anomaly detection.
