Saral Shiksha Yojna

Computer Vision

CSE471
Prof. Makarand Tapaswi + Prof. Charu Sharma, Spring 2025-26, 4 credits

Modern Transformer Upgrades

Unit 10 — Transformer Advances (ViT-5 era)

The Architecture Refines Itself

The Transformer paper — *Attention Is All You Need*, Vaswani et al. 2017 — is one of the most consequential pieces of computer science of the last decade. The architecture it introduced is, broadly speaking, still the architecture every state-of-the-art LLM and ViT runs on today.

But *"broadly speaking"* hides a lot of work. Almost every component of the 2017 Transformer has been quietly replaced between then and now. Not torn down, refined. The shape is the same. The internals are different in nine specific ways, and this unit walks through all nine, because that's exactly what makes a 2025 ViT-5 different from the original ViT (2020), which reused the 2017 block almost unchanged.

The slide pack is named "ViT-5: Vision Transformers for The Mid-2020s", and that's the framing: same skeleton, modernised organs.

We'll walk through the changes grouped into four themes: stability, normalisation, efficiency, and positional encoding.

Why does any of this matter?

When you train a 100-billion-parameter Transformer on a trillion tokens, every one of these is the difference between *"it diverges after 50k steps"* and *"it converges to GPT-4."* These are not "nice-to-haves." They are the engineering reality of modern Transformers.

Change #1 — Post-Norm → Pre-Norm

The 2017 original placed LayerNorm *after* the residual addition. This is Post-Norm, where F is the sublayer (attention or MLP):

$$x_{l+1} = \mathrm{LN}\big(x_l + F(x_l)\big)$$

Modern Transformers use Pre-Norm: LayerNorm goes *before* the sublayer, inside the residual branch:

$$x_{l+1} = x_l + F\big(\mathrm{LN}(x_l)\big)$$

Why does this matter? In Pre-Norm, the skip path from input to output is a plain sum, *untouched by any normalisation*: unrolling the recursion gives x_L = x_0 + Σ_l F_l(LN(x_l)). This direct path is called the residual stream, and it's the highway gradients use to flow back to early layers.

In Post-Norm, every residual sum passes through a LayerNorm before moving on, so gradients have to push through a normaliser at every layer, and very deep Transformers (24+ layers) become unstable to train without careful learning-rate warmup.
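A minimal pure-Python sketch of the two block types, assuming LayerNorm without its learnable γ, β. `zero_sublayer` is a hypothetical stand-in for an attention or MLP branch that contributes nothing, which isolates what the skip path alone does.

```python
import math

def layer_norm(x, eps=1e-5):
    # LayerNorm with gamma = 1, beta = 0.
    d = len(x)
    mu = sum(x) / d
    var = sum((xi - mu) ** 2 for xi in x) / d
    return [(xi - mu) / math.sqrt(var + eps) for xi in x]

def post_norm_block(x, sublayer):
    # 2017 original: normalise AFTER the residual addition.
    return layer_norm([xi + fi for xi, fi in zip(x, sublayer(x))])

def pre_norm_block(x, sublayer):
    # Modern: normalise inside the branch; the skip path x is untouched.
    return [xi + fi for xi, fi in zip(x, sublayer(layer_norm(x)))]

def zero_sublayer(x):
    # Hypothetical branch that outputs nothing.
    return [0.0] * len(x)

x = [3.0, -1.0, 4.0, 1.0]
print(pre_norm_block(x, zero_sublayer))   # [3.0, -1.0, 4.0, 1.0]: exactly x
print(post_norm_block(x, zero_sublayer))  # x is re-normalised anyway
```

With Pre-Norm a silent branch leaves the residual stream exactly as it was; with Post-Norm the normaliser still sits on the main path.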

Exam line: *Pre-Norm improves training stability by giving gradients a direct, unnormalised path through the residual stream.*

GPT-2 and later GPTs, LLaMA, PaLM, Gemma: all Pre-Norm (the original GPT-1 was still Post-Norm).

Change #2 — LayerNorm → RMSNorm

LayerNorm — the original

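Centres the data, rescales, applies a learnable affine:

$$\mathrm{LN}(x) = \gamma \odot \frac{x - \mu}{\sqrt{\sigma^2 + \epsilon}} + \beta, \quad \mu = \frac{1}{d}\sum_{i=1}^{d} x_i, \quad \sigma^2 = \frac{1}{d}\sum_{i=1}^{d}(x_i - \mu)^2$$

Two learnable params per dim (γ, β).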

RMSNorm — drop the mean

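$$\mathrm{RMSNorm}(x) = \gamma \odot \frac{x}{\mathrm{RMS}(x)}, \quad \mathrm{RMS}(x) = \sqrt{\frac{1}{d}\sum_{i=1}^{d} x_i^2 + \epsilon}$$

Two differences from LayerNorm:

  • No mean subtraction. Data is not recentred.
  • **No bias β.** Only a gain γ.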

Slightly faster, empirically as good or better. LLaMA, Gemma, T5 use RMSNorm.
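The two formulas side by side, as a minimal pure-Python sketch (vectors as plain lists; the ε value is an assumption). RMSNorm skips the mean reduction and the β term, which is where its small speed advantage comes from.

```python
import math

def layer_norm(x, gamma, beta, eps=1e-5):
    # Centre, divide by the standard deviation, then apply the affine (gamma, beta).
    d = len(x)
    mu = sum(x) / d
    var = sum((xi - mu) ** 2 for xi in x) / d
    return [g * (xi - mu) / math.sqrt(var + eps) + b
            for xi, g, b in zip(x, gamma, beta)]

def rms_norm(x, gamma, eps=1e-5):
    # No mean subtraction, no bias: divide by the root-mean-square, apply the gain.
    d = len(x)
    rms = math.sqrt(sum(xi * xi for xi in x) / d + eps)
    return [g * xi / rms for xi, g in zip(x, gamma)]

x = [2.0, 4.0, 6.0, 8.0]
ones, zeros = [1.0] * 4, [0.0] * 4
print(layer_norm(x, ones, zeros))  # zero mean, unit variance
print(rms_norm(x, ones))           # same direction as x, unit RMS
```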

The geometric quiz (slide 7)

Imagine a 2D scatter of points (x_1, x_2), and ignore the affine parameters (γ = 1, β = 0).

LayerNorm:

1. *Centring* (x ↦ x − μ·**1**): every point's projection onto the all-ones vector becomes zero, so every point lands on the line x_1 + x_2 = 0 (the anti-diagonal). 2. *Scaling* (divide by σ): a centred point (a, −a) has σ = |a|, so it becomes (1, −1) or (−1, 1). All points end up at one of these two symmetric points, each at distance √2 from the origin.

In d dimensions: LayerNorm projects every vector onto the hyperplane orthogonal to the all-ones vector (the zero-mean subspace), then rescales it to unit standard deviation, i.e. Euclidean norm √d, within that hyperplane. The output lives on a (d−2)-dimensional sphere of radius √d inside the (d−1)-dimensional zero-mean hyperplane (before γ and β are applied).

RMSNorm:

1. No centring. 2. Scaling: every point is rescaled to unit RMS, equivalently Euclidean norm √d, i.e. projected onto the full (d−1)-dimensional sphere of radius √d.

The exam picture:

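LayerNorm: points → (d−2)-sphere of radius √d on the (d−1)-dim zero-mean hyperplane (in 2D: two points on the anti-diagonal; in 3D: a circle in the plane x_1 + x_2 + x_3 = 0).
RMSNorm: points → full (d−1)-dim sphere of radius √d (in 2D: a circle of radius √2).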

LayerNorm removes both magnitude AND mean (2 constraints). RMSNorm removes only magnitude (1 constraint).
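The 2D quiz can be checked numerically. A minimal sketch, assuming γ = 1, β = 0 and a handful of arbitrary points: every LayerNorm output is (1, −1) or (−1, 1), while RMSNorm keeps each point's direction and only fixes its norm at √2.

```python
import math

def layer_norm(x, eps=1e-12):
    # LayerNorm with gamma = 1, beta = 0: centre, then divide by the std.
    d = len(x)
    mu = sum(x) / d
    var = sum((xi - mu) ** 2 for xi in x) / d
    return [(xi - mu) / math.sqrt(var + eps) for xi in x]

def rms_norm(x, eps=1e-12):
    # RMSNorm with gamma = 1: no centring, divide by the root-mean-square.
    rms = math.sqrt(sum(xi * xi for xi in x) / len(x) + eps)
    return [xi / rms for xi in x]

points = [(3.0, 1.0), (-2.0, 5.0), (0.5, 0.4), (10.0, -7.0)]
for p in points:
    ln, rn = layer_norm(p), rms_norm(p)
    # LN lands on the anti-diagonal at (1, -1) or (-1, 1); RMS output has norm sqrt(2).
    print(p, "LN:", [round(v, 3) for v in ln],
          "RMS:", [round(v, 3) for v in rn], "|RMS| =", round(math.hypot(*rn), 3))
```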

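Why does dropping the mean step not hurt? The RMSNorm paper (Zhang & Sennrich, 2019) argues that LayerNorm's benefit comes mainly from re-scaling, not re-centring, and the surrounding learnable layers can still shift activations wherever a non-zero mean is useful. Empirically it just works, and it's cheaper.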

Change #3 — LayerScale

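In very deep Transformers, residual contributions from each sublayer can explode or vanish. LayerScale (Touvron et al., CaiT, 2021) inserts a per-channel learnable diagonal scaling on each residual branch:

$$x_{l+1} = x_l + \mathrm{diag}(\lambda)\, F\big(\mathrm{LN}(x_l)\big), \quad \lambda \in \mathbb{R}^d$$

λ is initialised tiny (every λ_i = ε, e.g. 10⁻⁵, with smaller ε for deeper networks). At the start of training, every sublayer contributes almost nothing: the network behaves like a stack of identity functions, and gradients flow trivially. As training proceeds, the channels of λ that carry useful signal grow.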

Exam line: *LayerScale gives each sublayer a learnable "volume knob" per channel, initialised near zero so training begins from a near-identity network.*
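A minimal sketch of one LayerScale residual update. The branch output `f` is a made-up, deliberately large vector standing in for an untrained sublayer, and ε = 10⁻⁵ is an assumed init value.

```python
def layer_scale_residual(x, branch_out, lam):
    # x_{l+1} = x_l + diag(lam) F(x_l): a per-channel (elementwise) scaling.
    return [xi + li * fi for xi, li, fi in zip(x, lam, branch_out)]

d = 4
eps = 1e-5                      # assumed init value
lam = [eps] * d                 # learnable in a real model
x = [1.0, -2.0, 0.5, 3.0]
f = [100.0, -50.0, 20.0, 7.0]   # made-up, large untrained branch output

y = layer_scale_residual(x, f, lam)
print(y)   # each entry moves by at most 100 * 1e-5 = 1e-3: a near-identity block
```

Multiplying by a diagonal matrix is just an elementwise product, so the "volume knob" costs d parameters and d multiplies per sublayer.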

Change #4 — QK-Norm

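Attention has a hidden numerical hazard: the logits q·k can grow large (especially at long sequences or with large head dims), pushing softmax into saturation: one logit dominates and gradients vanish.

The original 1/√d_k scaling is approximately correct but not enough at scale, because it only compensates for the head dimension, not for activation norms that keep growing during training.

QK-Norm (Henry et al., 2020) normalises Q and K independently, per head, before the dot product. Later models (e.g. ViT-22B) use LayerNorm or RMSNorm here; the original paper uses L2 normalisation and replaces 1/√d_k with a learnable scale g:

$$\mathrm{Attn}(Q, K, V) = \mathrm{softmax}\big(g \cdot \hat{Q}\hat{K}^\top\big)V, \quad \hat{q}_i = \frac{q_i}{\lVert q_i \rVert}, \ \hat{k}_j = \frac{k_j}{\lVert k_j \rVert}$$

Every normalised query and key now has a fixed norm along the head dim, so each logit is bounded: |q̂_i·k̂_j| ≤ 1 before the scale g. The softmax sees a well-conditioned input: no extreme values, so no saturation caused by runaway activation norms.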
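A minimal sketch contrasting plain scaled dot-product logits with the L2 variant of QK-Norm. The query, keys and the scale g = 10 are made-up values (g is learnable in Henry et al.; it is fixed here for illustration).

```python
import math

def l2_normalize(v, eps=1e-6):
    n = math.sqrt(sum(a * a for a in v)) + eps
    return [a / n for a in v]

def dot(a, b):
    return sum(x * y for x, y in zip(a, b))

def softmax(z):
    m = max(z)
    e = [math.exp(v - m) for v in z]
    s = sum(e)
    return [v / s for v in e]

d_h = 64
q = [3.0] * d_h                                  # a query with large activations
keys = [[3.0] * d_h, [2.0] * d_h, [-1.0] * d_h]

# Standard scaled dot product: logits grow with the activation scale.
raw = [dot(q, k) / math.sqrt(d_h) for k in keys]
# QK-Norm (L2 variant): normalise q and each k, then multiply by the scale g.
g = 10.0
qk = [g * dot(l2_normalize(q), l2_normalize(k)) for k in keys]

print([round(v, 1) for v in raw], [round(p, 4) for p in softmax(raw)])  # saturated
print([round(v, 2) for v in qk], [round(p, 4) for p in softmax(qk)])    # bounded by g
```

The raw logits (72, 48, −24) put essentially all probability on one key; after normalisation every logit lies in [−g, g], so the softmax stays in a range where gradients still flow.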

Change #5 — Registers ("garbage collector" tokens)

This is a 2023 ViT paper (Darcet et al., *Vision Transformers Need Registers*) and one of the most surprising findings of the recent era.

The observation

Researchers training large ViTs noticed something strange: high-norm activations appearing at completely uninformative patches — chunks of sky, blurred backgrounds, low-content regions. The model was using these patches as scratch space, dumping unrelated global information there.

With no dedicated place to put global state, the network hijacks the least informative tokens — and corrupts them in the process. Attention maps look chaotic.

The fix — add register tokens

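Append a small number (4–8) of extra trainable tokens to the input sequence. Like [CLS] they are learned, but they receive no positional embedding:

Input: [CLS, p_1, …, p_N, r_1, …, r_k]
Output: [CLS′, p′_1, …, p′_N, r′_1, …, r′_k]; CLS′ is used for classification, the registers are discarded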

The network now has dedicated scratch space. High-norm activations cluster on the register tokens rather than corrupting real patches. Attention maps become clean; dense-prediction tasks improve.
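A minimal sketch of the bookkeeping around registers. The helper names (`embed_with_registers`, `read_outputs`) and all token values are hypothetical, the encoder is replaced by the identity, and it assumes registers are appended after positional embeddings are added, so they carry no position.

```python
def embed_with_registers(cls_tok, patches, pos_emb, registers):
    """Build [CLS, p_1..p_N, r_1..r_k]; only CLS and patches get positions."""
    tokens = [cls_tok] + patches
    tokens = [[t + p for t, p in zip(tok, pe)] for tok, pe in zip(tokens, pos_emb)]
    return tokens + registers          # registers appended without any PE

def read_outputs(outputs, num_registers):
    """Keep CLS (for classification) and patch outputs; discard the registers."""
    kept = outputs[:len(outputs) - num_registers]
    return kept[0], kept[1:]

# Toy example: d = 2, N = 3 patches, k = 4 registers (all values made up).
cls_tok = [0.0, 0.0]
patches = [[1.0, 1.0], [2.0, 2.0], [3.0, 3.0]]
pos_emb = [[0.1, 0.1], [0.2, 0.2], [0.3, 0.3], [0.4, 0.4]]
registers = [[9.0, 9.0] for _ in range(4)]   # learnable in a real model

seq = embed_with_registers(cls_tok, patches, pos_emb, registers)
outputs = seq                                 # stand-in for the Transformer encoder
cls_out, patch_out = read_outputs(outputs, num_registers=4)
print(len(seq), cls_out, len(patch_out))      # 8 tokens in; CLS + 3 patches kept
```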

Slide line verbatim: *Registers = "garbage collector" tokens; they prevent attention peaks at blank areas by giving the model dedicated scratch space.*

Change #6 — Flash Attention

The single most consequential efficiency improvement to Transformers in the last 5 years. Flash Attention (Dao et al., NeurIPS 2022) is an exact attention algorithm — same outputs as standard attention, no approximation — that is dramatically faster and uses far less memory.

The bottleneck

Standard attention computes softmax(QK^⊤/√d_k)V. The intermediate S = QK^⊤ has shape N × N; for N = 8192, that is about 67 million (2²⁶) entries, per head and per batch element.

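The bottleneck on modern GPUs isn't compute, it's memory traffic. HBM is large but comparatively slow; on-chip SRAM is tiny but has roughly an order of magnitude more bandwidth (on an A100, about 19 TB/s versus 1.5–2 TB/s, per the FlashAttention paper). Standard attention writes the full N × N matrix to HBM, reads it back for the softmax, writes the result again, and reads it once more for the multiplication with V. That is a lot of slow memory traffic.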

The trick — tile and never materialise

Flash Attention tiles the computation: process Q, K and V in small blocks that fit in SRAM, compute the softmax block by block with an online (streaming) update, and *never write the full N × N matrix to HBM*.

The mathematical core is the online softmax algorithm — you can compute softmax incrementally if you keep track of a running max and running denominator, rescaling at each step:

```
For each tile of K and V:
    Load the tile into SRAM
    Compute partial attention scores
    Update running max m, running denominator l, running output O
        (rescale the previous l and O when the max changes)
    Discard the tile
Return O / l
```

The full attention matrix is never stored. Extra memory: O(N) instead of O(N²). Reported wall-clock speedups are roughly 2–4× over standard attention.
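The online-softmax recurrence can be checked in plain Python. This is a scalar, single-query sketch, not the FlashAttention kernel (which also tiles Q and runs as one fused GPU kernel); it only shows that tile-by-tile rescaling gives the same result as materialising every score. The inputs are arbitrary made-up values.

```python
import math

def standard_attention(q, K, V):
    """Reference: materialise all scores, softmax, then weight V."""
    d = len(q)
    s = [sum(a * b for a, b in zip(q, k)) / math.sqrt(d) for k in K]
    m = max(s)
    p = [math.exp(x - m) for x in s]
    l = sum(p)
    return [sum(pi * v[j] for pi, v in zip(p, V)) / l for j in range(len(V[0]))]

def tiled_attention(q, K, V, tile=2):
    """Online softmax: stream over K/V tiles without storing all scores."""
    d = len(q)
    m = -math.inf                 # running max of the scores seen so far
    l = 0.0                       # running softmax denominator
    o = [0.0] * len(V[0])         # running unnormalised output
    for start in range(0, len(K), tile):
        K_t, V_t = K[start:start + tile], V[start:start + tile]
        s = [sum(a * b for a, b in zip(q, k)) / math.sqrt(d) for k in K_t]
        m_new = max(m, max(s))
        scale = math.exp(m - m_new)          # rescale old l and o to the new max
        p = [math.exp(x - m_new) for x in s]
        l = l * scale + sum(p)
        o = [oj * scale + sum(pi * v[j] for pi, v in zip(p, V_t))
             for j, oj in enumerate(o)]
        m = m_new
    return [oj / l for oj in o]              # normalise once at the end

q = [0.5, -1.0, 2.0]
K = [[1.0, 0.0, 0.5], [0.2, 2.0, -1.0], [3.0, 1.0, 0.0], [-0.5, 0.5, 1.5], [0.0, -2.0, 1.0]]
V = [[1.0, 2.0], [0.0, -1.0], [3.0, 0.5], [2.0, 2.0], [-1.0, 0.0]]
print(standard_attention(q, K, V))
print(tiled_attention(q, K, V, tile=2))      # same values up to rounding
```

On the first tile m is −∞, so `math.exp(m - m_new)` is 0.0 and the empty running state contributes nothing.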

Two exam emphases:

  • "Exact": Flash Attention is not an approximation. Outputs match standard attention up to floating-point rounding (the same sums are accumulated in a different order).
  • "IO-aware" — the speedup comes from minimising HBM ↔ SRAM data movement, not from doing less computation.

Change #7 — RoPE (Rotary Position Embedding)

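For a token at position m, head dim d, with frequencies θ_i = 10000^{−2i/d} for i = 0, …, d/2 − 1, RoPE rotates each coordinate pair by the angle mθ_i:

$$\begin{pmatrix} x'_{2i} \\ x'_{2i+1} \end{pmatrix} = \begin{pmatrix} \cos m\theta_i & -\sin m\theta_i \\ \sin m\theta_i & \cos m\theta_i \end{pmatrix} \begin{pmatrix} x_{2i} \\ x_{2i+1} \end{pmatrix}$$

Stacking these 2 × 2 blocks gives a block-diagonal rotation matrix R_m.

Rotated query/key: q′_m = R_m q, k′_n = R_n k. Critical property:

$$\langle R_m q, R_n k \rangle = q^\top R_m^\top R_n k = q^\top R_{n-m} k$$

The attention score depends only on the relative position n − m (plus the content of q and k).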
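A minimal pure-Python check of the relative-position property, assuming the interleaved pairing (x_{2i}, x_{2i+1}) and base 10000 as in the formula above (some libraries pair the first and second halves of the vector instead, which is an equivalent reordering). The vectors q and k are arbitrary.

```python
import math

def rope(x, pos, base=10000.0):
    """Rotate each pair (x[2i], x[2i+1]) by the angle pos * theta_i."""
    d = len(x)
    out = []
    for i in range(d // 2):
        theta = base ** (-2 * i / d)
        c, s = math.cos(pos * theta), math.sin(pos * theta)
        a, b = x[2 * i], x[2 * i + 1]
        out += [a * c - b * s, a * s + b * c]
    return out

def dot(u, v):
    return sum(a * b for a, b in zip(u, v))

q = [0.3, -1.2, 0.8, 0.5, -0.7, 1.1]
k = [1.0, 0.4, -0.6, 0.9, 0.2, -0.3]

# Same offset n - m = 3 at very different absolute positions.
s1 = dot(rope(q, 5), rope(k, 8))
s2 = dot(rope(q, 105), rope(k, 108))
print(s1, s2)   # equal up to floating-point error
```

Because each block is a pure rotation, RoPE also preserves the norms of q and k; it changes only their relative angle.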

Three advantages over learned absolute PEs:

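  • No fixed-size lookup table. The rotation can be evaluated at any position, so it applies to sequences longer than those seen in training, although quality beyond the training length usually needs extra tricks (e.g. position interpolation).
  • Encodes relative position directly. The score depends only on n − m.
  • Multiplicative, not additive. The rotation is applied to q and k inside attention, not added to the input embedding.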

Variants: 2D-RoPE (images), M-RoPE (video) — covered in the multimodal unit.

Change #8 — KV Caching

A *pure-inference* optimisation. When generating autoregressively:

Step 1: input = x_1 → predict x_2.
Step 2: input = x_1, x_2 → predict x_3.
Step 3: input = x_1, x_2, x_3 → predict x_4.

Naïvely, at every step you re-run attention over the entire history: at step t, all t positions attend causally to their prefixes, which is O(t²) work per token and O(N³) for a generation of length N.

KV-caching observes: the K and V projections of past tokens never change, because with a causal mask earlier tokens never attend to later ones. Once you compute k_i and v_i, they're fixed. So store them in a cache, and at each new step t:

1. Compute q_t, k_t, v_t only for the new token. 2. Append k_t, v_t to the cache. 3. Compute attention: the new q_t against all cached k's and v's.

Per-token work: O(t). Total generation: O(N²). This is why chatbot streaming is fast.
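A one-layer, single-head toy that counts query-key scores with and without a cache. W_Q, W_K, W_V and the tokens are made-up 2 × 2 values. In a one-layer model the naive loop's extra positions are redundant, but in a real multi-layer model deeper layers' keys depend on every prefix position, so a cache-free forward pass really does recompute them; the toy just mimics that cost.

```python
import math

def project(W, x):
    # Matrix-vector product: a stand-in for the learned W_Q, W_K, W_V.
    return [sum(w * xi for w, xi in zip(row, x)) for row in W]

def attend(q, keys, values):
    # softmax(q . k / sqrt(d)) weighted sum of values; also returns #scores computed.
    d = len(q)
    s = [sum(a * b for a, b in zip(q, k)) / math.sqrt(d) for k in keys]
    m = max(s)
    p = [math.exp(x - m) for x in s]
    z = sum(p)
    out = [sum(pi * v[j] for pi, v in zip(p, values)) / z for j in range(len(values[0]))]
    return out, len(s)

W_Q = [[1.0, 0.5], [-0.5, 1.0]]   # hypothetical toy weights
W_K = [[0.8, -0.2], [0.3, 1.1]]
W_V = [[1.0, 0.0], [0.4, 0.6]]
tokens = [[1.0, 0.0], [0.5, -1.0], [2.0, 1.0], [-1.0, 0.5], [0.3, 0.3]]

# Naive: each step re-runs causal attention over the whole prefix.
naive_out, naive_scores = [], 0
for t in range(1, len(tokens) + 1):
    K = [project(W_K, x) for x in tokens[:t]]
    V = [project(W_V, x) for x in tokens[:t]]
    for i in range(t):                          # every prefix position, causally
        out, n = attend(project(W_Q, tokens[i]), K[:i + 1], V[:i + 1])
        naive_scores += n
    naive_out.append(out)                       # keep only the newest position

# Cached: compute k, v once per token, append, attend with the new query only.
k_cache, v_cache = [], []
cached_out, cached_scores = [], 0
for x in tokens:
    k_cache.append(project(W_K, x))
    v_cache.append(project(W_V, x))
    out, n = attend(project(W_Q, x), k_cache, v_cache)
    cached_scores += n
    cached_out.append(out)

print(naive_scores, cached_scores)   # 35 vs 15 score computations for N = 5
```

The counts are Σ t(t+1)/2 versus Σ t, i.e. the O(N³) versus O(N²) totals, while the outputs are identical.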

Trade-off: memory. The cache grows linearly with sequence length, per layer, per head, per batch element. For long contexts, KV-cache memory dominates.
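To put a number on the memory trade-off: the cache stores one K and one V vector per layer, per KV head, per token, per batch element. The config below is hypothetical (roughly the shape of a 70B-scale model with full MHA) and assumes a 2-byte fp16 cache.

```python
def kv_cache_bytes(n_layers, n_kv_heads, head_dim, seq_len, batch=1, bytes_per_elem=2):
    # Factor 2: one K tensor and one V tensor per layer.
    return 2 * n_layers * n_kv_heads * head_dim * seq_len * batch * bytes_per_elem

GiB = 1024 ** 3
# Hypothetical config: 80 layers, 64 heads of dim 128, fp16, one 4096-token sequence.
mha = kv_cache_bytes(n_layers=80, n_kv_heads=64, head_dim=128, seq_len=4096)
gqa = kv_cache_bytes(n_layers=80, n_kv_heads=8, head_dim=128, seq_len=4096)
print(mha / GiB)   # 10.0 GiB with 64 KV heads
print(gqa / GiB)   # 1.25 GiB if only 8 KV heads are kept
```

The cost is linear in every factor, so doubling the context or the batch size doubles the cache.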

Change #9 — Grouped-Query Attention (GQA)

In MHA, you have H query heads, H key heads and H value heads. The KV cache is proportional to H (per layer, 2 × H × d_head × N values per sequence). For a 70B-param LLM with 64 heads, the KV cache becomes enormous at long contexts.

Two extremes:

  • MHA: H separate K and V heads. Maximum capacity, maximum cache.
  • MQA (Multi-Query): all query heads share one K head and one V head. Massive savings (a factor of H), but some quality loss.

GQA is the middle ground: divide the H query heads into G groups, where each group shares one K head and one V head.

H = 64 query heads, G = 8 KV groups
→ Each KV head is shared by H/G = 8 query heads
→ KV cache is 8× smaller than full MHA
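A minimal sketch of the head bookkeeping, assuming consecutive query heads share a KV head (a common layout, but an assumption here). Setting G = H recovers MHA and G = 1 recovers MQA.

```python
def kv_head_for(query_head, n_q_heads, n_kv_groups):
    # Consecutive query heads share a KV head; group size = H / G.
    assert n_q_heads % n_kv_groups == 0
    return query_head // (n_q_heads // n_kv_groups)

H, G = 64, 8
mapping = [kv_head_for(h, H, G) for h in range(H)]
print(mapping[:10])                       # [0, 0, 0, 0, 0, 0, 0, 0, 1, 1]
print("cache reduction:", H // G, "x")    # 8x fewer K/V heads than MHA
```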

Performance with GQA is nearly as good as MHA. Mistral 7B, LLaMA-2 70B, Gemma 2 and most recent open LLMs use GQA.

Trade-off in one line: MHA = max quality, max cache. MQA = min cache, some quality loss. GQA = tunable middle ground via G (G = H recovers MHA, G = 1 recovers MQA).

Putting it all together — the modern Transformer block

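```python
# Pseudocode: W_Q, W_K, W_V, W_O, λ_1, λ_2 and the helpers (RMSNorm, LayerNorm,
# apply_rope, flash_attention, SwiGLU, kv_cache) are assumed to exist.
def transformer_block(x, kv_cache, positions):
    # ---- Self-attention sublayer ----
    h = RMSNorm(x)                                # Pre-Norm + RMSNorm
    Q, K, V = h @ W_Q, h @ W_K, h @ W_V           # GQA: K, V have fewer heads than Q
    Q, K = LayerNorm(Q), LayerNorm(K)             # QK-Norm, per head
    Q = apply_rope(Q, positions)                  # RoPE on queries
    K = apply_rope(K, positions)                  # RoPE on keys, before caching
    K_cache, V_cache = kv_cache.append(K, V)      # KV cache holds all past K, V
    attn = flash_attention(Q, K_cache, V_cache)   # Flash Attention (exact, causal)
    x = x + λ_1 * (attn @ W_O)                    # output projection, then LayerScale
    # ---- MLP sublayer ----
    h = RMSNorm(x)                                # Pre-Norm + RMSNorm
    mlp = SwiGLU(h)
    x = x + λ_2 * mlp                             # LayerScale (per-channel λ_2)
    return x, kv_cache
```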

For ViTs, add register tokens to the input sequence at the start, and discard them at the end. Use 2D-RoPE for image patches.

Almost every line differs from the 2017 original.

What you carry into the exam

  • Pre-Norm vs Post-Norm equations and the residual-stream argument.
  • LayerNorm vs RMSNorm equations and the geometric quiz answer (LayerNorm → two points on the anti-diagonal in 2D, a (d−2)-sphere in the zero-mean hyperplane in general; RMSNorm → full (d−1)-sphere of radius √d).
  • LayerScale's near-zero init (λ_i = ε) for a near-identity start.
  • QK-Norm placement and the softmax-saturation rationale.
  • Registers as garbage-collector tokens with no PE.
  • Flash Attention's exact-and-IO-aware nature, online softmax, O(N) extra memory.
  • RoPE's rotation formula and the relative-position property.
  • KV-cache asymptotics: O(N³) → O(N²) generation.
  • GQA's H/G cache reduction and the MHA/MQA/GQA spectrum.
  • The modern Transformer block recipe.

That's the entire toolbox of *"what changed from 2017 to 2025."* Every modern LLM and ViT is a specific arrangement of these nine pieces.