Attention variants — MQA, GQA, FlashAttention, linear & sparse
Standard scaled dot-product attention is the workhorse of every transformer, but it has
two ugly cost terms: O(n^2 · d) compute and an O(n · h · d)
KV cache at inference. A whole zoo of variants — MQA, GQA, FlashAttention, linear
attention, sliding-window and sparse attention — exists to attack one or both of those
terms without breaking the inductive bias too badly.
Vanilla attention costs O(n^2 · d) compute and
materialises an n x n score matrix. FlashAttention keeps
the exact same math but tiles the computation so the score matrix never hits HBM,
making long-context training roughly 2-4x faster without approximating the output. MQA / GQA
share K and V across heads to shrink the KV cache during autoregressive decoding —
Llama-2 70B uses GQA-8. Linear attention replaces
softmax(QK^T)V with a kernel form φ(Q)(φ(K)^T V) for O(n · d^2)
cost; sliding-window and sparse attention drop most
of the score matrix entries. In IOAI/USAAIO tasks you meet these whenever a long
document, a long audio sequence, or a memory-tight inference budget shows up.
1. The intuition
Standard self-attention computes a similarity score between every pair of tokens.
For a sequence of length n that's n^2 scores, each costing
d multiplications — so compute is O(n^2 · d) and memory is
O(n^2). At n = 8k the score matrix alone is 64M entries per
head per layer; at n = 32k it is about 1B entries (2 GB in fp16) per head, so materialising it for every head of a layer is out of reach on most GPUs.
Variants attack the cost along two roughly orthogonal axes. The first
is exact-but-cheaper: FlashAttention reorders the computation so the
n x n score block is never written to high-bandwidth memory (HBM); the
answer is mathematically identical to vanilla attention (up to floating-point rounding), just faster and lower-memory. The second
is approximate-or-restricted: linear attention swaps softmax for a kernel
feature map so you can reassociate the matmul; sliding-window attention only attends
to a local neighbourhood; sparse attention picks a structured subset of token pairs.
These change the function the model can express.
A third concern lives entirely at inference time: the KV cache. When
a transformer decodes one token at a time, every previous K and V tensor must be
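kept around. For an L-layer model with h heads and head
dim d_h the cache is 2 · L · n · h · d_h numbers. A Llama-2 70B-sized model (L = 80, h = 64, d_h = 128) with full MHA would need 10 GiB of fp16 cache for a single 4k-token sequence, and each additional sequence in the batch adds as much again. MQA and GQA cut this directly by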
sharing K and V across heads.
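The cache formula is easy to turn into a calculator. This is a minimal sketch; `kv_cache_bytes` is a hypothetical helper name, and it assumes fp16 storage (2 bytes per number) and counts only K and V, not activations or weights. The Llama-2 70B shape (80 layers, 64 query heads, 8 KV heads, head dim 128) is used to compare a hypothetical full-MHA cache with the GQA-8 cache the model actually uses.

```python
def kv_cache_bytes(n_layers, seq_len, n_kv_heads, head_dim,
                   bytes_per_elem=2, batch=1):
    """K and V cache size: 2 (K and V) * L * n * h_kv * d_h * bytes * batch."""
    return 2 * n_layers * seq_len * n_kv_heads * head_dim * bytes_per_elem * batch


GiB = 2 ** 30

# Llama-2 70B shape at a 4k context, fp16.
mha = kv_cache_bytes(80, 4096, 64, 128)  # hypothetical: one K/V head per query head
gqa = kv_cache_bytes(80, 4096, 8, 128)   # GQA-8: 8 shared K/V heads
print(mha / GiB, gqa / GiB)  # 10.0 1.25
```

Passing `n_kv_heads = h`, `g`, or `1` gives the MHA, GQA-g and MQA rows of the cost table below.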
2. The math
Scaled dot-product attention
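Let Q, K in R^(n x d_k) and V in R^(n x d_v) (we'll fold the batch axis in later). The core
operation is (softmax taken row-wise):

$$S = \mathrm{softmax}\!\left(\frac{QK^\top}{\sqrt{d_k}}\right), \qquad \mathrm{Attention}(Q, K, V) = S\,V$$

The 1/sqrt(d_k) factor keeps the pre-softmax logits at unit variance when
Q and K have independent, zero-mean, unit-variance entries: each logit is a sum of d_k such products, so its variance is d_k before scaling. Without it, softmax saturates for large d_k and gradients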
vanish. Compute cost: the matmul Q K^T is n · n · d_k, and
S V is another n · n · d_v, so total
O(n^2 · d). Memory: the score matrix S alone is
n^2 entries.
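The variance argument can be checked numerically. A small sketch, assuming i.i.d. standard-normal queries and keys (the setting in which the claim holds): the unscaled dot products have variance close to d_k, the scaled ones close to 1, and softmax rows built from unscaled logits are much peakier on average.

```python
import torch

torch.manual_seed(0)
n, d_k = 4096, 256
q = torch.randn(n, d_k)          # zero-mean, unit-variance, independent entries
k = torch.randn(n, d_k)

raw = (q * k).sum(dim=-1)        # n independent dot products q_i . k_i
scaled = raw / d_k ** 0.5        # apply the 1/sqrt(d_k) factor
print(raw.var().item(), scaled.var().item())  # roughly d_k, then roughly 1

# Average largest softmax weight per row over 64 keys: unscaled logits saturate.
logits = q[:64] @ k[:64].T                       # (64, 64) unscaled logits
peak_raw = logits.softmax(dim=-1).max(dim=-1).values.mean().item()
peak_scaled = (logits / d_k ** 0.5).softmax(dim=-1).max(dim=-1).values.mean().item()
print(peak_raw, peak_scaled)
```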
Multi-head attention
Split the model dim d_model into h heads of size
d_h = d_model / h. Each head has its own (W_Q^i, W_K^i, W_V^i)
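and runs attention in parallel; outputs concatenate and project through W_O:

$$\mathrm{MultiHead}(X) = \mathrm{Concat}(\mathrm{head}_1, \dots, \mathrm{head}_h)\, W_O, \qquad \mathrm{head}_i = \mathrm{Attention}\big(X W_Q^{i},\; X W_K^{i},\; X W_V^{i}\big)$$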
Total compute is unchanged — splitting into heads is just a tensor reshape. Storage
of K and V for inference, however, is 2 · n · h · d_h per layer.
Multi-Query Attention (MQA)
Single K, V shared across all query heads (Shazeer, 2019). Q still has
h heads, but W_K, W_V project to a single
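d_h-dim space:

$$\mathrm{head}_i = \mathrm{Attention}\big(X W_Q^{i},\; X W_K,\; X W_V\big), \qquad i = 1, \dots, h$$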
KV cache shrinks from 2 · n · h · d_h to 2 · n · d_h — a
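factor of h smaller (typically 16-32x). Because autoregressive decoding is memory-bandwidth-bound and every generated token rereads the whole cache, decoding throughput rises substantially, though usually by less than h-fold since the weights must still be read on every step. Quality drops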
slightly vs full MHA because heads can no longer specialise their key/value
projections.
Grouped-Query Attention (GQA)
A middle ground (Ainslie et al., 2023). Partition h query heads into
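g groups of h/g heads; each group shares one K, V head:

$$\mathrm{head}_i = \mathrm{Attention}\big(X W_Q^{i},\; X W_K^{\lceil i g / h \rceil},\; X W_V^{\lceil i g / h \rceil}\big), \qquad i = 1, \dots, h$$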
g = h recovers vanilla MHA; g = 1 recovers MQA. Llama-2 70B
uses h = 64, g = 8 — KV cache 8x smaller than MHA, quality nearly
indistinguishable.
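The group assignment is just integer division. A minimal sketch using 0-indexed heads and contiguous groups (an assumption, though it is also the layout that `repeat_interleave` produces in the reference implementation of section 3); `kv_head_for_query_head` is a hypothetical helper name.

```python
import torch


def kv_head_for_query_head(i, n_heads, n_kv_heads):
    """0-indexed K/V head shared by query head i, with contiguous groups."""
    group_size = n_heads // n_kv_heads
    return i // group_size


h = 8
for g in (8, 2, 1):  # MHA, GQA-2, MQA
    print(g, [kv_head_for_query_head(i, h, g) for i in range(h)])
# 8 [0, 1, 2, 3, 4, 5, 6, 7]
# 2 [0, 0, 0, 0, 1, 1, 1, 1]
# 1 [0, 0, 0, 0, 0, 0, 0, 0]

# The same layout falls out of repeat_interleave on a tensor with 2 K/V heads.
kv_tags = torch.arange(2).view(2, 1)  # K/V heads tagged 0 and 1
expanded = kv_tags.repeat_interleave(h // 2, dim=0).flatten().tolist()
print(expanded)  # [0, 0, 0, 0, 1, 1, 1, 1]
```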
FlashAttention — IO-aware exact attention
Dao et al. (2022). Same math as vanilla attention, but the n x n score
matrix is never materialised in HBM. Tile Q into blocks of size
B_r and K, V into blocks of size B_c; for each
pair of tiles compute a partial softmax in SRAM and accumulate the output using the
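online-softmax recurrence (all operations row-wise):

$$S_{ij} = \frac{Q_i K_j^\top}{\sqrt{d_k}}, \qquad m^{\text{new}} = \max\big(m,\ \mathrm{rowmax}(S_{ij})\big), \qquad \tilde{P}_{ij} = e^{\,S_{ij} - m^{\text{new}}}$$

$$l^{\text{new}} = e^{\,m - m^{\text{new}}}\, l + \mathrm{rowsum}(\tilde{P}_{ij}), \qquad O^{\text{new}} = e^{\,m - m^{\text{new}}}\, O + \tilde{P}_{ij} V_j$$

where m is the running rowwise max and l the running
rowwise denominator; after the last K, V tile the output is O / l. HBM accesses drop from Θ(n · d + n^2) to
Θ(n^2 · d^2 / M), where M is the SRAM size, which typically translates to a
2-4x wall-clock speedup for long sequences (n >= 2k). In PyTorch 2.x you just call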
torch.nn.functional.scaled_dot_product_attention and the FlashAttention
backend is picked automatically when the shapes and dtypes are eligible.
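The recurrence is easiest to see on a single query row. A minimal sketch, assuming one row of scores streamed in key blocks of size 4 (`online_softmax_weighted_sum` is a hypothetical name): it keeps only m, l and an unnormalised output, never the full softmax row, and still matches softmax(scores) @ values.

```python
import torch


def online_softmax_weighted_sum(scores, values, block=4):
    """Stream over key blocks keeping running max m, denominator l and
    unnormalised output o; returns softmax(scores) @ values."""
    m = torch.tensor(float("-inf"))
    l = torch.tensor(0.0)
    o = torch.zeros(values.shape[-1])
    for j in range(0, scores.numel(), block):
        s = scores[j:j + block]
        m_new = torch.maximum(m, s.max())
        p = torch.exp(s - m_new)           # block weights relative to the new max
        correction = torch.exp(m - m_new)  # rescale old stats (0 on the first block)
        l = correction * l + p.sum()
        o = correction * o + p @ values[j:j + block]
        m = m_new
    return o / l                           # normalise once at the end


torch.manual_seed(0)
scores = torch.randn(10) * 5
values = torch.randn(10, 3)
ref = scores.softmax(dim=-1) @ values
print(torch.allclose(online_softmax_weighted_sum(scores, values), ref, atol=1e-5))  # True
```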
Linear / kernelized attention
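Replace softmax with a kernel feature map φ : R^d → R^m such that
the similarity κ(q, k) ≈ φ(q)^T φ(k). Then by associativity:

$$\mathrm{Attention}(Q, K, V)_i = \frac{\phi(q_i)^\top \left(\sum_{j} \phi(k_j)\, v_j^\top\right)}{\phi(q_i)^\top \sum_{j} \phi(k_j)}, \qquad \text{i.e. the rows of } \phi(Q)\big(\phi(K)^\top V\big) \text{ divided by } \phi(Q)\big(\phi(K)^\top \mathbf{1}\big)$$

Compute φ(K)^T V first (an m x d matrix), then multiply
by φ(Q). Total cost O(n · m · d), linear in n.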
Linear Transformer (Katharopoulos et al., 2020) uses
φ(x) = elu(x) + 1; Performer (Choromanski et al., 2020) uses positive random
features (FAVOR+) that give an unbiased estimate of the softmax kernel. Trade-off: random-feature estimates add variance,
and kernelised attention in general loses some of softmax's sharp, peaky attending behaviour.
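For causal decoding the same identity becomes a recurrence, which is where the O(m · d) "state" in the cost table comes from. A minimal sketch with φ(x) = elu(x) + 1 and a single head (function names are hypothetical): the recurrent version carries a (d, d_v) matrix S and a (d,) vector z instead of a growing KV cache, and it matches an explicit masked n x n computation.

```python
import torch
import torch.nn.functional as F


def phi(x):
    return F.elu(x) + 1.0  # strictly positive feature map


def causal_linear_attention_recurrent(Q, K, V):
    """Q, K: (n, d), V: (n, d_v). Constant-size state, one step per token."""
    n, d = Q.shape
    S = torch.zeros(d, V.shape[1])  # running sum of phi(k_j) v_j^T
    z = torch.zeros(d)              # running sum of phi(k_j)
    out = []
    for i in range(n):
        k_i, q_i = phi(K[i]), phi(Q[i])
        S = S + torch.outer(k_i, V[i])
        z = z + k_i
        out.append((q_i @ S) / (q_i @ z))
    return torch.stack(out)


def causal_linear_attention_quadratic(Q, K, V):
    """Same function via an explicit causal n x n similarity matrix."""
    A = torch.tril(phi(Q) @ phi(K).T)  # keep j <= i, all entries > 0
    return (A @ V) / A.sum(dim=-1, keepdim=True)


torch.manual_seed(0)
Q, K, V = torch.randn(16, 8), torch.randn(16, 8), torch.randn(16, 4)
rec = causal_linear_attention_recurrent(Q, K, V)
quad = causal_linear_attention_quadratic(Q, K, V)
print(torch.allclose(rec, quad, atol=1e-5))  # True
```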
Sliding-window attention
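Each query only attends to the w tokens on either side of it. The score matrix
becomes band-diagonal with bandwidth 2w + 1, which is the same as adding a mask M to the logits:

$$M_{ij} = \begin{cases} 0 & \text{if } |i - j| \le w \\ -\infty & \text{otherwise} \end{cases}, \qquad \mathrm{Attention}(Q, K, V) = \mathrm{softmax}\!\left(\frac{QK^\top}{\sqrt{d_k}} + M\right) V$$

Compute is O(n · w · d) when only the band is evaluated (a dense matmul followed by masking still costs O(n^2 · d)). After L stacked layers the
effective receptive field is L · w tokens on each side, so deep stacks recover long-range
dependencies indirectly. Longformer uses this two-sided window; Mistral 7B uses a causal variant that looks back w = 4096 tokens.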
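To actually get the O(n · w · d) cost, only the 2w + 1 band entries per query should be computed. A minimal single-head sketch (hypothetical names, two-sided window, no batch axis): it pads K and V by w rows, slides a window over them with `Tensor.unfold`, masks the padded slots, and is checked against the dense masked version.

```python
import torch
import torch.nn.functional as F


def banded_attention(Q, K, V, w):
    """Two-sided sliding-window attention forming only n x (2w+1) scores.
    Q, K, V: (n, d)."""
    n, d = Q.shape
    K_pad = F.pad(K, (0, 0, w, w))          # (n + 2w, d), zero rows at both ends
    V_pad = F.pad(V, (0, 0, w, w))
    K_win = K_pad.unfold(0, 2 * w + 1, 1)   # (n, d, 2w+1): keys i-w .. i+w
    V_win = V_pad.unfold(0, 2 * w + 1, 1)
    scores = torch.einsum("nd,ndw->nw", Q, K_win) / d ** 0.5
    # Slot t of query i points at key i - w + t; mask slots that fall off the ends.
    idx = torch.arange(n).unsqueeze(1) - w + torch.arange(2 * w + 1)
    scores = scores.masked_fill((idx < 0) | (idx >= n), float("-inf"))
    return torch.einsum("nw,ndw->nd", scores.softmax(dim=-1), V_win)


def dense_masked_attention(Q, K, V, w):
    """Reference: full n x n scores with the band mask applied."""
    n, d = Q.shape
    scores = Q @ K.T / d ** 0.5
    i, j = torch.arange(n).unsqueeze(1), torch.arange(n).unsqueeze(0)
    scores = scores.masked_fill((i - j).abs() > w, float("-inf"))
    return scores.softmax(dim=-1) @ V


torch.manual_seed(0)
Q, K, V = torch.randn(20, 8), torch.randn(20, 8), torch.randn(20, 8)
banded = banded_attention(Q, K, V, w=3)
dense = dense_masked_attention(Q, K, V, w=3)
print(torch.allclose(banded, dense, atol=1e-5))  # True
```

Production kernels avoid even the `unfold` copy by tiling the band directly, but the arithmetic is the same.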
Sparse attention
Sparse Transformer (Child et al., 2019) combines a local band of width
w ~ sqrt(n) with a strided pattern that hops every
sqrt(n) tokens. Two layers of this composition reach any pair, and total
cost is O(n · sqrt(n) · d). BigBird adds a random sparse pattern plus
a few global tokens and proves the result is a universal sequence-to-sequence
approximator.
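The two-hop claim can be verified with boolean reachability. A sketch of a causal strided factorisation with stride l = sqrt(n) (the exact pattern definitions here are our simplification of the Sparse Transformer description): position i attends locally to the previous l positions in one step and to every l-th earlier position in the other, and composing the two covers every pair j <= i.

```python
import torch


def strided_patterns(n, l):
    """Causal factorised patterns (simplified):
    local[i, j]   = True if 0 <= i - j < l
    strided[i, j] = True if j <= i and (i - j) % l == 0"""
    i = torch.arange(n).unsqueeze(1)
    j = torch.arange(n).unsqueeze(0)
    diff = i - j
    local = (diff >= 0) & (diff < l)
    strided = (diff >= 0) & (diff % l == 0)
    return local, strided


n, l = 64, 8  # l = sqrt(n)
local, strided = strided_patterns(n, l)
# Two hops: i attends (strided) to k, and k had attended (local) to j.
reach = (strided.float() @ local.float()) > 0
causal = torch.tril(torch.ones(n, n, dtype=torch.bool))
print(torch.equal(reach, causal))  # True: every j <= i is reachable
kept = local.sum().item() + strided.sum().item()
print(kept, n * n)  # entries evaluated vs. the full n x n matrix
```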
Rotary positional embedding (RoPE)
Not a complexity variant — RoPE modifies Q and K before the dot product so that
relative position is encoded in the angle of a 2D rotation pair. For each pair of
dimensions (2i, 2i+1) rotate by angle m · θ_i at position
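m, with θ_i = 10000^(-2i/d). The dot product of the rotated vectors
q_m · k_n then depends on the positions only through the offset m - n. RoPE composes cleanly
with MQA, GQA, FlashAttention and sliding-window or sparse masks because it only rotates Q and K per position before the scores are formed; with linear attention the rotation has to be applied after the feature map φ, so it needs a little more care.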
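A minimal RoPE sketch for a single vector, assuming the interleaved pairing (2i, 2i+1) and base 10000 described above (some libraries instead pair the first and second halves of the vector). The check shows that shifting both positions by the same amount leaves the dot product unchanged and that rotation preserves the vector norm.

```python
import torch


def rope(x, pos, base=10000.0):
    """Rotate each pair (x[2i], x[2i+1]) by pos * theta_i,
    theta_i = base ** (-2i / d). x: (d,) with d even."""
    d = x.shape[-1]
    theta = base ** (-torch.arange(0, d, 2, dtype=torch.float64) / d)  # (d/2,)
    angle = pos * theta
    cos, sin = torch.cos(angle), torch.sin(angle)
    x_even, x_odd = x[0::2], x[1::2]
    out = torch.empty_like(x)
    out[0::2] = x_even * cos - x_odd * sin  # standard 2D rotation per pair
    out[1::2] = x_even * sin + x_odd * cos
    return out


torch.manual_seed(0)
q = torch.randn(16, dtype=torch.float64)
k = torch.randn(16, dtype=torch.float64)
a = rope(q, 3) @ rope(k, 1)   # offset m - n = 2
b = rope(q, 10) @ rope(k, 8)  # same offset, different absolute positions
print(torch.allclose(a, b))   # True
```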
Cost summary
| Variant | Compute | Score-matrix memory | KV cache | Exact? |
|---|---|---|---|---|
| Vanilla MHA | O(n^2 · d) | O(n^2) | 2 · n · h · d_h | yes |
| FlashAttention | O(n^2 · d) | O(n) | 2 · n · h · d_h | yes |
| MQA | O(n^2 · d) | O(n^2) | 2 · n · d_h | yes (different params) |
| GQA-g | O(n^2 · d) | O(n^2) | 2 · n · g · d_h | yes (different params) |
| Sliding-window w | O(n · w · d) | O(n · w) | 2 · w · h · d_h | no (local only) |
| Sparse (strided) | O(n · sqrt(n) · d) | O(n · sqrt(n)) | 2 · n · h · d_h | no |
| Linear / Performer | O(n · m · d) | O(m · d) | O(m · d) state | approx |
3. PyTorch reference implementation
```python
import math

import torch
import torch.nn as nn
import torch.nn.functional as F

# ---------------------------------------------------------------------------
# 1. Vanilla scaled dot-product attention from scratch
# ---------------------------------------------------------------------------
def vanilla_attention(Q, K, V, mask=None):
    """Q, K, V: (B, h, n, d_h). mask: (B, 1, n, n) additive, 0 or -inf."""
    d_h = Q.size(-1)
    scores = (Q @ K.transpose(-2, -1)) / math.sqrt(d_h)  # (B, h, n, n)
    if mask is not None:
        scores = scores + mask
    attn = scores.softmax(dim=-1)
    return attn @ V  # (B, h, n, d_h)


# ---------------------------------------------------------------------------
# 2. Multi-head / MQA / GQA via a single module (g controls grouping)
# ---------------------------------------------------------------------------
class GroupedQueryAttention(nn.Module):
    """g = h -> standard MHA. g = 1 -> MQA. 1 < g < h -> GQA."""

    def __init__(self, d_model, n_heads, n_kv_heads):
        super().__init__()
        assert d_model % n_heads == 0
        assert n_heads % n_kv_heads == 0
        self.h, self.g, self.d_h = n_heads, n_kv_heads, d_model // n_heads
        self.W_Q = nn.Linear(d_model, n_heads * self.d_h, bias=False)
        self.W_K = nn.Linear(d_model, n_kv_heads * self.d_h, bias=False)
        self.W_V = nn.Linear(d_model, n_kv_heads * self.d_h, bias=False)
        self.W_O = nn.Linear(d_model, d_model, bias=False)

    def forward(self, x, mask=None):
        B, n, _ = x.shape
        Q = self.W_Q(x).view(B, n, self.h, self.d_h).transpose(1, 2)  # (B,h,n,d_h)
        K = self.W_K(x).view(B, n, self.g, self.d_h).transpose(1, 2)  # (B,g,n,d_h)
        V = self.W_V(x).view(B, n, self.g, self.d_h).transpose(1, 2)  # (B,g,n,d_h)
        # Repeat each K, V head across its group of h // g query heads
        K = K.repeat_interleave(self.h // self.g, dim=1)  # (B,h,n,d_h)
        V = V.repeat_interleave(self.h // self.g, dim=1)
        # SDPA picks a fused backend when eligible. With no mask we default to
        # causal attention; an explicit additive mask replaces the causal one.
        out = F.scaled_dot_product_attention(Q, K, V, attn_mask=mask,
                                             is_causal=(mask is None))
        out = out.transpose(1, 2).contiguous().view(B, n, self.h * self.d_h)
        return self.W_O(out)


# ---------------------------------------------------------------------------
# 3. Sliding-window mask
# ---------------------------------------------------------------------------
def sliding_window_mask(n, w, device=None):
    """Additive mask: 0 where |i - j| <= w, else -inf. Shape (1, 1, n, n)."""
    i = torch.arange(n, device=device).unsqueeze(1)
    j = torch.arange(n, device=device).unsqueeze(0)
    keep = (i - j).abs() <= w
    return torch.where(keep, 0.0, float("-inf")).view(1, 1, n, n)


# ---------------------------------------------------------------------------
# 4. Linear attention with feature map phi(x) = elu(x) + 1
# ---------------------------------------------------------------------------
def linear_attention(Q, K, V):
    """O(n * d^2) instead of O(n^2 * d). Q,K,V: (B, h, n, d_h). Non-causal."""
    phi_Q = F.elu(Q) + 1.0
    phi_K = F.elu(K) + 1.0
    KV = phi_K.transpose(-2, -1) @ V  # (B, h, d_h, d_h), computed once
    Z = phi_Q @ phi_K.sum(dim=-2, keepdim=True).transpose(-2, -1)  # (B,h,n,1) normaliser
    return (phi_Q @ KV) / (Z + 1e-6)


# ---------------------------------------------------------------------------
# 5. FlashAttention: conceptual tiled outer loop.
#    Real kernels live in CUDA; PyTorch dispatches to them via SDPA.
# ---------------------------------------------------------------------------
def flash_attention_reference(Q, K, V, B_r=64, B_c=64):
    """Numerically equivalent to vanilla_attention (no mask), computed tile by tile."""
    B, h, n, d_h = Q.shape
    scale = 1.0 / math.sqrt(d_h)
    O = torch.zeros_like(Q)
    for i in range(0, n, B_r):
        Qi = Q[:, :, i:i + B_r]                    # one query tile
        Oi = torch.zeros_like(Qi)                  # unnormalised output for the tile
        li = torch.zeros(B, h, Qi.size(2), 1, device=Q.device, dtype=Q.dtype)
        mi = torch.full((B, h, Qi.size(2), 1), float("-inf"),
                        device=Q.device, dtype=Q.dtype)
        for j in range(0, n, B_c):
            Kj = K[:, :, j:j + B_c]
            Vj = V[:, :, j:j + B_c]
            Sij = (Qi @ Kj.transpose(-2, -1)) * scale
            m_new = torch.maximum(mi, Sij.max(dim=-1, keepdim=True).values)
            P = torch.exp(Sij - m_new)
            l_new = torch.exp(mi - m_new) * li + P.sum(dim=-1, keepdim=True)
            Oi = torch.exp(mi - m_new) * Oi + P @ Vj  # rescale old, add new tile
            mi, li = m_new, l_new
        O[:, :, i:i + B_r] = Oi / li               # normalise once per query tile
    return O


# Use torch.compile in real code to fuse the linear / GQA variants:
# fast_gqa = torch.compile(GroupedQueryAttention(512, 8, 2))

if __name__ == "__main__":
    B, n, d_model, h, g = 2, 64, 256, 8, 2
    x = torch.randn(B, n, d_model)
    block = GroupedQueryAttention(d_model, h, g)
    block.eval()  # inference mode (no dropout here, but a good habit)
    with torch.no_grad():
        y = block(x)
    print(y.shape)  # torch.Size([2, 64, 256])

    # The tiled reference should agree with the naive implementation.
    Q, K, V = (torch.randn(B, h, n, 32) for _ in range(3))
    ref = vanilla_attention(Q, K, V)
    tiled = flash_attention_reference(Q, K, V, B_r=16, B_c=16)
    print(torch.allclose(ref, tiled, atol=1e-5))  # True
```
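F.scaled_dot_product_attention is the right call in production. It picks
the FlashAttention CUDA backend automatically on Ampere/Hopper GPUs when the dtypes are
fp16/bf16, no arbitrary attn_mask is passed, and the shapes are eligible; otherwise it falls back to a memory-efficient kernel
or the plain math implementation. In the GroupedQueryAttention block above, that means the Flash path is only available in the unmasked, causal case. The hand-rolled flash_attention_reference above is for
teaching only: it is slower than vanilla_attention, because each tile is a separate small kernel whose inputs and outputs still round-trip through HBM. The IO savings only appear when the whole loop is fused into one
CUDA kernel that keeps tiles in SRAM.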
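A quick sanity check you can run on CPU, as a sketch: it confirms that `is_causal=True` in `F.scaled_dot_product_attention` means the usual lower-triangular mask, by comparing against an explicit masked softmax. On CPU this exercises whichever backend PyTorch selects there, not the CUDA FlashAttention kernel.

```python
import math

import torch
import torch.nn.functional as F

torch.manual_seed(0)
B, h, n, d_h = 2, 4, 16, 8
Q, K, V = (torch.randn(B, h, n, d_h) for _ in range(3))

fused = F.scaled_dot_product_attention(Q, K, V, is_causal=True)

# Explicit version: scale, mask out j > i, softmax, weight the values.
scores = Q @ K.transpose(-2, -1) / math.sqrt(d_h)
causal = torch.ones(n, n, dtype=torch.bool).tril()
manual = scores.masked_fill(~causal, float("-inf")).softmax(dim=-1) @ V
print(torch.allclose(fused, manual, atol=1e-5))  # True
```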
4. Common USAAIO / IOAI applications
- Long-context summarisation / QA over articles or transcripts (IOAI 2024 NLP style): sliding-window or sparse attention can let a model trained on 4k-token inputs run over 32k tokens, and FlashAttention can make the training fit on a single GPU.
- Fast inference under a token budget — MQA/GQA cut KV cache memory and decoding latency. If you fine-tune Llama-2 or Mistral for a contest task and need to generate many answers in time-limited evaluation, GQA matters.
- Audio / spectrogram transformers — sequences are long (10k+ frames) and locally structured, perfect for sliding-window attention.
- Genomic / DNA models (IOAI-style biology tasks): sequences run to millions of bases, far beyond what full quadratic attention can handle, so linear or sparse attention is the practical choice among these variants.
- Low-VRAM training in the lab/Kaggle 16GB GPU world — FlashAttention drops activation memory enough to fit longer sequences or larger batch sizes without algorithmic changes.
5. Drills
D1 · Prove standard attention is O(n^2 · d)
Given Q, K, V in R^(n x d), count the multiplications and additions in
softmax(Q K^T / sqrt(d)) V.
Solution
Q K^T is two matrices of shape (n, d) and
(d, n) — the result is n x n with each entry costing
d multiplies and d - 1 adds, so O(n^2 · d).
Softmax is O(n^2). Multiplying the n x n attention by
V in R^(n x d) is another O(n^2 · d). Total
O(n^2 · d) compute and O(n^2) peak memory for the score
matrix.
D2 · MQA KV-cache savings
A model has L = 32 layers, h = 32 heads, head dim
d_h = 128, sequence length n = 4096, fp16. What is the KV
cache size for MHA vs MQA vs GQA-4?
Solution
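Cache = 2 · L · n · h_kv · d_h · 2 bytes, where h_kv is the number of K/V heads.
- MHA (h_kv = 32): 2 · 32 · 4096 · 32 · 128 · 2 = 2 GiB
- GQA-4 (h_kv = 4): 2 · 32 · 4096 · 4 · 128 · 2 = 256 MiB (8x smaller)
- MQA (h_kv = 1): 2 · 32 · 4096 · 1 · 128 · 2 = 64 MiB (32x smaller)

This is the shape of Llama-2 7B, whose fp16 weights take about 13.5 GB. On a 24 GB consumer GPU the roughly 10 GB left over holds about 4 to 5 MHA caches at this length, versus roughly 39 with GQA-4 and over 150 with MQA (ignoring activations).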
D3 · Sliding-window effective receptive field
A 12-layer transformer uses sliding-window attention with w = 128.
What is the maximum distance between two tokens such that one can still influence
the other's representation at the output of the stack?
Solution
Each layer extends reach by w tokens on each side, so after
L layers reach is L · w. Here
12 · 128 = 1536 tokens. That's enough for paragraph-level dependencies
but not document-level. Mistral 7B uses w = 4096 over 32 layers, giving a theoretical span of
32 · 4096 ≈ 131k tokens even though each layer only looks 4,096 tokens back.
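The L · w rule can be confirmed by composing band masks. A small sketch with two-sided windows and illustrative sizes (n = 64, w = 3, 4 layers, not the drill's numbers, so the matrices stay tiny); `reach_after_layers` is a hypothetical helper name.

```python
import torch


def reach_after_layers(n, w, n_layers):
    """Which output positions can see which inputs after n_layers windows."""
    i = torch.arange(n).unsqueeze(1)
    j = torch.arange(n).unsqueeze(0)
    band = ((i - j).abs() <= w).float()    # one layer's attention pattern
    reach = torch.eye(n)
    for _ in range(n_layers):
        reach = ((band @ reach) > 0).float()  # compose one more layer
    return reach.bool()


n, w, L = 64, 3, 4
reach = reach_after_layers(n, w, L)
rows, cols = torch.nonzero(reach, as_tuple=True)
max_dist = (rows - cols).abs().max().item()
print(max_dist, L * w)  # 12 12
```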
D4 · Linear-attention kernel choice
You try linear attention with φ(x) = x (the identity). Training loss
diverges. Why? What does φ(x) = elu(x) + 1 fix?
Solution
For the attention output to be a convex combination of values, the implicit
similarity φ(q)^T φ(k) must be non-negative; otherwise
the normaliser Z = φ(Q) · Σ φ(K) can be zero or negative and the
output blows up. The identity map produces signed dot products. The
elu(x) + 1 map is always strictly positive (range (0, ∞))
while staying smooth and gradient-friendly, which guarantees a positive denominator
and stable training.
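The sign problem shows up immediately on random inputs. A sketch assuming i.i.d. Gaussian queries and keys: with the identity map some normalisers Z_i = φ(q_i) · Σ_j φ(k_j) are negative, while with elu(x) + 1 every term is a product of positive numbers, so every Z_i is positive.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
n, d = 64, 16
Q, K = torch.randn(n, d), torch.randn(n, d)


def normaliser(phi):
    # Z_i = phi(q_i) . sum_j phi(k_j), the linear-attention denominator
    return phi(Q) @ phi(K).sum(dim=0)


Z_identity = normaliser(lambda x: x)
Z_elu = normaliser(lambda x: F.elu(x) + 1.0)
print((Z_identity <= 0).sum().item(), (Z_elu <= 0).sum().item())
```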
D5 · When NOT to use FlashAttention
Your model trains fine with vanilla attention. You switch to FlashAttention and accuracy drops a little. What likely happened?
Solution
FlashAttention is mathematically identical to vanilla attention, but its reductions happen in a different order, so fp16/bf16 rounding differs slightly. If you trained with vanilla and only switched at fine-tune time, those rounding shifts can move logits enough to matter. Either retrain end-to-end with FlashAttention so the model adapts to its numerics, or run attention in higher precision where the budget allows. It's not the algorithm; it's the numerics.
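A toy illustration of order-dependent rounding (not FlashAttention itself): in fp16 the spacing between representable numbers near 2048 is 2, so adding 1 twice to 2048 rounds away both times, while adding the two 1s first survives.

```python
import torch

x = torch.tensor([2048.0, 1.0, 1.0], dtype=torch.float16)
left = (x[0] + x[1]) + x[2]   # 2048 + 1 rounds back to 2048, twice
right = x[0] + (x[1] + x[2])  # 1 + 1 = 2 is exact, then 2048 + 2 = 2050
print(left.item(), right.item())  # 2048.0 2050.0
```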
Next step
Loop back to Transformers for the base architecture that all of these variants modify, and to Round 2 theory for short-answer drills on attention complexity, KV cache sizing, and RoPE. If you're building a long-context model end to end, the Deep Learning page covers the training-loop and mixed-precision plumbing you'll need underneath.