Saral Shiksha Yojna

Computer Vision

CSE471
Prof. Makarand Tapaswi + Prof. Charu Sharma, Spring 2025-26, 4 credits

Modern Transformer Upgrades

Intuition

Original Transformer (2017) → modern LLM (2024+) is roughly seven independent upgrades. Each fixes a specific failure mode of the original — training stability, attention sharpness, position generalisation, or inference efficiency.

Explanation

PreNorm vs PostNorm. Original: y = LN(x + Sublayer(x)); the residual stream is normalised after the addition. Modern: y = x + Sublayer(LN(x)); the residual stream is unbroken, so there is a direct identity path from input to output. With an unbroken residual, gradients reach early layers through the identity path instead of passing through a chain of normalisations, which keeps very deep models trainable. PostNorm models typically need careful learning-rate warmup; PreNorm is far less sensitive to it.
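A minimal NumPy sketch of the two orderings, assuming a generic `sublayer` callable (a stand-in for attention or the MLP) and a LayerNorm without learnable γ, β. The property to check is the identity path: if the sub-layer outputs zero, the PreNorm block returns its input unchanged, while the PostNorm block still rewrites the stream.

```python
import numpy as np

def layer_norm(x, eps=1e-5):
    # Normalise each token vector over its feature axis (no learnable gamma/beta here).
    mu = x.mean(axis=-1, keepdims=True)
    var = x.var(axis=-1, keepdims=True)
    return (x - mu) / np.sqrt(var + eps)

def post_norm_block(x, sublayer):
    # Original (2017) ordering: add first, then normalise the residual stream itself.
    return layer_norm(x + sublayer(x))

def pre_norm_block(x, sublayer):
    # Modern ordering: normalise only the sub-layer's input; x passes through untouched.
    return x + sublayer(layer_norm(x))

rng = np.random.default_rng(0)
x = rng.normal(loc=3.0, scale=2.0, size=(4, 8))  # 4 tokens, model width 8
silent = lambda h: np.zeros_like(h)              # a sub-layer that contributes nothing

pre = pre_norm_block(x, silent)    # equals x exactly: the identity path is unbroken
post = post_norm_block(x, silent)  # equals layer_norm(x): the stream is renormalised
```

In a real model `sublayer` has weights and the two blocks differ everywhere, but the zero-sub-layer case isolates what each ordering does to the residual stream.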

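LayerNorm vs RMSNorm. LayerNorm: y = γ · (x − μ)/σ + β, where μ and σ are the mean and standard deviation over the d features of a single token vector; it centers and scales. RMSNorm: y = γ · x / RMS(x); it only scales. Geometric picture, per token vector (before γ, β): LayerNorm subtracts the vector's own mean, which projects it onto the hyperplane orthogonal to the all-ones vector, then rescales it onto the sphere of radius √d; RMSNorm only rescales onto the radius-√d sphere, keeping the vector's direction. RMSNorm is cheaper (no mean computation), and empirically the mean-centering step adds little; Llama and T5 use it.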
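The geometric claims can be checked numerically. This sketch fixes γ = 1 and β = 0 (an assumption for the demo) and verifies that LayerNorm outputs are zero-mean per token, while RMSNorm outputs keep the input's direction (cosine similarity 1) and have unit RMS, i.e. norm √d.

```python
import numpy as np

def layer_norm(x, gamma, beta, eps=1e-5):
    # Subtract each vector's own mean over features, divide by its std, then apply the affine.
    mu = x.mean(axis=-1, keepdims=True)
    var = x.var(axis=-1, keepdims=True)
    return gamma * (x - mu) / np.sqrt(var + eps) + beta

def rms_norm(x, gamma, eps=1e-6):
    # Divide by the root-mean-square over features: no mean subtraction, no beta.
    rms = np.sqrt(np.mean(x ** 2, axis=-1, keepdims=True) + eps)
    return gamma * x / rms

def cosine(a, b):
    # Row-wise cosine similarity between two batches of vectors.
    return np.sum(a * b, axis=-1) / (np.linalg.norm(a, axis=-1) * np.linalg.norm(b, axis=-1))

d = 16
rng = np.random.default_rng(0)
x = rng.normal(loc=2.0, size=(5, d))   # nonzero mean, so centering visibly changes direction
ones, zeros = np.ones(d), np.zeros(d)

ln = layer_norm(x, ones, zeros)
rn = rms_norm(x, ones)
```

The only difference between the two functions is the `x - mu` term (and β), which is exactly what the "Common mistakes" entry on RMSNorm is about.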

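LayerScale. After each sub-layer, multiply its output by a learnable per-channel scale (a diagonal matrix diag(γ_l)) initialised to a small value, e.g. 1e-5, with smaller values for deeper models: x + diag(γ_l) · Sublayer(LN(x)). Initially each sub-layer contributes almost nothing, so every block starts close to the identity; γ_l grows during training as needed. This fixes the training instability of very deep ViTs (24+ layers).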
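A minimal sketch of a PreNorm block with LayerScale. The sub-layer is a hypothetical fixed random linear map and `init_scale=1e-5` is an assumed initial value; the point is that the block's change to its input scales linearly with γ, so at initialisation it is tiny.

```python
import numpy as np

def layer_norm(x, eps=1e-5):
    # LayerNorm over the feature axis, without learnable gamma/beta for brevity.
    mu = x.mean(axis=-1, keepdims=True)
    return (x - mu) / np.sqrt(x.var(axis=-1, keepdims=True) + eps)

class LayerScaleBlock:
    """PreNorm residual block with LayerScale: y = x + diag(gamma) . Sublayer(LN(x))."""

    def __init__(self, d, sublayer, init_scale=1e-5):
        self.sublayer = sublayer
        # diag(gamma) is stored as a length-d vector, so the matrix product
        # diag(gamma) @ v becomes the elementwise product gamma * v.
        self.gamma = np.full(d, init_scale)

    def __call__(self, x):
        return x + self.gamma * self.sublayer(layer_norm(x))

d = 32
rng = np.random.default_rng(0)
W = rng.normal(scale=d ** -0.5, size=(d, d))   # hypothetical stand-in sub-layer weights
block = LayerScaleBlock(d, lambda h: h @ W)
x = rng.normal(size=(4, d))

initial_change = np.abs(block(x) - x).max()    # tiny: the block starts near the identity
block.gamma[:] = 1.0                           # pretend training has grown gamma
grown_change = np.abs(block(x) - x).max()
```

Storing the diagonal as a vector is the usual design choice: it costs d parameters per sub-layer instead of d².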

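QK-Norm. Apply LayerNorm (or RMSNorm) to Q and K separately, per head, before computing QKᵀ. This prevents attention-logit explosion: in large models the magnitudes of Q and K can grow during training, so the logits QKᵀ/√d_k grow uncontrollably, the softmax saturates to near one-hot attention, and training diverges (in low precision this can surface as loss spikes or NaNs). Normalising Q and K fixes their scale (up to the learnable gain), which bounds the logits.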
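A small demonstration of the bound, assuming RMSNorm with its gain fixed at 1 and a single head. After normalisation each q and k has norm at most √d_head, so by Cauchy-Schwarz every logit satisfies |q·k|/√d_head ≤ √d_head, no matter how large the raw activations were. With a learnable gain the bound scales with the gain instead.

```python
import numpy as np

def rms_norm(x, eps=1e-6):
    # Per-vector RMSNorm with the learnable gain fixed at 1 for this demo.
    return x / np.sqrt(np.mean(x ** 2, axis=-1, keepdims=True) + eps)

def attention_logits(q, k, qk_norm):
    # q: (n_q, d_head), k: (n_k, d_head) for a single head.
    if qk_norm:
        q, k = rms_norm(q), rms_norm(k)
    return q @ k.T / np.sqrt(q.shape[-1])

d_head = 64
rng = np.random.default_rng(0)
# Simulated large-magnitude activations, as can arise when Q/K projections grow in training.
q = 100.0 * rng.normal(size=(16, d_head))
k = 100.0 * rng.normal(size=(16, d_head))

raw = attention_logits(q, k, qk_norm=False)     # logits in the thousands: softmax saturates
normed = attention_logits(q, k, qk_norm=True)   # every |logit| <= sqrt(d_head) = 8
```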

Registers (Darcet et al.). Extra learnable tokens with no spatial meaning, added to the input sequence and discarded at the output. Without them, trained ViTs repurpose a few uninformative background patches (e.g. blank sky) as scratch space for global information; these tokens develop abnormally high norms and show up as artifacts that pollute the attention maps. Registers give the model dedicated scratchpads, so real patch tokens stay focused on their content: cleaner attention maps and small but consistent quality gains.
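A sketch of how registers enter and leave the token sequence, with hypothetical helper names `build_sequence` and `split_outputs` and an identity function standing in for the ViT encoder. The ordering [CLS] + registers + patches is an assumption for illustration; what matters is that only patch tokens get positional embeddings and the register outputs are dropped.

```python
import numpy as np

def build_sequence(patch_emb, pos_emb, cls_token, registers):
    # Only the patch tokens receive positional embeddings; registers carry none.
    patches = patch_emb + pos_emb
    return np.concatenate([cls_token[None, :], registers, patches], axis=0)

def split_outputs(seq_out, num_registers):
    # Register outputs are discarded; only [CLS] and patch tokens are used downstream.
    cls_out = seq_out[0]
    patch_out = seq_out[1 + num_registers:]
    return cls_out, patch_out

num_patches, num_registers, d = 196, 4, 32
rng = np.random.default_rng(0)
patch_emb = rng.normal(size=(num_patches, d))
pos_emb = rng.normal(size=(num_patches, d))
cls_token = rng.normal(size=d)
registers = rng.normal(size=(num_registers, d))   # learnable parameters in a real model

seq = build_sequence(patch_emb, pos_emb, cls_token, registers)
encoder = lambda s: s                             # hypothetical stand-in for the encoder
cls_out, patch_out = split_outputs(encoder(seq), num_registers)
```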

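RoPE (Rotary Position Embeddings). Encode position by rotating Q and K by an angle proportional to position: for a token at position m, rotate each dimension pair (2i, 2i+1) by m · θᵢ, where θᵢ = 1/10000^(2i/d) and i = 0, …, d/2 − 1. Key property: the dot product between a rotated query at position m and a rotated key at position n depends on the positions only through the offset n − m, not on the absolute positions. Extrapolation beyond the training length is not automatic, though: plain RoPE degrades on sequences much longer than those seen in training, and long-context models typically rescale the rotation frequencies (e.g. position interpolation) to extend context. Used in Llama, GPT-NeoX, Qwen, etc.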
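A minimal NumPy sketch of RoPE using the interleaved (2i, 2i+1) pairing described above; `base=10000.0` matches θᵢ = 1/10000^(2i/d). It checks the relative-position property by shifting both positions by 1000 and confirming the score does not change, and that rotation preserves vector norms.

```python
import numpy as np

def rope(x, pos, base=10000.0):
    """Rotate each (2i, 2i+1) pair of x by angle pos * theta_i, theta_i = base^(-2i/d)."""
    d = x.shape[-1]
    i = np.arange(d // 2)
    theta = base ** (-2.0 * i / d)
    angle = pos * theta
    cos, sin = np.cos(angle), np.sin(angle)
    even, odd = x[..., 0::2], x[..., 1::2]
    out = np.empty_like(x)
    out[..., 0::2] = even * cos - odd * sin   # first coordinate of each rotated pair
    out[..., 1::2] = even * sin + odd * cos   # second coordinate of each rotated pair
    return out

rng = np.random.default_rng(0)
q, k = rng.normal(size=64), rng.normal(size=64)

# Same offset (n - m = 4) at two very different absolute positions.
near = rope(q, 3) @ rope(k, 7)
far = rope(q, 1003) @ rope(k, 1007)
```

Some libraries pair dimension i with i + d/2 instead of 2i with 2i + 1; both are valid rotations with the same relative-position property, but weights trained with one layout are not interchangeable with the other.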

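KV-cache + Flash Attention + GQA: three independent inference upgrades.

KV-cache. During autoregressive decoding, store the K and V of all past tokens. At step t, compute Q, K, V only for the new token, append the new K and V to the cache, and attend over the cached keys. Per-step attention cost drops from O(t²) (recomputing every token) to O(t).

Flash Attention. Split Q, K, V into tiles small enough to fit in on-chip SRAM and compute attention tile by tile with an online (running) softmax, so the full N × N score matrix is never written to HBM. It is exact attention, not an approximation (outputs match standard attention up to floating-point rounding); it needs O(N) extra memory instead of O(N²) and is typically 2-4× faster, because it cuts the HBM reads and writes that dominate standard attention's runtime.

GQA (Grouped-Query Attention). Divide the h query heads into G groups; all heads in a group SHARE one K head and one V head (only Q stays per-head). This shrinks the KV-cache by a factor of h/G. Multi-Query Attention (MQA) is the G = 1 extreme; GQA is the quality-friendly compromise used in Llama 2 (70B), Llama 3 and Mistral 7B.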
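A sketch of KV-caching for one causal attention head, assuming a fixed input sequence `x` (in real decoding each new token comes from the previous step's output) and a hypothetical `KVCache` class. Each step projects only the new token, appends its K and V, and attends over the cache; the stacked per-step outputs should match full causal attention recomputed over the whole sequence.

```python
import numpy as np

def softmax(z):
    z = z - z.max(axis=-1, keepdims=True)
    e = np.exp(z)
    return e / e.sum(axis=-1, keepdims=True)

def full_causal_attention(x, Wq, Wk, Wv):
    # Recompute Q, K, V for every token and mask future positions.
    # Redoing this at every decoding step costs O(t^2) per step.
    q, k, v = x @ Wq, x @ Wk, x @ Wv
    t, d = q.shape
    scores = q @ k.T / np.sqrt(d)
    scores = np.where(np.tril(np.ones((t, t), dtype=bool)), scores, -np.inf)
    return softmax(scores) @ v

class KVCache:
    """Hypothetical single-head cache: lists of past keys and values."""

    def __init__(self):
        self.k, self.v = [], []

    def step(self, x_t, Wq, Wk, Wv):
        # Project only the new token, append its K and V, attend over all cached keys: O(t).
        q = x_t @ Wq
        self.k.append(x_t @ Wk)
        self.v.append(x_t @ Wv)
        K, V = np.stack(self.k), np.stack(self.v)
        weights = softmax(q @ K.T / np.sqrt(q.shape[-1]))
        return weights @ V

d, T = 16, 10
rng = np.random.default_rng(0)
Wq, Wk, Wv = (rng.normal(scale=d ** -0.5, size=(d, d)) for _ in range(3))
x = rng.normal(size=(T, d))

cache = KVCache()
incremental = np.stack([cache.step(x[t], Wq, Wk, Wv) for t in range(T)])
reference = full_causal_attention(x, Wq, Wk, Wv)
```

No causal mask is needed in `step`: the cache only ever contains past and current tokens.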

Definitions

  • PreNorm: y = x + Sublayer(LN(x)). Unbroken residual stream; stable for deep models.
  • RMSNorm: y = γ · x / RMS(x). No mean subtraction; cheaper than LayerNorm.
  • LayerScale: per-channel scale γ_l initialised to a small value (e.g. 1e-5); lets each sub-layer 'opt in' gradually.
  • QK-Norm: apply LayerNorm/RMSNorm to Q and K before QKᵀ; prevents attention-logit explosion.
  • Registers: extra learnable tokens with no position; absorb the scratchpad usage that would otherwise pollute attention.
  • RoPE: rotary positional encoding via per-pair rotation by m·θᵢ; the dot product encodes relative position.
  • Flash Attention: IO-aware tiled attention; computes attention block by block in SRAM; never materialises the N×N matrix in HBM.
  • GQA: Grouped-Query Attention; heads in a group share K, V projections; reduces the KV-cache by h/G.
  • KV-cache: cache of K, V for past tokens; per-step compute drops from O(t²) to O(t).

Formulas

  • \text{PreNorm:}\ y = x + \text{Sublayer}(\text{LN}(x))
  • \text{RMSNorm:}\ y = \gamma \cdot x / \sqrt{\tfrac{1}{d}\sum_i x_i^2}
  • \text{LayerScale:}\ y = x + \text{diag}(\gamma_l) \cdot \text{Sublayer}(\text{LN}(x))
  • \text{RoPE:}\ [q'_{2i}, q'_{2i+1}] = R_{m\theta_i} [q_{2i}, q_{2i+1}],\ \theta_i = 1/10000^{2i/d}
  • \text{GQA KV-cache reduction factor:}\ h / G
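A sketch of GQA for one sequence, with a hypothetical `grouped_query_attention` function. Query head j reads K/V head j // (h/G), so each block of h/G consecutive query heads shares one K/V head. `np.repeat` duplicates the shared heads only for the computation; what would be cached is the G-head `k` and `v`, which gives the h/G reduction in the formula above.

```python
import numpy as np

def softmax(z):
    z = z - z.max(axis=-1, keepdims=True)
    e = np.exp(z)
    return e / e.sum(axis=-1, keepdims=True)

def grouped_query_attention(q, k, v):
    """q: (h, n, d_head); k, v: (G, n, d_head), with h divisible by G."""
    h, G = q.shape[0], k.shape[0]
    assert h % G == 0
    # Expand G K/V heads to h by repeating each one h // G times (compute-time only).
    k_rep = np.repeat(k, h // G, axis=0)
    v_rep = np.repeat(v, h // G, axis=0)
    scores = q @ k_rep.transpose(0, 2, 1) / np.sqrt(q.shape[-1])   # (h, n, n)
    return softmax(scores) @ v_rep                                  # (h, n, d_head)

h, G, n, d_head = 8, 2, 5, 16
rng = np.random.default_rng(0)
q = rng.normal(size=(h, n, d_head))
k = rng.normal(size=(G, n, d_head))   # only these G heads would be stored in the KV-cache
v = rng.normal(size=(G, n, d_head))

out = grouped_query_attention(q, k, v)
cache_ratio = (h * n * d_head) / k.size   # MHA stores h K-heads, GQA stores G: ratio h / G
```

Setting G = h recovers standard multi-head attention and G = 1 recovers MQA; optimised kernels typically avoid the explicit `repeat` and index the shared head directly.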

Derivations

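RoPE preserves relative position: after rotating q by m·θ and k by n·θ, the dot product of the rotated vectors equals qᵀR_{(n−m)θ}k, a function of q, k and the offset (n − m) only, so it is unchanged when both positions are shifted by the same amount. This makes attention scores depend on relative position; it does not by itself guarantee good behaviour at offsets larger than any seen during training.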
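The step from rotated vectors to the offset uses only two standard properties of 2-D rotation matrices, a transpose is the inverse rotation and rotations compose by adding angles, applied to each dimension pair i with the rotation R_{mθᵢ} from the Formulas list:

```latex
R_{\phi} = \begin{pmatrix} \cos\phi & -\sin\phi \\ \sin\phi & \cos\phi \end{pmatrix},
\qquad R_{\phi}^{\top} = R_{-\phi},
\qquad R_{\alpha} R_{\beta} = R_{\alpha + \beta}

\begin{aligned}
(R_{m\theta_i} q_{(i)})^{\top} (R_{n\theta_i} k_{(i)})
  &= q_{(i)}^{\top} R_{m\theta_i}^{\top} R_{n\theta_i}\, k_{(i)}
   = q_{(i)}^{\top} R_{-m\theta_i} R_{n\theta_i}\, k_{(i)}
   = q_{(i)}^{\top} R_{(n-m)\theta_i}\, k_{(i)}, \\
q'^{\top} k' &= \sum_{i=0}^{d/2 - 1} q_{(i)}^{\top} R_{(n-m)\theta_i}\, k_{(i)},
\qquad q_{(i)} = [q_{2i}, q_{2i+1}]^{\top},\ k_{(i)} = [k_{2i}, k_{2i+1}]^{\top}
\end{aligned}
```

Every term depends on m and n only through n − m, which is the relative-position property.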

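Why PreNorm is more stable: under PreNorm, y = x + Sublayer(LN(x)), so the output of the stack is the input plus a sum of block outputs, and every layer has a direct identity path to the loss. Gradients reach early layers without passing through a chain of LayerNorms, and each sub-layer sees a normalised input regardless of depth. The residual stream itself is not bounded: every block adds its output, so its variance grows with depth, which is why a final norm before the head is still needed. Under PostNorm, every residual connection passes through an LN, so there is no clean identity path, and deep PostNorm stacks tend to diverge early in training unless careful learning-rate warmup is used.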
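A forward-pass illustration of the variance claim at initialisation, assuming 24 hypothetical sub-layers that are fixed random linear maps with roughly unit-variance outputs. Under PreNorm the residual stream's per-token standard deviation grows (about √(1 + depth) here), while under PostNorm it is renormalised to about 1 after every block. This shows only the forward statistics, not the gradient argument.

```python
import numpy as np

def layer_norm(x, eps=1e-5):
    # LayerNorm over features, no learnable gamma/beta.
    mu = x.mean(axis=-1, keepdims=True)
    return (x - mu) / np.sqrt(x.var(axis=-1, keepdims=True) + eps)

d, depth = 64, 24
rng = np.random.default_rng(0)
# Hypothetical stand-in sub-layers: random linear maps scaled for unit-variance outputs.
Ws = [rng.normal(scale=d ** -0.5, size=(d, d)) for _ in range(depth)]
x0 = rng.normal(size=(16, d))

pre, post = x0.copy(), x0.copy()
for W in Ws:
    pre = pre + layer_norm(pre) @ W       # PreNorm: the stream accumulates every block's output
    post = layer_norm(post + post @ W)    # PostNorm: the stream is renormalised after each block

start_std = x0.std(axis=-1).mean()
pre_std = pre.std(axis=-1).mean()         # grows with depth
post_std = post.std(axis=-1).mean()       # stays near 1
final = layer_norm(pre)                   # the final norm a PreNorm model applies before its head
```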

Examples

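  • Llama 3 stack: PreNorm + RMSNorm + RoPE + GQA, typically run with Flash Attention kernels. No LayerScale, although Llama 3 70B is deep (80 layers); LayerScale is mainly a deep-ViT technique, while LLMs rely on PreNorm with careful initialisation and learning-rate warmup.
  • GQA in Llama-2-70B: 64 query heads share 8 K/V heads (8 groups of 8) → the KV-cache shrinks 8× with quality close to full multi-head attention.
  • Flash Attention (v1 onward) reduces attention's extra memory from O(N²) to O(N) in sequence length. At a 32k context (N = 32,768), the score matrix it avoids would hold 32,768² ≈ 1.07 × 10⁹ entries per head, 2 GiB in fp16.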

Diagrams

  • PreNorm block vs PostNorm block: residual stream highlighted; PreNorm has it unbroken.
  • LayerNorm vs RMSNorm geometric: points in d dimensions → LayerNorm removes each vector's mean, then rescales; RMSNorm only rescales.
  • GQA visualisation: 8 heads → 4 groups of 2; heads in a group share K, V.

Edge cases

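  • PreNorm's residual stream grows in magnitude with depth; a final norm (LayerNorm or RMSNorm) before the output head is still required.
  • Flash Attention requires specific GPU support (for the official CUDA kernels: Turing+ for v1, Ampere+ for v2, Hopper for v3).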
  • GQA with very few groups (MQA limit) loses quality on some tasks — GQA is the sweet spot.

Common mistakes

  • Calling RMSNorm 'LayerNorm without bias' — also drops the mean subtraction.
  • Treating RoPE as additive — it's multiplicative (rotation matrix applied to Q, K).
  • Conflating Flash Attention with KV-cache — they solve different problems (memory layout vs incremental compute).
  • Stating Registers are positional — they have no positional encoding, that's the point.

Shortcuts

  • Seven modern upgrades to memorise: PreNorm, RMSNorm, LayerScale, QK-Norm, Registers, RoPE, GQA + KV-cache + Flash Attention.
  • RoPE is multiplicative (rotation of Q and K); the original sinusoidal PE is additive (added to the input embeddings).
  • KV-cache memory grows linearly with sequence length — long-context inference is bottlenecked here.
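To put numbers on the linear growth: the cache holds one K and one V tensor per layer, so its size is 2 × layers × KV heads × head_dim × tokens × bytes per element. The sketch assumes the Llama-2-70B shape (80 layers, 64 query heads, 8 KV heads, head_dim 128), fp16/bf16 storage and batch size 1; `kv_cache_bytes` is a hypothetical helper.

```python
def kv_cache_bytes(n_layers, n_kv_heads, head_dim, seq_len, batch=1, bytes_per_elem=2):
    # Factor 2 = one K and one V tensor per layer; bytes_per_elem=2 assumes fp16/bf16.
    return 2 * n_layers * n_kv_heads * head_dim * seq_len * batch * bytes_per_elem

# Assumed Llama-2-70B shape.
LAYERS, Q_HEADS, KV_HEADS, HEAD_DIM = 80, 64, 8, 128

for seq_len in (4096, 32768):
    gqa = kv_cache_bytes(LAYERS, KV_HEADS, HEAD_DIM, seq_len)   # 8 shared K/V heads
    mha = kv_cache_bytes(LAYERS, Q_HEADS, HEAD_DIM, seq_len)    # one K/V head per query head
    print(f"{seq_len:>6} tokens: GQA {gqa / 2**30:.2f} GiB vs MHA {mha / 2**30:.2f} GiB")

# prints:
#   4096 tokens: GQA 1.25 GiB vs MHA 10.00 GiB
#  32768 tokens: GQA 10.00 GiB vs MHA 80.00 GiB
```

That works out to 320 KiB per token with GQA; each extra sequence in a batch adds the same amount again.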

Proofs / Algorithms