Transformers:
From Code to Understanding

A complete walkthrough of how data flows through a transformer — from raw text to predictions during inference, and from predictions back to weight updates during training. Every tensor shape, every class, every line of code.

text → tokenize → embed + pos → (attention → FFN) ×N layers → project to logits → softmax → sample next token → append to sequence, repeat: the autoregressive loop
Part I

Code → Inference Data Flow

The transformer is built from nested Python classes. Each class is an nn.Module — PyTorch's base class for anything that participates in the computation graph. The contract is simple: define __init__ (store your weights and sub-components) and forward (describe how data flows through).

The composition is what matters. These aren't separate things — they're nested like Russian dolls. When you call model(input_ids), it triggers a cascade: the Transformer calls each TransformerBlock, each TransformerBlock calls CausalSelfAttention and FeedForward, and CausalSelfAttention calls its four Linear layers internally.
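
The contract in miniature (a sketch, not part of the article's model):

SKETCH — the nn.Module contract
import torch
import torch.nn as nn

class Scale(nn.Module):
    def __init__(self):
        super().__init__()
        self.w = nn.Parameter(torch.ones(1))   # store a learnable weight

    def forward(self, x):
        return self.w * x                      # describe how data flows through

m = Scale()
print(m(torch.tensor([2.0])))   # tensor([2.], grad_fn=<MulBackward0>)
# calling m(...) routes through forward(); PyTorch tracks self.w automatically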

THE CLASS HIERARCHY
Transformer
  ├── token_emb            (nn.Embedding)
  ├── pos_emb              (nn.Embedding)
  ├── blocks               (4 × TransformerBlock)
  │     └── TransformerBlock
  │           ├── ln1      (LayerNorm)
  │           ├── attn     (CausalSelfAttention)
  │           │     ├── W_q, W_k, W_v, W_out  (nn.Linear)
  │           ├── ln2      (LayerNorm)
  │           └── ffn      (FeedForward)
  │                 ├── up   (nn.Linear 64→256)
  │                 └── down (nn.Linear 256→64)
  ├── ln_f                 (LayerNorm)
  └── lm_head              (nn.Linear 64→vocab)

That tree is the architecture. Every leaf node contains learnable weights. Training updates all of them. Inference uses all of them. Let's trace a complete forward pass through the code.

Slide 1 of 11
Steps 1–2: Raw Text → Tokenize

Input: Token IDs

CODE — Transformer.forward()
model(input_ids)
Signatures & Types
CLASS model Transformer
VAR input_ids Tensor[1, 5] — int64, token indices
→ OUT model() → Tensor[1, 5, vocab_size]
Tensor Shape
[1, 5]
1 batch × 5 tokens

Integer token IDs enter the model. The vocabulary is sorted alphabetically: a=0, bed=1, big=2, cat=3, dog=4, house=5, mat=6, on=7, ran=8, rug=9, sat=10, slept=11, the=12, to=13. So "the cat sat on the" becomes [12, 3, 10, 7, 12].

Notice "the" appears twice and gets the same index 12 both times — it will receive the identical embedding vector at both positions. This is exactly why position embeddings exist, as we'll see next.

Slide 2 of 11
Steps 2–3: Embed → Positional Encoding

Embed + Position

CODE — Transformer.forward()
x = self.token_emb(idx) + self.pos_emb(torch.arange(T))
Signatures & Types
CLASS self.token_emb nn.Embedding(14, 64) — 14 vocab × 64 dims
CLASS self.pos_emb nn.Embedding(32, 64) — 32 positions × 64 dims
VAR idx Tensor[1, 5] — same as input_ids
VAR T int = 5 — sequence length
→ OUT arange(T) → Tensor[5] — [0, 1, 2, 3, 4]
VAR x Tensor[1, 5, 64] — token + position vectors
Tensor Shape
[1, 5, 64]
1 batch × 5 tokens × 64 dimensions

nn.Embedding is not a matrix multiply — it's a lookup table. When you pass it idx = [[12, 3, 10, 7, 12]], it says: "go to row 12, grab all 64 values. Go to row 3, grab all 64 values." Five lookups, each returning a 64-dimensional vector. Result: [1, 5, 64].

The position embedding is completely independent of the tokens. torch.arange(T) produces [0, 1, 2, 3, 4] — just integers. pos_emb looks up row 0, row 1, row 2, row 3, row 4 from a separate 32×64 table. It doesn't know or care what tokens are at those positions. Position 1 gets the same position vector whether the token there is "cat" or "dog."

The addition is element-wise — each of the 64 dimensions added independently:

CONCRETE EXAMPLE — "cat" at position 1
token_emb("cat"):    [+0.42, -1.08, +0.73, -0.15, +0.91, -0.56, ...]
pos_emb(1):          [+0.11, +0.85, -0.32, +0.67, -0.04, +0.28, ...]
───────────────────────────────────────────────────────────────────────
x (result):          [+0.53, -0.23, +0.41, +0.52, +0.87, -0.28, ...]

That's all the "addition" is. dim 0: 0.42 + 0.11 = 0.53. The resulting vector now encodes both "I am cat" and "I am at position 1" simultaneously, baked into the same 64 numbers. The rest of the model never sees these two pieces separately — it only works with this combined vector. From here on, everything is vectors.

Both embedding tables must have the same second dimension — you can't add a 64-dim vector to a 128-dim vector. That's what d_model controls: the width of the highway that data flows through from embedding to logits. In a real model like Llama 7B, d_model is 4096.
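
A runnable sketch of the lookup plus addition (random weights, so the numbers won't match the concrete example above; only the shapes and the broadcast matter):

SKETCH — embedding lookup + position add
import torch
import torch.nn as nn

token_emb = nn.Embedding(14, 64)                # 14 vocab × 64 dims
pos_emb = nn.Embedding(32, 64)                  # 32 positions × 64 dims
idx = torch.tensor([[12, 3, 10, 7, 12]])        # [1, 5]
x = token_emb(idx) + pos_emb(torch.arange(5))   # [1, 5, 64] + [5, 64] broadcasts
print(x.shape)                                  # torch.Size([1, 5, 64])
# same token "the" at positions 0 and 4, different combined vectors:
print(torch.allclose(x[0, 0], x[0, 4]))         # False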

Slide 3 of 11
Step 8: ×N Layers

Enter TransformerBlock

CODE — Transformer.forward()
for block in self.blocks:
    x = block(x)
Signatures & Types
CLASS self.blocks ModuleList[4 × TransformerBlock]
CLASS block TransformerBlock — different instance each iteration
VAR x (in) Tensor[1, 5, 64]
→ OUT block(x) → Tensor[1, 5, 64] — same shape, different values
Tensor Shape
[1, 5, 64] → [1, 5, 64] → [1, 5, 64] → [1, 5, 64] → [1, 5, 64]
shape unchanged through all 4 blocks — same in, same out

This creates 4 separate TransformerBlock instances. The code is identical — same class, same forward method, same structure. The weights are different. When each instance was created, nn.Linear(64, 64) initialized fresh random weights. Block 0's W_q is a completely different 64×64 matrix from Block 1's W_q.

The output of Block 0 is fed directly as input to Block 1. The variable x is overwritten each iteration. It's like a relay race: each runner (block) has different strengths (weights), but each one picks up where the previous one left off — the baton is x.

THE LOOP UNROLLED — tracking "the" (position 4)
x = token_emb + pos_emb        → "the" — just embedded, knows nothing
x = block_0(x)                  → "the" after "on" — expects a noun
x = block_1(x)                  → "sat on the ___" — location context
x = block_2(x)                  → cat + sit + surface → mat, rug, bed...
x = block_3(x)                  → "mat" 32%, "floor" 18%, "rug" 12%

If all 4 blocks somehow had the same weights (shared parameters), the model would still work but would be much weaker — you'd effectively have one block applied 4 times, which limits what it can learn. The power comes from each block having independent weights that specialize during training for their specific position in the chain.
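
The independent initialization is easy to verify; any two separately constructed layers will do:

SKETCH — separate instances, separate weights
import torch.nn as nn

a, b = nn.Linear(64, 64), nn.Linear(64, 64)   # fresh random init for each
print((a.weight == b.weight).any().item())    # False: no entry matches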

Slide 4 of 11
Steps 4 + 7: Self-Attention + Residual

LayerNorm → Attention

CODE — TransformerBlock.forward()
x = x + self.attn(self.ln1(x))
Signatures & Types
CLASS self.ln1 nn.LayerNorm(64) — normalizes to mean=0, std=1
CLASS self.attn CausalSelfAttention — contains W_q, W_k, W_v, W_out
→ OUT self.ln1(x) → Tensor[1, 5, 64] — normalized, same shape
→ OUT self.attn(...) → Tensor[1, 5, 64] — attention output
VAR x + ... Tensor[1, 5, 64] — residual: input + attention output

One line, three operations, reading right to left.

self.ln1(x) — LayerNorm normalizes each token's 64-dim vector to have mean 0 and standard deviation 1. Without it, the numbers in x can drift to wildly different scales as they pass through blocks. LayerNorm is nonlinear — it divides by the standard deviation, which is computed from the input itself. If you doubled every value, the std doubles, but the output stays the same. That violates linearity.

self.attn(...) — this calls CausalSelfAttention.forward, which is everything in slides 5, 6, and 7: project to Q, K, V, compute attention scores, mask the future, softmax, blend values, project output. All of that happens inside this one call.

x + ... — the residual connection. This is the most important part. The attention output is added to the original x, not replacing it. If attention produces garbage (which happens early in training), the original information survives. The block can only nudge x, never erase it.

Key Insight — Residual Connections

Think of it concretely. Before this line, position 4 has a vector meaning "the, at position 4." After this line, that vector still contains "the, at position 4" plus whatever attention contributed — maybe "preceded by on, following a sitting action." The original is preserved; new information is layered on top.
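
Both claims can be checked in a few lines: LayerNorm's output is unchanged when its input is scaled (that is the nonlinearity), and the residual sum keeps the original input recoverable. A sketch, not the article's code:

SKETCH — LayerNorm scale invariance + residual
import torch
import torch.nn as nn

ln = nn.LayerNorm(64)
x = torch.randn(1, 5, 64)
print(torch.allclose(ln(x), ln(2 * x), atol=1e-4))   # True: f(2x) = f(x)

f_x = torch.randn(1, 5, 64)           # stand-in for attn(ln1(x))
out = x + f_x                         # the residual addition
print(torch.allclose(out - f_x, x))   # True: the original survives in the sum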

Slide 5 of 11
Step 4: Self-Attention (Q·K·V)

Q, K, V Projections

CODE — CausalSelfAttention.forward()
Q = self.W_q(x).view(B, T, 4, 16).transpose(1, 2)
K = self.W_k(x).view(B, T, 4, 16).transpose(1, 2)
V = self.W_v(x).view(B, T, 4, 16).transpose(1, 2)
Signatures & Types
CLASS self.W_q nn.Linear(64, 64) — 64×64 weight matrix + bias
CLASS self.W_k nn.Linear(64, 64) — separate 64×64 matrix
CLASS self.W_v nn.Linear(64, 64) — third separate 64×64 matrix
→ OUT self.W_q(x) → Tensor[1, 5, 64] — matrix multiply: x @ W_q
→ OUT .view(1,5,4,16) → Tensor[1, 5, 4, 16] — reshape 64 → 4 heads × 16
→ OUT .transpose(1,2) → Tensor[1, 4, 5, 16] — move heads before tokens
VAR Q, K, V 3× Tensor[1, 4, 5, 16]
Tensor Shape
[1, 5, 64] → 3× [1, 4, 5, 16]
one input → three outputs, split into 4 heads of 16 dimensions each

Three operations chained on each line. Two of the three are just reshuffling dimensions — the only actual computation is the matrix multiply with W_q (or W_k, or W_v).

self.W_q(x) — nn.Linear(64, 64) multiplies each token's 64-dim vector by the W_q weight matrix. Each token is projected independently by the same matrix. Output: [1, 5, 64]. Same shape, but the 64 numbers have been rotated and mixed into "query space."

.view(B, T, 4, 16) — a reshape, not a computation. No numbers change. It reinterprets the 64 numbers as 4 groups of 16. Head 0 gets dims 0–15, head 1 gets dims 16–31, and so on. This is how multi-head attention splits the work.

.transpose(1, 2) — swaps the token and head dimensions: [1, 5, 4, 16] → [1, 4, 5, 16]. This rearrangement is needed because the next operation — Q @ K.transpose(-2, -1) — does matrix multiplication on the last two dimensions. We need those to be [tokens, d_k] so the result is [tokens, tokens].

The 4 and 16 are n_heads and d_k respectively, defined as d_k = d_model // n_heads = 64 // 4 = 16. More heads means more independent attention patterns but each with a narrower view.
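
A shape-only sketch of the three chained operations (random weights; only the shapes matter here):

SKETCH — project, split into heads, move heads forward
import torch
import torch.nn as nn

B, T, n_heads, d_k = 1, 5, 4, 16
W_q = nn.Linear(64, 64)
x = torch.randn(B, T, 64)

q = W_q(x)                       # [1, 5, 64], the only real computation
q = q.view(B, T, n_heads, d_k)   # [1, 5, 4, 16], reinterpret, no math
Q = q.transpose(1, 2)            # [1, 4, 5, 16], heads before tokens
print(Q.shape)                   # torch.Size([1, 4, 5, 16])
# head 0 holds dims 0–15 of each projected token vector:
print(torch.allclose(Q[0, 0], W_q(x)[0, :, :16]))   # True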

Slide 6 of 11
Step 4: Self-Attention (scores)

Attention Scores

CODE — CausalSelfAttention.forward()
scores = (Q @ K.transpose(-2, -1)) / math.sqrt(self.d_k)
scores = scores.masked_fill(mask, float("-inf"))
attn = F.softmax(scores, dim=-1)
Signatures & Types
VAR Q @ K.T Tensor[1,4,5,16] @ Tensor[1,4,16,5] — (5×16)@(16×5)=5×5
VAR scores Tensor[1, 4, 5, 5] — raw attention scores
VAR / √d_k / float(4.0) — √16 = 4.0, scaling factor
VAR mask Tensor[5, 5] bool — upper triangle = True (future)
→ OUT F.softmax() → Tensor[1, 4, 5, 5] — each row sums to 1.0
VAR attn Tensor[1, 4, 5, 5] — attention weights (probabilities)
Tensor Shape
[1, 4, 5, 16] @ [1, 4, 16, 5] → [1, 4, 5, 5]
5×5 attention matrix per head — "how much should token i attend to token j?"

Q @ K.transpose(-2, -1) — the @ is Python's matrix multiplication operator. K.transpose(-2, -1) swaps the last two dimensions of K. The batch and head dimensions are along for the ride — 4 independent 5×5 score matrices, one per head. Each entry (i, j) is the dot product of token i's Query with token j's Key: "how relevant is token j to what token i is looking for?"

/ math.sqrt(self.d_k) — divide every score by √16 = 4.0. Without this, the scores grow with head dimension: the dot product of two 16-dim vectors with unit-variance components has variance d_k = 16, i.e. standard deviation 4. Softmax is extremely sensitive to scale — large scores push almost all weight onto the maximum value. Dividing by √d_k restores unit variance and keeps the distribution smooth, allowing the model to attend to multiple tokens simultaneously. This is the "scaled" in "scaled dot-product attention."

scores.masked_fill(mask, float("-inf")) — the causal mask. Every position above the diagonal gets filled with negative infinity. After softmax, e^−∞ = 0, so those positions contribute zero attention weight. Token 2 ("sat") can attend to tokens 0, 1, 2 but not 3 or 4. This is what makes it autoregressive — no cheating by looking ahead.

F.softmax(scores, dim=-1) — converts each row to probabilities summing to 1.0. Before softmax: [0.82, 1.45, 0.31, -inf, -inf]. After: [0.29, 0.54, 0.17, 0.00, 0.00]. Each head has its own completely different attention pattern — learned during training.
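
A sketch of the mask-then-softmax step on one 5×5 score matrix (random scores; torch.triu builds the upper-triangle mask exactly as the source listing does):

SKETCH — causal mask, then softmax
import torch
import torch.nn.functional as F

T = 5
scores = torch.randn(T, T) / 4.0                         # stand-in for Q·Kᵀ/√d_k
mask = torch.triu(torch.ones(T, T), diagonal=1).bool()   # True above the diagonal
attn = F.softmax(scores.masked_fill(mask, float("-inf")), dim=-1)
print(attn.round(decimals=2))   # row i is zero after column i
print(attn.sum(dim=-1))         # tensor([1., 1., 1., 1., 1.])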

Slide 7 of 11
Step 4: Self-Attention (output)

Weighted Values → Output

CODE — CausalSelfAttention.forward()
out = (attn @ V).transpose(1, 2).contiguous().view(B, T, C)
return self.W_out(out)
Signatures & Types
VAR attn @ V Tensor[1,4,5,5] @ Tensor[1,4,5,16] — (5×5)@(5×16) = 5×16 per head
→ OUT (attn @ V) → Tensor[1, 4, 5, 16] — blended values
→ OUT .transpose().view() → Tensor[1, 5, 64] — 4 heads × 16 → 64 dims
CLASS self.W_out nn.Linear(64, 64) — mixes across heads
→ OUT return → Tensor[1, 5, 64] — back to TransformerBlock
Tensor Shape
[1, 4, 5, 5] @ [1, 4, 5, 16] → [1, 5, 64]
back to original shape, but enriched with information from other tokens

attn @ V — the payoff of all the work so far. For each token, it blends all tokens' Value vectors using the attention weights as the recipe. For token 4 ("the") in head 0, if the attention weights are [0.15, 0.10, 0.45, 0.20, 0.10], then:

THE BLEND
output[4] = 0.15 × V("the"₀) 
          + 0.10 × V("cat") 
          + 0.45 × V("sat") 
          + 0.20 × V("on") 
          + 0.10 × V("the"₄)

A 16-dim vector that is mostly "sat"'s value content (45%) blended with "on" (20%). The Q/K computation decided who to attend to. The V vectors decide what information to pass along.
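
A sketch of the blend for one head, using the weights from the example above (V is random; the point is that attn @ V is exactly this weighted sum of value rows):

SKETCH — attention weights blend the value rows
import torch

w = torch.tensor([0.15, 0.10, 0.45, 0.20, 0.10])   # token 4's weights, head 0
V = torch.randn(5, 16)                             # one value vector per token
out = w @ V                                        # [16], the blended vector
manual = sum(w[i] * V[i] for i in range(5))        # the same sum, spelled out
print(torch.allclose(out, manual))                 # True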

.transpose(1, 2).contiguous().view(B, T, C) — the 4 heads' outputs are concatenated back into one 64-dim vector per token. .contiguous() is a memory bookkeeping operation — transpose doesn't physically rearrange data in memory, just changes how indices map to addresses. .view() needs contiguous memory to reshape.

self.W_out(out) — one final nn.Linear(64, 64). This mixes information across heads. Up to this point, each head operated independently. W_out can combine insights from different heads: head 0 found the verb, head 2 found the subject — W_out blends them into a unified representation.

Slide 8 of 11
Steps 5 + 7: Feed-Forward + Residual

LayerNorm → FFN

CODE — TransformerBlock.forward() + FeedForward.forward()
x = x + self.ffn(self.ln2(x))

# Inside FeedForward.forward(x):
return self.down(F.gelu(self.up(x)))
Signatures & Types
CLASS self.ffn.up nn.Linear(64, 256) — expand: 64 → 4×64 = 256
→ OUT F.gelu() → Tensor[1, 5, 256] — activation (element-wise)
CLASS self.ffn.down nn.Linear(256, 64) — compress: 256 → 64
VAR x + ... Tensor[1, 5, 64] — residual: input + FFN output
Tensor Shape
[1, 5, 64] → [1, 5, 256] → [1, 5, 64]
expand to 4×, activate, compress back

self.up(x) — each token's 64-dim vector is multiplied by a 64×256 weight matrix, producing a 256-dim vector. The expansion gives the model a larger space where it can represent more complex features and combinations.

F.gelu(...) — an activation function applied element-wise. Positive values pass through mostly unchanged, negative values get pushed toward zero. Without this nonlinearity, down(up(x)) would be two matrix multiplies in a row — mathematically equivalent to a single matrix multiply. GELU breaks that equivalence. The expand-activate-compress sequence can now represent functions that a single linear transform cannot.

self.down(...) — compress from 256 back to 64. The model did its work in the expanded space and distilled the result back to d_model dimensions.

The critical difference from attention: no cross-token communication. The FFN processes each token independently — the same up and down matrices applied to every token identically. Attention figured out what the context is. The FFN applies learned knowledge: "given a sitting-on-surface context, surfaces include mat, rug, floor." This is the "knowledge store" — roughly two thirds of all parameters live in the FFN layers.
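
That independence can be verified directly: shuffle the tokens and the FFN outputs shuffle identically. A sketch with random weights, mirroring the FeedForward class in the source listing:

SKETCH — the FFN treats every token independently
import torch
import torch.nn as nn
import torch.nn.functional as F

up, down = nn.Linear(64, 256), nn.Linear(256, 64)
ffn = lambda t: down(F.gelu(up(t)))

x = torch.randn(1, 5, 64)
perm = torch.tensor([4, 2, 0, 3, 1])   # shuffle the 5 tokens
print(torch.allclose(ffn(x)[:, perm], ffn(x[:, perm]), atol=1e-6))   # True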

Slide 9 of 11
Step 8: ×N Layers (repeat)

Block Output → Next Block

CODE — TransformerBlock.forward()
return x  # → next iteration of the loop

x entered the block carrying whatever previous blocks contributed. It now leaves with two residual additions baked in — attention and FFN. The returned tensor goes back to the loop in Transformer.forward, where x = block(x) overwrites the variable, and the next block begins.

One TransformerBlock.forward is really: "let tokens talk to each other (attention), add that to what we had, then let each token think independently (FFN), add that to what we had." Two residual additions per block, 4 blocks, means 8 cumulative additions — the final x is the original embedding plus 8 layers of learned adjustments stacked on top.

Slide 10 of 11
Step 9: Logits → Sample

Final → Logits

CODE — Transformer.forward()
return self.lm_head(self.ln_f(x))
Signatures & Types
CLASS self.ln_f nn.LayerNorm(64) — final normalization
CLASS self.lm_head nn.Linear(64, 14) — 64 → vocab_size (no bias)
→ OUT return (logits) → Tensor[1, 5, 14] — one score per word
Tensor Shape
[1, 5, 64] → [1, 5, 14]
64 dims → 14 vocabulary scores

self.ln_f(x) — one last LayerNorm. After 4 blocks of residual additions, the numbers in x could have drifted to various scales. Final normalization puts them in a clean range before the output projection.

self.lm_head(x) — nn.Linear(64, 14). Each token's 64-dim vector is multiplied by a 64×14 weight matrix, producing 14 numbers — one score per word in the vocabulary. These raw scores are the logits. They're not probabilities — unbounded, can be negative, don't sum to 1.

Each position produces its own logits, but during inference we only care about the last position — that's where the next token prediction comes from. Position 4 has seen all 5 tokens via attention, so its 64-dim vector encodes the full context. softmax(logits[0, -1, :]) converts the last position's 14 scores to probabilities, and we sample from that distribution.

The Complete Forward Pass
[1, 5]           integer token IDs
[1, 5, 64]       embedded + position
[1, 5, 64]       block 0: attention + FFN
[1, 5, 64]       block 1: attention + FFN
[1, 5, 64]       block 2: attention + FFN
[1, 5, 64]       block 3: attention + FFN
[1, 5, 14]       logits — one score per vocabulary word

5 integers in, 14 scores out. Everything in between is learned vector transformations — projections between spaces. The entire forward pass is a pipeline of geometric transformations gradually reshaping vector clouds until the desired prediction becomes easy to read off the final layer.

Slide 11 of 11
Step 10: The Generation Loop

Append → Repeat: The Autoregressive Loop

CODE — generate_text()
for _ in range(max_tokens):
    logits = model(input_tensor)                    # full forward pass
    last_logits = logits[0, -1, :] / temperature    # last position only
    probs = F.softmax(last_logits, dim=-1)          # → probabilities
    next_id = torch.multinomial(probs, 1).item()    # sample one token
    input_tensor = torch.cat(                       # APPEND to sequence
        [input_tensor, torch.tensor([[next_id]])], dim=1)
Signatures & Types — each iteration
VAR input_tensor (iter 0) Tensor[1, 5] — "the cat sat on the"
→ OUT model(input_tensor) → Tensor[1, 5, 14]
→ OUT logits[0, -1, :] → Tensor[14] — position 4's scores only
→ OUT F.softmax() → Tensor[14] — [mat: 35%, rug: 12%, ...]
→ OUT torch.multinomial() → int = 6 — sampled "mat"
VAR input_tensor (iter 1) Tensor[1, 6] — "the cat sat on the mat"
→ OUT model(input_tensor) → Tensor[1, 6, 14] — now 6 positions
→ OUT logits[0, -1, :] → Tensor[14] — position 5's scores
→ OUT torch.multinomial() → int = 12 — sampled "the"
VAR input_tensor (iter 2) Tensor[1, 7] — "the cat sat on the mat the"
Tensor Shape — growing each iteration
[1, 5] → [1, 6] → [1, 7] → [1, 8] → ...
sequence grows by 1 token each iteration — forward pass re-runs on the full sequence

This is the ouroboros — the loop that closes the circle. The forward pass produces logits. We take only the last position's logits (that's where the prediction for the next token lives), convert to probabilities via softmax, and sample one token. Then the critical step: torch.cat appends that new token to the input sequence, making it one token longer. And we run the entire forward pass again.

Each iteration, the sequence grows by one. The forward pass gets slightly more expensive — one more token to embed, one more row and column in every attention matrix, one more token through every FFN. The model sees the full history every time.

THE LOOP UNROLLED
iter 0: model(["the","cat","sat","on","the"])           → sample "mat"
iter 1: model(["the","cat","sat","on","the","mat"])      → sample "the"
iter 2: model(["the","cat","sat","on","the","mat","the"]) → sample "cat"
...
Each iteration: fresh forward pass on the FULL growing sequence.
The causal mask ensures earlier positions produce the same logits
as before — they can't see the new tokens. But the new token
can attend to everything before it.

This is where the KV cache optimization matters. Without it, every iteration recomputes K and V for all previous tokens — work that was already done. With caching, only the new token's K and V are computed and appended to the cache. The new token's Q attends to all cached K's. The speedup grows linearly with sequence length.
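
A single-head sketch of the caching idea (hypothetical code, not from the source listing; a real implementation caches per layer and per head):

SKETCH — KV cache, one head
import math
import torch

d_k = 16
W_q, W_k, W_v = (torch.randn(d_k, d_k) for _ in range(3))
K_cache = torch.empty(0, d_k)
V_cache = torch.empty(0, d_k)

def decode_step(x_new):   # x_new: [d_k], only the NEWEST token's vector
    global K_cache, V_cache
    K_cache = torch.cat([K_cache, (x_new @ W_k).unsqueeze(0)])   # append one row
    V_cache = torch.cat([V_cache, (x_new @ W_v).unsqueeze(0)])
    q = x_new @ W_q
    attn = torch.softmax(K_cache @ q / math.sqrt(d_k), dim=-1)   # vs ALL cached keys
    return attn @ V_cache   # [d_k], no recomputation of old K and V

for _ in range(3):
    out = decode_step(torch.randn(d_k))
print(K_cache.shape)        # torch.Size([3, 16]), one row added per step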

The temperature parameter controls the sharpness of sampling. Division happens before softmax: logits / temperature. Temperature 0.5 makes the distribution sharper (more deterministic — probability concentrates on the likeliest tokens; as temperature approaches 0, sampling becomes greedy). Temperature 1.5 flattens it (more creative, more random). Temperature 1.0 uses the raw probabilities as-is.
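
The effect is easy to see on a fixed logit vector (illustrative values):

SKETCH — temperature reshapes the distribution
import torch
import torch.nn.functional as F

logits = torch.tensor([2.0, 1.0, 0.5, 0.0])
for temp in [0.5, 1.0, 1.5]:
    print(temp, F.softmax(logits / temp, dim=-1).round(decimals=2))
# 0.5 → tensor([0.83, 0.11, 0.04, 0.02])   sharper
# 1.0 → tensor([0.58, 0.21, 0.13, 0.08])   as-is
# 1.5 → tensor([0.47, 0.24, 0.17, 0.12])   flatter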

The loop continues for max_tokens iterations or until the model produces a special end-of-sequence token. Every response Claude generates — every sentence, every paragraph — was produced by this exact loop: forward pass, sample, append, repeat. One token at a time, each conditioned on everything that came before.

The Ouroboros

The autoregressive loop is the snake eating its tail. The model's output becomes its own input. Each new token changes the context, which changes the next prediction, which changes the next token. A deterministic architecture producing varied, context-dependent text — all from one simple rule: predict the next token, append it, repeat.

Part II

Code → Training Data Flow

Training wraps the forward pass in a learning loop. The forward pass itself is just inference — the same Transformer.forward() runs. The difference is what happens after: measure the error, compute gradients, and nudge every weight. That cycle, repeated hundreds or thousands of times, transforms random noise into a model that understands language.

THE TRAINING PIPELINE
# Setup (once)
vocab = build_vocab(CORPUS)                    # → 24 unique words
tokenizer = SimpleTokenizer(vocab)              # word ↔ integer lookup
model = Transformer(vocab_size=24, ...)         # random weights
inputs, targets = make_training_pairs(CORPUS, tokenizer)  # shifted by 1
optimizer = Adam(model.parameters(), lr=0.003)  # tracks all weights

# Loop (150 epochs × 5 batches = 750 steps)
for epoch in range(150):
    batches = collate_batch(inputs, targets, batch_size=8)
    for input_ids, target_ids in batches:
        logits = model(input_ids)               # forward pass
        loss = F.cross_entropy(logits.view(-1, 24), target_ids.view(-1))  # measure error
        optimizer.zero_grad()                    # clear old gradients
        loss.backward()                          # compute gradients
        optimizer.step()                         # update all weights
Slide 1 of 10
Step 1: Training Data

Training Corpus

CODE
CORPUS = [
    "the cat sat on the mat",
    "the dog sat on the rug",
    "the cat slept on the bed",
    ...  # 20 sentences total
]
vocab = build_vocab(CORPUS)
Signatures & Types
VAR CORPUS list[str] — 20 sentences
→ OUT build_vocab() → list[str] — sorted unique words
VAR vocab list[str], len=24 — ["a","bed","big","bird","brown","cat",...]

Everything starts with text. 20 sentences, extract all unique words, sort them alphabetically. The result is a vocabulary of 24 strings, each getting an integer index. This is the entire world the model will learn from. Real training uses trillions of tokens, but the process is identical: text in, vocabulary out.
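
A sketch of the vocabulary build (a set comprehension plus a sort; the source's build_vocab does the equivalent):

SKETCH — build_vocab is a set plus a sort
CORPUS = ["the cat sat on the mat", "the dog sat on the rug"]
vocab = sorted({w for line in CORPUS for w in line.split()})
print(vocab)   # ['cat', 'dog', 'mat', 'on', 'rug', 'sat', 'the']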

GPT-style models use subword tokenization (BPE — byte pair encoding) with ~100k tokens rather than whole words. The algorithm starts with individual characters and iteratively merges the most frequent adjacent pairs. "Unbelievable" becomes three tokens: ["un", "believ", "able"]. This handles any text, even words never seen before. For our learning purposes, word-level tokenization keeps things concrete.
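
A toy sketch of one merge step (not the article's tokenizer; real BPE also weights pairs by word frequency and repeats the merge tens of thousands of times):

SKETCH — one BPE merge step
from collections import Counter

words = ["the", "then", "there", "cat"]
tokens = [list(w) for w in words]   # start from individual characters
pairs = Counter()
for t in tokens:
    pairs.update(zip(t, t[1:]))     # count adjacent pairs
print(pairs.most_common(2))   # [(('t', 'h'), 3), (('h', 'e'), 3)]
# merge the winner into a single token "th", recount, repeat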

Slide 2 of 10
Step 2: Input → Target Pairs

Input → Target Pairs

CODE
inputs, targets = make_training_pairs(corpus, tokenizer)
Signatures & Types
CLASS tokenizer SimpleTokenizer — word ↔ integer lookup
→ OUT tokenizer.encode() → list[int] — [12,3,10,7,12,6]
VAR inputs[0] list[int] — [12,3,10,7,12] = "the cat sat on the"
VAR targets[0] list[int] — [3,10,7,12,6] = "cat sat on the mat"

This is where the training signal comes from. Take "the cat sat on the mat" and split it:

THE SHIFT
sentence: the    cat    sat    on    the    mat
          ─────  ─────  ─────  ─────  ─────  ─────
input:    the    cat    sat    on     the
target:   cat    sat    on     the    mat
          ─────  ─────  ─────  ─────  ─────
pos 0:    "the" should predict "cat"
pos 1:    "the cat" should predict "sat"
pos 2:    "the cat sat" should predict "on"
pos 3:    "the cat sat on" should predict "the"
pos 4:    "the cat sat on the" should predict "mat"

Target is the input shifted by one position. Every position is an independent prediction problem: position 0 sees "the" and should predict "cat", position 4 sees "the cat sat on the" and should predict "mat". All 5 are computed in a single forward pass — that's the density of next-token prediction.
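
The shift itself is one slice each (make_training_pairs in the source does the equivalent per sentence):

SKETCH — target = input shifted by one
ids = [12, 3, 10, 7, 12, 6]    # "the cat sat on the mat"
inp, tgt = ids[:-1], ids[1:]
print(inp)                     # [12, 3, 10, 7, 12]
print(tgt)                     # [3, 10, 7, 12, 6]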

Slide 3 of 10
Step 2: Batching

Collate into Batches

CODE
batches = collate_batch(inputs, targets, batch_size=8)
Signatures & Types
VAR batch_size int = 8 — sequences per batch
→ OUT torch.tensor(padded_inp) → Tensor[8, T] — T = max seq len in batch
VAR batches[0] (Tensor[8, 5], Tensor[8, 5]) — (input_ids, target_ids)
Data Transformation
list[list[int]] → list[(Tensor[B, T], Tensor[B, T])]
variable-length lists → fixed-shape tensor pairs, shuffled each epoch

The ~40 training pairs are shuffled and grouped into batches of 8. Each batch is two tensors: input_ids: Tensor[8, T] and target_ids: Tensor[8, T]. Shorter sequences are padded to equal length.

The B sequences in a batch are independent — they share the model weights but cannot attend to each other. Sequence 0 has its own attention matrix, its own causal mask. Sequence 7 has completely different ones. B is just parallelism — one big matrix multiply is faster than 8 small ones on a GPU.

Shuffling each epoch forces the model to generalize. If it always saw the same sequences grouped together, it could learn shortcuts specific to that ordering.
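
A minimal sketch of the collation (hypothetical: the source's collate_batch may differ in details such as the pad ID and shuffling):

SKETCH — pad to equal length, stack into one tensor
import torch

seqs = [[12, 3, 10, 7, 12], [12, 4, 10, 7]]   # variable lengths
T = max(len(s) for s in seqs)
PAD = 0                                       # assumed pad ID
batch = torch.tensor([s + [PAD] * (T - len(s)) for s in seqs])
print(batch.shape)                            # torch.Size([2, 5])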

Slide 4 of 10
Step 3: Forward Pass

Forward Pass

CODE — train_step()
model.train()
logits = model(input_ids)
Signatures & Types
CLASS model Transformer — 33,000+ learnable parameters
VAR input_ids Tensor[8, 5] — 8 sequences × 5 tokens
→ OUT model(input_ids) → Tensor[8, 5, 24] — 8 seqs × 5 positions × 24 vocab scores
Tensor Shape
Tensor[8, 5] → Tensor[8, 5, 24]
input integers → one score per vocab word, at every position

This is the exact same forward pass we traced through the first 10 slides of Part I. Embed, add position, pass through 4 transformer blocks, project to vocabulary. The code is identical. The architecture is identical.

The critical difference from inference: we keep all positions' logits, not just the last one. During inference, we only care about the last position's prediction. During training, every position is a learning opportunity. One 5-token sequence gives us 5 separate training signals. One batch of 8 sequences gives us 40.

During this forward pass, PyTorch is quietly recording every operation in a computation graph. Every matrix multiply, every addition, every softmax — all recorded. This graph is what loss.backward() will traverse in the backpropagation step.
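
The recording is visible on any tensor that requires gradients (grad_fn is the graph node PyTorch attaches to each result):

SKETCH — PyTorch records the graph as you compute
import torch

w = torch.randn(3, requires_grad=True)
y = (w * 2.0).sum()
print(y.grad_fn)    # <SumBackward0 ...>, the recorded operation
y.backward()        # traverse the graph in reverse
print(w.grad)       # tensor([2., 2., 2.]), d(sum(2w))/dw = 2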

Slide 5 of 10
Steps 3–4: Reshape

Reshape for Loss

CODE — train_step()
B, T, V = logits.shape           # 8, 5, 24
logits_flat = logits.view(B*T, V)  # [40, 24]
targets_flat = target_ids.view(B*T) # [40]
Signatures & Types
VAR B int = 8 — batch size
VAR T int = 5 — sequence length
VAR V int = 24 — vocabulary size
→ OUT logits.view(B*T, V) → Tensor[40, 24] — 40 predictions, 24 scores each
→ OUT target_ids.view(B*T) → Tensor[40] — 40 correct answers (integers)

Pure bookkeeping — no math. We have 40 rows of 24 numbers (the model's predictions at each position) and 40 integers (what actually came next — the ground truth). .view() flattens the batch and sequence dimensions: 8 sequences × 5 positions = 40 individual prediction problems. PyTorch's F.cross_entropy expects this flat layout.

Slide 6 of 10
Step 4: Cross-Entropy Loss

Cross-Entropy Loss

CODE — train_step()
loss = F.cross_entropy(
    logits.view(B * T, V),     # [40, 24] — predictions
    target_ids.view(B * T),    # [40] — correct answers
)
Signatures & Types
CLASS F.cross_entropy function — PyTorch loss function
VAR logits (flat) Tensor[40, 24] — 40 predictions
VAR targets (flat) Tensor[40] — 40 correct answers
VAR loss Tensor[] (scalar) — single number, e.g. 3.18
Data Transformation
Tensor[40, 24] + Tensor[40] → Tensor[] (scalar)
40 predictions + 40 targets → 1 number

This function does three things internally, for each of the 40 predictions:

Softmax — converts 24 raw scores to probabilities summing to 1.0. Large scores get high probabilities via exponentiation.

Lookup — finds the probability the model assigned to the correct token. If the target is "on" (ID 7) and the model assigned 6% to it, that's the number we care about.

Negative log — loss = −log(probability). Small probability → large loss. 90% → 0.105 loss. 6% → 2.81 loss. 1% → 4.6 loss. The log curve punishes low confidence much more harshly than it rewards high confidence, giving the strongest learning signal where the model needs the most improvement.

Cross-entropy is just the measuring stick — it says "you're this far off." It doesn't nudge anything. Backprop figures out which weights are responsible. The optimizer does the actual nudging.

At the start of training, random weights assign roughly equal probability to all 24 words — about 4% each. Expected initial loss: −log(1/24) ≈ 3.18.
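
The three steps, spelled out for one prediction and checked against F.cross_entropy (random logits; target ID 7 is "on", as above):

SKETCH — cross-entropy by hand, one prediction
import torch
import torch.nn.functional as F

logits = torch.randn(24)            # one position's 24 scores
target = torch.tensor(7)            # the correct next token

probs = F.softmax(logits, dim=-1)   # 1. softmax
p_correct = probs[target]           # 2. look up the correct token's probability
loss = -torch.log(p_correct)        # 3. negative log

ref = F.cross_entropy(logits.unsqueeze(0), target.unsqueeze(0))
print(torch.allclose(loss, ref))    # True
print(-torch.log(torch.tensor(1 / 24)))   # tensor(3.1781), the random-init loss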

Slide 7 of 10
Step 5: Backpropagation

Backpropagation

CODE — train_step()
optimizer.zero_grad()   # clear old gradients
loss.backward()         # THE chain rule
Signatures & Types
CLASS optimizer torch.optim.Adam — tracks all model parameters
VAR loss Tensor[] (scalar) — connected to full computation graph
→ OUT loss.backward() → None — fills .grad for all 33k+ params
VAR W_q.weight.grad Tensor[64, 64] — gradient for every W_q entry
VAR ffn.up.weight.grad Tensor[256, 64] — gradient for every FFN entry
VAR token_emb.weight.grad Tensor[24, 64] — gradient for every embedding
Data Transformation
Tensor[] (scalar) → .grad on all 33,000+ parameters
1 number flows backward to fill 33,000+ gradient values

loss.backward() traces backward through every operation that produced the loss, applying the chain rule at every step. The backward pass follows the computation graph in reverse:

THE BACKWARD TRACE
loss
  ← cross_entropy ← lm_head ← ln_f
    ← Block 3:
      ← residual: gradient splits → FFN path + straight-through
        ← down ← gelu ← up
      ← residual: gradient splits → attention path + straight-through
        ← W_out ← attn@V ← softmax ← Q@K^T
          ← W_q (via Q path)
          ← W_k (via K path)
          ← W_v (via V path)
    ← Block 2 ← Block 1 ← Block 0
      ← token_emb, pos_emb

Notice the residual connections. At each x = x + self.attn(...), the gradient splits: one copy flows through the attention path, another passes straight through the addition unchanged. That straight-through path is crucial — it provides a gradient highway from the loss directly to early layers, combating the vanishing gradient problem.
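
The split is visible on a scalar: through x + f(x) the derivative is 1 + f′(x), so the straight-through path always contributes. A sketch:

SKETCH — the residual's gradient highway
import torch

x = torch.tensor(2.0, requires_grad=True)
y = x + x ** 2   # residual form x + f(x), with f(x) = x²
y.backward()
print(x.grad)    # tensor(5.) = 1 (straight-through) + 2x (the f path)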

This is also where W_q, W_k, and W_v get different gradients despite sharing the same input. W_q's gradient comes through the "left side of the dot product" path. W_k's comes through the "right side." W_v's comes through the "weighted sum" path. Different paths → different gradients → different updates → different specializations.

loss.backward() is computationally about 2× the forward pass. Training is roughly 3× inference per token: 1× forward, 2× backward.

Slide 8 of 10
Step 6: Gradient Descent

Gradient Descent

CODE — train_step()
optimizer.step()
Signatures & Types
→ OUT optimizer.step() → None — updates all parameters in-place
VAR p.data (before) Tensor[64, 64] — e.g. W_q = +0.0372
VAR p.grad Tensor[64, 64] — gradient = +0.0031
VAR lr float = 3e-3 — learning rate (0.003)
VAR p.data (after) Tensor[64, 64] — W_q ≈ +0.0362 (nudged)

Every weight in the model is updated simultaneously. The basic principle: new_weight = old_weight − learning_rate × gradient. Positive gradient means the weight is pushing the loss up — decrease it. Negative means it's helping — increase it.

We use Adam optimizer, which solves three problems plain gradient descent suffers from:

Momentum — a running average of recent gradients. Consistent direction → amplified step. Oscillating direction → dampened step. Smooths out the noise from small batches.

Adaptive learning rate — tracks squared gradients per parameter. Weights with large gradients get smaller effective steps; weights with tiny gradients get larger steps. Every single weight gets its own automatically tuned step size.

Bias correction — compensates for the zero initialization of momentum and variance estimates in the first few steps.

Adam maintains two extra numbers per parameter (momentum and variance), so it uses 3× the memory of the weights alone. For a 7B parameter model: 7B weights + 7B momentum + 7B variance = 21B numbers in GPU memory.
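
A sketch of one Adam update for a single parameter, using PyTorch's default β₁ = 0.9, β₂ = 0.999 and the illustrative numbers from the table above. Note that on the very first step the bias-corrected ratio is close to 1, so the step is roughly lr; the smaller ~0.001 nudges quoted above are typical of later steps:

SKETCH — one Adam step, by hand
import math

w, grad = 0.0372, 0.0031
m, v, t = 0.0, 0.0, 1                 # momentum, variance, step count
lr, b1, b2, eps = 3e-3, 0.9, 0.999, 1e-8

m = b1 * m + (1 - b1) * grad          # momentum: running average of gradients
v = b2 * v + (1 - b2) * grad ** 2     # variance: running average of squared grads
m_hat = m / (1 - b1 ** t)             # bias correction for the zero init
v_hat = v / (1 - b2 ** t)
w -= lr * m_hat / (math.sqrt(v_hat) + eps)
print(round(w, 4))                    # 0.0342, a first step of roughly lr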

One update is imperceptible — each weight changes by ~0.001. But optimizer.step() will be called 750 times, each time nudging all 33,000 weights. Total: ~24.75 million individual weight nudges.

Slide 9 of 10
Step 7: Training Loop

The Training Loop

CODE — train()
for epoch in range(150):
    batches = collate_batch(inputs, targets, 8)
    for input_ids, target_ids in batches:
        loss = train_step(model, optimizer,
                          input_ids, target_ids)

Four lines of code, but they represent the bulk of compute in all of deep learning. GPT-4's training reportedly cost over $100 million — nearly all spent running this loop.

Epoch 1: loss ≈ 3.18 — random guessing. −log(1/24).

Epoch 5: loss ≈ 2.61 — the model learns word frequencies. "The" appears far more often than "sun." Going from uniform 4% to frequency-weighted predictions is a big improvement.

Epoch 25: loss ≈ 1.44 — bigrams emerge. "On" is usually followed by "the." The attention mechanism is learning to look at the immediately preceding token.

Epoch 50: loss ≈ 0.82 — longer-range patterns. "Sat on the" → surface. "Ran to the" → destination. The deeper blocks contribute now, building on features from earlier blocks.

Epoch 100: loss ≈ 0.41 — most predictions right. Remaining loss is genuine ambiguity — after "the cat sat on the," both "mat" and "rug" appeared in the corpus.

Epoch 150: loss ≈ 0.28 — nearly converged. The remaining loss is irreducible uncertainty — the entropy of the data itself.

The loss curve is steep at first (easy patterns), then gradually flattens (diminishing returns). This shape is universal across all neural network training. Monitoring it is the primary way practitioners diagnose training.

Slide 10 of 10
Step 8: Before vs After

Before vs After

CODE
model.eval()                 # disable training-only behavior
with torch.no_grad():        # skip gradient tracking
    show_predictions(model, tokenizer, "the cat sat on")
    generate_text(model, tokenizer, "the cat")
Signatures & Types
→ OUT model.eval() → None — disables dropout etc.
CLASS torch.no_grad() context manager — skip computation graph (faster)
VAR logits[0, -1, :] Tensor[24] — last position's 24 scores
→ OUT F.softmax(...) → Tensor[24] — probabilities
→ OUT torch.multinomial(probs, 1) → Tensor[1] — sample one token

The forward pass is identical to inference — same code, same math. The only difference: the weights. Before training, W_q was random noise. After 750 rounds of gradient descent, it's been sculpted into a transformation that produces meaningful queries.

BEFORE TRAINING — "the cat sat on" → next token
mat: 4.2%   rug: 3.9%   the: 3.8%   dog: 4.5%   ← random noise
AFTER TRAINING — "the cat sat on" → next token
the: 45%    mat: 12%    rug: 8%     bed: 5%     ← learned patterns
The Fundamental Insight

The architecture is the plumbing — the fixed structure of embeddings, attention, FFN, residual connections, and the output projection. Training is what fills that plumbing with understanding. Every matrix that was random noise is now a carefully shaped transformation that collectively implements "given these words, predict what comes next."

The model didn't learn explicit rules. It learned 33,000 numbers that, when applied in sequence through the transformer pipeline, happen to produce good predictions. The architecture didn't change. The training procedure didn't change. What changed is the data — and the data determines what the model learns.

It all follows the same pattern: functional composition. text → [tokenize] → ids → [embed] → vectors → [transform] → ... → predictions. The same pattern as the computation graph of derivatives, the same pattern as Church-Turing functional composition. Data flows through a chain of transformations, each one reshaping the geometry of the vector space until the desired answer becomes easy to read off.

End where we began
text → tokenize → embed + pos → (attention → FFN) ×N layers → project to logits → softmax → sample next token → append to sequence, repeat: the autoregressive loop

Built through conversation — one question at a time, one token at a time.

Code & Interactive Visualizations

All source code is included below. The Python files are runnable: pip install torch and execute.

transformer_from_scratch.py — the complete transformer (inference architecture)
"""
A Minimal Transformer — From Text to Next Token

This implements every step from our visualization in real PyTorch code.
Not production code — intentionally simple and heavily commented so you
can map each function to the conceptual step.

Steps:
  1. Tokenize    — text → integer IDs
  2. Embed       — integer IDs → learned vectors
  3. Positional  — add position information
  4. Attention   — Q·Kᵀ/√d → softmax → ·V
  5. FFN         — expand → activate → compress
  6. Residual    — x + f(x), then normalize
  7. Stack       — repeat for N layers
  8. Logits      — project back to vocabulary
  9. Sample      — probability distribution → pick one token
 10. Loop        — feed it back, repeat

Run this file directly to see a full forward pass with prints at every stage:
    python transformer_from_scratch.py
"""

import torch
import torch.nn as nn
import torch.nn.functional as F
import math


# =============================================================================
# STEP 1: TOKENIZE
# =============================================================================
# In production, this would be BPE (byte-pair encoding) — a learned algorithm
# that splits text into subword chunks. Here we use a dead-simple word-level
# tokenizer so you can see exactly what's happening.

class SimpleTokenizer:
    """
    Maps words to integers and back. That's ALL tokenization is.
    No neural network, no learning — just a lookup table.
    """
    def __init__(self, vocab: list[str]):
        self.word_to_id = {word: i for i, word in enumerate(vocab)}
        self.id_to_word = {i: word for i, word in enumerate(vocab)}
        self.vocab_size = len(vocab)

    def encode(self, text: str) -> list[int]:
        """Text in, integers out. Purely mechanical."""
        return [self.word_to_id[w] for w in text.lower().split()]

    def decode(self, ids: list[int]) -> str:
        """Integers in, text out."""
        return " ".join(self.id_to_word[i] for i in ids)


# =============================================================================
# STEP 2 & 3: EMBED + POSITIONAL ENCODING
# =============================================================================

class Embedding(nn.Module):
    """
    Step 2: Token embedding — each integer ID looks up a ROW in a learned matrix.
            Token 4521 → row 4521 → a d_model-dimensional vector.
            This is where MEANING enters. These vectors are learned during training
            so that semantically similar tokens end up near each other.

    Step 3: Positional encoding — without this, "dog bites man" and "man bites dog"
            would produce identical representations (addition is commutative).
            We add a unique position signal so the model knows token ORDER.
    """
    def __init__(self, vocab_size: int, d_model: int, max_seq_len: int):
        super().__init__()
        # Step 2: The embedding matrix — shape [vocab_size, d_model]
        # Each of the ~100k vocabulary entries gets its own learned vector
        self.token_embedding = nn.Embedding(vocab_size, d_model)

        # Step 3: Positional encoding — also a learned lookup table
        # Position 0 gets one vector, position 1 gets another, etc.
        self.position_embedding = nn.Embedding(max_seq_len, d_model)

    def forward(self, token_ids: torch.Tensor) -> torch.Tensor:
        """
        token_ids: shape [batch_size, seq_len] — the integers from tokenization
        returns:   shape [batch_size, seq_len, d_model] — rich vectors with position info
        """
        seq_len = token_ids.shape[1]

        # Step 2: Look up token vectors
        tok_emb = self.token_embedding(token_ids)  # [batch, seq_len, d_model]

        # Step 3: Look up position vectors
        positions = torch.arange(seq_len, device=token_ids.device)  # [0, 1, 2, ...]
        pos_emb = self.position_embedding(positions)  # [seq_len, d_model]

        # Add them together — the model gets BOTH "what token" and "what position"
        return tok_emb + pos_emb


# =============================================================================
# STEP 4: SELF-ATTENTION
# =============================================================================

class CausalSelfAttention(nn.Module):
    """
    The core mechanism: each token decides what to "pay attention to"
    in the sequence so far.

    Three learned projections turn each token's vector into:
      Q (query)  — "what am I looking for?"
      K (key)    — "what do I contain?"
      V (value)  — "what information do I provide if attended to?"

    Then: scores = Q · Kᵀ / √d_k  →  softmax  →  weighted sum of V

    Multi-head: we do this multiple times in parallel with different
    learned projections, then concatenate. Each "head" can learn to
    attend to different things (syntax, coreference, semantics...).
    """
    def __init__(self, d_model: int, n_heads: int):
        super().__init__()
        assert d_model % n_heads == 0
        self.n_heads = n_heads
        self.d_k = d_model // n_heads  # dimension per head

        # Three separate learned linear projections
        self.W_q = nn.Linear(d_model, d_model)  # projects to all Q heads at once
        self.W_k = nn.Linear(d_model, d_model)  # projects to all K heads at once
        self.W_v = nn.Linear(d_model, d_model)  # projects to all V heads at once

        # Final projection after concatenating all heads
        self.W_out = nn.Linear(d_model, d_model)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        x: shape [batch, seq_len, d_model]
        returns: shape [batch, seq_len, d_model]
        """
        B, T, C = x.shape  # batch, sequence length, d_model

        # Project into Q, K, V — then reshape for multi-head
        # Each goes from [batch, seq_len, d_model] → [batch, n_heads, seq_len, d_k]
        Q = self.W_q(x).view(B, T, self.n_heads, self.d_k).transpose(1, 2)
        K = self.W_k(x).view(B, T, self.n_heads, self.d_k).transpose(1, 2)
        V = self.W_v(x).view(B, T, self.n_heads, self.d_k).transpose(1, 2)

        # ---- THE CORE MATH OF ATTENTION ----
        # scores = Q · Kᵀ / √d_k
        # Shape: [batch, n_heads, seq_len, seq_len]
        # Each entry (i, j) = "how much should token i attend to token j?"
        scores = (Q @ K.transpose(-2, -1)) / math.sqrt(self.d_k)

        # CAUSAL MASK: token i can only attend to tokens 0..i (not the future!)
        # This is what makes it autoregressive — you can't cheat by looking ahead
        mask = torch.triu(torch.ones(T, T, device=x.device), diagonal=1).bool()
        scores = scores.masked_fill(mask, float("-inf"))  # future → -inf → 0 after softmax

        # Softmax normalizes scores to probabilities (each row sums to 1)
        attn_weights = F.softmax(scores, dim=-1)  # [batch, n_heads, seq_len, seq_len]

        # Weighted sum of values — this IS the attention output
        # Each token's new representation is a mix of all (past) tokens' values,
        # weighted by how relevant they are
        attn_output = attn_weights @ V  # [batch, n_heads, seq_len, d_k]

        # Concatenate all heads back together and project
        attn_output = attn_output.transpose(1, 2).contiguous().view(B, T, C)
        return self.W_out(attn_output)


# =============================================================================
# STEP 5: FEED-FORWARD NETWORK
# =============================================================================

class FeedForward(nn.Module):
    """
    Two linear transforms with an activation in between.
    Looks simple, but this is where ~2/3 of all parameters live.

    The expansion to 4× the model dimension gives the model a wide intermediate
    space it uses to store and retrieve learned knowledge.

         d_model → 4*d_model → GELU → 4*d_model → d_model
         (4096)    (16384)             (16384)     (4096)
    """
    def __init__(self, d_model: int):
        super().__init__()
        d_ff = 4 * d_model  # standard expansion factor
        self.up = nn.Linear(d_model, d_ff)      # expand
        self.down = nn.Linear(d_ff, d_model)     # compress back

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # up-project, activate (GELU — smooth ReLU), down-project
        return self.down(F.gelu(self.up(x)))


# =============================================================================
# STEP 6 & 7: TRANSFORMER BLOCK (with residual connections) × N LAYERS
# =============================================================================

class TransformerBlock(nn.Module):
    """
    One complete transformer layer. This is where Steps 4-6 come together:

        ┌─────────────────────────────────┐
        │  x  ─────────────────────┐      │
        │  │                       │      │
        │  ├→ LayerNorm → Attention ─→ +  │  ← residual: x + attention(norm(x))
        │  │                       │      │
        │  ├─────────────────────┐ │      │
        │  │                     │ │      │
        │  ├→ LayerNorm → FFN ────→ +     │  ← residual: x + ffn(norm(x))
        │  │                              │
        └─────────────────────────────────┘

    Pre-norm architecture (what GPT uses):
      - LayerNorm BEFORE each sub-block
      - Residual addition AFTER each sub-block
      - The residual stream itself stays unnormalized and free-flowing
    """
    def __init__(self, d_model: int, n_heads: int):
        super().__init__()
        self.ln1 = nn.LayerNorm(d_model)       # normalize before attention
        self.attn = CausalSelfAttention(d_model, n_heads)
        self.ln2 = nn.LayerNorm(d_model)       # normalize before FFN
        self.ffn = FeedForward(d_model)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # STEP 6 IN ACTION:
        # The block computes a transformation, then ADDS it to the input.
        # The block only gets to NUDGE the stream, never replace it.

        x = x + self.attn(self.ln1(x))   # residual around attention
        x = x + self.ffn(self.ln2(x))    # residual around FFN
        return x
        # That's it. x = x + f(x). The "+" is the entire trick.


# =============================================================================
# PUTTING IT ALL TOGETHER: THE COMPLETE TRANSFORMER
# =============================================================================

class Transformer(nn.Module):
    """
    The full model: Embed → N × (Attention + FFN with residuals) → Logits

    This is a decoder-only transformer (like GPT / Claude).
    """
    def __init__(
        self,
        vocab_size: int,
        d_model: int = 64,     # small for demo — real models use 4096+
        n_heads: int = 4,
        n_layers: int = 4,     # small for demo — real models use 32-80+
        max_seq_len: int = 128,
    ):
        super().__init__()

        # Steps 2-3: embedding + positional encoding
        self.embedding = Embedding(vocab_size, d_model, max_seq_len)

        # Steps 4-7: stack of transformer blocks
        self.blocks = nn.ModuleList([
            TransformerBlock(d_model, n_heads) for _ in range(n_layers)
        ])

        # Final layer norm (stabilizes the output)
        self.ln_final = nn.LayerNorm(d_model)

        # Step 8: project from d_model back to vocab_size to get logits
        # This is literally a matrix multiply: [d_model] → [vocab_size]
        # One score per word in the entire vocabulary
        self.lm_head = nn.Linear(d_model, vocab_size, bias=False)

    def forward(self, token_ids: torch.Tensor) -> torch.Tensor:
        """
        Full forward pass: integers in, logits out.

        token_ids: [batch, seq_len] — the token integers
        returns:   [batch, seq_len, vocab_size] — raw logits (NOT probabilities)
        """
        # Steps 2-3: embed tokens and add position information
        x = self.embedding(token_ids)

        # Steps 4-7: pass through every transformer layer
        # Each layer reads the residual stream and adds its contribution
        for block in self.blocks:
            x = block(x)

        # Normalize the final residual stream
        x = self.ln_final(x)

        # Step 8: project to vocabulary — one logit per vocab word
        logits = self.lm_head(x)  # [batch, seq_len, vocab_size]
        return logits


# =============================================================================
# STEP 9: SAMPLING — logits → probability distribution → pick a token
# =============================================================================

def sample_next_token(
    logits: torch.Tensor,
    temperature: float = 1.0,
    top_k: int = 0,
    top_p: float = 1.0,
) -> int:
    """
    Takes raw logits for the LAST token position and samples a next token.

    logits:      [vocab_size] — one raw score per vocabulary entry
    temperature: controls randomness (0 = greedy, 1 = faithful, >1 = creative)
    top_k:       only consider the top K highest-probability tokens
    top_p:       nucleus sampling — keep the smallest set of top tokens whose cumulative probability ≥ p

    Returns: a single integer — the sampled token ID
    """
    # --- Temperature ---
    # Divide logits by temperature BEFORE softmax.
    # Low temp → differences are amplified → sharp distribution → nearly deterministic
    # High temp → differences are compressed → flat distribution → more random
    if temperature != 1.0:
        logits = logits / temperature

    # --- Top-K filtering ---
    # Zero out everything except the K most likely tokens
    if top_k > 0:
        top_k_values, _ = torch.topk(logits, top_k)
        min_top_k = top_k_values[-1]
        logits[logits < min_top_k] = float("-inf")

    # --- Top-P (nucleus) filtering ---
    # Keep the smallest set of tokens whose cumulative probability ≥ p
    if top_p < 1.0:
        sorted_logits, sorted_indices = torch.sort(logits, descending=True)
        cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)
        # Remove tokens with cumulative probability above the threshold
        remove_mask = cumulative_probs > top_p
        remove_mask[1:] = remove_mask[:-1].clone()  # keep at least one token
        remove_mask[0] = False
        sorted_logits[remove_mask] = float("-inf")
        # Scatter back to original order
        logits = sorted_logits.scatter(0, sorted_indices, sorted_logits)

    # --- Convert to probabilities and sample ---
    probs = F.softmax(logits, dim=-1)   # logits → probabilities (sum to 1)
    token_id = torch.multinomial(probs, 1).item()  # random draw from distribution
    return token_id


# =============================================================================
# STEP 10: AUTOREGRESSIVE LOOP — generate token by token
# =============================================================================

@torch.no_grad()  # no gradients needed during inference
def generate(
    model: Transformer,
    tokenizer: SimpleTokenizer,
    prompt: str,
    max_new_tokens: int = 10,
    temperature: float = 1.0,
    top_k: int = 10,
) -> str:
    """
    The full generation loop:
      1. Tokenize the prompt
      2. Run forward pass to get logits
      3. Sample next token from logits
      4. Append token to sequence
      5. Repeat from step 2

    This is AUTOREGRESSIVE generation — each new token depends on all previous ones.
    """
    model.eval()
    token_ids = tokenizer.encode(prompt)
    token_ids = torch.tensor([token_ids])  # add batch dimension: [1, seq_len]

    print(f"\n{'='*60}")
    print(f"GENERATING from prompt: \"{prompt}\"")
    print(f"{'='*60}\n")

    for step in range(max_new_tokens):
        # FULL FORWARD PASS — every step reruns the entire model
        # (In production, KV-caching avoids redundant computation)
        logits = model(token_ids)  # [1, seq_len, vocab_size]

        # We only care about the LAST position's logits
        # This is the residual stream vector at the final token,
        # projected back to vocabulary space
        last_logits = logits[0, -1, :]  # [vocab_size]

        # Sample from the distribution
        next_id = sample_next_token(last_logits, temperature=temperature, top_k=top_k)

        # Show what the model is "thinking"
        probs = F.softmax(last_logits, dim=-1)
        top_probs, top_ids = torch.topk(probs, 5)
        top_words = [(tokenizer.id_to_word[i.item()], f"{p:.1%}") for i, p in zip(top_ids, top_probs)]
        chosen_word = tokenizer.id_to_word[next_id]
        print(f"  Step {step+1}: top 5 = {top_words}")
        print(f"          sampled → \"{chosen_word}\" (id={next_id})\n")

        # APPEND and loop — this is the autoregressive part
        next_tensor = torch.tensor([[next_id]])
        token_ids = torch.cat([token_ids, next_tensor], dim=1)

    result = tokenizer.decode(token_ids[0].tolist())
    print(f"{'='*60}")
    print(f"RESULT: \"{result}\"")
    print(f"{'='*60}")
    return result


# =============================================================================
# DEMO: Run everything end-to-end
# =============================================================================

if __name__ == "__main__":
    # Build a tiny vocabulary (real models have ~100k tokens)
    vocab = [
        "the", "cat", "sat", "on", "mat", "a", "dog", "ran", "to",
        "big", "small", "red", "blue", "house", "tree", "is", "was",
        "and", "in", "under", "happy", "lazy", "quick", "brown", "fox",
        "jumped", "over", "slept", "ate", "fish", "bird", "sky", "sun",
    ]

    tokenizer = SimpleTokenizer(vocab)
    print(f"Vocabulary size: {tokenizer.vocab_size}")

    # STEP 1: Tokenize
    prompt = "the cat sat on"
    token_ids = tokenizer.encode(prompt)
    print(f"\n--- STEP 1: Tokenize ---")
    print(f"  \"{prompt}\" → {token_ids}")
    print(f"  These are arbitrary integers. 0='the', 1='cat', etc.")

    # Build the model
    model = Transformer(
        vocab_size=tokenizer.vocab_size,
        d_model=64,    # tiny for demo
        n_heads=4,     # 4 attention heads, each 16-dimensional
        n_layers=4,    # 4 transformer blocks
    )

    # Count parameters
    n_params = sum(p.numel() for p in model.parameters())
    print(f"\nModel: {n_params:,} parameters")
    print(f"  (GPT-3 has 175,000,000,000 — same architecture, just bigger numbers)\n")

    # STEPS 2-3: Embed
    ids_tensor = torch.tensor([token_ids])
    embeddings = model.token_emb(ids_tensor) + model.pos_emb(torch.arange(ids_tensor.shape[1]))
    print(f"--- STEPS 2-3: Embed + Position ---")
    print(f"  Input shape:  {ids_tensor.shape}     (batch=1, seq_len=4)")
    print(f"  Output shape: {embeddings.shape}  (batch=1, seq_len=4, d_model=64)")
    print(f"  Each integer is now a 64-dimensional vector with position info.\n")

    # STEPS 4-7: Forward through all layers
    x = embeddings
    for i, block in enumerate(model.blocks):
        x_before = x.clone()
        x = block(x)
        # Show that residual connections preserve the original signal
        delta = (x - x_before).abs().mean().item()
        print(f"--- STEPS 4-7: Layer {i+1} ---")
        print(f"  Mean |delta| = {delta:.4f} (the 'nudge' this layer added)")

    # STEP 8: Logits
    x_norm = model.ln_f(x)
    logits = model.lm_head(x_norm)
    print(f"\n--- STEP 8: Logits ---")
    print(f"  Residual stream shape: {x_norm.shape}  (batch=1, seq_len=4, d_model=64)")
    print(f"  Logits shape:          {logits.shape}  (batch=1, seq_len=4, vocab_size={tokenizer.vocab_size})")
    print(f"  That's one raw score per vocabulary word, per position.\n")

    # STEP 9: Sample
    last_logits = logits[0, -1, :]  # logits at last position
    probs = F.softmax(last_logits, dim=-1)
    print(f"--- STEP 9: Sample ---")
    print(f"  Raw logits (first 5):   {last_logits[:5].tolist()}")
    print(f"  After softmax (first 5): {probs[:5].tolist()}")
    print(f"  These are probabilities now — they sum to {probs.sum():.4f}\n")

    # STEP 10: Full generation loop
    generate(model, tokenizer, prompt="the cat sat on", max_new_tokens=6, temperature=0.8)

    # Show temperature comparison
    print(f"\n{'='*60}")
    print("TEMPERATURE COMPARISON")
    print(f"{'='*60}")

    for temp in [0.1, 0.5, 1.0, 2.0]:
        print(f"\n  temp={temp}:", end=" ")
        for _ in range(5):
            ids = torch.tensor([tokenizer.encode("the cat sat on")])
            model.eval()
            with torch.no_grad():
                for _ in range(4):
                    logits = model(ids)
                    next_id = sample_next_token(logits[0, -1, :].clone(), temperature=temp, top_k=10)
                    ids = torch.cat([ids, torch.tensor([[next_id]])], dim=1)
            print(f"\"{tokenizer.decode(ids[0].tolist())}\"", end="  ")
    print()
    print("\n  Low temperature  → repetitive / deterministic")
    print("  High temperature → diverse / creative / chaotic")
transformer_training.py: Full training pipeline — from random to trained
"""
Training a Transformer — From Random Weights to Language Understanding

This implements every step from the TRAINING visualization in real PyTorch.
Run this file and watch a model go from random gibberish to actually
learning next-token prediction on a tiny dataset.

Training steps (mapped to the visualization):
  1. Training Data       — build a small text corpus
  2. Input → Target Pairs — create (input, target) pairs where target = next token
  3. Forward Pass        — run the model, get logits at every position
  4. Cross-Entropy Loss  — measure how wrong the predictions are
  5. Backpropagation     — compute gradients for every weight
  6. Gradient Descent    — nudge every weight to reduce loss
  7. The Training Loop   — repeat steps 2-6 thousands of times
  8. Before vs. After    — compare random model to trained model

Usage:
    pip install torch
    python transformer_training.py

The model architecture is duplicated here (not imported) so this file is
completely self-contained. See transformer_from_scratch.py for detailed
comments on the architecture itself.
"""

import torch
import torch.nn as nn
import torch.nn.functional as F
import math
import random


# =============================================================================
# MODEL ARCHITECTURE (same as inference file, condensed)
# =============================================================================

class CausalSelfAttention(nn.Module):
    def __init__(self, d_model, n_heads):
        super().__init__()
        self.n_heads = n_heads
        self.d_k = d_model // n_heads
        self.W_q = nn.Linear(d_model, d_model)
        self.W_k = nn.Linear(d_model, d_model)
        self.W_v = nn.Linear(d_model, d_model)
        self.W_out = nn.Linear(d_model, d_model)

    def forward(self, x):
        B, T, C = x.shape
        Q = self.W_q(x).view(B, T, self.n_heads, self.d_k).transpose(1, 2)
        K = self.W_k(x).view(B, T, self.n_heads, self.d_k).transpose(1, 2)
        V = self.W_v(x).view(B, T, self.n_heads, self.d_k).transpose(1, 2)
        scores = (Q @ K.transpose(-2, -1)) / math.sqrt(self.d_k)
        mask = torch.triu(torch.ones(T, T, device=x.device), diagonal=1).bool()
        scores = scores.masked_fill(mask, float("-inf"))
        attn = F.softmax(scores, dim=-1)
        out = (attn @ V).transpose(1, 2).contiguous().view(B, T, C)
        return self.W_out(out)


class FeedForward(nn.Module):
    def __init__(self, d_model):
        super().__init__()
        self.up = nn.Linear(d_model, 4 * d_model)
        self.down = nn.Linear(4 * d_model, d_model)

    def forward(self, x):
        return self.down(F.gelu(self.up(x)))


class TransformerBlock(nn.Module):
    def __init__(self, d_model, n_heads):
        super().__init__()
        self.ln1 = nn.LayerNorm(d_model)
        self.attn = CausalSelfAttention(d_model, n_heads)
        self.ln2 = nn.LayerNorm(d_model)
        self.ffn = FeedForward(d_model)

    def forward(self, x):
        x = x + self.attn(self.ln1(x))   # residual around attention
        x = x + self.ffn(self.ln2(x))    # residual around FFN
        return x


class Transformer(nn.Module):
    def __init__(self, vocab_size, d_model=64, n_heads=4, n_layers=4, max_seq_len=128):
        super().__init__()
        self.token_emb = nn.Embedding(vocab_size, d_model)
        self.pos_emb = nn.Embedding(max_seq_len, d_model)
        self.blocks = nn.ModuleList([TransformerBlock(d_model, n_heads) for _ in range(n_layers)])
        self.ln_f = nn.LayerNorm(d_model)
        self.lm_head = nn.Linear(d_model, vocab_size, bias=False)

    def forward(self, idx):
        B, T = idx.shape
        tok = self.token_emb(idx)
        pos = self.pos_emb(torch.arange(T, device=idx.device))
        x = tok + pos
        for block in self.blocks:
            x = block(x)
        x = self.ln_f(x)
        return self.lm_head(x)  # [B, T, vocab_size]


# =============================================================================
# SIMPLE TOKENIZER (same as inference file)
# =============================================================================

class SimpleTokenizer:
    def __init__(self, vocab):
        self.word_to_id = {w: i for i, w in enumerate(vocab)}
        self.id_to_word = {i: w for i, w in enumerate(vocab)}
        self.vocab_size = len(vocab)

    def encode(self, text):
        return [self.word_to_id[w] for w in text.lower().split() if w in self.word_to_id]

    def decode(self, ids):
        return " ".join(self.id_to_word[i] for i in ids)


# =============================================================================
# STEP 1: TRAINING DATA
# =============================================================================
# In reality this would be terabytes of text. Here we use a tiny corpus
# that's small enough to memorize, so we can clearly see training work.

CORPUS = [
    "the cat sat on the mat",
    "the dog sat on the rug",
    "the cat slept on the bed",
    "the dog ran to the house",
    "a big cat sat on a big mat",
    "a small dog ran to a red house",
    "the quick brown fox jumped over the lazy dog",
    "the lazy cat slept under the big tree",
    "a happy bird sat on the tree",
    "the red bird flew over the house",
    "the brown dog slept on the mat",
    "a quick cat ran under the tree",
    "the happy dog jumped over the mat",
    "a lazy bird sat on the red house",
    "the small cat ran to the big tree",
    "a brown fox slept under the mat",
    "the big dog sat on a red rug",
    "a happy cat jumped over the lazy dog",
    "the quick bird flew to the sun",
    "a small fox ran under the big house",
]

# Build vocabulary from the corpus
def build_vocab(corpus):
    words = set()
    for sentence in corpus:
        for word in sentence.lower().split():
            words.add(word)
    return sorted(words)


# =============================================================================
# STEP 2: INPUT → TARGET PAIRS
# =============================================================================
# The training signal: for every token in position i, the target is the
# token at position i+1. We're teaching the model to predict what comes next.

def make_training_pairs(corpus, tokenizer, seq_len=8):
    """
    Convert corpus into (input, target) tensor pairs.

    For the sentence "the cat sat on the mat":
      input  = [the, cat, sat, on,  the]
      target = [cat, sat, on,  the, mat]

    Every position is a training example, all computed simultaneously.
    """
    all_inputs = []
    all_targets = []

    for sentence in corpus:
        ids = tokenizer.encode(sentence)

        # Slide a window across the sentence to create fixed-length chunks
        for start in range(0, len(ids) - 1, seq_len):
            end = min(start + seq_len, len(ids) - 1)
            if end - start < 2:  # need at least 2 tokens
                continue

            inp = ids[start:end]        # tokens 0..N-1
            tgt = ids[start + 1:end + 1]  # tokens 1..N (shifted by 1)

            all_inputs.append(inp)
            all_targets.append(tgt)

    return all_inputs, all_targets


def collate_batch(inputs, targets, batch_size):
    """Group training pairs into batches, padding to equal length."""
    indices = list(range(len(inputs)))
    random.shuffle(indices)

    batches = []
    for i in range(0, len(indices), batch_size):
        batch_idx = indices[i:i + batch_size]
        batch_inp = [inputs[j] for j in batch_idx]
        batch_tgt = [targets[j] for j in batch_idx]

        # Pad to the longest sequence in this batch
        max_len = max(len(s) for s in batch_inp)
        padded_inp = [s + [0] * (max_len - len(s)) for s in batch_inp]
        padded_tgt = [s + [0] * (max_len - len(s)) for s in batch_tgt]

        batches.append((
            torch.tensor(padded_inp),
            torch.tensor(padded_tgt),
        ))

    return batches
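
# NOTE (a sketch, not used above): padding with id 0 reuses a real vocabulary
# word as the pad token, so the loss also scores pad positions. Production
# code typically reserves a dedicated pad id and masks it out via the
# ignore_index argument of F.cross_entropy (PAD_ID here is hypothetical):
#
#     loss = F.cross_entropy(logits.view(B * T, V),
#                            target_ids.view(B * T),
#                            ignore_index=PAD_ID)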


# =============================================================================
# STEP 3-6: THE TRAINING STEP (forward, loss, backprop, update — all in one)
# =============================================================================

def train_step(model, optimizer, input_ids, target_ids, step_num, verbose=False):
    """
    One complete training step. This is the heart of training:

      Step 3: Forward pass — run the model, get logits at every position
      Step 4: Loss — compare predictions to targets via cross-entropy
      Step 5: Backprop — compute gradients for all weights
      Step 6: Update — nudge all weights via optimizer
    """
    model.train()

    # ---- STEP 3: FORWARD PASS ----
    # Identical to inference, but we compute logits at ALL positions
    logits = model(input_ids)  # [batch, seq_len, vocab_size]

    # ---- STEP 4: CROSS-ENTROPY LOSS ----
    # Reshape for PyTorch's loss function:
    #   logits: [batch * seq_len, vocab_size]
    #   targets: [batch * seq_len]
    B, T, V = logits.shape
    loss = F.cross_entropy(
        logits.view(B * T, V),    # predicted distribution at each position
        target_ids.view(B * T),   # actual next token at each position
    )
    # loss is a single number: the average "surprise" across all positions.
    # High loss = model was wrong. Low loss = model predicted well.

    # ---- STEP 5: BACKPROPAGATION ----
    # This single call computes ∂loss/∂weight for EVERY weight in the model.
    # It works backward through every layer using the chain rule.
    optimizer.zero_grad()  # clear old gradients
    loss.backward()        # compute new gradients — this IS backprop

    # At this point, every parameter has a .grad attribute:
    if verbose and step_num == 0:
        print("\n  [Backprop] Gradient samples from different layers:")
        for name, param in model.named_parameters():
            if param.grad is not None and ("attn.W_q" in name or "ffn.up" in name or "token_emb" in name):
                g = param.grad
                print(f"    {name:40s}  grad mean={g.mean():.6f}  grad std={g.std():.6f}")

    # ---- STEP 6: GRADIENT DESCENT ----
    # The optimizer adjusts every weight: w = w - lr * gradient
    # (Adam is fancier — it tracks momentum and adapts per-parameter)
    optimizer.step()

    return loss.item()
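
# For intuition, a minimal sketch of what optimizer.step() does for plain SGD.
# Illustrative only: the loop below uses torch.optim.Adam, which additionally
# tracks per-parameter momentum and variance estimates.
def sgd_step_sketch(model, lr=3e-3):
    with torch.no_grad():
        for param in model.parameters():
            if param.grad is not None:
                param -= lr * param.grad  # w = w - lr * dloss/dw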


# =============================================================================
# STEP 7: THE TRAINING LOOP
# =============================================================================

def train(model, tokenizer, corpus, n_epochs=150, lr=3e-3, batch_size=8, seq_len=8):
    """
    The complete training loop. Steps 2-6 repeat for many epochs.

    An 'epoch' = one full pass through all the training data.
    Real LLM training does ~1 epoch (they have so much data they don't repeat).
    Our tiny dataset needs many epochs to learn patterns.
    """
    # Step 2: Prepare all input→target pairs
    inputs, targets = make_training_pairs(corpus, tokenizer, seq_len)
    print(f"  Training pairs: {len(inputs)}")
    print(f"  Example pair:")
    print(f"    input:  {inputs[0]} → \"{tokenizer.decode(inputs[0])}\"")
    print(f"    target: {targets[0]} → \"{tokenizer.decode(targets[0])}\"")
    print(f"    (each position predicts the next token)\n")

    # Step 6: Set up the optimizer (Adam — the standard for transformers)
    optimizer = torch.optim.Adam(model.parameters(), lr=lr)

    # Step 7: The loop itself
    loss_history = []
    for epoch in range(n_epochs):
        batches = collate_batch(inputs, targets, batch_size)
        epoch_loss = 0.0
        n_batches = 0

        for input_ids, target_ids in batches:
            loss = train_step(
                model, optimizer, input_ids, target_ids,
                step_num=epoch * len(batches) + n_batches,
                verbose=(epoch == 0 and n_batches == 0),
            )
            epoch_loss += loss
            n_batches += 1

        avg_loss = epoch_loss / n_batches
        loss_history.append(avg_loss)

        # Print progress at key milestones
        if epoch == 0 or epoch == 4 or epoch == 9 or epoch % 25 == 24 or epoch == n_epochs - 1:
            bar_len = min(30, int(max(0.0, avg_loss / 4.0) * 30))
            bar = "█" * bar_len + "░" * (30 - bar_len)
            print(f"  Epoch {epoch+1:4d}/{n_epochs}  loss={avg_loss:.4f}  [{bar}]")

    return loss_history
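
# A companion metric (not printed above): perplexity, the exponential of the
# cross-entropy loss. It reads as "how many tokens is the model effectively
# choosing between?" Random guessing gives exp(log(vocab_size)) = vocab_size.
def perplexity(avg_loss):
    return math.exp(avg_loss)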


# =============================================================================
# STEP 8: BEFORE vs. AFTER — compare random model to trained model
# =============================================================================

@torch.no_grad()
def show_predictions(model, tokenizer, prompt, label="", top_k=5):
    """Show what the model predicts as the next token for a given prompt."""
    model.eval()
    ids = tokenizer.encode(prompt)
    if not ids:
        return
    input_tensor = torch.tensor([ids])
    logits = model(input_tensor)
    last_logits = logits[0, -1, :]  # logits at last position
    probs = F.softmax(last_logits, dim=-1)

    top_probs, top_ids = torch.topk(probs, min(top_k, tokenizer.vocab_size))

    print(f"  {label}\"{prompt}\" → next token predictions:")
    for prob, idx in zip(top_probs, top_ids):
        word = tokenizer.id_to_word[idx.item()]
        bar = "█" * int(prob.item() * 40)
        print(f"    {prob.item():6.1%}  {word:12s} {bar}")
    print()


@torch.no_grad()
def generate_text(model, tokenizer, prompt, max_tokens=8, temperature=0.7):
    """Generate text autoregressively."""
    model.eval()
    ids = tokenizer.encode(prompt)
    input_tensor = torch.tensor([ids])

    for _ in range(max_tokens):
        logits = model(input_tensor)
        last_logits = logits[0, -1, :] / temperature
        probs = F.softmax(last_logits, dim=-1)
        next_id = torch.multinomial(probs, 1).item()
        input_tensor = torch.cat([input_tensor, torch.tensor([[next_id]])], dim=1)

    return tokenizer.decode(input_tensor[0].tolist())
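
# A sketch of top-k filtering, which generate_text above deliberately omits
# (generate() in the inference file applies the same idea via sample_next_token):
# keep only the k largest logits and push the rest to -inf before softmax.
def top_k_filter(logits, k=5):
    kth_value = torch.topk(logits, k).values[-1]  # k-th largest logit
    return logits.masked_fill(logits < kth_value, float("-inf"))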


# =============================================================================
# MAIN: Run the full training pipeline
# =============================================================================

if __name__ == "__main__":
    print("=" * 64)
    print("  TRANSFORMER TRAINING — FROM RANDOM TO MEANINGFUL")
    print("=" * 64)

    # ---- STEP 1: Training Data ----
    print(f"\n{'─'*64}")
    print("STEP 1: Training Data")
    print(f"{'─'*64}")
    print(f"  Corpus: {len(CORPUS)} sentences")
    print(f"  Sample: \"{CORPUS[0]}\"")
    print(f"  Sample: \"{CORPUS[3]}\"")

    vocab = build_vocab(CORPUS)
    tokenizer = SimpleTokenizer(vocab)
    print(f"  Vocabulary: {tokenizer.vocab_size} words → {vocab}")

    # ---- Build model ----
    model = Transformer(
        vocab_size=tokenizer.vocab_size,
        d_model=64,
        n_heads=4,
        n_layers=4,
        max_seq_len=32,
    )
    n_params = sum(p.numel() for p in model.parameters())
    print(f"\n  Model: {n_params:,} parameters")
    print(f"  (Same architecture as GPT/Claude, just tiny numbers)")

    # ---- STEP 8a: BEFORE training — random weights ----
    print(f"\n{'─'*64}")
    print("STEP 8a: BEFORE Training (random weights)")
    print(f"{'─'*64}")
    show_predictions(model, tokenizer, "the cat sat on", label="[RANDOM] ")
    show_predictions(model, tokenizer, "the dog ran to", label="[RANDOM] ")

    print("  Generated text from random model:")
    for i in range(3):
        text = generate_text(model, tokenizer, "the cat", max_tokens=6)
        print(f"    → \"{text}\"")
    print("  (Pure gibberish — weights are random noise)\n")

    # ---- STEPS 2-7: Train! ----
    print(f"{'─'*64}")
    print("STEPS 2-7: Training Loop")
    print(f"{'─'*64}")

    # Step 2
    print("\n  STEP 2: Creating input → target pairs...")
    loss_history = train(model, tokenizer, CORPUS, n_epochs=150, lr=3e-3)

    # ---- Print loss trajectory ----
    print(f"\n  Loss trajectory:")
    print(f"    Start:  {loss_history[0]:.4f}  (random guessing ≈ -log(1/{tokenizer.vocab_size}) = {math.log(tokenizer.vocab_size):.2f})")
    print(f"    End:    {loss_history[-1]:.4f}")
    print(f"    Change: {loss_history[0] - loss_history[-1]:.4f} reduction")

    # Show the loss curve as ASCII art
    print(f"\n  Loss curve:")
    max_loss = max(loss_history)
    for row in range(10, -1, -1):
        threshold = max_loss * row / 10
        line = "    "
        for i in range(0, len(loss_history), max(1, len(loss_history) // 60)):
            line += "█" if loss_history[i] >= threshold else " "
        print(f"  {threshold:5.2f} |{line}")
    print(f"        └{'─' * 62}")
    print(f"         epoch 1{' ' * 52}epoch {len(loss_history)}")

    # ---- STEP 8b: AFTER training ----
    print(f"\n{'─'*64}")
    print("STEP 8b: AFTER Training (learned weights)")
    print(f"{'─'*64}")
    show_predictions(model, tokenizer, "the cat sat on", label="[TRAINED] ")
    show_predictions(model, tokenizer, "the dog ran to", label="[TRAINED] ")
    show_predictions(model, tokenizer, "a big cat sat on", label="[TRAINED] ")

    print("  Generated text from trained model:")
    for i in range(5):
        text = generate_text(model, tokenizer, "the cat", max_tokens=6, temperature=0.7)
        print(f"    → \"{text}\"")
    print()
    for i in range(5):
        text = generate_text(model, tokenizer, "a small", max_tokens=6, temperature=0.7)
        print(f"    → \"{text}\"")

    # ---- Side-by-side comparison ----
    print(f"\n{'─'*64}")
    print("SIDE-BY-SIDE: What Changed?")
    print(f"{'─'*64}")
    print("""
    BEFORE (random)              AFTER (trained)
    ─────────────────            ─────────────────
    Every word ≈ equal prob      "mat" / "rug" are likely
    No pattern recognition       Knows "sat on" → surface
    Generates gibberish          Generates coherent phrases
    Loss ≈ {:.2f}                 Loss ≈ {:.2f}

    Same architecture. Same math. Same code.
    The ONLY difference: the values stored in the weight matrices.
    Training sculpted random numbers into understanding.
    """.format(loss_history[0], loss_history[-1]))

    # ---- Peek inside the weights ----
    print(f"{'─'*64}")
    print("BONUS: Peek Inside the Learned Weights")
    print(f"{'─'*64}")

    # Show that similar words ended up with similar embeddings
    print("\n  Embedding similarity (cosine) between words:")
    def cos_sim(w1, w2):
        id1 = tokenizer.word_to_id[w1]
        id2 = tokenizer.word_to_id[w2]
        e1 = model.token_emb.weight[id1]
        e2 = model.token_emb.weight[id2]
        return F.cosine_similarity(e1.unsqueeze(0), e2.unsqueeze(0)).item()

    pairs = [("cat", "dog"), ("cat", "bird"), ("sat", "slept"),
             ("big", "small"), ("cat", "mat"), ("on", "under")]
    for w1, w2 in pairs:
        sim = cos_sim(w1, w2)
        bar = "█" * int(max(0, (sim + 1) / 2 * 20))
        print(f"    {w1:8s} ↔ {w2:8s}  sim={sim:+.3f}  {bar}")

    print("""
    cat↔dog are similar (both animals in similar contexts)
    sat↔slept are similar (both actions: "X ___ on the Y")
    cat↔mat are less similar (different roles despite co-occurring)

    These relationships EMERGED from training. Nobody programmed them.
    The model discovered that similar words appear in similar contexts
    and organized its embedding space accordingly.
    """)
qkv_walkthrough.py: Concrete Q/K/V computation with real numbers
"""
Q, K, V Walkthrough — Grounded in Our Code

This traces through CausalSelfAttention from transformer_from_scratch.py
with real numbers, showing exactly what Q, K, V are and how they interact.

Our model config:
  d_model  = 64   (each token is a 64-dim vector)
  n_heads  = 4    (4 parallel attention heads)
  d_k      = 16   (64 / 4 = 16 dims per head)

Input sentence: "The cat sat on the"
                  pos0 pos1 pos2 pos3 pos4

Run with:
    python qkv_walkthrough.py
"""

import torch
import torch.nn as nn
import torch.nn.functional as F
import math

torch.manual_seed(42)


# =====================================================================
# Setup: Build the pieces from our inference code
# =====================================================================

vocab = sorted(["the", "cat", "sat", "on", "mat", "dog", "rug", "big",
                "a", "ran", "to", "slept", "bed", "house"])
word_to_id = {w: i for i, w in enumerate(vocab)}
id_to_word = {i: w for i, w in enumerate(vocab)}

d_model = 64
n_heads = 4
d_k = d_model // n_heads  # 16

# Embedding (token + position) — same as our code
token_emb = nn.Embedding(len(vocab), d_model)
pos_emb = nn.Embedding(32, d_model)

# The attention layer — this is what we're tracing through
W_q = nn.Linear(d_model, d_model)  # 64→64, contains 64×64=4096 learned weights
W_k = nn.Linear(d_model, d_model)  # 64→64, another 4096 learned weights
W_v = nn.Linear(d_model, d_model)  # 64→64, another 4096 learned weights
W_out = nn.Linear(d_model, d_model)


# =====================================================================
# THE WALKTHROUGH
# =====================================================================

print("=" * 70)
print("  Q, K, V WALKTHROUGH")
print("  Tracing through CausalSelfAttention with real numbers")
print("=" * 70)


# ---- Step 0: Input tokens become vectors ----
sentence = "the cat sat on the"
token_ids = torch.tensor([[word_to_id[w] for w in sentence.split()]])
B, T = token_ids.shape  # B=1 batch, T=5 tokens

print(f"\n{'─'*70}")
print(f"BEFORE ATTENTION: Input vectors")
print(f"{'─'*70}")
print(f"  Sentence: \"{sentence}\"")
print(f"  Token IDs: {token_ids[0].tolist()}")

# Embed tokens + positions
x = token_emb(token_ids) + pos_emb(torch.arange(T))
# x shape: [1, 5, 64] — 5 tokens, each a 64-dimensional vector

print(f"  Input x shape: {list(x.shape)} — {T} tokens, each a {d_model}-dim vector")
print(f"\n  Each token is now a vector of {d_model} numbers:")
for i, word in enumerate(sentence.split()):
    vec = x[0, i]
    print(f"    pos {i} \"{word:4s}\" → [{vec[0]:.3f}, {vec[1]:.3f}, {vec[2]:.3f}, ... {vec[-1]:.3f}]")


# ---- Step 1: Project into Q, K, V ----
# This is the CRITICAL step. The same input x gets multiplied by
# three DIFFERENT learned weight matrices to produce three different
# sets of vectors.

print(f"\n{'─'*70}")
print(f"STEP 1: Project into Q, K, V")
print(f"{'─'*70}")
print(f"""
  Our code (lines 142-144):
    Q = self.W_q(x)    # x × W_q  — "what am I looking for?"
    K = self.W_k(x)    # x × W_k  — "what do I advertise as?"  
    V = self.W_v(x)    # x × W_v  — "what content do I provide?"

  Same input x, three different weight matrices, three different outputs.
  Each weight matrix is {d_model}×{d_model} = {d_model*d_model} learned parameters.
""")

Q_full = W_q(x)  # [1, 5, 64]
K_full = W_k(x)  # [1, 5, 64]
V_full = W_v(x)  # [1, 5, 64]

print(f"  For token 'sat' (pos 2):")
print(f"    Input vector:  [{x[0,2,0]:.3f}, {x[0,2,1]:.3f}, {x[0,2,2]:.3f}, ...]")
print(f"    × W_q  →  Q:  [{Q_full[0,2,0]:.3f}, {Q_full[0,2,1]:.3f}, {Q_full[0,2,2]:.3f}, ...]")
print(f"    × W_k  →  K:  [{K_full[0,2,0]:.3f}, {K_full[0,2,1]:.3f}, {K_full[0,2,2]:.3f}, ...]")
print(f"    × W_v  →  V:  [{V_full[0,2,0]:.3f}, {V_full[0,2,1]:.3f}, {V_full[0,2,2]:.3f}, ...]")
print(f"""
  Notice: Q, K, V are all DIFFERENT vectors derived from the same input.
  - The Q for "sat" encodes what "sat" is looking for (maybe: a subject? a location?)
  - The K for "sat" encodes how "sat" advertises itself (maybe: past-tense verb, positional)
  - The V for "sat" encodes the information "sat" provides (maybe: action implies surface)
  
  These differences are entirely learned during training. W_q, W_k, W_v are
  trained so that Q·K similarities are high for tokens that SHOULD attend to
  each other, and the V vectors carry the information that's useful to pass forward.
""")


# ---- Step 2: Reshape for multi-head attention ----
print(f"{'─'*70}")
print(f"STEP 2: Split into {n_heads} heads")
print(f"{'─'*70}")

Q = Q_full.view(B, T, n_heads, d_k).transpose(1, 2)  # [1, 4, 5, 16]
K = K_full.view(B, T, n_heads, d_k).transpose(1, 2)
V = V_full.view(B, T, n_heads, d_k).transpose(1, 2)

print(f"""
  Our code (line 142, the .view().transpose() part):
    Q = W_q(x).view(B, T, n_heads, d_k).transpose(1, 2)
    
  This reshapes from [1, 5, 64] → [1, 4, 5, 16]
                       ↑  ↑  ↑      ↑  ↑  ↑  ↑
                    batch T d_model  batch heads T d_per_head

  The 64-dim vector is split into 4 heads of 16 dims each.
  Each head operates independently — like 4 parallel attention mechanisms,
  each looking at a different 16-dimensional "view" of the data.

  Head 1 might learn to track syntax (subject-verb agreement)
  Head 2 might learn to track position (what's nearby)
  Head 3 might learn to track semantic role (agent, patient, location)
  Head 4 might learn to track something else entirely
""")

print(f"  Q shape: {list(Q.shape)} — {n_heads} heads, each processing {T} tokens of {d_k} dims")


# ---- Step 3: Compute attention scores ----
print(f"\n{'─'*70}")
print(f"STEP 3: Compute attention scores (Q · Kᵀ / √d_k)")
print(f"{'─'*70}")

scores = (Q @ K.transpose(-2, -1)) / math.sqrt(d_k)
# scores shape: [1, 4, 5, 5] — for each head, a 5×5 matrix

print(f"""
  Our code (line 150):
    scores = (Q @ K.transpose(-2, -1)) / math.sqrt(self.d_k)
    
  This is the "library lookup" — each token's Query is compared against
  every token's Key by dot product. High dot product = high relevance.
  
  scores shape: {list(scores.shape)} — for each of {n_heads} heads, a {T}×{T} matrix
  Entry (i, j) = "how much should token i attend to token j?"
""")

# Show the score matrix for head 0
words = sentence.split()
print(f"  Score matrix for Head 1 (before masking):")
print(f"  {'':12s}", end="")
for w in words:
    print(f"{w:>8s}", end="")
print()
for i, w in enumerate(words):
    print(f"  {w:>10s}  ", end="")
    for j in range(T):
        print(f"{scores[0, 0, i, j].item():8.2f}", end="")
    print()


# ---- Step 4: Causal mask ----
print(f"\n{'─'*70}")
print(f"STEP 4: Apply causal mask (can't look at future tokens)")
print(f"{'─'*70}")

mask = torch.triu(torch.ones(T, T), diagonal=1).bool()
scores_masked = scores.clone()
scores_masked = scores_masked.masked_fill(mask, float("-inf"))

print(f"""
  Our code (lines 154-155):
    mask = torch.triu(torch.ones(T, T), diagonal=1).bool()
    scores = scores.masked_fill(mask, float("-inf"))
    
  Future positions → -inf → 0 after softmax.
  Token "sat" (pos 2) can look at "the" (0), "cat" (1), "sat" (2)
  but NOT "on" (3) or "the" (4).
""")

print(f"  Score matrix for Head 1 (after masking):")
print(f"  {'':12s}", end="")
for w in words:
    print(f"{w:>8s}", end="")
print()
for i, w in enumerate(words):
    print(f"  {w:>10s}  ", end="")
    for j in range(T):
        val = scores_masked[0, 0, i, j].item()
        if val == float("-inf"):
            print(f"{'−∞':>8s}", end="")
        else:
            print(f"{val:8.2f}", end="")
    print()


# ---- Step 5: Softmax → attention weights ----
print(f"\n{'─'*70}")
print(f"STEP 5: Softmax → attention weights (each row sums to 1)")
print(f"{'─'*70}")

attn_weights = F.softmax(scores_masked, dim=-1)

print(f"""
  Our code (line 158):
    attn_weights = F.softmax(scores, dim=-1)
    
  Converts raw scores to probabilities. Each row sums to 1.0.
  These weights tell us: for each token, what fraction of its
  attention goes to each other token?
""")

print(f"  Attention weights for Head 1:")
print(f"  {'':12s}", end="")
for w in words:
    print(f"{w:>8s}", end="")
print("    interpretation")
for i, w in enumerate(words):
    print(f"  {w:>10s}  ", end="")
    max_j = 0
    max_val = 0
    for j in range(T):
        val = attn_weights[0, 0, i, j].item()
        if val > max_val:
            max_val = val
            max_j = j
        print(f"{val:8.1%}", end="")
    print(f"    ← mostly attends to \"{words[max_j]}\"")


# ---- Step 6: Weighted sum of Values ----
print(f"\n{'─'*70}")
print(f"STEP 6: Weighted sum of Values → attention output")
print(f"{'─'*70}")

attn_output = attn_weights @ V  # [1, 4, 5, 16]

print(f"""
  Our code (line 163):
    attn_output = attn_weights @ V
    
  This is the punchline. Each token's output is a weighted blend of
  all attending tokens' VALUE vectors, using the attention weights.

  For "the" at position 4 (Head 1):
""")

# Show the blending for position 4 (the last "the")
print(f"    output[pos4] = ", end="")
for j, w in enumerate(words):
    weight = attn_weights[0, 0, 4, j].item()
    if weight > 0.01:
        print(f"{weight:.1%} × V(\"{w}\")", end="")
        if j < T - 1 and any(attn_weights[0, 0, 4, k].item() > 0.01 for k in range(j+1, T)):
            print(f"  +  ", end="")
print()
print(f"""
  The output vector for "the" at position 4 now contains information
  from the tokens it attended to — blended according to relevance.
  This is how information flows between tokens in the sequence.
""")


# ---- Step 7: Concatenate heads and project ----
print(f"{'─'*70}")
print(f"STEP 7: Concatenate all heads → final output")
print(f"{'─'*70}")

# Concatenate heads: [1, 4, 5, 16] → [1, 5, 64]
concat = attn_output.transpose(1, 2).contiguous().view(B, T, d_model)
final_output = W_out(concat)

print(f"""
  Our code (lines 166-167):
    attn_output = attn_output.transpose(1, 2).contiguous().view(B, T, C)
    return self.W_out(attn_output)
    
  4 heads of 16 dims → concatenated to 64 dims → one more projection.
  
  attn_output: {list(attn_output.shape)} → concat: {list(concat.shape)} → final: {list(final_output.shape)}
  
  Each head's 16-dim output is concatenated, then W_out mixes them
  together. This lets the model combine information across heads.
""")


# ---- Summary: What happened ----
print(f"{'─'*70}")
print(f"SUMMARY: What attention did to each token")
print(f"{'─'*70}")

for i, w in enumerate(words):
    input_vec = x[0, i]
    output_vec = final_output[0, i]
    delta = (output_vec - input_vec).norm().item()
    cosine = F.cosine_similarity(input_vec.unsqueeze(0), output_vec.unsqueeze(0)).item()
    print(f"  \"{w}\" (pos {i}):")
    print(f"    Input:  [{input_vec[0]:.3f}, {input_vec[1]:.3f}, {input_vec[2]:.3f}, ...]")
    print(f"    Output: [{output_vec[0]:.3f}, {output_vec[1]:.3f}, {output_vec[2]:.3f}, ...]")
    print(f"    Change: ‖delta‖ = {delta:.3f},  cosine similarity = {cosine:.3f}")

    # Show what this token attended to most (averaged across heads)
    avg_attn = attn_weights[0, :, i, :].mean(dim=0)  # average across 4 heads
    top_vals, top_idx = torch.topk(avg_attn, min(3, i + 1))
    attended = [f"\"{words[j]}\"({v:.0%})" for j, v in zip(top_idx.tolist(), top_vals.tolist())]
    print(f"    Attended to: {', '.join(attended)}")
    print()

print(f"""
  KEY INSIGHT: The output vectors are DIFFERENT from the input vectors.
  Each token has been enriched with information from the tokens it
  attended to. "the" at position 4 now carries information about
  "cat", "sat", "on" — not just the word "the" in isolation.

  This is the output of ONE attention layer. This enriched vector then
  goes to the FFN (step 5), gets added back via residual connection
  (step 6), and then the NEXT layer's attention does the same thing
  again — but now working with already-enriched vectors.

  IMPORTANT: The W_q, W_k, W_v, W_out weight matrices are FROZEN at
  inference time. They were learned during training. What we just
  traced is what happens every time this layer processes any input.
  The weights don't change — only the input vectors change.
""")
qkv_training_connection.py: How W_q/W_k/W_v specialize via backprop
"""
How W_q, W_k, W_v Get Trained — Connecting Inference to Training

The punchline: there is NO special code for training attention weights.
Two lines update EVERY weight in the model — W_q, W_k, W_v, W_out,
FFN weights, embeddings, everything:

    loss.backward()     # compute gradients for ALL weights
    optimizer.step()    # nudge ALL weights

This file proves it by showing the actual W_q weights before and after
a single training step, and printing the gradients that drive the change.

Run with:
    python qkv_training_connection.py
"""

import torch
import torch.nn as nn
import torch.nn.functional as F
import math

torch.manual_seed(42)


# =====================================================================
# The model (condensed from transformer_from_scratch.py)
# =====================================================================

class CausalSelfAttention(nn.Module):
    def __init__(self, d_model, n_heads):
        super().__init__()
        self.n_heads = n_heads
        self.d_k = d_model // n_heads
        self.W_q = nn.Linear(d_model, d_model)   # ← THESE are what we're tracking
        self.W_k = nn.Linear(d_model, d_model)   # ← 
        self.W_v = nn.Linear(d_model, d_model)   # ←
        self.W_out = nn.Linear(d_model, d_model)

    def forward(self, x):
        B, T, C = x.shape
        Q = self.W_q(x).view(B, T, self.n_heads, self.d_k).transpose(1, 2)
        K = self.W_k(x).view(B, T, self.n_heads, self.d_k).transpose(1, 2)
        V = self.W_v(x).view(B, T, self.n_heads, self.d_k).transpose(1, 2)
        scores = (Q @ K.transpose(-2, -1)) / math.sqrt(self.d_k)
        mask = torch.triu(torch.ones(T, T, device=x.device), diagonal=1).bool()
        scores = scores.masked_fill(mask, float("-inf"))
        attn = F.softmax(scores, dim=-1)
        out = (attn @ V).transpose(1, 2).contiguous().view(B, T, C)
        return self.W_out(out)

class FeedForward(nn.Module):
    def __init__(self, d_model):
        super().__init__()
        self.up = nn.Linear(d_model, 4 * d_model)
        self.down = nn.Linear(4 * d_model, d_model)
    def forward(self, x):
        return self.down(F.gelu(self.up(x)))

class TransformerBlock(nn.Module):
    def __init__(self, d_model, n_heads):
        super().__init__()
        self.ln1 = nn.LayerNorm(d_model)
        self.attn = CausalSelfAttention(d_model, n_heads)
        self.ln2 = nn.LayerNorm(d_model)
        self.ffn = FeedForward(d_model)
    def forward(self, x):
        x = x + self.attn(self.ln1(x))
        x = x + self.ffn(self.ln2(x))
        return x

class Transformer(nn.Module):
    def __init__(self, vocab_size, d_model=64, n_heads=4, n_layers=4, max_seq_len=32):
        super().__init__()
        self.token_emb = nn.Embedding(vocab_size, d_model)
        self.pos_emb = nn.Embedding(max_seq_len, d_model)
        self.blocks = nn.ModuleList([TransformerBlock(d_model, n_heads) for _ in range(n_layers)])
        self.ln_f = nn.LayerNorm(d_model)
        self.lm_head = nn.Linear(d_model, vocab_size, bias=False)
    def forward(self, idx):
        B, T = idx.shape
        x = self.token_emb(idx) + self.pos_emb(torch.arange(T, device=idx.device))
        for block in self.blocks:
            x = block(x)
        return self.lm_head(self.ln_f(x))


# =====================================================================
# Setup
# =====================================================================

vocab = sorted(["the", "cat", "sat", "on", "mat", "dog", "rug", "big",
                "a", "ran", "to", "slept", "bed", "house"])
word_to_id = {w: i for i, w in enumerate(vocab)}
id_to_word = {i: w for i, w in enumerate(vocab)}

model = Transformer(vocab_size=len(vocab), d_model=64, n_heads=4, n_layers=4)
optimizer = torch.optim.Adam(model.parameters(), lr=3e-3)


# =====================================================================
# THE WALKTHROUGH
# =====================================================================

print("=" * 70)
print("  HOW W_q, W_k, W_v GET TRAINED")
print("=" * 70)

# ---- Show all named parameters ----
print(f"\n{'─'*70}")
print(f"FIRST: What parameters does the model have?")
print(f"{'─'*70}")
print(f"\n  model.named_parameters() lists EVERY learnable weight:\n")

total = 0
attn_params = 0
for name, param in model.named_parameters():
    n = param.numel()
    total += n
    is_attn = "W_q" in name or "W_k" in name or "W_v" in name or "W_out" in name
    if is_attn:
        attn_params += n
    marker = " ← Q/K/V WEIGHT" if ("W_q" in name or "W_k" in name or "W_v" in name) else ""
    print(f"    {name:45s}  shape {str(list(param.shape)):18s}  ({n:>6,} params){marker}")

print(f"\n  Total: {total:,} parameters")
print(f"  Attention (Q/K/V/Out): {attn_params:,} parameters ({attn_params/total:.0%} of total)")

print(f"""
  Notice: W_q, W_k, W_v appear in every block (0, 1, 2, 3).
  They are regular nn.Linear layers — nothing special about them.
  PyTorch tracks them automatically because they're attributes of nn.Module.
""")


# ---- Snapshot W_q before training ----
print(f"{'─'*70}")
print(f"BEFORE TRAINING: Snapshot of Layer 0's W_q weights")
print(f"{'─'*70}")

# Get the W_q weight matrix from block 0
wq_before = model.blocks[0].attn.W_q.weight.data.clone()
print(f"\n  model.blocks[0].attn.W_q.weight — shape: {list(wq_before.shape)}")
print(f"  This is a {wq_before.shape[0]}×{wq_before.shape[1]} matrix of learned numbers.\n")
print(f"  First 5×5 corner (randomly initialized):")
for i in range(5):
    print(f"    [{', '.join(f'{wq_before[i,j]:+.4f}' for j in range(5))}, ...]")
print(f"\n  These are RANDOM — the model hasn't learned anything yet.")


# ---- The training input ----
print(f"\n{'─'*70}")
print(f"THE TRAINING EXAMPLE")
print(f"{'─'*70}")

sentence = "the cat sat on the mat"
words = sentence.split()
input_ids = torch.tensor([[word_to_id[w] for w in words[:-1]]])   # "the cat sat on the"
target_ids = torch.tensor([[word_to_id[w] for w in words[1:]]])   # "cat sat on the mat"

print(f"\n  Sentence: \"{sentence}\"")
print(f"  Input:    {[id_to_word[i] for i in input_ids[0].tolist()]}")
print(f"  Target:   {[id_to_word[i] for i in target_ids[0].tolist()]}")
print(f"  (Each input token should predict the next token)")


# ---- STEP 3: Forward pass ----
print(f"\n{'─'*70}")
print(f"STEP 3: Forward pass — the input flows through W_q, W_k, W_v")
print(f"{'─'*70}")

model.train()
logits = model(input_ids)

print(f"""
  logits = model(input_ids)

  What just happened inside:
    1. input_ids → embedding layer → 5 vectors of dim 64
    2. Those vectors enter block 0:
       a. LayerNorm
       b. CausalSelfAttention:
          Q = W_q(x)    ← the vector was multiplied by W_q's 64×64 matrix
          K = W_k(x)    ← same vector, different matrix
          V = W_v(x)    ← same vector, third matrix
          scores = Q · Kᵀ / √16
          attention_weights = softmax(masked scores)
          output = attention_weights · V
          output = W_out(output)
       c. Residual addition: x = x + attention_output
       d. FFN (expand → GELU → compress)
       e. Residual addition: x = x + ffn_output
    3. Repeat for blocks 1, 2, 3
    4. Final LayerNorm → lm_head → logits

  W_q was USED in the forward pass but NOT CHANGED.
  At this point, W_q is still the exact same random matrix.
""")


# ---- STEP 4: Loss ----
print(f"{'─'*70}")
print(f"STEP 4: Cross-entropy loss")
print(f"{'─'*70}")

B, T, V = logits.shape
loss = F.cross_entropy(logits.view(B * T, V), target_ids.view(B * T))

print(f"\n  loss = F.cross_entropy(logits, targets)")
print(f"  loss = {loss.item():.4f}")
print(f"\n  This single number measures: 'how wrong was the model across all 5 positions?'")
print(f"  High loss = model's W_q/W_k/W_v produced bad attention patterns")
print(f"  that led to bad predictions.")


# ---- STEP 5: Backpropagation ----
print(f"\n{'─'*70}")
print(f"STEP 5: loss.backward() — THIS is where W_q gets its gradient")
print(f"{'─'*70}")

# Before backward, there are no gradients
print(f"\n  Before loss.backward():")
print(f"    model.blocks[0].attn.W_q.weight.grad = {model.blocks[0].attn.W_q.weight.grad}")

optimizer.zero_grad()
loss.backward()

# Now every parameter has a gradient!
print(f"\n  After loss.backward():")

wq_grad = model.blocks[0].attn.W_q.weight.grad
wk_grad = model.blocks[0].attn.W_k.weight.grad
wv_grad = model.blocks[0].attn.W_v.weight.grad

print(f"    W_q.grad shape: {list(wq_grad.shape)} — a gradient for EVERY weight in W_q")
print(f"    W_k.grad shape: {list(wk_grad.shape)}")
print(f"    W_v.grad shape: {list(wv_grad.shape)}")

print(f"\n  W_q gradient — first 5×5 corner:")
for i in range(5):
    print(f"    [{', '.join(f'{wq_grad[i,j]:+.6f}' for j in range(5))}, ...]")

print(f"""
  WHAT IS THIS GRADIENT?

  Each number answers: "if I increased this specific weight by a tiny amount,
  how much would the loss change?"

  Positive gradient (+0.003) → increasing this weight increases loss → DECREASE it
  Negative gradient (-0.005) → increasing this weight decreases loss → INCREASE it
  Near-zero gradient (0.0001) → this weight barely matters for this example

  loss.backward() computed this by tracing the chain rule BACKWARD:
    loss ← logits ← lm_head ← block3 ← ... ← block0.attn ← W_q × input
                                                               ↑
                                                          ∂loss/∂W_q

  The gradient flows backward through every operation, all the way from
  the loss to W_q. The residual connections act as a "gradient highway"
  ensuring the signal stays strong even for W_q in the very first layer.
""")

# Show gradients for ALL attention parameters
print(f"  Gradient statistics for all Q/K/V weights in the model:")
print(f"  {'Parameter':45s} {'grad mean':>12s} {'grad std':>12s}")
print(f"  {'─'*45} {'─'*12} {'─'*12}")
for name, param in model.named_parameters():
    if param.grad is not None and ("W_q" in name or "W_k" in name or "W_v" in name):
        g = param.grad
        print(f"  {name:45s} {g.mean():+12.6f} {g.std():12.6f}")

print(f"""
  Every W_q, W_k, W_v in every layer got a gradient. No special code —
  PyTorch's autograd traced the computation graph automatically.
""")


# ---- STEP 6: Gradient descent ----
print(f"{'─'*70}")
print(f"STEP 6: optimizer.step() — W_q weights actually change")
print(f"{'─'*70}")

print(f"\n  W_q BEFORE optimizer.step() (first 5 values of row 0):")
print(f"    [{', '.join(f'{model.blocks[0].attn.W_q.weight.data[0,j]:+.6f}' for j in range(5))}]")

optimizer.step()

print(f"\n  W_q AFTER optimizer.step() (first 5 values of row 0):")
print(f"    [{', '.join(f'{model.blocks[0].attn.W_q.weight.data[0,j]:+.6f}' for j in range(5))}]")

wq_after = model.blocks[0].attn.W_q.weight.data.clone()
delta = (wq_after - wq_before).abs()

print(f"\n  The difference (first 5):")
print(f"    [{', '.join(f'{delta[0,j]:.6f}' for j in range(5))}]")
print(f"\n  Average change per weight: {delta.mean():.6f}")
print(f"  Max change:                {delta.max():.6f}")

print(f"""
  TINY changes — on the order of 0.001. But this happens for:
    • Every weight in W_q ({wq_before.numel():,} weights)
    • Every weight in W_k ({wq_before.numel():,} weights)  
    • Every weight in W_v ({wq_before.numel():,} weights)
    • Every weight in W_out, FFN, embeddings, etc.
    • Total: {sum(p.numel() for p in model.parameters()):,} weights updated simultaneously

  And then we do this AGAIN for the next batch, and the next, and the next,
  for millions of batches. Each tiny nudge accumulates. After enough steps,
  W_q has been sculpted from random noise into a matrix that projects tokens
  into a Query space where meaningful similarity comparisons emerge.
""")


# ---- The connection ----
print(f"{'─'*70}")
print(f"THE FULL PICTURE: Inference ← Training")
print(f"{'─'*70}")
print(f"""
  TRAINING (what we just traced):
    1. Forward pass: input → W_q/W_k/W_v → attention → logits
    2. Loss: how wrong were the predictions?
    3. loss.backward(): compute ∂loss/∂W_q, ∂loss/∂W_k, ∂loss/∂W_v
    4. optimizer.step(): nudge each weight to reduce loss
    5. Repeat millions of times

  INFERENCE (the QKV walkthrough):
    1. Forward pass: input → W_q/W_k/W_v → attention → logits → sample
    2. That's it. Same matrices, same computation, no updates.

  The W_q matrix used at inference IS the W_q matrix that was nudged
  millions of times during training. Every entry in that 64×64 matrix
  was shaped by gradient descent to produce Query vectors where the
  right tokens attend to each other.

  There is no separate "attention training" or "Q/K/V learning" step.
  It's all just: forward pass, loss, backward, step. The entire model —
  embeddings, attention, FFN, output head — learns together, end to end,
  from one unified signal: "predict the next token better."
""")
kv_cache_walkthrough.py: KV cache, cached vs. uncached inference
"""
KV Cache — Concrete Example

Shows exactly what gets cached and reused during autoregressive generation.
We trace through generating 3 tokens, showing how K and V vectors
accumulate in the cache and avoid redundant computation.

Run with:
    python kv_cache_walkthrough.py
"""

import torch
import torch.nn as nn
import torch.nn.functional as F
import math

torch.manual_seed(42)


# =====================================================================
# Minimal setup
# =====================================================================

vocab = sorted(["the", "cat", "sat", "on", "mat", "dog", "rug", "big",
                "a", "ran", "to", "slept", "bed", "house"])
word_to_id = {w: i for i, w in enumerate(vocab)}
id_to_word = {i: w for i, w in enumerate(vocab)}

d_model = 64
n_heads = 4
d_k = d_model // n_heads  # 16

# One attention layer's components
token_emb = nn.Embedding(len(vocab), d_model)
pos_emb = nn.Embedding(32, d_model)
W_q = nn.Linear(d_model, d_model)
W_k = nn.Linear(d_model, d_model)
W_v = nn.Linear(d_model, d_model)
W_out = nn.Linear(d_model, d_model)


def embed(token_ids, positions):
    """Embed tokens with position info."""
    return token_emb(token_ids) + pos_emb(positions)


def attention_full(x):
    """Standard attention — recompute Q, K, V for ALL tokens every time."""
    B, T, C = x.shape
    Q = W_q(x).view(B, T, n_heads, d_k).transpose(1, 2)
    K = W_k(x).view(B, T, n_heads, d_k).transpose(1, 2)
    V = W_v(x).view(B, T, n_heads, d_k).transpose(1, 2)

    scores = (Q @ K.transpose(-2, -1)) / math.sqrt(d_k)
    mask = torch.triu(torch.ones(T, T), diagonal=1).bool()
    scores = scores.masked_fill(mask, float("-inf"))
    attn = F.softmax(scores, dim=-1)
    out = (attn @ V).transpose(1, 2).contiguous().view(B, T, C)
    return W_out(out), K, V


def attention_cached(x_new, pos, kv_cache):
    """
    Cached attention — only compute Q, K, V for the NEW token.
    Reuse cached K, V from all previous tokens.
    """
    B, T_new, C = x_new.shape  # T_new = 1 (just the new token)

    # Only compute Q, K, V for the new token
    Q_new = W_q(x_new).view(B, T_new, n_heads, d_k).transpose(1, 2)  # [1, 4, 1, 16]
    K_new = W_k(x_new).view(B, T_new, n_heads, d_k).transpose(1, 2)  # [1, 4, 1, 16]
    V_new = W_v(x_new).view(B, T_new, n_heads, d_k).transpose(1, 2)  # [1, 4, 1, 16]

    # Append new K, V to cache
    if kv_cache is not None:
        K_cached, V_cached = kv_cache
        K_full = torch.cat([K_cached, K_new], dim=2)  # [1, 4, T_old+1, 16]
        V_full = torch.cat([V_cached, V_new], dim=2)  # [1, 4, T_old+1, 16]
    else:
        K_full = K_new
        V_full = V_new

    # Attention: new token's Q against ALL K's (cached + new)
    T_full = K_full.shape[2]
    scores = (Q_new @ K_full.transpose(-2, -1)) / math.sqrt(d_k)  # [1, 4, 1, T_full]
    # No mask needed — the new token is always the LAST position,
    # so it can attend to everything before it (all cached) + itself
    attn = F.softmax(scores, dim=-1)  # [1, 4, 1, T_full]
    out = (attn @ V_full).transpose(1, 2).contiguous().view(B, 1, C)  # [1, 1, 64]

    return W_out(out), (K_full, V_full)
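
# Sanity check (illustrative): with the same weights, cached attention must
# reproduce the full recomputation for the newest token (up to float error).
def check_cache_equivalence(token_id_list):
    T = len(token_id_list)
    x = embed(torch.tensor([token_id_list]), torch.arange(T))
    full_out, K, V = attention_full(x)
    cached_out, _ = attention_cached(x[:, -1:, :], T - 1, (K[:, :, :-1], V[:, :, :-1]))
    return torch.allclose(full_out[:, -1:, :], cached_out, atol=1e-5)

# e.g. check_cache_equivalence([word_to_id[w] for w in "the cat sat".split()])
# should return True.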


# =====================================================================
# THE WALKTHROUGH
# =====================================================================

print("=" * 70)
print("  KV CACHE WALKTHROUGH")
print("=" * 70)

prompt = "the cat sat on the"
prompt_ids = [word_to_id[w] for w in prompt.split()]

print(f"\n  Prompt: \"{prompt}\"")
print(f"  Token IDs: {prompt_ids}")
print(f"  We'll generate 3 tokens after this prompt.\n")


# =====================================================================
# METHOD 1: No cache (recompute everything each time)
# =====================================================================

print(f"{'─'*70}")
print(f"  METHOD 1: WITHOUT KV CACHE")
print(f"{'─'*70}")

generated_no_cache = list(prompt_ids)
total_qkv_ops_no_cache = 0

for gen_step in range(3):
    T = len(generated_no_cache)
    ids = torch.tensor([generated_no_cache])
    positions = torch.arange(T)
    x = embed(ids, positions)  # [1, T, 64]

    out, K, V = attention_full(x)

    # Count Q, K, V computations (per token)
    qkv_ops = T * 3  # Q, K, V each computed for all T tokens
    total_qkv_ops_no_cache += qkv_ops

    # Sample next token (just take argmax for simplicity)
    # In real inference, only the last position's output matters
    last_vec = out[0, -1, :]  # [64] — last position only

    # Fake a simple "prediction" by picking a token
    # (we're not running the full model, just demonstrating the cache)
    next_token = (gen_step + 6) % len(vocab)  # deterministic for demo
    generated_no_cache.append(next_token)

    print(f"\n  Step {gen_step + 1}: generating token {gen_step + 6}")
    print(f"    Sequence length: {T} tokens")
    print(f"    Q computed for: ALL {T} tokens  ({T} × W_q multiply)")
    print(f"    K computed for: ALL {T} tokens  ({T} × W_k multiply)")
    print(f"    V computed for: ALL {T} tokens  ({T} × W_v multiply)")
    print(f"    Total Q/K/V ops this step: {qkv_ops}")
    print(f"    But we only USE position {T-1}'s output (the last one)")
    print(f"    Positions 0–{T-2}'s Q, K, V → same as last step, wasted work!")

print(f"\n  Total Q/K/V operations: {total_qkv_ops_no_cache}")


# =====================================================================
# METHOD 2: With KV cache
# =====================================================================

print(f"\n{'─'*70}")
print(f"  METHOD 2: WITH KV CACHE")
print(f"{'─'*70}")

# First: process the full prompt (no cache yet)
T = len(prompt_ids)
ids = torch.tensor([prompt_ids])
x = embed(ids, torch.arange(T))
out, K_initial, V_initial = attention_full(x)

kv_cache = (K_initial, V_initial)  # Cache K and V from the prompt

print(f"\n  Initial: process full prompt ({T} tokens)")
print(f"    Q/K/V computed for all {T} tokens (no cache yet)")
print(f"    K cache shape: {list(kv_cache[0].shape)}")
print(f"    → {kv_cache[0].shape[2]} tokens × {n_heads} heads × {d_k} dims cached")

total_qkv_ops_cache = T * 3  # Initial prompt processing
generated_cached = list(prompt_ids)

for gen_step in range(3):
    # Only embed the ONE new token
    new_token_id = (gen_step + 6) % len(vocab)
    new_pos = len(generated_cached)
    generated_cached.append(new_token_id)

    x_new = embed(
        torch.tensor([[new_token_id]]),   # [1, 1] — just one token
        torch.tensor([new_pos]),          # its position
    )  # [1, 1, 64]

    out_new, kv_cache = attention_cached(x_new, new_pos, kv_cache)

    qkv_ops = 1 * 3  # Q, K, V computed for just 1 token!
    total_qkv_ops_cache += qkv_ops

    K_cached, V_cached = kv_cache
    T_total = K_cached.shape[2]

    print(f"\n  Step {gen_step + 1}: generating token {gen_step + 6}")
    print(f"    New token embedded: {id_to_word.get(new_token_id, '?')} at position {new_pos}")
    print(f"    Q computed for: 1 token   (just the new one)")
    print(f"    K computed for: 1 token   (just the new one)")
    print(f"    V computed for: 1 token   (just the new one)")
    print(f"    K cache: {T_total - 1} old + 1 new = {T_total} total")
    print(f"    V cache: {T_total - 1} old + 1 new = {T_total} total")
    print(f"    Attention: new Q[1,4,1,16] @ all K[1,4,{T_total},16]^T = scores[1,4,1,{T_total}]")
    print(f"    → new token attends to ALL {T_total} tokens (using cached K's)")
    print(f"    Total Q/K/V ops this step: {qkv_ops}")

print(f"\n  Total Q/K/V operations: {total_qkv_ops_cache}")


# =====================================================================
# Comparison
# =====================================================================

print(f"\n{'─'*70}")
print(f"  COMPARISON")
print(f"{'─'*70}")
print(f"""
  Generating 3 tokens after a 5-token prompt:

  Without cache:
    Step 1: compute Q/K/V for 5 tokens  =  15 ops
    Step 2: compute Q/K/V for 6 tokens  =  18 ops
    Step 3: compute Q/K/V for 7 tokens  =  21 ops
    Total:                                 {total_qkv_ops_no_cache} ops

  With cache:
    Initial: compute Q/K/V for 5 tokens =  15 ops  (prompt, once)
    Step 1: compute Q/K/V for 1 token   =   3 ops  (reuse 5 cached)
    Step 2: compute Q/K/V for 1 token   =   3 ops  (reuse 6 cached)
    Step 3: compute Q/K/V for 1 token   =   3 ops  (reuse 7 cached)
    Total:                                 {total_qkv_ops_cache} ops

  Savings: {total_qkv_ops_no_cache - total_qkv_ops_cache} fewer Q/K/V computations ({(1 - total_qkv_ops_cache/total_qkv_ops_no_cache):.0%} less)

  This gap grows with sequence length. Generating 500 tokens:
    Without cache: 3×(5+6+7+...+504) = {3 * sum(range(5, 505)):,} ops
    With cache:    3×5 + 3×500        = {3*5 + 3*500:,} ops
    Savings: {(1 - (3*5 + 3*500) / (3 * sum(range(5, 505)))):.1%} less computation
""")


# =====================================================================
# What the cache actually looks like in memory
# =====================================================================

print(f"{'─'*70}")
print(f"  WHAT THE CACHE LOOKS LIKE IN MEMORY")
print(f"{'─'*70}")

K_cached, V_cached = kv_cache
print(f"\n  After generating 3 tokens (8 total in sequence):")
print(f"\n  K cache shape: {list(K_cached.shape)}")
print(f"  = [batch=1, heads={n_heads}, tokens={K_cached.shape[2]}, d_k={d_k}]")
print(f"\n  V cache shape: {list(V_cached.shape)}")
print(f"  = [batch=1, heads={n_heads}, tokens={V_cached.shape[2]}, d_k={d_k}]")

total_cached_numbers = K_cached.numel() + V_cached.numel()
print(f"\n  Total numbers in cache: {total_cached_numbers:,}")
print(f"  = 2 (K+V) × {n_heads} heads × {K_cached.shape[2]} tokens × {d_k} d_k")

print(f"""
  In a real model (e.g., Llama 7B):
    d_model = 4096, n_heads = 32, d_k = 128, n_layers = 32

    For a 2048-token sequence, the KV cache holds:
    2 × 32 heads × 2048 tokens × 128 d_k × 32 layers × 2 bytes (fp16)
    = ~1 GB of GPU memory just for the cache

  This is why long conversations are expensive — the KV cache grows
  with every token generated and must stay in fast GPU memory.
  It's also why context length limits exist — the cache eventually
  exceeds available memory.
""")


# =====================================================================
# Important subtlety: Q is NOT cached
# =====================================================================

print(f"{'─'*70}")
print(f"  WHY NOT CACHE Q?")
print(f"{'─'*70}")
print(f"""
  Only K and V are cached. Q is not. Here's why:

  At each generation step, we need:
    scores = Q_new @ K_all.T     (new token's Q against ALL keys)
    output = attn @ V_all        (attention weights against ALL values)

  Q_new:  we only need the NEW token's query — it asks "what should I
          attend to?" Only the new token needs to ask this question.
          Previous tokens' queries were used in their own steps and
          are no longer needed.

  K_all:  we need ALL tokens' keys — the new token might want to
          attend to any previous token. Old keys don't change, so
          we cache them.

  V_all:  we need ALL tokens' values — once attention decides who
          to attend to, it needs to read their content. Old values
          don't change, so we cache them.

  The asymmetry:
    Q is consumed once (by the token that generated it)
    K and V are consumed repeatedly (by every future token)
""")
transformer_moe.py — Mixture of Experts implementation + training (Python, 548 lines)
"""
Mixture of Experts — The MoE Variant of Our Transformer

This adds MoE to the same architecture from transformer_from_scratch.py.
The ONLY change: the single FFN in each layer is replaced by multiple
expert FFNs + a learned router. Everything else is identical.

Usage:
    pip install torch
    python transformer_moe.py

Compares a standard dense transformer to an MoE variant on the same
training data, showing how MoE adds parameter capacity while each
token activates only a fraction of those parameters.
"""

import torch
import torch.nn as nn
import torch.nn.functional as F
import math
import random


# =============================================================================
# STANDARD COMPONENTS (same as base transformer)
# =============================================================================

class CausalSelfAttention(nn.Module):
    def __init__(self, d_model, n_heads):
        super().__init__()
        self.n_heads = n_heads
        self.d_k = d_model // n_heads
        self.W_q = nn.Linear(d_model, d_model)
        self.W_k = nn.Linear(d_model, d_model)
        self.W_v = nn.Linear(d_model, d_model)
        self.W_out = nn.Linear(d_model, d_model)

    def forward(self, x):
        B, T, C = x.shape
        Q = self.W_q(x).view(B, T, self.n_heads, self.d_k).transpose(1, 2)
        K = self.W_k(x).view(B, T, self.n_heads, self.d_k).transpose(1, 2)
        V = self.W_v(x).view(B, T, self.n_heads, self.d_k).transpose(1, 2)
        scores = (Q @ K.transpose(-2, -1)) / math.sqrt(self.d_k)
        mask = torch.triu(torch.ones(T, T, device=x.device), diagonal=1).bool()
        scores = scores.masked_fill(mask, float("-inf"))
        attn = F.softmax(scores, dim=-1)
        out = (attn @ V).transpose(1, 2).contiguous().view(B, T, C)
        return self.W_out(out)


class FeedForward(nn.Module):
    """Standard single FFN — same as before."""
    def __init__(self, d_model):
        super().__init__()
        self.up = nn.Linear(d_model, 4 * d_model)
        self.down = nn.Linear(4 * d_model, d_model)

    def forward(self, x):
        return self.down(F.gelu(self.up(x)))


# =============================================================================
# THE MoE LAYER — this is the new part
# =============================================================================

class MoELayer(nn.Module):
    """
    Mixture of Experts: replaces a single FFN with multiple expert FFNs
    and a learned router.

    Architecture:
        1. Router: a single linear layer that scores each expert for each token
        2. Top-K selection: pick the K highest-scoring experts
        3. Expert computation: run only the selected experts
        4. Weighted combination: blend expert outputs using router scores

    Parameters:
        d_model:     dimension of the residual stream (e.g., 64)
        n_experts:   total number of expert FFNs (e.g., 8)
        top_k:       how many experts each token uses (e.g., 2)
    """
    def __init__(self, d_model: int, n_experts: int = 8, top_k: int = 2):
        super().__init__()
        self.n_experts = n_experts
        self.top_k = top_k

        # THE ROUTER: a single linear layer
        # Input: token vector [d_model]
        # Output: one score per expert [n_experts]
        # This is the ONLY new learned component. Everything else is just
        # multiple copies of the same FFN architecture.
        self.router = nn.Linear(d_model, n_experts, bias=False)

        # THE EXPERTS: n_experts independent FFNs
        # Each one is identical in architecture to the standard FFN,
        # but they have different learned weights.
        self.experts = nn.ModuleList([
            FeedForward(d_model) for _ in range(n_experts)
        ])

    def forward(self, x: torch.Tensor):
        """
        x: [batch, seq_len, d_model]
        returns: (output, aux_loss)
            output:   [batch, seq_len, d_model] — the blended expert outputs
            aux_loss: scalar — load balancing loss to keep experts evenly used
        """
        B, T, D = x.shape

        # ---- STEP 1: ROUTER SCORES ----
        # Each token gets a score for each expert
        router_logits = self.router(x)              # [B, T, n_experts]
        router_probs = F.softmax(router_logits, dim=-1)  # normalize to probabilities

        # ---- STEP 2: TOP-K SELECTION ----
        # Pick the top_k experts with highest scores for each token
        top_k_probs, top_k_indices = torch.topk(router_probs, self.top_k, dim=-1)
        # top_k_probs:   [B, T, top_k] — the weights for combining
        # top_k_indices: [B, T, top_k] — which experts were selected

        # Normalize the top-k probabilities so they sum to 1
        top_k_probs = top_k_probs / top_k_probs.sum(dim=-1, keepdim=True)
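        # e.g. raw top-2 probs (0.50, 0.30) renormalize to (0.625, 0.375)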

        # ---- STEP 3 & 4: RUN SELECTED EXPERTS AND COMBINE ----
        # For each token, run only its selected experts and blend the results.
        #
        # In production, this is heavily optimized with custom CUDA kernels.
        # Here we use a simple loop for clarity.
        output = torch.zeros_like(x)

        for k in range(self.top_k):
            # Which expert does each token use for this slot?
            expert_indices = top_k_indices[:, :, k]  # [B, T]
            weights = top_k_probs[:, :, k]            # [B, T]

            # Run each expert on the tokens assigned to it
            for e_idx in range(self.n_experts):
                # Find which tokens selected this expert in this slot
                token_mask = (expert_indices == e_idx)  # [B, T] boolean

                if token_mask.any():
                    # Get those tokens, run through expert, weight by router score
                    expert_input = x[token_mask]  # [num_selected, D]
                    expert_output = self.experts[e_idx](expert_input)  # [num_selected, D]
                    expert_weights = weights[token_mask].unsqueeze(-1)  # [num_selected, 1]

                    output[token_mask] += expert_weights * expert_output

        # ---- LOAD BALANCING LOSS ----
        # Without this, the router collapses to always picking the same 1-2 experts.
        # We penalize uneven distribution of tokens across experts.
        #
        # Note: a naive `(usage * uniform_target).sum()` is constant — router
        # probabilities always sum to 1 — so it carries no gradient. Instead we
        # use the Switch-Transformer-style loss: n_experts · Σ_i f_i · P_i, where
        #   f_i = fraction of token→slot assignments routed to expert i (hard counts)
        #   P_i = mean router probability for expert i (soft — carries the gradient)

        one_hot = F.one_hot(top_k_indices, self.n_experts).float()  # [B, T, top_k, n_experts]
        tokens_per_expert = one_hot.sum(dim=(0, 1, 2)) / (B * T * self.top_k)  # f_i: [n_experts]
        prob_per_expert = router_probs.mean(dim=(0, 1))                        # P_i: [n_experts]

        # Minimized (value 1.0) when both distributions are uniform at 1/n_experts;
        # the n_experts factor keeps the loss O(1) regardless of expert count.
        aux_loss = self.n_experts * (tokens_per_expert * prob_per_expert).sum()

        return output, aux_loss
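

# Quick shape sanity check for MoELayer (a minimal usage sketch with
# illustrative dimensions; runs once when the script starts):
_moe_demo = MoELayer(d_model=64, n_experts=8, top_k=2)
_demo_out, _demo_aux = _moe_demo(torch.randn(2, 5, 64))
assert _demo_out.shape == (2, 5, 64) and _demo_aux.dim() == 0
del _moe_demo, _demo_out, _demo_aux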


# =============================================================================
# TRANSFORMER BLOCKS: STANDARD vs MoE
# =============================================================================

class StandardBlock(nn.Module):
    """Standard transformer block with single FFN."""
    def __init__(self, d_model, n_heads):
        super().__init__()
        self.ln1 = nn.LayerNorm(d_model)
        self.attn = CausalSelfAttention(d_model, n_heads)
        self.ln2 = nn.LayerNorm(d_model)
        self.ffn = FeedForward(d_model)

    def forward(self, x):
        x = x + self.attn(self.ln1(x))
        x = x + self.ffn(self.ln2(x))
        return x, torch.zeros((), device=x.device)  # no aux loss — zero keeps the interface identical to MoEBlock


class MoEBlock(nn.Module):
    """
    MoE transformer block — identical to StandardBlock except
    the FFN is replaced by the MoE layer.

    Attention is unchanged. Residual connections are unchanged.
    The only swap: single FFN → router + multiple expert FFNs.
    """
    def __init__(self, d_model, n_heads, n_experts=8, top_k=2):
        super().__init__()
        self.ln1 = nn.LayerNorm(d_model)
        self.attn = CausalSelfAttention(d_model, n_heads)
        self.ln2 = nn.LayerNorm(d_model)
        self.moe = MoELayer(d_model, n_experts, top_k)  # ← THE ONLY CHANGE

    def forward(self, x):
        x = x + self.attn(self.ln1(x))
        moe_out, aux_loss = self.moe(self.ln2(x))
        x = x + moe_out                               # residual, same as before
        return x, aux_loss


# =============================================================================
# FULL MODELS
# =============================================================================

class Transformer(nn.Module):
    """Configurable transformer: can use standard or MoE blocks."""
    def __init__(self, vocab_size, d_model=64, n_heads=4, n_layers=4,
                 max_seq_len=32, use_moe=False, n_experts=8, top_k=2):
        super().__init__()
        self.use_moe = use_moe
        self.token_emb = nn.Embedding(vocab_size, d_model)
        self.pos_emb = nn.Embedding(max_seq_len, d_model)

        if use_moe:
            self.blocks = nn.ModuleList([
                MoEBlock(d_model, n_heads, n_experts, top_k)
                for _ in range(n_layers)
            ])
        else:
            self.blocks = nn.ModuleList([
                StandardBlock(d_model, n_heads)
                for _ in range(n_layers)
            ])

        self.ln_f = nn.LayerNorm(d_model)
        self.lm_head = nn.Linear(d_model, vocab_size, bias=False)

    def forward(self, idx):
        B, T = idx.shape
        x = self.token_emb(idx) + self.pos_emb(torch.arange(T, device=idx.device))

        total_aux_loss = torch.tensor(0.0, device=idx.device)
        for block in self.blocks:
            x, aux_loss = block(x)
            total_aux_loss = total_aux_loss + aux_loss

        return self.lm_head(self.ln_f(x)), total_aux_loss


# =============================================================================
# TOKENIZER + DATA (same as training file)
# =============================================================================

class SimpleTokenizer:
    def __init__(self, vocab):
        self.word_to_id = {w: i for i, w in enumerate(vocab)}
        self.id_to_word = {i: w for i, w in enumerate(vocab)}
        self.vocab_size = len(vocab)

    def encode(self, text):
        return [self.word_to_id[w] for w in text.lower().split() if w in self.word_to_id]

    def decode(self, ids):
        return " ".join(self.id_to_word[i] for i in ids)


CORPUS = [
    "the cat sat on the mat", "the dog sat on the rug",
    "the cat slept on the bed", "the dog ran to the house",
    "a big cat sat on a big mat", "a small dog ran to a red house",
    "the quick brown fox jumped over the lazy dog",
    "the lazy cat slept under the big tree",
    "a happy bird sat on the tree", "the red bird flew over the house",
    "the brown dog slept on the mat", "a quick cat ran under the tree",
    "the happy dog jumped over the mat", "a lazy bird sat on the red house",
    "the small cat ran to the big tree", "a brown fox slept under the mat",
    "the big dog sat on a red rug", "a happy cat jumped over the lazy dog",
    "the quick bird flew to the sun", "a small fox ran under the big house",
]


def build_vocab(corpus):
    words = set()
    for s in corpus:
        for w in s.lower().split():
            words.add(w)
    return sorted(words)


def make_pairs(corpus, tokenizer, seq_len=8):
    inputs, targets = [], []
    for s in corpus:
        ids = tokenizer.encode(s)
        for start in range(0, len(ids) - 1, seq_len):
            end = min(start + seq_len, len(ids) - 1)
            if end - start < 2:
                continue
            inputs.append(ids[start:end])
            targets.append(ids[start+1:end+1])
    return inputs, targets


def collate(inputs, targets, batch_size):
    indices = list(range(len(inputs)))
    random.shuffle(indices)
    batches = []
    for i in range(0, len(indices), batch_size):
        batch_idx = indices[i:i+batch_size]
        b_inp = [inputs[j] for j in batch_idx]
        b_tgt = [targets[j] for j in batch_idx]
        max_len = max(len(s) for s in b_inp)
        padded_inp = [s + [0]*(max_len-len(s)) for s in b_inp]
        padded_tgt = [s + [0]*(max_len-len(s)) for s in b_tgt]
        batches.append((torch.tensor(padded_inp), torch.tensor(padded_tgt)))
    return batches


# =============================================================================
# TRAINING
# =============================================================================

def train_model(model, tokenizer, corpus, n_epochs=150, lr=3e-3,
                batch_size=8, aux_loss_weight=0.01, label="Model"):
    """Train a model and return loss history."""
    inputs, targets = make_pairs(corpus, tokenizer)
    optimizer = torch.optim.Adam(model.parameters(), lr=lr)

    loss_history = []
    for epoch in range(n_epochs):
        batches = collate(inputs, targets, batch_size)
        epoch_loss = 0.0

        for inp, tgt in batches:
            model.train()
            logits, aux_loss = model(inp)

            B, T, V = logits.shape
            # Main loss: cross-entropy on next-token prediction
            main_loss = F.cross_entropy(logits.view(B*T, V), tgt.view(B*T))

            # Total loss: main + weighted auxiliary (load balancing) loss
            loss = main_loss + aux_loss_weight * aux_loss

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            epoch_loss += main_loss.item()

        avg = epoch_loss / len(batches)
        loss_history.append(avg)

        if epoch == 0 or epoch == 9 or epoch % 25 == 24 or epoch == n_epochs - 1:
            print(f"  [{label}] Epoch {epoch+1:4d}/{n_epochs}  loss={avg:.4f}")

    return loss_history


@torch.no_grad()
def show_predictions(model, tokenizer, prompt, label=""):
    model.eval()
    ids = tokenizer.encode(prompt)
    logits, _ = model(torch.tensor([ids]))
    probs = F.softmax(logits[0, -1, :], dim=-1)
    top_probs, top_ids = torch.topk(probs, 5)
    print(f"  {label}\"{prompt}\" → next token:")
    for p, idx in zip(top_probs, top_ids):
        w = tokenizer.id_to_word[idx.item()]
        bar = "█" * int(p.item() * 40)
        print(f"    {p.item():6.1%}  {w:10s} {bar}")
    print()


@torch.no_grad()
def show_expert_routing(model, tokenizer, sentence):
    """Show which experts each token in a sentence gets routed to."""
    model.eval()
    ids = tokenizer.encode(sentence)
    words = sentence.lower().split()
    x = model.token_emb(torch.tensor([ids])) + model.pos_emb(torch.arange(len(ids)))  # [1, T, d_model] via broadcasting

    print(f"  Expert routing for: \"{sentence}\"")
    print(f"  {'Token':12s} {'Layer 1':18s} {'Layer 2':18s} {'Layer 3':18s} {'Layer 4':18s}")
    print(f"  {'─'*12} {'─'*18} {'─'*18} {'─'*18} {'─'*18}")

    # Run through each layer, tracking routing
    for block in model.blocks:
        x_norm = block.ln1(x)
        x = x + block.attn(x_norm)
        x_norm2 = block.ln2(x)

        # Get router decisions
        router_logits = block.moe.router(x_norm2)
        router_probs = F.softmax(router_logits, dim=-1)
        top_probs, top_idx = torch.topk(router_probs, block.moe.top_k, dim=-1)

        # Store per-token routing decisions for the table below
        block._routing_info = list(zip(top_idx[0].tolist(), top_probs[0].tolist()))

        # Actually run the MoE to advance x
        moe_out, _ = block.moe(x_norm2)
        x = x + moe_out

    # Print routing table
    for t in range(len(words)):
        line = f"  {words[t]:12s}"
        for block in model.blocks:
            indices, probs = block._routing_info[t]
            e1, e2 = indices
            p1, p2 = probs
            line += f" E{e1+1}({p1:.0%})+E{e2+1}({p2:.0%})  "
        print(line)

    # Clean up
    for block in model.blocks:
        if hasattr(block, '_routing_info'):
            del block._routing_info
    print()


# =============================================================================
# MAIN: Compare Standard vs MoE
# =============================================================================

if __name__ == "__main__":
    print("=" * 64)
    print("  STANDARD TRANSFORMER vs MIXTURE OF EXPERTS")
    print("=" * 64)

    vocab = build_vocab(CORPUS)
    tokenizer = SimpleTokenizer(vocab)
    print(f"\n  Vocabulary: {tokenizer.vocab_size} words")
    print(f"  Corpus: {len(CORPUS)} sentences\n")

    # Build both models
    d_model = 64
    n_heads = 4
    n_layers = 4

    standard = Transformer(
        tokenizer.vocab_size, d_model, n_heads, n_layers,
        use_moe=False,
    )

    moe = Transformer(
        tokenizer.vocab_size, d_model, n_heads, n_layers,
        use_moe=True, n_experts=8, top_k=2,
    )

    # Count parameters
    std_params = sum(p.numel() for p in standard.parameters())
    moe_params = sum(p.numel() for p in moe.parameters())
    moe_ffn_params = sum(p.numel() for n, p in moe.named_parameters() if "experts" in n)
    moe_router_params = sum(p.numel() for n, p in moe.named_parameters() if "router" in n)

    print(f"{'─'*64}")
    print(f"  PARAMETER COMPARISON")
    print(f"{'─'*64}")
    print(f"  Standard:  {std_params:>8,} total parameters")
    print(f"  MoE:       {moe_params:>8,} total parameters ({moe_params/std_params:.1f}× more)")
    print(f"    ├─ Expert FFNs: {moe_ffn_params:,} (8 experts × standard FFN size)")
    print(f"    └─ Routers:     {moe_router_params:,} (tiny — just one linear layer per block)")
    print(f"\n  But each token only activates 2/8 experts = ~{2/8:.0%} of FFN parameters")
    print(f"  So compute per token is much closer to Standard than the param count suggests.\n")

    # Train both
    print(f"{'─'*64}")
    print(f"  TRAINING: Standard Transformer")
    print(f"{'─'*64}")
    std_history = train_model(standard, tokenizer, CORPUS, n_epochs=150, label="Standard")

    print(f"\n{'─'*64}")
    print(f"  TRAINING: MoE Transformer (8 experts, top-2)")
    print(f"{'─'*64}")
    moe_history = train_model(moe, tokenizer, CORPUS, n_epochs=150, label="MoE")

    # Compare loss
    print(f"\n{'─'*64}")
    print(f"  LOSS COMPARISON")
    print(f"{'─'*64}")
    print(f"  Standard:  start={std_history[0]:.4f}  end={std_history[-1]:.4f}")
    print(f"  MoE:       start={moe_history[0]:.4f}  end={moe_history[-1]:.4f}")

    # ASCII loss curve comparison
    print(f"\n  Loss curves (S=Standard, M=MoE):")
    max_loss = max(max(std_history), max(moe_history))
    for row in range(8, -1, -1):
        threshold = max_loss * row / 8
        line = f"  {threshold:5.2f} │"
        for i in range(0, 150, 3):
            s = std_history[i] >= threshold
            m = moe_history[i] >= threshold
            if s and m:
                line += "▓"
            elif s:
                line += "S"
            elif m:
                line += "M"
            else:
                line += " "
        print(line)
    print(f"        └{'─' * 50}")
    print(f"         S=Standard  M=MoE  ▓=both")

    # Compare predictions
    print(f"\n{'─'*64}")
    print(f"  PREDICTION COMPARISON")
    print(f"{'─'*64}")
    for prompt in ["the cat sat on", "the dog ran to", "a big cat sat on"]:
        show_predictions(standard, tokenizer, prompt, label="[Standard] ")
        show_predictions(moe, tokenizer, prompt, label="[MoE     ] ")

    # Show expert routing (MoE only)
    print(f"{'─'*64}")
    print(f"  EXPERT ROUTING (which experts handle which tokens)")
    print(f"{'─'*64}")
    for sentence in ["the cat sat on the mat",
                      "the quick brown fox jumped over the lazy dog"]:
        show_expert_routing(moe, tokenizer, sentence)

    # Summary
    print(f"{'─'*64}")
    print(f"  SUMMARY")
    print(f"{'─'*64}")
    print(f"""
    Standard Transformer        MoE Transformer
    ─────────────────           ─────────────────
    {std_params:,} params            {moe_params:,} params ({moe_params/std_params:.1f}×)
    1 FFN per layer             8 FFNs per layer (2 active)
    All tokens, same path       Different tokens, different paths
    Final loss: {std_history[-1]:.4f}         Final loss: {moe_history[-1]:.4f}

    The MoE model has {moe_params/std_params:.1f}× more parameters but each token
    only uses ~25% of the FFN weights. More knowledge capacity,
    similar inference cost.

    Look at the routing table above — different tokens naturally
    get sent to different experts. This specialization emerged
    from training, not from any explicit programming.
    """)
Also in the source-code addendum, interactive and not reproduced here:
model-scaling.html (an interactive explorer of where the parameters come
from) and streams/ (the animated idea streams — the React visualizations
that fed this lesson).