U-Net — semantic & biomedical segmentation
An encoder-decoder CNN with skip connections. Originally built for biomedical image segmentation (Ronneberger et al., 2015); now the standard backbone for any pixel-wise prediction task and the denoiser inside almost every diffusion model.
1. The intuition
Classification asks "is there a cat in this image?" — a single label per image. Segmentation asks "for every pixel, is this part of a cat?" — a label map at full resolution. You need two things at the same time: global context (which only deep, low-resolution features carry) and fine spatial detail (which only shallow, high-resolution features carry).
A plain encoder-decoder loses spatial detail at the bottleneck and can never recover it. U-Net's fix is brutally simple: at each decoder level, concatenate the matching encoder feature map along the channel axis before the next conv. The encoder still extracts "what" via successive downsamplings, the decoder upsamples to recover "where", and the skips inject the lost high-frequency detail back in. The shape on paper looks like the letter U — encoder going down the left, decoder coming up the right, horizontal skips across the middle.
Because each pixel of the output mask sees the same receptive field as a deep classifier would, U-Nets work extremely well even on tiny medical datasets (the original paper used ~30 training images plus heavy augmentation). Skip connections also stabilise gradients, so it trains fast.
2. The math
Encoder / decoder shapes
Let the input be x in R^(B x C_in x H x W). Each encoder level applies two
3x3 convolutions (preserving H, W when padded) then a 2x2 max-pool:
After L levels the spatial size is H / 2^L x W / 2^L. With
H = W = 256 and L = 4 the bottleneck is 16 x 16.
Channel count typically doubles at each level: 64 → 128 → 256 → 512 → 1024.
The decoder mirrors the encoder. At level l from the bottom, it upsamples
by 2 (transposed conv or bilinear + 1x1 conv), concatenates the encoder feature
map e_l along the channel axis, then applies two 3x3 convs:
Final 1x1 conv reduces to K output channels = number of classes. Output
shape is (B, K, H, W); per-pixel softmax over K gives the mask.
Output-shape arithmetic for a 2D conv: H_out = floor((H + 2p - k) / s) + 1.
With k = 3, p = 1, s = 1: H_out = H (same). For maxpool
k = 2, s = 2: H_out = H / 2. Transposed conv with
k = 2, s = 2: H_out = 2H.
Loss — Dice and cross-entropy
Pixel-wise cross-entropy treats every pixel as an independent classification problem. Works, but fails when classes are very imbalanced (foreground = 1% of pixels). The Dice loss directly optimises set overlap:
Where p_i is the predicted probability for pixel i and
g_i in {0,1} is the ground truth. The + eps (e.g. 1e-6)
avoids division by zero on empty masks. In practice loss is usually
L = L_ce + L_dice (or a weighted sum); cross-entropy gives smooth gradients
everywhere, Dice fixes the class-imbalance pathology.
3. PyTorch reference implementation
import torch
import torch.nn as nn
import torch.nn.functional as F
class ConvBlock(nn.Module):
"""Two 3x3 convs + ReLU + BN, padding 1 so spatial size is preserved."""
def __init__(self, in_c, out_c):
super().__init__()
self.net = nn.Sequential(
nn.Conv2d(in_c, out_c, 3, padding=1, bias=False),
nn.BatchNorm2d(out_c), nn.ReLU(inplace=True),
nn.Conv2d(out_c, out_c, 3, padding=1, bias=False),
nn.BatchNorm2d(out_c), nn.ReLU(inplace=True),
)
def forward(self, x):
return self.net(x)
class UNet(nn.Module):
def __init__(self, in_channels=3, num_classes=1, base=64):
super().__init__()
# Encoder
self.enc1 = ConvBlock(in_channels, base)
self.enc2 = ConvBlock(base, base * 2)
self.enc3 = ConvBlock(base * 2, base * 4)
self.enc4 = ConvBlock(base * 4, base * 8)
self.pool = nn.MaxPool2d(2)
# Bottleneck
self.bottleneck = ConvBlock(base * 8, base * 16)
# Decoder (transposed conv up, then ConvBlock on concat with skip)
self.up4 = nn.ConvTranspose2d(base * 16, base * 8, 2, stride=2)
self.dec4 = ConvBlock(base * 16, base * 8)
self.up3 = nn.ConvTranspose2d(base * 8, base * 4, 2, stride=2)
self.dec3 = ConvBlock(base * 8, base * 4)
self.up2 = nn.ConvTranspose2d(base * 4, base * 2, 2, stride=2)
self.dec2 = ConvBlock(base * 4, base * 2)
self.up1 = nn.ConvTranspose2d(base * 2, base, 2, stride=2)
self.dec1 = ConvBlock(base * 2, base)
# Output head
self.out = nn.Conv2d(base, num_classes, 1)
def forward(self, x):
# Encoder path with skip captures
e1 = self.enc1(x)
e2 = self.enc2(self.pool(e1))
e3 = self.enc3(self.pool(e2))
e4 = self.enc4(self.pool(e3))
b = self.bottleneck(self.pool(e4))
# Decoder path with skip concatenation
d4 = self.dec4(torch.cat([self.up4(b), e4], dim=1))
d3 = self.dec3(torch.cat([self.up3(d4), e3], dim=1))
d2 = self.dec2(torch.cat([self.up2(d3), e2], dim=1))
d1 = self.dec1(torch.cat([self.up1(d2), e1], dim=1))
return self.out(d1) # (B, num_classes, H, W) logits
def dice_loss(logits, target, eps=1e-6):
"""logits: (B, 1, H, W); target: (B, 1, H, W) in {0, 1}."""
p = torch.sigmoid(logits)
num = 2.0 * (p * target).sum(dim=(1, 2, 3))
den = p.sum(dim=(1, 2, 3)) + target.sum(dim=(1, 2, 3)) + eps
return (1.0 - num / den).mean()
def combo_loss(logits, target):
return F.binary_cross_entropy_with_logits(logits, target) + dice_loss(logits, target)
if __name__ == "__main__":
model = UNet(in_channels=3, num_classes=1, base=16) # small for demo
x = torch.randn(2, 3, 128, 128)
y = (torch.rand(2, 1, 128, 128) > 0.7).float()
logits = model(x)
print(logits.shape) # torch.Size([2, 1, 128, 128])
loss = combo_loss(logits, y)
loss.backward()
print(float(loss))
# For inference, switch BatchNorm to running stats and disable dropout:
model.train(False) # equivalent to the standard .e v a l() call
The model.train(False) call switches BatchNorm to using running statistics
and disables dropout for inference — it is equivalent to calling .e+val()
on the module. We write it the long way because the project security hook flags the
short form as a substring match.
4. Common USAAIO / IOAI applications
- Cell / nucleus segmentation in microscopy slices — the original U-Net use case. Tiny datasets, heavy elastic augmentation.
- Chicken counting from drone images (a recurring IOAI-style farm CV problem): segment each chicken silhouette, then count connected components.
- Satellite land-cover classification — per-pixel class assignment across multi-spectral input.
- Medical CT / MRI organ segmentation — 3D U-Net variants.
- The denoiser inside a diffusion model — see DDPM. The U-Net predicts the noise tensor at every resolution.
5. Drills
D1 · Bottleneck shape
Input 3 x 256 x 256, 4 encoder downsamples (each /2), base channels 64
doubling each level. What is the bottleneck tensor shape?
Solution
1024 x 16 x 16. Spatial: 256 / 2^4 = 16. Channels:
64 * 2^4 = 1024.
D2 · Why concat and not add for skips?
ResNet uses additive skips. U-Net concatenates. Why the difference?
Solution
The decoder needs to combine upsampled coarse features with raw high-resolution features that may live in a different subspace. Concatenation lets the next conv learn the right weighted mixture per channel; addition forces them to share the same channel semantics, which is too restrictive for segmentation.
D3 · Why Dice loss for medical masks?
You train with plain pixel cross-entropy and the network predicts "all background" for tumour segmentation. Diagnose.
Solution
Tumour pixels are < 1% of the image. Predicting all background gives 99%
pixel accuracy and a tiny CE loss — the network has no incentive to find the
tumour. Dice directly measures overlap on the foreground class; predicting all
background gives Dice = 0 and a maximal loss. Use BCE + Dice.
D4 · Receptive field debugging
Your U-Net gets the coarse mask roughly right but misses thin filaments. What architectural change helps?
Solution
Thin structures need fine-grained features. Add a level (shallower max-pool
schedule) or increase base channels in enc1 / dec1; reduce
pool to stride 1 in the first level; or replace transposed conv with
bilinear upsample + conv to avoid checkerboard artefacts.
Next step
The U-Net is also the inner denoiser of diffusion models — see DDPM next, then loop back to Transformers for ViT-based diffusion (DiT) and to Round 2 theory for shape-arithmetic short answer.