Attention Mechanism & Transformer Architecture
Intuition
Attention is a soft, learnable dictionary lookup. The decoder asks a question (Q); each position in the source provides a key (K); the value (V) at the matching position is returned. Multi-head attention runs many such lookups in parallel, each learning a different relationship.
Explanation
Pre-attention seq2seq RNNs compressed the entire source into a single fixed-length hidden vector, then unrolled the decoder from that bottleneck. For long sentences the model could not revisit specific source positions. Bahdanau attention let the decoder, at each output step, take a weighted sum over all encoder hidden states with weights that depend on the current decoder state — effectively a content-addressable memory.
Transformers replace recurrence entirely: every token attends to every other in one shot. Three reasons to prefer them: (1) parallelisation — RNNs are inherently sequential, Transformers process all tokens simultaneously; (2) constant path length between any two positions (good gradient flow vs RNN's vanishing problem); (3) richer modelling via stacked self-attention. Trade-off: O(n²) memory in sequence length.
Scaled dot-product attention: Attn(Q, K, V) = softmax(QKᵀ / √dₖ) V. The √dₖ rescales variance — for random Q, K with unit-variance components, the dot product has variance dₖ, which would saturate the softmax and kill gradients without scaling.
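A minimal NumPy sketch of the formula above (function and variable names are illustrative, not from any particular library):

```python
import numpy as np

def softmax(x, axis=-1):
    x = x - x.max(axis=axis, keepdims=True)      # subtract max for numerical stability
    e = np.exp(x)
    return e / e.sum(axis=axis, keepdims=True)

def attention(Q, K, V, mask=None):
    d_k = Q.shape[-1]
    scores = Q @ K.swapaxes(-1, -2) / np.sqrt(d_k)   # (..., n_q, n_k), scaled by sqrt(d_k)
    if mask is not None:
        scores = np.where(mask, scores, -1e9)        # mask=False positions get a huge negative score
    weights = softmax(scores, axis=-1)               # each row sums to 1
    return weights @ V, weights

# toy usage: 5 query positions, 7 key/value positions, d_k = d_v = 64
Q = np.random.randn(5, 64); K = np.random.randn(7, 64); V = np.random.randn(7, 64)
out, w = attention(Q, K, V)                          # out: (5, 64), w: (5, 7)
```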
Multi-head attention projects Q/K/V into h subspaces of dimension dₖ = d_model / h, runs scaled dot-product attention per head, concatenates outputs, and applies a final linear projection. Different heads learn different relationships (syntactic dependency, semantic anaphora, positional locality).
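A hedged NumPy sketch of the split-into-heads / per-head attention / concatenate / project pipeline; the weight matrices W_q, W_k, W_v, W_o are random stand-ins and biases are omitted:

```python
import numpy as np

d_model, H = 512, 8
d_k = d_model // H                           # 64 dimensions per head
rng = np.random.default_rng(0)
W_q, W_k, W_v, W_o = [rng.standard_normal((d_model, d_model)) / np.sqrt(d_model)
                      for _ in range(4)]     # illustrative projection matrices

def softmax(x):
    x = x - x.max(axis=-1, keepdims=True)
    e = np.exp(x)
    return e / e.sum(axis=-1, keepdims=True)

def split_heads(x):                          # (n, d_model) -> (H, n, d_k)
    return x.reshape(x.shape[0], H, d_k).transpose(1, 0, 2)

def multi_head_attention(x_q, x_kv):
    Qh, Kh, Vh = split_heads(x_q @ W_q), split_heads(x_kv @ W_k), split_heads(x_kv @ W_v)
    scores = Qh @ Kh.transpose(0, 2, 1) / np.sqrt(d_k)              # (H, n_q, n_k)
    out = softmax(scores) @ Vh                                      # per-head outputs (H, n_q, d_k)
    concat = out.transpose(1, 0, 2).reshape(x_q.shape[0], d_model)  # concatenate heads
    return concat @ W_o                                             # final linear projection

x = rng.standard_normal((10, d_model))
y = multi_head_attention(x, x)               # self-attention: output shape (10, 512)
```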
Encoder layer: self-attention then FFN, each with Add & Norm. Decoder layer: MASKED self-attention (look-ahead mask sets future-position logits to −∞), then cross-attention (Q from decoder, K, V from encoder output), then FFN — each with Add & Norm. Cross-attention is what lets the decoder query the source.
Self-attention is permutation equivariant, so positional encoding must be added. The original sinusoidal scheme PE(pos, 2i) = sin(pos / 10000^{2i/d}), PE(pos, 2i+1) = cos(…) uses many frequencies so the model can recover both fine and coarse position; it generalises to longer sequences than seen at training time.
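A small sketch that builds the sinusoidal PE table from the formula above (assuming an even d_model); the result is added to the input embeddings:

```python
import numpy as np

def sinusoidal_pe(max_len, d_model):
    pos = np.arange(max_len)[:, None]                   # (max_len, 1)
    two_i = np.arange(0, d_model, 2)[None, :]           # 0, 2, 4, ... : the "2i" index
    angles = pos / np.power(10000.0, two_i / d_model)   # pos / 10000^(2i/d)
    pe = np.zeros((max_len, d_model))
    pe[:, 0::2] = np.sin(angles)                        # even dimensions: sin
    pe[:, 1::2] = np.cos(angles)                        # odd dimensions:  cos
    return pe

pe = sinusoidal_pe(max_len=100, d_model=512)            # (100, 512), added to embeddings
```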
Show-Attend-and-Tell is the image-captioning precursor: image → CNN → grid of L spatial features X ∈ ℝ^{L×D}; at each LSTM step, attention weights over the L locations are computed conditioned on h_{t−1}, yielding a context vector z_t; the LSTM update is h_t = LSTM(z_t, Ey_{t−1}, h_{t−1}), where Ey_{t−1} is the previous word embedding. The decoder learns to 'look' at different image regions for different words — but the attention map reveals only where the model looks, not whether it interprets that region correctly (it can attend to the right object and still mislabel its colour).
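A rough sketch of one soft-attention step at decoding time; the linear scoring function is a simplified stand-in for the paper's attention MLP, and all dimensions are illustrative:

```python
import numpy as np

L, D, H_dim = 196, 512, 256                 # e.g. a 14x14 CNN feature grid
X = np.random.randn(L, D)                   # spatial features, one row per location
h_prev = np.random.randn(H_dim)             # previous LSTM hidden state

W_x = np.random.randn(D, 1) * 0.01          # illustrative scoring weights
W_h = np.random.randn(H_dim, 1) * 0.01

scores = X @ W_x + (h_prev @ W_h)           # (L, 1): relevance of each location
alpha = np.exp(scores - scores.max())
alpha = alpha / alpha.sum()                 # attention weights over the L locations
z_t = (alpha * X).sum(axis=0)               # context vector z_t in R^D
# z_t is then fed into the LSTM together with the previous word embedding
```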
Definitions
- Q / K / V — Query / Key / Value — three projections of the same (self) or different (cross) input.
- Self-attention — Q, K, V all from the same sequence.
- Cross-attention — Q from one sequence (decoder), K, V from another (encoder).
- Look-ahead mask — Upper-triangular mask sets future-position logits to −∞ in the decoder's first sub-layer.
- PostNorm vs PreNorm — PostNorm: LN(x + sublayer); needs warmup. PreNorm: x + sublayer(LN(x)); modern, stable for deep nets.
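A toy sketch contrasting the two residual arrangements (layer_norm omits the learned gain and bias, and the sublayer is just a placeholder):

```python
import numpy as np

def layer_norm(x, eps=1e-5):
    mu = x.mean(axis=-1, keepdims=True)
    var = x.var(axis=-1, keepdims=True)
    return (x - mu) / np.sqrt(var + eps)     # gain/bias omitted for brevity

def post_norm_block(x, sublayer):
    return layer_norm(x + sublayer(x))       # original Transformer; needs LR warmup

def pre_norm_block(x, sublayer):
    return x + sublayer(layer_norm(x))       # modern default; stable for deep stacks

x = np.random.randn(10, 512)
ffn = lambda h: np.maximum(h, 0.0)           # placeholder sublayer
y_post = post_norm_block(x, ffn)
y_pre = pre_norm_block(x, ffn)
```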
Formulas
\text{Attn}(Q,K,V) = \text{softmax}\!\left(\tfrac{QK^\top}{\sqrt{d_k}}\right) V
\text{MHA}(Q,K,V) = \text{Concat}(h_1,\dots,h_H) W^O,\quad h_i = \text{Attn}(QW_i^Q, KW_i^K, VW_i^V)
\text{PE}(pos, 2i) = \sin\!\left(pos / 10000^{2i/d}\right);\quad \text{PE}(pos, 2i+1) = \cos\!\left(pos / 10000^{2i/d}\right)
y = \text{LayerNorm}(x + \text{Sublayer}(x))\quad \text{(Post-Norm)}
y = x + \text{Sublayer}(\text{LayerNorm}(x))\quad \text{(Pre-Norm; modern)}
Derivations
Why divide by √dₖ: assume Q, K ∈ ℝ^{n × dₖ} with i.i.d. unit-variance entries. The dot product (QKᵀ)_{ij} = Σ_k Q_{ik} K_{jk} has variance dₖ. With dₖ = 64, std is 8 — large entries push softmax into the saturated regime where gradients vanish. Dividing by √dₖ rescales variance back to 1, putting softmax in its useful regime.
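A quick empirical check of this argument:

```python
import numpy as np

d_k, n = 64, 100_000
q = np.random.randn(n, d_k)                 # unit-variance components
k = np.random.randn(n, d_k)
dots = (q * k).sum(axis=1)                  # n independent dot products
print(dots.std())                           # ~ sqrt(d_k) = 8
print((dots / np.sqrt(d_k)).std())          # ~ 1 after scaling
```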
Number of parameters in a Transformer layer (base model, d = 512, d_ff = 2048): QKVO projections ≈ 4d² ≈ 1 M; FFN ≈ 2·d·d_ff ≈ 2 M; LayerNorm ≈ 4d (negligible). Per layer ~3 M, × 6 layers ≈ 19 M for the encoder; each decoder layer adds a cross-attention block (another ≈ 4d² ≈ 1 M).
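The same arithmetic as a sanity check (biases and embeddings ignored):

```python
d, d_ff, layers = 512, 2048, 6
attn = 4 * d * d          # W_Q, W_K, W_V, W_O: ~1.05 M
ffn = 2 * d * d_ff        # two FFN matrices:   ~2.10 M
per_layer = attn + ffn    # ~3.15 M per encoder layer
print(per_layer, per_layer * layers)   # 3145728, 18874368 (~19 M for 6 encoder layers)
```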
Examples
- Cross-attention in machine translation: when generating 'le' in French for English source 'the', the decoder query 'le' attends most strongly to the K position for 'the'. The corresponding V is the encoder's contextual representation of 'the'.
- Look-ahead mask example for target [BOS y₁ y₂ y₃]: a 4×4 mask with −∞ on the upper triangle (excluding diagonal) before softmax.
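A NumPy sketch of that mask:

```python
import numpy as np

n = 4                                                  # [BOS y1 y2 y3]
future = np.triu(np.ones((n, n), dtype=bool), k=1)     # True strictly above the diagonal
allowed = ~future                                      # True where attention is permitted
logits = np.random.randn(n, n)
masked = np.where(allowed, logits, -np.inf)            # future positions get -inf before softmax
# after softmax, row i places zero weight on positions j > i;
# `allowed` matches the mask convention used in the attention() sketch earlier
```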
Diagrams
- Single Transformer encoder block (PreNorm): x → LN → MHA → +residual → LN → FFN → +residual.
- Decoder block: masked MHA → cross-attention → FFN, each with Add & Norm.
- Multi-head attention internals: parallel heads with their own W_Q, W_K, W_V; concat then W_O.
Edge cases
- O(n²) memory in sequence length — long-context inference is bottlenecked by attention, not FFN.
- Without positional encoding, self-attention is order-blind — outputs are a permutation of inputs.
- Cross-attention quality depends on the encoder output — a bad encoder bottlenecks the decoder.
Common mistakes
- Forgetting the look-ahead mask in the decoder's first sub-layer — leaks future tokens.
- Adding positional encoding AFTER the encoder instead of to the input embedding.
- Writing 'multi-head attention has h × more parameters than single head' — total parameter count is roughly identical.
Shortcuts
- Sub-layers per encoder = 2 (MHA, FFN). Per decoder = 3 (masked MHA, cross-attn, FFN).
- d_ff = 4 · d_model in the base transformer.
- Sinusoidal PE generalises to unseen lengths; learned PE does not (in the trivial implementation).