Attention & transformers
Tokenization, embeddings, scaled dot-product attention, multi-head attention, the transformer block, and how the same architecture powers modern NLP, vision (ViT), and even graph and audio models.
nn.Linear + nn.LayerNorm; (4) distinguish encoder-only (BERT), decoder-only (GPT), and encoder-decoder
(T5) and pick one for a given task; (5) describe how ViTs, U-Nets, VAEs, and diffusion models reuse or
extend the same building blocks for vision and generation. USAAIO transformer questions test concept
fluency, shape arithmetic, and small implementations.
1. Tokenization
Concept
Models don't see text — they see integer token IDs. The tokenizer converts strings to IDs and back. Three strategies:
- Word-level: one ID per whitespace-separated word. Simple but breaks on unseen words (out-of-vocabulary), inflates vocabulary to hundreds of thousands, and can't share morphology across "run", "running", "runner".
- Character-level: tiny vocabulary (~100), no OOV, but sequence length explodes and the model has to relearn that "th" appears together a lot.
- Subword (BPE / WordPiece / SentencePiece): the modern default. Start from characters and greedily merge the most common adjacent pairs until reaching a target vocabulary (typically 32K–128K). Common words stay whole; rare words split into 2–5 subword pieces. No OOV — any Unicode string is representable.
Worked example — BPE merges by hand
Starting vocabulary {l, o, w, e, r, n} and corpus {"low": 5, "lower": 2, "newer":
6}. Each word is initially split into characters with a special end-of-word marker (omitted
here for brevity).
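- Pair counts after step 0: (l,o):7, (o,w):7, (w,e):8, (e,r):8, (n,e):6, (e,w):6.
- Most frequent: (w,e) tied with (e,r). Break the tie by merging w + e -> we.
- Next merge: (we,r) -> wer. After the first merge the pair (e,r) no longer occurs, because every r now follows the new symbol we; the pair (we,r) has 8 occurrences.
- After the first merge, "newer" tokenizes as n, e, we, r; had the tie gone the other way it would be n, e, w, er. After the second merge it is n, e, wer on either path. Modern BPE resolves such ties deterministically by merge rank.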
from collections import Counter
import re
# Minimal BPE training loop sketch (do NOT use in production; it just illustrates the idea)
def get_pairs(word_freqs):
    # Count adjacent symbol pairs, weighted by word frequency
    pairs = Counter()
    for word, freq in word_freqs.items():
        symbols = word.split()
        for a, b in zip(symbols, symbols[1:]):
            pairs[(a, b)] += freq
    return pairs
def merge(pair, word_freqs):
    # Match the pair only where both halves are whole symbols; a plain
    # str.replace could also match across symbol boundaries (e.g. "e r" inside "we r")
    pattern = re.compile(r"(?<!\S)" + re.escape(" ".join(pair)) + r"(?!\S)")
    repl = "".join(pair)
    new = {}
    for w, f in word_freqs.items():
        new[pattern.sub(lambda m: repl, w)] = f
    return new
vocab = {"l o w </w>": 5, "l o w e r </w>": 2, "n e w e r </w>": 6}
for _ in range(4):
    pairs = get_pairs(vocab)
    best = max(pairs, key=pairs.get)  # ties go to the first pair counted
    vocab = merge(best, vocab)
    print(vocab)
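Once merges are learned, tokenizing a new word amounts to replaying them in rank order. The sketch below is only an illustration: the helper name `bpe_encode` is hypothetical, and the merge list is an assumed set of ranks for this toy corpus, not output from any real tokenizer library.

```python
def bpe_encode(word, merges):
    """Tokenize one word by replaying merges in rank order (earliest first)."""
    symbols = list(word) + ["</w>"]  # characters plus the end-of-word marker
    for a, b in merges:
        merged, i = [], 0
        while i < len(symbols):
            # Merge only exact adjacent symbol matches, scanning left to right.
            if i + 1 < len(symbols) and symbols[i] == a and symbols[i + 1] == b:
                merged.append(a + b)
                i += 2
            else:
                merged.append(symbols[i])
                i += 1
        symbols = merged
    return symbols

# Assumed merge ranks for illustration (hypothetical, not from a real tokenizer).
merges = [("w", "e"), ("we", "r"), ("wer", "</w>"), ("l", "o")]
print(bpe_encode("newer", merges))   # ['n', 'e', 'wer</w>']
print(bpe_encode("slower", merges))  # ['s', 'lo', 'wer</w>']
```

Note that "slower" never appeared in the toy corpus, yet it still tokenizes into known pieces, which is the no-OOV property in action.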
Drills
D1 · Why subword beats word
Give two concrete advantages of BPE over word-level tokenization.
Solution
(1) No OOV: any new word becomes a sequence of known subwords. (2) Morphological generalisation:
"running" -> run + ning shares the "run" embedding with "runs" and "runner".
D2 · Sequence-length cost
A 1 000-character English sentence becomes ~200 BPE tokens; the same sentence is ~1 000 byte-level tokens (one byte per ASCII character). Why does this matter for compute?
Solution
Attention is O(n^2) in sequence length. A 5x sequence length is a 25x attention
cost. Subword tokenization is largely a compute optimisation.
2. Embeddings & positional encodings
Concept
Each token ID is looked up in an embedding table E in R^(V x d) to produce a dense vector
(typically d = 256 to 4096). Embeddings are learned during training; semantically similar
tokens end up nearby in vector space.
Self-attention is permutation-equivariant — it doesn't know token order. Position is injected explicitly:
- Learned positional embeddings (original GPT, BERT): a second table indexed by position, added to the token embedding.
- Sinusoidal (original Transformer): fixed PE(pos, 2i) = sin(pos / 10000^(2i/d)), PE(pos, 2i+1) = cos(...). No parameters; extrapolates to longer sequences in principle.
- Rotary (RoPE): rotates pairs of dimensions in Q and K by an angle proportional to position. Used in LLaMA, Mistral, modern LLMs. Encodes relative positions naturally.
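To make the "relative positions" claim for rotary encodings concrete, here is a minimal pure-Python sketch. The function name `rope` and the example vectors are hypothetical, and the frequency schedule is assumed to mirror the sinusoidal one above; production implementations operate on batched tensors and may pair dimensions differently.

```python
import math

def rope(vec, pos, base=10000.0):
    """Rotate consecutive pairs of `vec` by angles proportional to `pos`."""
    d = len(vec)
    out = []
    for i in range(0, d, 2):
        theta = pos / (base ** (i / d))  # earlier pairs rotate faster
        c, s = math.cos(theta), math.sin(theta)
        x, y = vec[i], vec[i + 1]
        out.extend([x * c - y * s, x * s + y * c])
    return out

def dot(a, b):
    return sum(x * y for x, y in zip(a, b))

q = [0.3, -1.2, 0.7, 0.5]
k = [1.0, 0.4, -0.6, 0.9]
# Same offset (2 positions apart) at very different absolute positions.
near = dot(rope(q, 5), rope(k, 3))
far = dot(rope(q, 105), rope(k, 103))
print(abs(near - far) < 1e-9)  # the score depends only on the offset
```

Because both vectors are rotated, the rotation angles cancel except for their difference, so the attention score sees only the relative offset.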
Worked example — embedding + sinusoidal PE
import torch, torch.nn as nn, math
class SinusoidalPE(nn.Module):
    def __init__(self, max_len, d_model):
        super().__init__()
        pe = torch.zeros(max_len, d_model)
        pos = torch.arange(max_len).unsqueeze(1).float()
        div = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))
        pe[:, 0::2] = torch.sin(pos * div)
        pe[:, 1::2] = torch.cos(pos * div)
        self.register_buffer("pe", pe)  # not a parameter
    def forward(self, x):  # x: (B, T, d_model)
        return x + self.pe[:x.size(1)]
vocab_size, d_model = 32000, 512
emb = nn.Embedding(vocab_size, d_model)
pos = SinusoidalPE(max_len=2048, d_model=d_model)
ids = torch.tensor([[1, 5, 9, 4]])
x = pos(emb(ids))  # (1, 4, 512)
print(x.shape)
Drills
D1 · Why scale-invariant frequencies?
Why does sinusoidal PE use frequencies geometric in i (factor 10000)?
Solution
Geometric frequencies give the network a multi-scale "clock": high frequencies (short wavelengths) encode local position, low frequencies (long wavelengths) encode coarse position. Shifting by a fixed offset k rotates each sin/cos pair by a fixed angle, so PE(pos + k) is a linear function of PE(pos) and the model can learn relative offsets easily.
D2 · Why not just one-hot positions?
One sentence.
Solution
One-hot positions can't generalise to lengths unseen at training time and use far more dimensions; smooth periodic or rotary encodings interpolate naturally.
3. Scaled dot-product attention
Concept
Given queries Q in R^(T x d_k), keys K in R^(T x d_k), and values V in R^(T x d_v) (typically d_k = d_v):
Attention(Q, K, V) = softmax(Q K^T / sqrt(d_k)) V
- Q K^T measures how strongly each query "matches" each key, giving a T x T similarity matrix.
- / sqrt(d_k) prevents the dot products from growing too large at high dimension (which would saturate the softmax and produce near-one-hot weights, killing gradient flow).
- softmax over the key axis turns scores into a probability distribution: each query reads a weighted mixture of values.
- V multiplication takes the weighted average of value vectors according to those probabilities.
For causal (decoder) attention, mask the upper triangle with -inf before
softmax so position t can only attend to positions <= t. For
padded batches, mask out padding tokens.
Worked example — attention from first principles
import torch, torch.nn.functional as F
def attention(Q, K, V, mask=None):
    # Q, K, V: (batch, n_heads, T, d_head)
    d_k = Q.size(-1)
    scores = (Q @ K.transpose(-2, -1)) / (d_k ** 0.5)
    if mask is not None:
        scores = scores.masked_fill(mask == 0, float('-inf'))
    weights = F.softmax(scores, dim=-1)
    return weights @ V, weights
# Tiny sanity check: identical Q and K -> each position attends mostly to itself
T, d = 4, 8
Q = K = V = torch.randn(1, 1, T, d)
out, w = attention(Q, K, V)
print(w[0,0])  # diagonal should dominate
Drills
D1 · Why divide by sqrt(d_k)?
What concretely goes wrong if you omit the scaling?
Solution
For independent Q, K with unit variance, each entry of Q K^T has
variance d_k. As d_k grows, the softmax becomes extremely peaked
(effectively argmax), so gradients vanish for non-selected positions. Dividing by
sqrt(d_k) keeps the variance ~ 1.
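The variance claim can be checked empirically. This is a small simulation sketch, assuming zero-mean, unit-variance Gaussian entries; the function name `dot_variance` and the trial count are arbitrary choices for illustration.

```python
import random

def dot_variance(d_k, trials=4000, seed=0):
    """Empirical variance of q . k for random unit-variance q, k of size d_k."""
    rng = random.Random(seed)
    samples = []
    for _ in range(trials):
        q = [rng.gauss(0, 1) for _ in range(d_k)]
        k = [rng.gauss(0, 1) for _ in range(d_k)]
        samples.append(sum(a * b for a, b in zip(q, k)))
    mean = sum(samples) / trials
    return sum((s - mean) ** 2 for s in samples) / trials

for d_k in (4, 64):
    raw = dot_variance(d_k)
    scaled = raw / d_k  # dividing scores by sqrt(d_k) divides variance by d_k
    print(d_k, round(raw, 1), round(scaled, 2))
```

The raw variance should track d_k, while the scaled variance should stay close to 1 regardless of dimension.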
D2 · Causal mask shape
For sequence length T = 4, write the causal mask as a 4x4 matrix of 1s and 0s where 1
means "allowed".
Solution
1 0 0 0
1 1 0 0
1 1 1 0
1 1 1 1
Lower triangular including the diagonal. Apply by setting positions where the mask is 0 to
-inf in the score matrix before softmax.
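A pure-Python sketch of the same mask, applied to a dummy score matrix, shows where the probability mass ends up. The helper names are hypothetical and the uniform scores are chosen only to make the result easy to read.

```python
import math

def causal_mask(T):
    """1 where query t may attend to key s (s <= t), else 0."""
    return [[1 if s <= t else 0 for s in range(T)] for t in range(T)]

def masked_softmax(scores, mask):
    """Row-wise softmax with masked entries set to -inf first."""
    out = []
    for row, mrow in zip(scores, mask):
        masked = [x if m else float("-inf") for x, m in zip(row, mrow)]
        top = max(masked)  # subtract the row max for numerical stability
        exps = [math.exp(x - top) for x in masked]
        total = sum(exps)
        out.append([e / total for e in exps])
    return out

T = 4
mask = causal_mask(T)
scores = [[0.0] * T for _ in range(T)]  # uniform scores, for illustration
weights = masked_softmax(scores, mask)
for row in weights:
    print([round(w, 2) for w in row])
```

Each row spreads its weight evenly over the allowed positions and puts exactly zero on future positions.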
D3 · Attention complexity
Sequence length T = 4096, d_model = 512. How does compute scale with
T? With d_model?
Solution
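O(T^2 * d_model) for the Q K^T and softmax * V products.
Doubling T quadruples this cost; doubling d_model doubles it. The Q, K, V, and output projections add O(T * d_model^2), which is linear in T and matters mainly for short sequences. Long-sequence
tricks target the T^2 term: FlashAttention cuts its memory traffic, while sliding-window and sparse attention cut the compute itself.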
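The scaling can be checked with a rough operation count. This is a back-of-the-envelope sketch that counts multiply-adds only and ignores softmax, biases, and constant factors; the function name `attention_flops` is hypothetical.

```python
def attention_flops(T, d_model):
    """Rough multiply-add count for one self-attention layer (all heads)."""
    scores = T * T * d_model                 # Q K^T
    mix = T * T * d_model                    # weights @ V
    projections = 4 * T * d_model * d_model  # W_Q, W_K, W_V, W_O
    return {"attention": scores + mix, "projections": projections}

base = attention_flops(4096, 512)
longer = attention_flops(8192, 512)
print(longer["attention"] / base["attention"])      # 4.0
print(longer["projections"] / base["projections"])  # 2.0
```

Doubling T quadruples the attention term but only doubles the projection term, which is why long sequences are dominated by the T^2 part.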
4. Multi-head attention
Concept
Instead of one attention with full d_model dimensions, split into h heads of
size d_head = d_model / h. Each head has its own learned W_Q, W_K, W_V and
can attend to a different relationship (e.g. one head tracks subject-verb agreement, another tracks
coreference). Outputs are concatenated and projected by W_O.
Worked example — multi-head attention module
import torch.nn as nn
# attention() from above
class MultiHeadAttention(nn.Module):
    def __init__(self, d_model, n_heads):
        super().__init__()
        assert d_model % n_heads == 0
        self.n_heads = n_heads
        self.d_head = d_model // n_heads
        self.W_q = nn.Linear(d_model, d_model)
        self.W_k = nn.Linear(d_model, d_model)
        self.W_v = nn.Linear(d_model, d_model)
        self.W_o = nn.Linear(d_model, d_model)
    def forward(self, x, mask=None):
        B, T, D = x.shape
        # project and reshape to (B, n_heads, T, d_head)
        Q = self.W_q(x).view(B, T, self.n_heads, self.d_head).transpose(1, 2)
        K = self.W_k(x).view(B, T, self.n_heads, self.d_head).transpose(1, 2)
        V = self.W_v(x).view(B, T, self.n_heads, self.d_head).transpose(1, 2)
        out, _ = attention(Q, K, V, mask)  # (B, n_heads, T, d_head)
        out = out.transpose(1, 2).contiguous().view(B, T, D)
        return self.W_o(out)
Drills
D1 · Param count of MHA
For d_model = 512, n_heads = 8, how many parameters in the four projection
matrices (ignoring biases)?
Solution
Each of W_Q, W_K, W_V, W_O is 512 x 512 = 262 144. Total
4 * 262 144 = 1 048 576 = ~1.0 M parameters (plus 4*512 biases).
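The same arithmetic as a small helper, which also makes it explicit that the head count does not change the total. The function name `mha_params` is hypothetical, and the layout is assumed to be four d_model x d_model projections as in the module above.

```python
def mha_params(d_model, n_heads, bias=True):
    """Parameter count of the four projections in multi-head attention."""
    assert d_model % n_heads == 0
    weights = 4 * d_model * d_model  # W_Q, W_K, W_V, W_O
    biases = 4 * d_model if bias else 0
    return weights + biases

print(mha_params(512, 8, bias=False))  # 1048576
print(mha_params(512, 8))              # 1050624
print(mha_params(512, 16) == mha_params(512, 8))  # True
```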
D2 · Why split into heads?
If the total parameter count is the same as one big attention with d_head = d_model,
what's the point of splitting?
Solution
Each head has its own subspace and its own softmax — it can specialise. Empirically, h heads each attending to a different relation outperform one large head, while computing the same total work.
5. The transformer block
Concept
A transformer block applies, in order: LayerNorm -> multi-head self-attention -> residual; LayerNorm -> position-wise feed-forward MLP -> residual. The MLP is typically 2 linear layers with a 4x widening factor and a GELU activation. Residual connections are critical — they let the model add small refinements at each layer and keep gradients flowing through deep stacks.
Two orderings exist: post-LN (original, harder to train deeply) and pre-LN (modern default, trains stably to hundreds of layers). The example below uses pre-LN.
Worked example — encoder block
class TransformerBlock(nn.Module):
    def __init__(self, d_model, n_heads, d_ff, dropout=0.1):
        super().__init__()
        self.attn = MultiHeadAttention(d_model, n_heads)
        self.norm1 = nn.LayerNorm(d_model)
        self.ff = nn.Sequential(
            nn.Linear(d_model, d_ff), nn.GELU(),
            nn.Linear(d_ff, d_model),
        )
        self.norm2 = nn.LayerNorm(d_model)
        self.drop = nn.Dropout(dropout)
    def forward(self, x, mask=None):
        x = x + self.drop(self.attn(self.norm1(x), mask))
        x = x + self.drop(self.ff(self.norm2(x)))
        return x
Drills
D1 · LayerNorm vs BatchNorm
Why do transformers use LayerNorm, not BatchNorm?
Solution
BatchNorm normalises across the batch dimension; for variable-length sequences and small batches (common in NLP) the statistics are unreliable. LayerNorm normalises within a single sample, across the feature dimension — independent of batch size and sequence padding.
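A minimal sketch of the per-token computation, assuming no learnable gain or bias (which nn.LayerNorm adds by default). The function name and example vectors are hypothetical; the point is that each token is normalised using only its own features.

```python
import math

def layer_norm(x, eps=1e-5):
    """Normalise one token's feature vector to zero mean and unit variance."""
    mean = sum(x) / len(x)
    var = sum((v - mean) ** 2 for v in x) / len(x)
    return [(v - mean) / math.sqrt(var + eps) for v in x]

token_a = [1.0, 2.0, 3.0, 4.0]
token_b = [100.0, 200.0, 300.0, 400.0]  # same pattern, 100x the scale
print([round(v, 3) for v in layer_norm(token_a)])
print([round(v, 3) for v in layer_norm(token_b)])
```

Both tokens normalise to essentially the same vector, and neither result depends on what else is in the batch or how much padding surrounds it.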
D2 · Residual + LN math
With pre-LN, what tensor enters the attention layer: x or
LayerNorm(x)?
Solution
LayerNorm(x). The residual added back is the un-normalised x. This
ordering keeps the residual stream's scale stable across many layers.
6. Pre-training, fine-tuning, prompting
Concept
Transformers achieve their power by being pre-trained on massive unlabeled corpora with self-supervised objectives, then adapted to downstream tasks.
- Causal (autoregressive) LM: predict the next token. The objective behind GPT, LLaMA, Mistral, and all decoder-only chat models. Trained on trillions of tokens; few-shot capability emerges at scale.
- Masked language modeling (MLM): mask ~15% of tokens at random and predict them. The objective behind BERT and its descendants. Bidirectional context; a great encoder for classification / retrieval, but not suited to free-form generation.
- Fine-tuning: continue training the pre-trained model on a small labeled dataset for the target task. Tiny LR (1e-5 to 5e-5), few epochs (2-5), early stopping.
- Parameter-efficient fine-tuning (PEFT): LoRA, adapters. Freeze the base model; train only a small set of additional parameters (often < 1% of total). Cheap, fast, nearly matches full fine-tuning.
- Prompting / in-context learning: for sufficiently large models, the task is described in natural language with examples and no weights change at all.
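To see why PEFT trains so few parameters, here is a rough count for LoRA on a single weight matrix. It assumes the usual low-rank form, where a frozen d_out x d_in matrix is augmented by the product of a d_out x r and an r x d_in matrix; the layer size and rank below are illustrative assumptions, not values from any particular model.

```python
def lora_trainable(d_in, d_out, rank):
    """Trainable parameters added by a rank-`rank` LoRA update to one matrix."""
    return rank * (d_in + d_out)

d = 4096   # assumed hidden size, for illustration
rank = 8   # assumed LoRA rank, for illustration
full = d * d
lora = lora_trainable(d, d, rank)
print(full, lora, lora / full)  # 16777216 65536 0.00390625
```

Under these assumptions the adapter holds about 0.4% of the parameters of the matrix it adapts.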
7. Encoder vs decoder vs encoder-decoder
| Architecture | Attention | Examples | Best for |
|---|---|---|---|
| Encoder-only | Bidirectional self-attn | BERT, RoBERTa, DeBERTa | Classification, NER, retrieval, embeddings |
| Decoder-only | Causal self-attn | GPT, LLaMA, Mistral | Generation, chat, code, in-context learning |
| Encoder-decoder | Encoder self-attn + decoder causal self-attn + cross-attn | T5, BART, Whisper | Translation, summarisation, ASR |
Cross-attention in encoder-decoder models is identical to self-attention except Q comes from the decoder's current state while K and V come from the encoder's outputs.
Drills
D1 · Pick the architecture
(a) Translate French to English. (b) Detect toxic comments. (c) Build a chatbot. (d) Score sentence similarity for search.
Solution
(a) encoder-decoder (T5/BART); (b) encoder-only (BERT) + linear head; (c) decoder-only (GPT-style); (d) encoder-only with [CLS] pooling or a sentence-encoder like Sentence-BERT.
D2 · Cross-attention shape
Decoder generating token t with state d_t in R^d; encoder outputs
H in R^(S x d). What are the shapes of Q, K, V in cross-attention?
Solution
Q has shape (1, d) projected from d_t; K, V
have shape (S, d) projected from H. The decoder reads a weighted summary of
encoder outputs.
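The shapes can be traced in a tiny pure-Python sketch. It assumes identity projections (a real layer would first apply learned W_Q, W_K, W_V), and the function name and toy values are hypothetical.

```python
import math

def cross_attend(q, K, V):
    """One decoder query q (length d) attends over S encoder rows in K and V."""
    d = len(q)
    scores = [sum(a * b for a, b in zip(q, k)) / math.sqrt(d) for k in K]
    top = max(scores)  # subtract the max for numerical stability
    exps = [math.exp(s - top) for s in scores]
    total = sum(exps)
    weights = [e / total for e in exps]  # length S, sums to 1
    out = [sum(w * v[j] for w, v in zip(weights, V)) for j in range(len(V[0]))]
    return out, weights

S, d = 3, 4
H = [[1.0, 0.0, 0.0, 0.0],
     [0.0, 1.0, 0.0, 0.0],
     [0.0, 0.0, 1.0, 0.0]]   # encoder outputs, shape (S, d)
d_t = [2.0, 0.0, 0.0, 0.0]  # decoder state, shape (d,)
out, w = cross_attend(d_t, H, H)
print(len(w), len(out))  # 3 4
```

There is one weight per encoder position and one output vector of width d, regardless of how long the source sequence is.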
8. Vision transformers & generative models
Concept
Vision Transformer (ViT): chop a 224x224 image into 196 patches of 16x16, flatten each
patch (16 x 16 x 3 = 768 values for RGB) and map it to the model width with a linear projection (also 768 in ViT-Base), prepend a learned [CLS] token, add
positional embeddings, and run a stack of transformer encoder blocks. At ImageNet-21k scale and above,
ViTs beat ResNets; below that, CNNs tend to win because they bake in translation equivariance as an inductive
bias.
Generative & CV applications (per the syllabus):
- U-Net: encoder-decoder CNN with skip connections; the workhorse for segmentation and the denoising backbone inside almost every diffusion model.
- VAE (Variational Auto-Encoder): encoder maps x to a Gaussian over latent z; decoder reconstructs x from sampled z. Loss = reconstruction + KL to prior. Used inside Stable Diffusion to compress images to a tractable latent space.
- DDPM (Denoising Diffusion Probabilistic Model): gradually adds Gaussian noise to data (forward process), trains a network to predict the noise at each step (typically a U-Net conditioned on time and prompt), then samples by reversing the process. State of the art for image, audio, and protein generation.
- CLIP-style contrastive models: train a text encoder and image encoder so that matched (caption, image) pairs are close in shared embedding space. Powers zero-shot classification and the text conditioning of modern generators.
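The DDPM forward process has a closed form that is easy to sketch: x_t = sqrt(abar_t) * x_0 + sqrt(1 - abar_t) * eps, where abar_t is the running product of (1 - beta). The code below assumes a linear beta schedule from 1e-4 to 0.02 over 1000 steps, a common choice; the function names and toy data are hypothetical.

```python
import math
import random

def alpha_bars(T=1000, beta_start=1e-4, beta_end=0.02):
    """Cumulative products of (1 - beta_t) under a linear beta schedule."""
    out, prod = [], 1.0
    for t in range(T):
        beta = beta_start + (beta_end - beta_start) * t / (T - 1)
        prod *= 1.0 - beta
        out.append(prod)
    return out

def q_sample(x0, t, abar, rng):
    """Jump straight to noise level t; also return the noise that was added."""
    eps = [rng.gauss(0, 1) for _ in x0]
    a, b = math.sqrt(abar[t]), math.sqrt(1.0 - abar[t])
    return [a * x + b * e for x, e in zip(x0, eps)], eps

abar = alpha_bars()
rng = random.Random(0)
x0 = [1.0, -1.0, 0.5]  # toy "image"
x_t, eps = q_sample(x0, 500, abar, rng)
print(abar[0], abar[-1])  # near 1 at the start, near 0 at the end
```

The denoising network is trained to recover `eps` from `x_t` and `t`; by the last step almost none of the original signal remains.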
Drills
D1 · Number of ViT tokens
Patch size 16, image 224x224. How many patch tokens? Plus the CLS token, how many positions does self-attention see?
Solution
(224/16)^2 = 14*14 = 196 patches. With [CLS]: 197
positions.
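The same count as a helper, which also reports the flattened patch size. The function name `vit_tokens` is hypothetical; RGB input and a single [CLS] token are assumed.

```python
def vit_tokens(image_size, patch_size, channels=3, cls_token=True):
    """Return (sequence length, flattened patch dimension) for a square image."""
    assert image_size % patch_size == 0
    per_side = image_size // patch_size
    n_patches = per_side ** 2
    patch_dim = patch_size * patch_size * channels
    return n_patches + (1 if cls_token else 0), patch_dim

print(vit_tokens(224, 16))  # (197, 768)
print(vit_tokens(224, 32))  # (50, 3072)
```

Doubling the patch size cuts the sequence length by roughly 4x, and the attention cost by roughly 16x.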
D2 · Why diffusion uses a U-Net, not a plain transformer
One sentence.
Solution
The denoising target is a noise tensor at full image resolution; U-Net's skip connections preserve high-resolution spatial detail while still letting deep bottleneck layers reason globally. (Recent work — DiT — does swap in a transformer, at higher compute cost.)
D3 · VAE latent
Why is the VAE encoder constrained to output a distribution (mean + log-variance) rather than a point?
Solution
Sampling from a continuous distribution and adding a KL term to a prior produces a smooth,
interpolatable latent space — slight perturbations of z decode to small image changes.
A deterministic auto-encoder has no such guarantee.
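The two ingredients mentioned here, sampling from the encoder's Gaussian and the KL term to a standard-normal prior, can be sketched in a few lines. This assumes a diagonal Gaussian parameterised by mean and log-variance; the function names are hypothetical.

```python
import math
import random

def reparameterize(mu, logvar, rng):
    """Sample z = mu + sigma * eps, with sigma = exp(0.5 * logvar)."""
    return [m + math.exp(0.5 * lv) * rng.gauss(0, 1) for m, lv in zip(mu, logvar)]

def kl_to_standard_normal(mu, logvar):
    """KL( N(mu, diag(exp(logvar))) || N(0, I) ), summed over dimensions."""
    return 0.5 * sum(math.exp(lv) + m * m - 1.0 - lv for m, lv in zip(mu, logvar))

rng = random.Random(0)
mu, logvar = [0.5, -0.2], [0.0, -1.0]
z = reparameterize(mu, logvar, rng)
print(len(z), kl_to_standard_normal(mu, logvar))
print(kl_to_standard_normal([0.0, 0.0], [0.0, 0.0]))  # 0.0
```

The KL term is zero only when the encoder output matches the prior exactly, so it pulls every encoded point toward a shared, smooth region of latent space.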
Common pitfalls
- Forgetting positional information. Without positions, the transformer treats input as a bag of tokens.
- Wrong attention mask. For causal models you need a lower-triangular mask that includes the diagonal; for padded batches you need to mask out padding tokens; cross-attention needs the encoder padding mask, not the decoder causal mask.
- Numerical instability in softmax. Always divide by sqrt(d_k) before softmax to avoid saturation.
- Memory blows up. Attention is O(n^2) in sequence length. Long sequences need tricks (sparse attention, sliding window, FlashAttention).
- Overfitting on small fine-tuning datasets. Use small learning rate (1e-5 to 5e-5), low number of epochs, warmup schedule, and consider PEFT (LoRA) to constrain capacity.
Checkpoint — answer out loud
- Can you write scaled dot-product attention from memory and explain every term?
- Can you compute the parameter count of a multi-head attention layer given d_model and n_heads?
- Can you sketch a transformer block (pre-LN) including residuals?
- Can you pick encoder-only vs decoder-only vs encoder-decoder for a new task and justify in one sentence?
- Can you describe in one sentence each what a ViT, a U-Net, a VAE, and a DDPM do?
Next step
With the architecture stack complete, head to Problems to apply it under contest conditions, then Mocks for the full timed experience. For paper short-answer reps on scaled dot-product, positional encoding, multi-head cost, and KV-cache math, hit the Round 2 theory drills.