U-Net — semantic & biomedical segmentation

An encoder-decoder CNN with skip connections. Originally built for biomedical image segmentation (Ronneberger et al., 2015), it is now a standard backbone for pixel-wise prediction tasks and the denoiser inside many diffusion models.

TL;DR. U-Net contracts an input image through a CNN encoder to a small bottleneck, then expands it back to full resolution through a mirror-image decoder, copying high-resolution encoder feature maps across as skip connections. The skips let the decoder produce a pixel-accurate mask without losing the global context captured at the bottleneck. USAAIO and IOAI use U-Nets for segmentation tasks (cell counting, satellite land cover, chicken counting from drone images) and as the inner denoiser of DDPMs. The usual loss is per-pixel cross-entropy plus Dice, computed on softmax outputs (or sigmoid outputs when there is a single foreground class).

1. The intuition

Classification asks "is there a cat in this image?" — a single label per image. Segmentation asks "for every pixel, is this part of a cat?" — a label map at full resolution. You need two things at the same time: global context (which only deep, low-resolution features carry) and fine spatial detail (which only shallow, high-resolution features carry).

A plain encoder-decoder squeezes everything through the low-resolution bottleneck, and the decoder cannot recover the fine spatial detail discarded there. U-Net's fix is brutally simple: at each decoder level, concatenate the matching encoder feature map along the channel axis before the next conv. The encoder extracts "what" through successive downsamplings, the decoder upsamples to recover "where", and the skips reinject the high-frequency detail the bottleneck lost. Drawn on paper, the network looks like the letter U: the encoder descends on the left, the decoder climbs back up on the right, and the skips cross horizontally between them.

Every output pixel gets both a large receptive field (through the bottleneck) and sharp local detail (through the skips). Combined with heavy data augmentation (elastic deformations in the original paper), this makes U-Nets data-efficient: the original paper trained on only ~30 annotated images. The skips also give gradients a short path back to the early layers, so U-Nets tend to train quickly.

2. The math

Encoder / decoder shapes

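Let the input be x in R^(B x C_in x H x W). Each encoder level applies two 3x3 convolutions (padding 1, so H and W are preserved) followed by a 2x2 max-pool. The original paper used unpadded convolutions and cropped the encoder maps before concatenating; padded convs are the common modern choice and are used throughout this page.

e_l = Conv2(d_{l-1}), d_l = MaxPool2x2(e_l), with d_0 = x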

After L levels the spatial size is H / 2^L x W / 2^L. With H = W = 256 and L = 4 the bottleneck is 16 x 16. Channel count typically doubles at each level: 64 → 128 → 256 → 512 → 1024.
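A quick way to sanity-check this schedule is to tabulate the shape at every level. The helper below is a minimal sketch; `encoder_shapes` is a hypothetical name, and it assumes size-preserving padded 3x3 convs, 2x2 stride-2 max-pools, and channel doubling, as described above.

```python
def encoder_shapes(h, w, in_channels=3, base=64, levels=4):
    """Return (channels, height, width) for the input, each e_l, and the bottleneck.

    Assumes padded (size-preserving) 3x3 convs, 2x2 stride-2 max-pools,
    and channels doubling at every level. Hypothetical helper.
    """
    shapes = [(in_channels, h, w)]
    c = base
    for _ in range(levels):
        shapes.append((c, h, w))   # e_l: two padded convs keep H and W
        h, w = h // 2, w // 2      # 2x2 max-pool halves H and W (floor)
        c *= 2                     # the next level doubles the channels
    shapes.append((c, h, w))       # bottleneck: base * 2**levels channels
    return shapes


names = ["input", "e1", "e2", "e3", "e4", "bottleneck"]
for name, shape in zip(names, encoder_shapes(256, 256)):
    print(f"{name:>10}: {shape}")
# last line printed: bottleneck: (1024, 16, 16)
```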

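The decoder mirrors the encoder. Decoder level l runs from l = L (just above the bottleneck) back up to l = 1. It upsamples the output of the level below by 2 (transposed conv, or bilinear upsampling + 1x1 conv), concatenates the encoder feature map e_l of the same resolution along the channel axis, then applies two 3x3 convs:

u_l = Up(o_{l+1}), c_l = concat(u_l, e_l), o_l = Conv2(c_l), with o_{L+1} = Conv2(d_L) the bottleneck output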
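The channel bookkeeping of the decoder is where most shape bugs hide. This sketch (the function name `decoder_channels` is hypothetical) assumes, like the section 3 code, that each upsampling step halves the channel count and that each decoder level's conv maps back to the width of its matching encoder level.

```python
def decoder_channels(base=64, levels=4):
    """Per decoder level, return (level, upsampled, skip, concat, output) channels.

    Assumptions: encoder level l has base * 2**(l-1) channels, the bottleneck
    has base * 2**levels, Up halves the channels, and each decoder level's
    two convs output the same width as its skip. Hypothetical helper.
    """
    rows = []
    below = base * 2 ** levels        # the deepest level upsamples the bottleneck
    for l in range(levels, 0, -1):    # decoder runs from level L back up to 1
        up = below // 2               # upsampled output of the level below
        skip = base * 2 ** (l - 1)    # encoder feature map e_l
        concat = up + skip            # channels after concatenation
        out = skip                    # two 3x3 convs return to e_l's width
        rows.append((l, up, skip, concat, out))
        below = out
    return rows


for l, up, skip, concat, out in decoder_channels():
    print(f"level {l}: up {up} + skip {skip} -> concat {concat} -> out {out}")
```

The concat column is exactly the in_c of each decoder ConvBlock in the reference implementation (base * 16, base * 8, base * 4, base * 2).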

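A final 1x1 conv maps the top decoder output to K channels, one per class, giving logits of shape (B, K, H, W). A per-pixel softmax over K gives class probabilities and an argmax over K gives the mask. For binary segmentation the usual choice is K = 1 with a per-pixel sigmoid, as in the code in section 3.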
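Turning logits into a mask is a softmax along the class axis followed by an argmax. The NumPy sketch below (the function name `logits_to_mask` is hypothetical) assumes the (B, K, H, W) layout described above.

```python
import numpy as np


def logits_to_mask(logits):
    """logits: (B, K, H, W) raw scores -> (probs, mask).

    probs: per-pixel softmax over the K class channels (same shape as logits).
    mask:  (B, H, W) integer label map from the argmax over K.
    Hypothetical helper for illustration.
    """
    # Subtract the per-pixel max before exponentiating, for numerical stability.
    z = logits - logits.max(axis=1, keepdims=True)
    e = np.exp(z)
    probs = e / e.sum(axis=1, keepdims=True)
    mask = probs.argmax(axis=1)
    return probs, mask


rng = np.random.default_rng(0)
logits = rng.normal(size=(2, 3, 4, 4))       # B=2, K=3 classes, 4x4 image
probs, mask = logits_to_mask(logits)
print(probs.shape, mask.shape)               # (2, 3, 4, 4) (2, 4, 4)
print(np.allclose(probs.sum(axis=1), 1.0))   # True
```

Because softmax is monotonic per pixel, the argmax of the probabilities equals the argmax of the raw logits; the softmax is only needed when you want calibrated probabilities or a loss.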

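Output-shape arithmetic for a 2D conv (dilation 1): H_out = floor((H + 2p - k) / s) + 1. With k = 3, p = 1, s = 1: H_out = H (same). Max-pool with k = 2, s = 2 follows the same formula: H_out = floor(H / 2). A transposed conv obeys H_out = (H - 1)s - 2p + k, so k = 2, s = 2, p = 0 gives H_out = 2H. Because pooling floors, H and W must be divisible by 2^L for the skip concatenations to line up.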
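The two formulas are easy to encode and check. This is a sketch for one spatial axis, assuming dilation 1 and no output_padding on the transposed conv; `conv_out` and `conv_transpose_out` are hypothetical names, not PyTorch functions.

```python
def conv_out(h, k, s=1, p=0):
    """Output size of a conv or max-pool along one axis (dilation 1)."""
    return (h + 2 * p - k) // s + 1


def conv_transpose_out(h, k, s=1, p=0):
    """Output size of a transposed conv along one axis (dilation 1, no output_padding)."""
    return (h - 1) * s - 2 * p + k


print(conv_out(256, k=3, s=1, p=1))       # 256: padded 3x3 conv keeps the size
print(conv_out(256, k=2, s=2))            # 128: 2x2 max-pool halves it
print(conv_transpose_out(16, k=2, s=2))   # 32: 2x2 transposed conv doubles it

# An odd size breaks the U: pooling floors, and upsampling cannot undo that.
h = 125
print(conv_transpose_out(conv_out(h, k=2, s=2), k=2, s=2))  # 124, not 125
```

With four pools, as in the reference UNet, H and W therefore need to be multiples of 16; otherwise the concatenation with the skip fails on a size mismatch, and the usual remedy is to pad or resize the input.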

Loss — Dice and cross-entropy

Pixel-wise cross-entropy treats every pixel as an independent classification problem. That works, but it breaks down when classes are heavily imbalanced (e.g. foreground = 1% of pixels), because the background pixels dominate the average. The Dice loss directly optimises set overlap:

Dice = 2 * |P ∩ G| / (|P| + |G|) = 2 * sum(p_i * g_i) / (sum(p_i) + sum(g_i) + eps)
L_dice = 1 - Dice

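Here p_i is the predicted foreground probability for pixel i and g_i in {0,1} is the ground truth; replacing the hard sets P and G with probabilities ("soft Dice") makes the loss differentiable. The eps (e.g. 1e-6) avoids division by zero on empty masks. With eps only in the denominator, an all-empty prediction on an all-empty mask still scores Dice = 0; adding eps to the numerator as well is a common fix. In practice the loss is usually L = L_ce + L_dice (or a weighted sum): cross-entropy gives smooth per-pixel gradients everywhere, and Dice counteracts the class imbalance.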
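A NumPy sketch of the soft Dice formula and a BCE term on probabilities. The function names are hypothetical, and it works on probabilities rather than logits to mirror the equations (the PyTorch version in section 3 applies the sigmoid itself).

```python
import numpy as np


def soft_dice_loss(p, g, eps=1e-6):
    """1 - soft Dice, with eps in the denominator only, as in the formula above.

    p: predicted foreground probabilities in [0, 1]; g: binary ground truth.
    Sums run over every pixel. Hypothetical helper.
    """
    inter = (p * g).sum()
    return float(1.0 - 2.0 * inter / (p.sum() + g.sum() + eps))


def bce(p, g, eps=1e-7):
    """Mean per-pixel binary cross-entropy on probabilities, clipped so log stays finite."""
    p = np.clip(p, eps, 1.0 - eps)
    return float(-(g * np.log(p) + (1 - g) * np.log(1 - p)).mean())


g = np.zeros((32, 32))
g[10:14, 10:14] = 1.0                     # 16 foreground pixels out of 1024
perfect = g.copy()
blank = np.zeros_like(g)

print(round(soft_dice_loss(perfect, g), 4))    # 0.0
print(round(soft_dice_loss(blank, g), 4))      # 1.0: no overlap, maximal loss
print(round(soft_dice_loss(blank, blank), 4))  # 1.0: empty-mask pitfall of eps-only-below
print(bce(perfect, g) + soft_dice_loss(perfect, g))  # combined loss, close to 0
```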

3. PyTorch reference implementation

import torch
import torch.nn as nn
import torch.nn.functional as F


class ConvBlock(nn.Module):
    """Two 3x3 convs + ReLU + BN, padding 1 so spatial size is preserved."""
    def __init__(self, in_c, out_c):
        super().__init__()
        self.net = nn.Sequential(
            nn.Conv2d(in_c,  out_c, 3, padding=1, bias=False),
            nn.BatchNorm2d(out_c), nn.ReLU(inplace=True),
            nn.Conv2d(out_c, out_c, 3, padding=1, bias=False),
            nn.BatchNorm2d(out_c), nn.ReLU(inplace=True),
        )

    def forward(self, x):
        return self.net(x)


class UNet(nn.Module):
    """4-level U-Net. Input H and W must be divisible by 16 (four 2x2 max-pools)."""
    def __init__(self, in_channels=3, num_classes=1, base=64):
        super().__init__()
        # Encoder
        self.enc1 = ConvBlock(in_channels, base)
        self.enc2 = ConvBlock(base,        base * 2)
        self.enc3 = ConvBlock(base * 2,    base * 4)
        self.enc4 = ConvBlock(base * 4,    base * 8)
        self.pool = nn.MaxPool2d(2)
        # Bottleneck
        self.bottleneck = ConvBlock(base * 8, base * 16)
        # Decoder (transposed conv up, then ConvBlock on concat with skip)
        self.up4  = nn.ConvTranspose2d(base * 16, base * 8, 2, stride=2)
        self.dec4 = ConvBlock(base * 16, base * 8)
        self.up3  = nn.ConvTranspose2d(base * 8,  base * 4, 2, stride=2)
        self.dec3 = ConvBlock(base * 8,  base * 4)
        self.up2  = nn.ConvTranspose2d(base * 4,  base * 2, 2, stride=2)
        self.dec2 = ConvBlock(base * 4,  base * 2)
        self.up1  = nn.ConvTranspose2d(base * 2,  base,     2, stride=2)
        self.dec1 = ConvBlock(base * 2,  base)
        # Output head
        self.out  = nn.Conv2d(base, num_classes, 1)

    def forward(self, x):
        # Encoder path with skip captures
        e1 = self.enc1(x)
        e2 = self.enc2(self.pool(e1))
        e3 = self.enc3(self.pool(e2))
        e4 = self.enc4(self.pool(e3))
        b  = self.bottleneck(self.pool(e4))
        # Decoder path with skip concatenation
        d4 = self.dec4(torch.cat([self.up4(b),  e4], dim=1))
        d3 = self.dec3(torch.cat([self.up3(d4), e3], dim=1))
        d2 = self.dec2(torch.cat([self.up2(d3), e2], dim=1))
        d1 = self.dec1(torch.cat([self.up1(d2), e1], dim=1))
        return self.out(d1)   # (B, num_classes, H, W) logits


def dice_loss(logits, target, eps=1e-6):
    """logits: (B, 1, H, W); target: (B, 1, H, W) in {0, 1}."""
    p = torch.sigmoid(logits)
    num = 2.0 * (p * target).sum(dim=(1, 2, 3))
    den = p.sum(dim=(1, 2, 3)) + target.sum(dim=(1, 2, 3)) + eps
    return (1.0 - num / den).mean()


def combo_loss(logits, target):
    return F.binary_cross_entropy_with_logits(logits, target) + dice_loss(logits, target)


if __name__ == "__main__":
    model = UNet(in_channels=3, num_classes=1, base=16)   # small for demo
    x = torch.randn(2, 3, 128, 128)
    y = (torch.rand(2, 1, 128, 128) > 0.7).float()
    logits = model(x)
    print(logits.shape)             # torch.Size([2, 1, 128, 128])
    loss = combo_loss(logits, y)
    loss.backward()
    print(float(loss))
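    # For inference, switch BatchNorm to running stats and disable dropout,
    # then turn off gradient tracking and threshold the sigmoid probabilities.
    model.train(False)
    with torch.no_grad():
        mask = torch.sigmoid(model(x)) > 0.5   # (B, 1, H, W) boolean mask
    print(mask.shape)                          # torch.Size([2, 1, 128, 128])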

The model.train(False) call switches BatchNorm to its running statistics and disables dropout for inference; nn.Module's one-word inference-mode method is simply a wrapper around this same call. It does not stop gradient tracking, so wrap inference in torch.no_grad() as well.

4. Common USAAIO / IOAI applications

5. Drills

D1 · Bottleneck shape

Input 3 x 256 x 256, 4 encoder downsamples (each /2), base channels 64 doubling each level. What is the bottleneck tensor shape?

Solution

1024 x 16 x 16. Spatial: 256 / 2^4 = 16. Channels: 64 * 2^4 = 1024.

D2 · Why concat and not add for skips?

ResNet uses additive skips. U-Net concatenates. Why the difference?

Solution

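The decoder has to combine upsampled coarse features with raw high-resolution features that live in different feature subspaces. Concatenation keeps the two sources in separate channels, so the next conv learns its own weights for each; addition merges them before any learned mixing, which amounts to forcing the same weights on both. That constraint is fine for ResNet, where the skip and the residual branch carry the same representation, but too restrictive for bridging very different depths as U-Net does.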
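The "concat generalises add" argument can be checked numerically. This NumPy sketch uses a 1x1 conv for simplicity (a 3x3 conv splits over input channels the same way, since convolution is linear in its input channels); all names and sizes are illustrative assumptions.

```python
import numpy as np

rng = np.random.default_rng(0)
C, H, W = 4, 8, 8
u = rng.normal(size=(C, H, W))   # upsampled decoder features (illustrative)
e = rng.normal(size=(C, H, W))   # encoder skip features (illustrative)


def conv1x1(x, weight):
    """1x1 convolution: weight (C_out, C_in) mixes channels at every pixel."""
    return np.einsum("oc,chw->ohw", weight, x)


W_u = rng.normal(size=(C, C))
W_e = rng.normal(size=(C, C))

# Concat skip: one conv over [u; e] with weight [W_u | W_e] ...
concat_out = conv1x1(np.concatenate([u, e], axis=0), np.concatenate([W_u, W_e], axis=1))
# ... equals two separate convs summed, i.e. independent weights per source.
print(np.allclose(concat_out, conv1x1(u, W_u) + conv1x1(e, W_e)))  # True

# Additive skip followed by a conv is the special case with tied weights W_u == W_e.
add_out = conv1x1(u + e, W_u)
tied = conv1x1(np.concatenate([u, e], axis=0), np.concatenate([W_u, W_u], axis=1))
print(np.allclose(add_out, tied))  # True
```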

D3 · Why Dice loss for medical masks?

You train with plain pixel cross-entropy and the network predicts "all background" for tumour segmentation. Diagnose.

Solution

Tumour pixels are < 1% of the image. Predicting all background gives ~99% pixel accuracy and a low CE loss, so the signal pushing the network towards the tumour is weak. Dice measures overlap on the foreground class only: predicting all background gives Dice = 0, the maximal Dice loss of 1. Use BCE + Dice.
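A toy version of this failure, on a hypothetical 100 x 100 image with an 8 x 8 "tumour" and a collapsed model that outputs 1% foreground probability everywhere. The numbers are synthetic, chosen only to illustrate the diagnosis.

```python
import numpy as np

g = np.zeros((100, 100))
g[45:53, 45:53] = 1.0                # 64 tumour pixels out of 10,000 (0.64%)
p = np.full_like(g, 0.01)            # collapsed model: 1% foreground probability everywhere

accuracy = ((p > 0.5) == g).mean()   # thresholded mask is all background
bce = -(g * np.log(p) + (1 - g) * np.log(1 - p)).mean()
dice_loss = 1 - 2 * (p * g).sum() / (p.sum() + g.sum() + 1e-6)

print(f"pixel accuracy = {accuracy:.4f}")   # 0.9936
print(f"BCE            = {bce:.3f}")        # 0.039
print(f"Dice loss      = {dice_loss:.3f}")  # 0.992
```

BCE and accuracy both look healthy, while the Dice loss is close to its maximum of 1, which is exactly why the combined loss keeps pushing on the foreground.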

D4 · Receptive field debugging

Your U-Net gets the coarse mask roughly right but misses thin filaments. What architectural change helps?

Solution

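Thin structures are resolved by the highest-resolution path, so strengthen the top of the U rather than the bottom (an extra deep level adds context, not detail). Options: increase the channel width of enc1 / dec1; use a shallower pooling schedule, e.g. set the first max-pool's stride to 1 and drop the matching upsample so skip resolutions still match; or replace transposed convs with bilinear upsampling + conv to avoid checkerboard artefacts.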
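To see why detail lives at the top of the U, compute the receptive field and cumulative stride along the encoder. This sketch (hypothetical helper `receptive_field`) uses the standard recursion r <- r + (k - 1) * j, j <- j * s, assuming dilation 1 and the reference layout of two 3x3 convs plus a 2x2 pool per level.

```python
def receptive_field(layers):
    """Receptive field r and cumulative stride j after a stack of (kernel, stride) layers.

    Standard recursion for dilation-1 layers: r <- r + (k - 1) * j, j <- j * s.
    Hypothetical helper for illustration.
    """
    r, j = 1, 1
    for k, s in layers:
        r += (k - 1) * j
        j *= s
    return r, j


conv, pool = (3, 1), (2, 2)
level = [conv, conv, pool]
encoder = level * 4 + [conv, conv]       # 4 encoder levels + the bottleneck ConvBlock

print(receptive_field([conv, conv]))     # (5, 1): e1, full resolution
print(receptive_field(level))            # (6, 2) after the first pool
print(receptive_field(encoder))          # (140, 16) at the bottleneck
```

Bottleneck features see a 140-pixel window but sit on a 16-pixel grid, so a one-pixel filament occupies a small fraction of one bottleneck cell; only the stride-1 e1 / dec1 path can place its edges precisely.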

Next step

The U-Net is also the inner denoiser of diffusion models: see DDPM next, then loop back to Transformers for ViT-based diffusion (DiT) and to Round 2 theory for shape-arithmetic short-answer questions.