Attention & transformers

Tokenization, embeddings, scaled dot-product attention, multi-head attention, the transformer block, and how the same architecture powers modern NLP, vision (ViT), and even graph and audio models.

Why this matters. Transformers are the substrate of nearly every modern AI system — LLMs, image generators, speech recognition, code assistants. Understanding attention is no longer optional.
Syllabus link. Mapped to the modern-architectures block of the official USAAIO syllabus.
TL;DR. By the end of this page you should be able to (1) tokenize text and embed tokens into vectors with positional information; (2) derive scaled dot-product attention and explain every term; (3) build a multi-head attention layer and a transformer block from nn.Linear + nn.LayerNorm; (4) distinguish encoder-only (BERT), decoder-only (GPT), and encoder-decoder (T5) and pick one for a given task; (5) describe how ViTs, U-Nets, VAEs, and diffusion models reuse or extend the same building blocks for vision and generation. USAAIO transformer questions test concept fluency, shape arithmetic, and small implementations.

1. Tokenization

Concept

Models don't see text; they see integer token IDs. The tokenizer converts strings to IDs and back. Three strategies: word-level, subword (such as BPE), and character- or byte-level.
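As a minimal sketch of the string-to-ID round trip, here is a character-level tokenizer built from a toy corpus. The helper names (build_vocab, encode, decode) are hypothetical and not taken from any library; real tokenizers also handle unknown symbols and special tokens.

```python
def build_vocab(corpus):
    """Map every distinct character in the corpus to an integer ID."""
    chars = sorted(set("".join(corpus)))
    return {ch: i for i, ch in enumerate(chars)}

def encode(text, stoi):
    """String -> list of integer token IDs."""
    return [stoi[ch] for ch in text]

def decode(ids, stoi):
    """List of integer token IDs -> string."""
    itos = {i: ch for ch, i in stoi.items()}
    return "".join(itos[i] for i in ids)

stoi = build_vocab(["low", "lower", "newer"])
ids = encode("lower", stoi)
print(ids)
print(decode(ids, stoi))
```

The model only ever receives the integer list; the mapping back to text lives entirely in the tokenizer.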

Worked example — BPE merges by hand

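Start from the character vocabulary {l, o, w, e, r, n} and the corpus {"low": 5, "lower": 2, "newer": 6}. Each word is first split into characters followed by an end-of-word marker, written </w> in the code. Every iteration counts adjacent symbol pairs, weighted by word frequency, and merges the most frequent pair into a new symbol.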

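import re
from collections import Counter

# Minimal BPE training loop sketch (do NOT use in production; it just illustrates the idea)
def get_pairs(word_freqs):
    # Count adjacent symbol pairs, weighted by word frequency
    pairs = Counter()
    for word, freq in word_freqs.items():
        symbols = word.split()
        for a, b in zip(symbols, symbols[1:]):
            pairs[(a, b)] += freq
    return pairs

def merge(pair, word_freqs):
    # Match the pair only at symbol boundaries; a plain str.replace could
    # also match inside a longer symbol (e.g. "e r" inside "we r")
    pattern = re.compile(r"(?<!\S)" + re.escape(" ".join(pair)) + r"(?!\S)")
    repl = "".join(pair)
    return {pattern.sub(lambda m: repl, w): f for w, f in word_freqs.items()}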

vocab = {"l o w </w>": 5, "l o w e r </w>": 2, "n e w e r </w>": 6}
for _ in range(4):
    pairs = get_pairs(vocab)
    best  = max(pairs, key=pairs.get)
    vocab = merge(best, vocab)
print(vocab)
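Training only produces an ordered list of merges; tokenizing a new word means replaying those merges in order. The sketch below assumes a hypothetical merge list and a hypothetical helper apply_merges; it is an illustration of the idea, not how any particular library implements it.

```python
def apply_merges(word, merges):
    """Tokenize one word by replaying learned merges in order."""
    symbols = list(word) + ["</w>"]
    for a, b in merges:
        merged, i = [], 0
        while i < len(symbols):
            # Merge only when the two adjacent symbols match the pair exactly
            if i + 1 < len(symbols) and symbols[i] == a and symbols[i + 1] == b:
                merged.append(a + b)
                i += 2
            else:
                merged.append(symbols[i])
                i += 1
        symbols = merged
    return symbols

# Hypothetical merge list, in the order the merges were learned
merges = [("w", "e"), ("we", "r"), ("wer", "</w>"), ("l", "o")]
print(apply_merges("lower", merges))
print(apply_merges("slower", merges))
```

A word that never appeared in training still splits into smaller pieces, falling back to single characters where no merge applies.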

Drills

D1 · Why subword beats word

Give two concrete advantages of BPE over word-level tokenization.

Solution

(1) No OOV: any new word becomes a sequence of known subwords. (2) Morphological generalisation: "running" -> run + ning shares the "run" embedding with "runs" and "runner".

D2 · Sequence-length cost

A 1 000-character English sentence becomes ~200 BPE tokens; the same sentence is ~1 000 byte-level tokens (one byte per ASCII character). Why does this matter for compute?

Solution

Attention is O(n^2) in sequence length. A 5x sequence length is a 25x attention cost. Subword tokenization is largely a compute optimisation.
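The quadratic relationship can be checked with a one-line helper. This is a rough sketch: attention_cost_ratio is a hypothetical name, and it counts only the T x T score matrix, ignoring every term that is linear in sequence length.

```python
def attention_cost_ratio(n_long, n_short):
    """Relative cost of the T x T score matrix for two sequence lengths."""
    return (n_long / n_short) ** 2

# A k-times longer sequence costs k^2 times more attention compute
for factor in (2, 5, 25):
    print(factor, attention_cost_ratio(factor * 200, 200))
```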

2. Embeddings & positional encodings

Concept

Each token ID is looked up in an embedding table E in R^(V x d) to produce a dense vector (typically d = 256 to 4096). Embeddings are learned during training; semantically similar tokens end up nearby in vector space.

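Self-attention is permutation-equivariant: on its own it cannot tell token order. Position is therefore injected explicitly, for example with fixed sinusoidal encodings, learned position embeddings, or rotary encodings.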
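Permutation equivariance can be checked numerically. The sketch below uses a stripped-down single-head self-attention with Q = K = V = x and no learned weights (an assumption made to keep the example short): shuffling the input tokens only shuffles the outputs in the same way.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)

def self_attention(x):
    """Single-head self-attention with Q = K = V = x (no learned weights)."""
    scores = x @ x.transpose(-2, -1) / x.size(-1) ** 0.5
    return F.softmax(scores, dim=-1) @ x

x = torch.randn(5, 8)                 # 5 tokens, 8 features, no positional info
perm = torch.tensor([3, 0, 4, 1, 2])  # an arbitrary reordering of the tokens

out = self_attention(x)
out_perm = self_attention(x[perm])

# Permuting the inputs permutes the outputs identically
print(torch.allclose(out[perm], out_perm, atol=1e-6))
```

Because the outputs carry no trace of the original order, any order information has to come from the inputs themselves.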

Worked example — embedding + sinusoidal PE

import torch, torch.nn as nn, math

class SinusoidalPE(nn.Module):
    def __init__(self, max_len, d_model):
        super().__init__()
        pe = torch.zeros(max_len, d_model)
        pos = torch.arange(max_len).unsqueeze(1).float()
        div = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))
        pe[:, 0::2] = torch.sin(pos * div)
        pe[:, 1::2] = torch.cos(pos * div)
        self.register_buffer("pe", pe)         # not a parameter

    def forward(self, x):                       # x: (B, T, d_model)
        return x + self.pe[:x.size(1)]

vocab_size, d_model = 32000, 512
emb = nn.Embedding(vocab_size, d_model)
pos = SinusoidalPE(max_len=2048, d_model=d_model)
ids = torch.tensor([[1, 5, 9, 4]])
x   = pos(emb(ids))                            # (1, 4, 512)
print(x.shape)

Drills

D1 · Why scale-invariant frequencies?

Why does sinusoidal PE use frequencies that decrease geometrically with the dimension index i (base 10000)?

Solution

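Geometric frequencies give the network a multi-scale "clock": high frequencies (short wavelengths) encode local position, low frequencies (long wavelengths) encode coarse position. For each frequency, shifting the position by k rotates the (sin, cos) pair by a fixed angle that depends only on k, so the encoding of position p + k is a linear function of the encoding of p and the model can learn relative offsets easily.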
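The relative-offset property can be verified for a single (sin, cos) pair with the angle-addition identities. The helper names pe_pair and shift are hypothetical, and the frequency is the one at index i = 1 for d_model = 512; treat this as a numerical sanity check rather than a proof.

```python
import math

def pe_pair(pos, omega):
    """One (sin, cos) pair of a sinusoidal encoding at frequency omega."""
    return math.sin(omega * pos), math.cos(omega * pos)

def shift(pair, k, omega):
    """Rotate a (sin, cos) pair by the angle omega * k (independent of pos)."""
    s, c = pair
    ck, sk = math.cos(omega * k), math.sin(omega * k)
    return s * ck + c * sk, c * ck - s * sk

omega = 1.0 / 10000 ** (2 / 512)   # frequency at index i = 1 for d_model = 512
p, k = 17, 5
shifted = shift(pe_pair(p, omega), k, omega)
direct = pe_pair(p + k, omega)
print(shifted, direct)
```

The rotation applied in shift never looks at p, which is why a fixed linear map can move the encoding by a given offset.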

D2 · Why not just one-hot positions?

One sentence.

Solution

One-hot positions can't generalise to lengths unseen at training time and use far more dimensions; smooth periodic or rotary encodings interpolate naturally.

3. Scaled dot-product attention

Concept

Given Queries Q in R^(T x d_k), Keys K in R^(T x d_k), Values V in R^(T x d_v) (typically d_k = d_v):

Attention(Q, K, V) = softmax( Q K^T / sqrt(d_k) ) V

For causal (decoder) attention, mask the upper triangle with -inf before softmax so position t can only attend to positions <= t. For padded batches, mask out padding tokens.

Worked example — attention from first principles

import torch, torch.nn.functional as F

def attention(Q, K, V, mask=None):
    # Q, K, V: (batch, n_heads, T, d_head)
    d_k = Q.size(-1)
    scores = (Q @ K.transpose(-2, -1)) / (d_k ** 0.5)
    if mask is not None:
        scores = scores.masked_fill(mask == 0, float('-inf'))
    weights = F.softmax(scores, dim=-1)
    return weights @ V, weights

# Tiny sanity check: identical Q and K -> each position attends mostly to itself
T, d = 4, 8
Q = K = V = torch.randn(1, 1, T, d)
out, w = attention(Q, K, V)
print(w[0,0])     # diagonal should dominate

Drills

D1 · Why divide by sqrt(d_k)?

What concretely goes wrong if you omit the scaling?

Solution

For independent Q and K entries with zero mean and unit variance, each entry of Q K^T is a sum of d_k products and so has variance d_k. As d_k grows, the scores spread out, the softmax becomes extremely peaked (effectively argmax), and its gradients become vanishingly small. Dividing by sqrt(d_k) keeps the score variance near 1.
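The variance claim is easy to check empirically. This sketch draws standard-normal query and key vectors (an assumption; trained Q and K are not exactly standard normal) and estimates the variance of their dot product; dot_variance is a hypothetical helper.

```python
import random

random.seed(0)

def dot_variance(d_k, n_samples=2000):
    """Empirical variance of q . k for standard-normal q, k in R^d_k."""
    dots = []
    for _ in range(n_samples):
        q = [random.gauss(0, 1) for _ in range(d_k)]
        k = [random.gauss(0, 1) for _ in range(d_k)]
        dots.append(sum(a * b for a, b in zip(q, k)))
    mean = sum(dots) / n_samples
    return sum((x - mean) ** 2 for x in dots) / n_samples

d_k = 64
raw = dot_variance(d_k)
# Dividing the scores by sqrt(d_k) divides their variance by d_k
print(raw, raw / d_k)
```

The first number should land near d_k and the second near 1.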

D2 · Causal mask shape

For sequence length T = 4, write the causal mask as a 4x4 matrix of 1s and 0s where 1 means "allowed".

Solution
1 0 0 0
1 1 0 0
1 1 1 0
1 1 1 1

Lower triangular including the diagonal. Apply by setting positions where the mask is 0 to -inf in the score matrix before softmax.
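A small sketch of building and applying that mask with torch.tril. The scores are all-zero placeholders (an assumption, chosen so the resulting weights are easy to read): each row ends up uniform over the allowed positions.

```python
import torch
import torch.nn.functional as F

T = 4
mask = torch.tril(torch.ones(T, T))           # 1 = allowed, 0 = blocked
scores = torch.zeros(T, T)                    # placeholder scores, all equal
scores = scores.masked_fill(mask == 0, float("-inf"))
weights = F.softmax(scores, dim=-1)           # blocked positions get weight 0
print(mask)
print(weights)
```

Every row still sums to 1, and row t spreads its weight over positions 0..t only.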

D3 · Attention complexity

Sequence length T = 4096, d_model = 512. How does compute scale with T? With d_model?

Solution

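O(T^2 * d_model) for the Q K^T and softmax * V products. For these two products, doubling T quadruples compute and doubling d_model doubles it. The Q, K, V, and output projections add a further O(T * d_model^2), which is linear in T. Long-sequence tricks (FlashAttention, sliding window, sparse attention) attack the T^2 term.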
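A back-of-the-envelope FLOP count makes the scaling concrete. Both helpers are hypothetical and deliberately rough: they count 2 FLOPs per multiply-add, sum over all heads, and ignore the softmax, biases, and the feed-forward MLP.

```python
def attention_matmul_flops(T, d_model):
    """Rough FLOPs for Q K^T plus weights @ V, summed over all heads."""
    qk = 2 * T * T * d_model      # (T x d) @ (d x T)
    av = 2 * T * T * d_model      # (T x T) @ (T x d)
    return qk + av

def projection_flops(T, d_model):
    """Rough FLOPs for the four d_model x d_model projections."""
    return 4 * 2 * T * d_model * d_model

base = attention_matmul_flops(4096, 512)
print(attention_matmul_flops(8192, 512) / base)    # doubling T
print(attention_matmul_flops(4096, 1024) / base)   # doubling d_model
print(projection_flops(4096, 512) / base)          # projections vs attention matmuls
```

Under these assumptions the projections cost 2 * d_model / T times the attention products, so the T^2 term dominates once T is several times larger than d_model.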

4. Multi-head attention

Concept

Instead of one attention with full d_model dimensions, split into h heads of size d_head = d_model / h. Each head has its own learned W_Q, W_K, W_V and can attend to a different relationship (e.g. one head tracks subject-verb agreement, another tracks coreference). Outputs are concatenated and projected by W_O.

Worked example — multi-head attention module

import torch.nn as nn
# attention() from above

class MultiHeadAttention(nn.Module):
    def __init__(self, d_model, n_heads):
        super().__init__()
        assert d_model % n_heads == 0
        self.n_heads = n_heads
        self.d_head  = d_model // n_heads
        self.W_q = nn.Linear(d_model, d_model)
        self.W_k = nn.Linear(d_model, d_model)
        self.W_v = nn.Linear(d_model, d_model)
        self.W_o = nn.Linear(d_model, d_model)

    def forward(self, x, mask=None):
        B, T, D = x.shape
        # project and reshape to (B, n_heads, T, d_head)
        Q = self.W_q(x).view(B, T, self.n_heads, self.d_head).transpose(1, 2)
        K = self.W_k(x).view(B, T, self.n_heads, self.d_head).transpose(1, 2)
        V = self.W_v(x).view(B, T, self.n_heads, self.d_head).transpose(1, 2)
        out, _ = attention(Q, K, V, mask)             # (B, n_heads, T, d_head)
        out = out.transpose(1, 2).contiguous().view(B, T, D)
        return self.W_o(out)

Drills

D1 · Param count of MHA

For d_model = 512, n_heads = 8, how many parameters in the four projection matrices (ignoring biases)?

Solution

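Each of W_Q, W_K, W_V, W_O is 512 x 512 = 262 144. Total 4 * 262 144 = 1 048 576, about 1.05 M parameters (plus 4 * 512 = 2 048 biases). The count does not depend on n_heads, because the heads only partition the same d_model columns.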
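The count can be confirmed by instantiating the four projections as nn.Linear layers and summing their parameter sizes. This is a quick check under the assumption that each projection is a plain nn.Linear(d_model, d_model) with bias.

```python
import torch.nn as nn

d_model = 512
# Stand-ins for W_Q, W_K, W_V, W_O
projections = [nn.Linear(d_model, d_model) for _ in range(4)]

weights = sum(p.weight.numel() for p in projections)
biases = sum(p.bias.numel() for p in projections)
print(weights, biases)
```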

D2 · Why split into heads?

If the total parameter count is the same as one big attention with d_head = d_model, what's the point of splitting?

Solution

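Each head has its own subspace and its own softmax, so it can specialise: one softmax produces a single weighting over positions, while h heads produce h different weightings in parallel. Empirically, h heads each attending to a different relation outperform one large head at roughly the same total compute.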

5. The transformer block

Concept

A transformer block applies, in order: LayerNorm -> multi-head self-attention -> residual; LayerNorm -> position-wise feed-forward MLP -> residual. The MLP is typically 2 linear layers with a 4x widening factor and a GELU activation. Residual connections are critical — they let the model add small refinements at each layer and keep gradients flowing through deep stacks.

Two orderings exist: post-LN (original, harder to train deeply) and pre-LN (modern default, trains stably to hundreds of layers). The example below uses pre-LN.
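To make the two orderings concrete, here is a minimal sketch of a single residual step in each style. The nn.Linear sublayer is a stand-in for attention or the MLP (an assumption for brevity), and pre_ln and post_ln are hypothetical helper names.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
d_model = 16
norm = nn.LayerNorm(d_model)
sublayer = nn.Linear(d_model, d_model)   # stand-in for attention or the MLP

def pre_ln(x):
    return x + sublayer(norm(x))         # normalise the sublayer input only

def post_ln(x):
    return norm(x + sublayer(x))         # normalise after the residual add

x = 10 * torch.randn(2, 4, d_model)      # deliberately large-scale input
print(pre_ln(x).std().item())            # skip path passes x through untouched
print(post_ln(x).std().item())           # output is re-normalised
```

In the pre-LN step the input x reaches the output through a pure identity path; in the post-LN step every layer's output passes through a normalisation.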

Worked example — encoder block

class TransformerBlock(nn.Module):
    def __init__(self, d_model, n_heads, d_ff, dropout=0.1):
        super().__init__()
        self.attn  = MultiHeadAttention(d_model, n_heads)
        self.norm1 = nn.LayerNorm(d_model)
        self.ff = nn.Sequential(
            nn.Linear(d_model, d_ff), nn.GELU(),
            nn.Linear(d_ff, d_model),
        )
        self.norm2 = nn.LayerNorm(d_model)
        self.drop  = nn.Dropout(dropout)

    def forward(self, x, mask=None):
        x = x + self.drop(self.attn(self.norm1(x), mask))
        x = x + self.drop(self.ff(self.norm2(x)))
        return x

Drills

D1 · LayerNorm vs BatchNorm

Why do transformers use LayerNorm, not BatchNorm?

Solution

BatchNorm normalises across the batch dimension; for variable-length sequences and small batches (common in NLP) the statistics are unreliable. LayerNorm normalises within a single sample, across the feature dimension — independent of batch size and sequence padding.
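The batch independence of LayerNorm can be checked directly. The sketch below also reproduces nn.LayerNorm by hand, assuming its default initial affine parameters (weight 1, bias 0).

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
ln = nn.LayerNorm(8)
x = torch.randn(3, 5, 8)                       # (batch, tokens, features)

# Manual LayerNorm over the feature dimension (default weight = 1, bias = 0)
mean = x.mean(dim=-1, keepdim=True)
var = x.var(dim=-1, unbiased=False, keepdim=True)
manual = (x - mean) / torch.sqrt(var + ln.eps)

alone = ln(x[:1])                              # first sample on its own
in_batch = ln(x)[:1]                           # same sample inside the batch
print(torch.allclose(manual, ln(x), atol=1e-6))
print(torch.allclose(alone, in_batch))
```

Each token vector is normalised using only its own features, so the rest of the batch has no influence on the result.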

D2 · Residual + LN math

With pre-LN, what tensor enters the attention layer: x or LayerNorm(x)?

Solution

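LayerNorm(x). The residual added back is the un-normalised x, so the skip path stays a pure identity: every sublayer sees a normalised input, and gradients reach early layers directly. That is what makes deep pre-LN stacks stable to train.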

6. Pre-training, fine-tuning, prompting

Concept

Transformers achieve their power by being pre-trained on massive unlabeled corpora with self-supervised objectives, then adapted to downstream tasks.

7. Encoder vs decoder vs encoder-decoder

Architecture | Attention | Examples | Best for
Encoder-only | Bidirectional self-attn | BERT, RoBERTa, DeBERTa | Classification, NER, retrieval, embeddings
Decoder-only | Causal self-attn | GPT, LLaMA, Mistral | Generation, chat, code, in-context learning
Encoder-decoder | Encoder self-attn + decoder causal self-attn + cross-attn | T5, BART, Whisper | Translation, summarisation, ASR

Cross-attention in encoder-decoder models is identical to self-attention except Q comes from the decoder's current state while K and V come from the encoder's outputs.

Drills

D1 · Pick the architecture

(a) Translate French to English. (b) Detect toxic comments. (c) Build a chatbot. (d) Score sentence similarity for search.

Solution

(a) encoder-decoder (T5/BART); (b) encoder-only (BERT) + linear head; (c) decoder-only (GPT-style); (d) encoder-only with [CLS] pooling or a sentence-encoder like Sentence-BERT.

D2 · Cross-attention shape

Decoder generating token t with state d_t in R^d; encoder outputs H in R^(S x d). What are the shapes of Q, K, V in cross-attention?

Solution

Q has shape (1, d) projected from d_t; K, V have shape (S, d) projected from H. The decoder reads a weighted summary of encoder outputs.
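The shapes can be traced in a few lines of PyTorch. This is a single-head sketch with toy sizes; the projection layers W_q, W_k, W_v are freshly initialised stand-ins, not weights from any trained model.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)
S, d = 6, 16                          # encoder length and model width (toy sizes)
W_q, W_k, W_v = (nn.Linear(d, d) for _ in range(3))

d_t = torch.randn(1, d)               # decoder state for the current token
H = torch.randn(S, d)                 # encoder outputs

Q, K, V = W_q(d_t), W_k(H), W_v(H)    # (1, d), (S, d), (S, d)
weights = F.softmax(Q @ K.T / d ** 0.5, dim=-1)   # (1, S): one weight per source token
context = weights @ V                 # (1, d): weighted summary of the encoder
print(Q.shape, K.shape, weights.shape, context.shape)
```

The attention weights have one entry per encoder position, so the score matrix is 1 x S rather than square.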

8. Vision transformers & generative models

Concept

Vision Transformer (ViT): chop a 224x224 image into 196 patches of 16x16, flatten each patch (16 * 16 * 3 = 768 values for RGB), map it to a d_model-dimensional vector with a linear projection (d_model = 768 in ViT-Base), prepend a learned [CLS] token, add positional embeddings, and run a stack of transformer encoder blocks. When pre-trained at ImageNet-21k scale and above, ViTs tend to beat ResNets; with less data, CNNs tend to win because they bake in translation equivariance as an inductive bias.

Generative & CV applications (per the syllabus): U-Nets, VAEs, and diffusion models.

Drills

D1 · Number of ViT tokens

Patch size 16, image 224x224. How many patch tokens? Plus the CLS token, how many positions does self-attention see?

Solution

(224/16)^2 = 14*14 = 196 patches. With [CLS]: 197 positions.
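The token count can be reproduced by patchifying a random image tensor. This is a minimal sketch assuming RGB input and d_model = 768; the variable names are illustrative, and positional embeddings are omitted.

```python
import torch
import torch.nn as nn

B, C, H, W, P, d_model = 2, 3, 224, 224, 16, 768
img = torch.randn(B, C, H, W)

# Cut into non-overlapping P x P patches and flatten each one
patches = img.unfold(2, P, P).unfold(3, P, P)          # (B, C, 14, 14, P, P)
patches = patches.permute(0, 2, 3, 1, 4, 5).reshape(B, -1, C * P * P)

proj = nn.Linear(C * P * P, d_model)                   # patch embedding
cls = nn.Parameter(torch.zeros(1, 1, d_model))         # learned [CLS] token
tokens = torch.cat([cls.expand(B, -1, -1), proj(patches)], dim=1)
print(patches.shape, tokens.shape)
```

From here on the image is just a sequence of 197 vectors, which is why the same encoder block used for text applies unchanged.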

D2 · Why diffusion uses a U-Net, not a plain transformer

One sentence.

Solution

The denoising target is a noise tensor at full image resolution; U-Net's skip connections preserve high-resolution spatial detail while still letting deep bottleneck layers reason globally. (Recent work — DiT — does swap in a transformer, at higher compute cost.)

D3 · VAE latent

Why is the VAE encoder constrained to output a distribution (mean + log-variance) rather than a point?

Solution

Sampling from a continuous distribution and adding a KL penalty that pulls it toward a prior produces a smooth, interpolatable latent space: slight perturbations of z decode to small image changes. A deterministic auto-encoder has no such guarantee.
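The two ingredients, sampling and the KL term, can be sketched as follows. This assumes a diagonal Gaussian posterior and a standard-normal prior, which is the common setup but is an assumption here; the function names are hypothetical.

```python
import torch

torch.manual_seed(0)

def reparameterize(mu, logvar):
    """Sample z = mu + sigma * eps with eps ~ N(0, I), keeping gradients."""
    eps = torch.randn_like(mu)
    return mu + torch.exp(0.5 * logvar) * eps

def kl_to_standard_normal(mu, logvar):
    """KL( N(mu, sigma^2) || N(0, I) ), summed over latent dimensions."""
    return -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp(), dim=-1)

mu = torch.zeros(4, 8)          # posterior already equal to the prior
logvar = torch.zeros(4, 8)
z = reparameterize(mu, logvar)
print(z.shape, kl_to_standard_normal(mu, logvar))
```

The KL term is zero exactly when the encoder's output matches the prior, and grows as the mean drifts from 0 or the variance from 1.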

Common pitfalls

Checkpoint — answer out loud

Next step

With the architecture stack complete, head to Problems to apply it under contest conditions, then Mocks for the full timed experience. For paper short-answer reps on scaled dot-product, positional encoding, multi-head cost, and KV-cache math, hit the Round 2 theory drills.