A complete walkthrough of how data flows through a transformer — from raw text to predictions during inference, and from predictions back to weight updates during training. Every tensor shape, every class, every line of code.
The transformer is built from nested Python classes. Each class is an nn.Module — PyTorch's base class for anything that participates in the computation graph. The contract is simple: define __init__ (store your weights and sub-components) and forward (describe how data flows through).
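A minimal sketch of that contract (the class name TinyProjection and its sizes are illustrative only, not part of the model below):

import torch
import torch.nn as nn

class TinyProjection(nn.Module):
    def __init__(self, d_in: int, d_out: int):
        super().__init__()                     # register this module with PyTorch
        self.linear = nn.Linear(d_in, d_out)   # __init__: store the learnable weights
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.linear(x)                  # forward: describe how data flows through

Calling TinyProjection(8, 4)(torch.randn(2, 8)) returns a [2, 4] tensor; PyTorch tracks self.linear's weights automatically because it is an attribute of an nn.Module.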
The composition is what matters. These aren't separate things — they're nested like Russian dolls. When you call model(input_ids), it triggers a cascade: the Transformer calls each TransformerBlock, each TransformerBlock calls CausalSelfAttention and FeedForward, and CausalSelfAttention calls its four Linear layers internally.
Transformer
├── token_emb (nn.Embedding)
├── pos_emb (nn.Embedding)
├── blocks (4 × TransformerBlock)
│   └── TransformerBlock
│       ├── ln1 (LayerNorm)
│       ├── attn (CausalSelfAttention)
│       │   └── W_q, W_k, W_v, W_out (nn.Linear)
│       ├── ln2 (LayerNorm)
│       └── ffn (FeedForward)
│           ├── up (nn.Linear 64→256)
│           └── down (nn.Linear 256→64)
├── ln_f (LayerNorm)
└── lm_head (nn.Linear 64→vocab)
That tree is the architecture. Every leaf node contains learnable weights. Training updates all of them. Inference uses all of them. Let's trace a complete forward pass through the code.
Training wraps the forward pass in a learning loop. The forward pass itself is just inference — the same Transformer.forward() runs. The difference is what happens after: measure the error, compute gradients, and nudge every weight. That cycle, repeated hundreds or thousands of times, transforms random noise into a model that understands language.
# Setup (once)
vocab = build_vocab(CORPUS)                       # → 28 unique words
tokenizer = SimpleTokenizer(vocab)                # word ↔ integer lookup
model = Transformer(vocab_size=len(vocab), ...)   # random weights
inputs, targets = make_training_pairs(CORPUS, tokenizer)   # shifted by 1
optimizer = Adam(model.parameters(), lr=0.003)    # tracks all weights

# Loop (150 epochs, a few batches each)
for epoch in range(150):
    batches = collate_batch(inputs, targets, batch_size=8)
    for input_ids, target_ids in batches:
        logits = model(input_ids)                            # forward pass
        loss = F.cross_entropy(                              # measure error
            logits.view(-1, logits.size(-1)), target_ids.view(-1))
        optimizer.zero_grad()                                # clear old gradients
        loss.backward()                                      # compute gradients
        optimizer.step()                                     # update all weights
Built through conversation — one question at a time, one token at a time.
All source code is embedded below.
Python files are runnable: pip install torch and execute.
"""
A Minimal Transformer — From Text to Next Token
This implements every step from our visualization in real PyTorch code.
Not production code — intentionally simple and heavily commented so you
can map each function to the conceptual step.
Steps:
1. Tokenize — text → integer IDs
2. Embed — integer IDs → learned vectors
3. Positional — add position information
4. Attention — Q·Kᵀ/√d → softmax → ·V
5. FFN — expand → activate → compress
6. Residual — x + f(x), then normalize
7. Stack — repeat for N layers
8. Logits — project back to vocabulary
9. Sample — probability distribution → pick one token
10. Loop — feed it back, repeat
Run this file directly to see a full forward pass with prints at every stage:
python transformer_from_scratch.py
"""
import torch
import torch.nn as nn
import torch.nn.functional as F
import math
# =============================================================================
# STEP 1: TOKENIZE
# =============================================================================
# In production, this would be BPE (byte-pair encoding) — a learned algorithm
# that splits text into subword chunks. Here we use a dead-simple word-level
# tokenizer so you can see exactly what's happening.
class SimpleTokenizer:
"""
Maps words to integers and back. That's ALL tokenization is.
No neural network, no learning — just a lookup table.
"""
def __init__(self, vocab: list[str]):
self.word_to_id = {word: i for i, word in enumerate(vocab)}
self.id_to_word = {i: word for i, word in enumerate(vocab)}
self.vocab_size = len(vocab)
def encode(self, text: str) -> list[int]:
"""Text in, integers out. Purely mechanical."""
return [self.word_to_id[w] for w in text.lower().split()]
def decode(self, ids: list[int]) -> str:
"""Integers in, text out."""
return " ".join(self.id_to_word[i] for i in ids)
# =============================================================================
# STEP 2 & 3: EMBED + POSITIONAL ENCODING
# =============================================================================
class Embedding(nn.Module):
"""
Step 2: Token embedding — each integer ID looks up a ROW in a learned matrix.
Token 4521 → row 4521 → a d_model-dimensional vector.
This is where MEANING enters. These vectors are learned during training
so that semantically similar tokens end up near each other.
Step 3: Positional encoding — without this, "dog bites man" and "man bites dog"
would produce identical representations (addition is commutative).
We add a unique position signal so the model knows token ORDER.
"""
def __init__(self, vocab_size: int, d_model: int, max_seq_len: int):
super().__init__()
# Step 2: The embedding matrix — shape [vocab_size, d_model]
# Each vocabulary entry gets its own learned vector (~100k entries in production models, a few dozen in this demo)
self.token_embedding = nn.Embedding(vocab_size, d_model)
# Step 3: Positional encoding — also a learned lookup table
# Position 0 gets one vector, position 1 gets another, etc.
self.position_embedding = nn.Embedding(max_seq_len, d_model)
def forward(self, token_ids: torch.Tensor) -> torch.Tensor:
"""
token_ids: shape [batch_size, seq_len] — the integers from tokenization
returns: shape [batch_size, seq_len, d_model] — rich vectors with position info
"""
seq_len = token_ids.shape[1]
# Step 2: Look up token vectors
tok_emb = self.token_embedding(token_ids) # [batch, seq_len, d_model]
# Step 3: Look up position vectors
positions = torch.arange(seq_len, device=token_ids.device) # [0, 1, 2, ...]
pos_emb = self.position_embedding(positions) # [seq_len, d_model]
# Add them together — the model gets BOTH "what token" and "what position"
return tok_emb + pos_emb
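# Shape note (illustrative numbers): for token_ids of shape [2, 5] with d_model=64,
#   tok_emb is [2, 5, 64] and pos_emb is [5, 64]; the addition broadcasts pos_emb
#   across the batch dimension, so every sequence in the batch gets the same
#   position vectors added, and the result is [2, 5, 64].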
# =============================================================================
# STEP 4: SELF-ATTENTION
# =============================================================================
class CausalSelfAttention(nn.Module):
"""
The core mechanism: each token decides what to "pay attention to"
in the sequence so far.
Three learned projections turn each token's vector into:
Q (query) — "what am I looking for?"
K (key) — "what do I contain?"
V (value) — "what information do I provide if attended to?"
Then: scores = Q · Kᵀ / √d_k → softmax → weighted sum of V
Multi-head: we do this multiple times in parallel with different
learned projections, then concatenate. Each "head" can learn to
attend to different things (syntax, coreference, semantics...).
"""
def __init__(self, d_model: int, n_heads: int):
super().__init__()
assert d_model % n_heads == 0
self.n_heads = n_heads
self.d_k = d_model // n_heads # dimension per head
# Three separate learned linear projections
self.W_q = nn.Linear(d_model, d_model) # projects to all Q heads at once
self.W_k = nn.Linear(d_model, d_model) # projects to all K heads at once
self.W_v = nn.Linear(d_model, d_model) # projects to all V heads at once
# Final projection after concatenating all heads
self.W_out = nn.Linear(d_model, d_model)
def forward(self, x: torch.Tensor) -> torch.Tensor:
"""
x: shape [batch, seq_len, d_model]
returns: shape [batch, seq_len, d_model]
"""
B, T, C = x.shape # batch, sequence length, d_model
# Project into Q, K, V — then reshape for multi-head
# Each goes from [batch, seq_len, d_model] → [batch, n_heads, seq_len, d_k]
Q = self.W_q(x).view(B, T, self.n_heads, self.d_k).transpose(1, 2)
K = self.W_k(x).view(B, T, self.n_heads, self.d_k).transpose(1, 2)
V = self.W_v(x).view(B, T, self.n_heads, self.d_k).transpose(1, 2)
# ---- THE CORE MATH OF ATTENTION ----
# scores = Q · Kᵀ / √d_k
# Shape: [batch, n_heads, seq_len, seq_len]
# Each entry (i, j) = "how much should token i attend to token j?"
scores = (Q @ K.transpose(-2, -1)) / math.sqrt(self.d_k)
# CAUSAL MASK: token i can only attend to tokens 0..i (not the future!)
# This is what makes it autoregressive — you can't cheat by looking ahead
mask = torch.triu(torch.ones(T, T, device=x.device), diagonal=1).bool()
scores = scores.masked_fill(mask, float("-inf")) # future → -inf → 0 after softmax
# Softmax normalizes scores to probabilities (each row sums to 1)
attn_weights = F.softmax(scores, dim=-1) # [batch, n_heads, seq_len, seq_len]
# Weighted sum of values — this IS the attention output
# Each token's new representation is a mix of all (past) tokens' values,
# weighted by how relevant they are
attn_output = attn_weights @ V # [batch, n_heads, seq_len, d_k]
# Concatenate all heads back together and project
attn_output = attn_output.transpose(1, 2).contiguous().view(B, T, C)
return self.W_out(attn_output)
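# Quick sanity check (illustrative, not executed by this file):
#   attn = CausalSelfAttention(d_model=64, n_heads=4)
#   y = attn(torch.randn(1, 5, 64))   # y.shape == [1, 5, 64]; same shape in, same shape out
# Attention mixes information ACROSS positions but never changes the tensor shape.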
# =============================================================================
# STEP 5: FEED-FORWARD NETWORK
# =============================================================================
class FeedForward(nn.Module):
"""
Two linear transforms with an activation in between.
Looks simple, but this is where ~2/3 of all parameters live.
The expansion to 4× the model dimension gives each token a much wider hidden
representation that the model uses to store and retrieve learned knowledge.
d_model → 4*d_model → GELU → 4*d_model → d_model
(e.g. 4096 → 16384 → 16384 → 4096 at production scale)
"""
def __init__(self, d_model: int):
super().__init__()
d_ff = 4 * d_model # standard expansion factor
self.up = nn.Linear(d_model, d_ff) # expand
self.down = nn.Linear(d_ff, d_model) # compress back
def forward(self, x: torch.Tensor) -> torch.Tensor:
# up-project, activate (GELU — smooth ReLU), down-project
return self.down(F.gelu(self.up(x)))
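# Parameter count for this demo's d_model=64 (illustrative arithmetic):
#   up   = nn.Linear(64, 256): 64*256 weights + 256 biases = 16,640 parameters
#   down = nn.Linear(256, 64): 256*64 weights +  64 biases = 16,448 parameters
#   → 33,088 parameters per FeedForward, roughly twice the attention weights in a block.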
# =============================================================================
# STEP 6 & 7: TRANSFORMER BLOCK (with residual connections) × N LAYERS
# =============================================================================
class TransformerBlock(nn.Module):
"""
One complete transformer layer. This is where Steps 4-6 come together:
┌─────────────────────────────────┐
│ x ─────────────────────┐ │
│ │ │ │
│ ├→ LayerNorm → Attention ─→ + │ ← residual: x + attention(norm(x))
│ │ │ │
│ ├─────────────────────┐ │ │
│ │ │ │ │
│ ├→ LayerNorm → FFN ────→ + │ ← residual: x + ffn(norm(x))
│ │ │
└─────────────────────────────────┘
Pre-norm architecture (what GPT uses):
- LayerNorm BEFORE each sub-block
- Residual addition AFTER each sub-block
- The residual stream itself stays unnormalized and free-flowing
"""
def __init__(self, d_model: int, n_heads: int):
super().__init__()
self.ln1 = nn.LayerNorm(d_model) # normalize before attention
self.attn = CausalSelfAttention(d_model, n_heads)
self.ln2 = nn.LayerNorm(d_model) # normalize before FFN
self.ffn = FeedForward(d_model)
def forward(self, x: torch.Tensor) -> torch.Tensor:
# STEP 6 IN ACTION:
# The block computes a transformation, then ADDS it to the input.
# The block only gets to NUDGE the stream, never replace it.
x = x + self.attn(self.ln1(x)) # residual around attention
x = x + self.ffn(self.ln2(x)) # residual around FFN
return x
# That's it. x = x + f(x). The "+" is the entire trick.
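# Shape check (illustrative): a block maps [batch, seq_len, d_model] to the same shape,
#   block = TransformerBlock(d_model=64, n_heads=4)
#   block(torch.randn(2, 5, 64)).shape   # torch.Size([2, 5, 64])
# which is exactly what lets us stack N of them back to back.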
# =============================================================================
# PUTTING IT ALL TOGETHER: THE COMPLETE TRANSFORMER
# =============================================================================
class Transformer(nn.Module):
"""
The full model: Embed → N × (Attention + FFN with residuals) → Logits
This is a decoder-only transformer (like GPT / Claude).
"""
def __init__(
self,
vocab_size: int,
d_model: int = 64, # small for demo — real models use 4096+
n_heads: int = 4,
n_layers: int = 4, # small for demo — real models use 32-80+
max_seq_len: int = 128,
):
super().__init__()
# Steps 2-3: embedding + positional encoding
self.embedding = Embedding(vocab_size, d_model, max_seq_len)
# Steps 4-7: stack of transformer blocks
self.blocks = nn.ModuleList([
TransformerBlock(d_model, n_heads) for _ in range(n_layers)
])
# Final layer norm (stabilizes the output)
self.ln_final = nn.LayerNorm(d_model)
# Step 8: project from d_model back to vocab_size to get logits
# This is literally a matrix multiply: [d_model] → [vocab_size]
# One score per word in the entire vocabulary
self.lm_head = nn.Linear(d_model, vocab_size, bias=False)
def forward(self, token_ids: torch.Tensor) -> torch.Tensor:
"""
Full forward pass: integers in, logits out.
token_ids: [batch, seq_len] — the token integers
returns: [batch, seq_len, vocab_size] — raw logits (NOT probabilities)
"""
# Steps 2-3: embed tokens and add position information
x = self.embedding(token_ids)
# Steps 4-7: pass through every transformer layer
# Each layer reads the residual stream and adds its contribution
for block in self.blocks:
x = block(x)
# Normalize the final residual stream
x = self.ln_final(x)
# Step 8: project to vocabulary — one logit per vocab word
logits = self.lm_head(x) # [batch, seq_len, vocab_size]
return logits
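# Illustrative usage (hypothetical sizes, not the demo config below):
#   model = Transformer(vocab_size=100)
#   logits = model(torch.randint(0, 100, (1, 10)))   # logits.shape == [1, 10, 100]
# One row of 100 scores per position; only the last row is needed to pick the next token.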
# =============================================================================
# STEP 9: SAMPLING — logits → probability distribution → pick a token
# =============================================================================
def sample_next_token(
logits: torch.Tensor,
temperature: float = 1.0,
top_k: int = 0,
top_p: float = 1.0,
) -> int:
"""
Takes raw logits for the LAST token position and samples a next token.
logits: [vocab_size] — one raw score per vocabulary entry
temperature: controls randomness (→0 ≈ greedy, 1 = unchanged, >1 = more random)
top_k: only consider the top K highest-probability tokens
top_p: keep the smallest set of top tokens whose cumulative probability reaches p (nucleus sampling)
Returns: a single integer — the sampled token ID
"""
# --- Temperature ---
# Divide logits by temperature BEFORE softmax.
# Low temp → differences are amplified → sharp distribution → nearly deterministic
# High temp → differences are compressed → flat distribution → more random
if temperature != 1.0:
logits = logits / temperature
# --- Top-K filtering ---
# Mask out (set to -inf) everything except the K most likely tokens
if top_k > 0:
top_k_values, _ = torch.topk(logits, top_k)
min_top_k = top_k_values[-1]
logits[logits < min_top_k] = float("-inf")
# --- Top-P (nucleus) filtering ---
# Keep the smallest set of tokens whose cumulative probability ≥ p
if top_p < 1.0:
sorted_logits, sorted_indices = torch.sort(logits, descending=True)
cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)
# Remove tokens with cumulative probability above the threshold
remove_mask = cumulative_probs > top_p
remove_mask[1:] = remove_mask[:-1].clone() # keep at least one token
remove_mask[0] = False
sorted_logits[remove_mask] = float("-inf")
# Scatter back to original order
logits = sorted_logits.scatter(0, sorted_indices, sorted_logits)
# --- Convert to probabilities and sample ---
probs = F.softmax(logits, dim=-1) # logits → probabilities (sum to 1)
token_id = torch.multinomial(probs, 1).item() # random draw from distribution
return token_id
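# Worked example of the temperature step (illustrative numbers):
#   logits = [2.0, 1.0, 0.1] at temperature 0.5  →  scaled logits = [4.0, 2.0, 0.2]
#   softmax of the originals ≈ [0.66, 0.24, 0.10]; softmax of the scaled ≈ [0.86, 0.12, 0.02]
# Lower temperature sharpens the distribution toward the top token; higher flattens it.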
# =============================================================================
# STEP 10: AUTOREGRESSIVE LOOP — generate token by token
# =============================================================================
@torch.no_grad() # no gradients needed during inference
def generate(
model: Transformer,
tokenizer: SimpleTokenizer,
prompt: str,
max_new_tokens: int = 10,
temperature: float = 1.0,
top_k: int = 10,
) -> str:
"""
The full generation loop:
1. Tokenize the prompt
2. Run forward pass to get logits
3. Sample next token from logits
4. Append token to sequence
5. Repeat from step 2
This is AUTOREGRESSIVE generation — each new token depends on all previous ones.
"""
model.eval()
token_ids = tokenizer.encode(prompt)
token_ids = torch.tensor([token_ids]) # add batch dimension: [1, seq_len]
print(f"\n{'='*60}")
print(f"GENERATING from prompt: \"{prompt}\"")
print(f"{'='*60}\n")
for step in range(max_new_tokens):
# FULL FORWARD PASS — every step reruns the entire model
# (In production, KV-caching avoids redundant computation)
logits = model(token_ids) # [1, seq_len, vocab_size]
# We only care about the LAST position's logits
# This is the residual stream vector at the final token,
# projected back to vocabulary space
last_logits = logits[0, -1, :] # [vocab_size]
# Sample from the distribution
next_id = sample_next_token(last_logits, temperature=temperature, top_k=top_k)
# Show what the model is "thinking"
probs = F.softmax(last_logits, dim=-1)
top_probs, top_ids = torch.topk(probs, 5)
top_words = [(tokenizer.id_to_word[i.item()], f"{p:.1%}") for i, p in zip(top_ids, top_probs)]
chosen_word = tokenizer.id_to_word[next_id]
print(f" Step {step+1}: top 5 = {top_words}")
print(f" sampled → \"{chosen_word}\" (id={next_id})\n")
# APPEND and loop — this is the autoregressive part
next_tensor = torch.tensor([[next_id]])
token_ids = torch.cat([token_ids, next_tensor], dim=1)
result = tokenizer.decode(token_ids[0].tolist())
print(f"{'='*60}")
print(f"RESULT: \"{result}\"")
print(f"{'='*60}")
return result
# =============================================================================
# DEMO: Run everything end-to-end
# =============================================================================
if __name__ == "__main__":
# Build a tiny vocabulary (real models have ~100k tokens)
vocab = [
"the", "cat", "sat", "on", "mat", "a", "dog", "ran", "to",
"big", "small", "red", "blue", "house", "tree", "is", "was",
"and", "in", "under", "happy", "lazy", "quick", "brown", "fox",
"jumped", "over", "slept", "ate", "fish", "bird", "sky", "sun",
]
tokenizer = SimpleTokenizer(vocab)
print(f"Vocabulary size: {tokenizer.vocab_size}")
# STEP 1: Tokenize
prompt = "the cat sat on"
token_ids = tokenizer.encode(prompt)
print(f"\n--- STEP 1: Tokenize ---")
print(f" \"{prompt}\" → {token_ids}")
print(f" These are arbitrary integers. 0='the', 1='cat', etc.")
# Build the model
model = Transformer(
vocab_size=tokenizer.vocab_size,
d_model=64, # tiny for demo
n_heads=4, # 4 attention heads, each 16-dimensional
n_layers=4, # 4 transformer blocks
)
# Count parameters
n_params = sum(p.numel() for p in model.parameters())
print(f"\nModel: {n_params:,} parameters")
print(f" (GPT-3 has 175,000,000,000 — same architecture, just bigger numbers)\n")
# STEPS 2-3: Embed
ids_tensor = torch.tensor([token_ids])
embeddings = model.embedding(ids_tensor)
print(f"--- STEPS 2-3: Embed + Position ---")
print(f" Input shape: {ids_tensor.shape} (batch=1, seq_len=4)")
print(f" Output shape: {embeddings.shape} (batch=1, seq_len=4, d_model=64)")
print(f" Each integer is now a 64-dimensional vector with position info.\n")
# STEPS 4-7: Forward through all layers
x = embeddings
for i, block in enumerate(model.blocks):
x_before = x.clone()
x = block(x)
# Show that residual connections preserve the original signal
delta = (x - x_before).abs().mean().item()
print(f"--- STEPS 4-7: Layer {i+1} ---")
print(f" Mean |delta| = {delta:.4f} (the 'nudge' this layer added)")
# STEP 8: Logits
x_norm = model.ln_final(x)
logits = model.lm_head(x_norm)
print(f"\n--- STEP 8: Logits ---")
print(f" Residual stream shape: {x_norm.shape} (batch=1, seq_len=4, d_model=64)")
print(f" Logits shape: {logits.shape} (batch=1, seq_len=4, vocab_size={tokenizer.vocab_size})")
print(f" That's one raw score per vocabulary word, per position.\n")
# STEP 9: Sample
last_logits = logits[0, -1, :] # logits at last position
probs = F.softmax(last_logits, dim=-1)
print(f"--- STEP 9: Sample ---")
print(f" Raw logits (first 5): {last_logits[:5].tolist()}")
print(f" After softmax (first 5): {probs[:5].tolist()}")
print(f" These are probabilities now — they sum to {probs.sum():.4f}\n")
# STEP 10: Full generation loop
generate(model, tokenizer, prompt="the cat sat on", max_new_tokens=6, temperature=0.8)
# Show temperature comparison
print(f"\n{'='*60}")
print("TEMPERATURE COMPARISON")
print(f"{'='*60}")
for temp in [0.1, 0.5, 1.0, 2.0]:
print(f"\n temp={temp}:", end=" ")
for _ in range(5):
ids = torch.tensor([tokenizer.encode("the cat sat on")])
model.eval()
with torch.no_grad():
for _ in range(4):
logits = model(ids)
next_id = sample_next_token(logits[0, -1, :].clone(), temperature=temp, top_k=10)
ids = torch.cat([ids, torch.tensor([[next_id]])], dim=1)
print(f"\"{tokenizer.decode(ids[0].tolist())}\"", end=" ")
print()
print("\n Low temperature → repetitive / deterministic")
print(" High temperature → diverse / creative / chaotic")
"""
Training a Transformer — From Random Weights to Language Understanding
This implements every step from the TRAINING visualization in real PyTorch.
Run this file and watch a model go from random gibberish to actually
learning next-token prediction on a tiny dataset.
Training steps (mapped to the visualization):
1. Training Data — build a small text corpus
2. Input → Target Pairs — create (input, target) pairs where target = next token
3. Forward Pass — run the model, get logits at every position
4. Cross-Entropy Loss — measure how wrong the predictions are
5. Backpropagation — compute gradients for every weight
6. Gradient Descent — nudge every weight to reduce loss
7. The Training Loop — repeat steps 2-6 thousands of times
8. Before vs. After — compare random model to trained model
Usage:
pip install torch
python transformer_training.py
The model architecture is duplicated here (not imported) so this file is
completely self-contained. See transformer_from_scratch.py for detailed
comments on the architecture itself.
"""
import torch
import torch.nn as nn
import torch.nn.functional as F
import math
import random
# =============================================================================
# MODEL ARCHITECTURE (same as inference file, condensed)
# =============================================================================
class CausalSelfAttention(nn.Module):
def __init__(self, d_model, n_heads):
super().__init__()
self.n_heads = n_heads
self.d_k = d_model // n_heads
self.W_q = nn.Linear(d_model, d_model)
self.W_k = nn.Linear(d_model, d_model)
self.W_v = nn.Linear(d_model, d_model)
self.W_out = nn.Linear(d_model, d_model)
def forward(self, x):
B, T, C = x.shape
Q = self.W_q(x).view(B, T, self.n_heads, self.d_k).transpose(1, 2)
K = self.W_k(x).view(B, T, self.n_heads, self.d_k).transpose(1, 2)
V = self.W_v(x).view(B, T, self.n_heads, self.d_k).transpose(1, 2)
scores = (Q @ K.transpose(-2, -1)) / math.sqrt(self.d_k)
mask = torch.triu(torch.ones(T, T, device=x.device), diagonal=1).bool()
scores = scores.masked_fill(mask, float("-inf"))
attn = F.softmax(scores, dim=-1)
out = (attn @ V).transpose(1, 2).contiguous().view(B, T, C)
return self.W_out(out)
class FeedForward(nn.Module):
def __init__(self, d_model):
super().__init__()
self.up = nn.Linear(d_model, 4 * d_model)
self.down = nn.Linear(4 * d_model, d_model)
def forward(self, x):
return self.down(F.gelu(self.up(x)))
class TransformerBlock(nn.Module):
def __init__(self, d_model, n_heads):
super().__init__()
self.ln1 = nn.LayerNorm(d_model)
self.attn = CausalSelfAttention(d_model, n_heads)
self.ln2 = nn.LayerNorm(d_model)
self.ffn = FeedForward(d_model)
def forward(self, x):
x = x + self.attn(self.ln1(x)) # residual around attention
x = x + self.ffn(self.ln2(x)) # residual around FFN
return x
class Transformer(nn.Module):
def __init__(self, vocab_size, d_model=64, n_heads=4, n_layers=4, max_seq_len=128):
super().__init__()
self.token_emb = nn.Embedding(vocab_size, d_model)
self.pos_emb = nn.Embedding(max_seq_len, d_model)
self.blocks = nn.ModuleList([TransformerBlock(d_model, n_heads) for _ in range(n_layers)])
self.ln_f = nn.LayerNorm(d_model)
self.lm_head = nn.Linear(d_model, vocab_size, bias=False)
def forward(self, idx):
B, T = idx.shape
tok = self.token_emb(idx)
pos = self.pos_emb(torch.arange(T, device=idx.device))
x = tok + pos
for block in self.blocks:
x = block(x)
x = self.ln_f(x)
return self.lm_head(x) # [B, T, vocab_size]
# =============================================================================
# SIMPLE TOKENIZER (same as inference file)
# =============================================================================
class SimpleTokenizer:
def __init__(self, vocab):
self.word_to_id = {w: i for i, w in enumerate(vocab)}
self.id_to_word = {i: w for i, w in enumerate(vocab)}
self.vocab_size = len(vocab)
def encode(self, text):
return [self.word_to_id[w] for w in text.lower().split() if w in self.word_to_id]
def decode(self, ids):
return " ".join(self.id_to_word[i] for i in ids)
# =============================================================================
# STEP 1: TRAINING DATA
# =============================================================================
# In reality this would be terabytes of text. Here we use a tiny corpus
# that's small enough to memorize, so we can clearly see training work.
CORPUS = [
"the cat sat on the mat",
"the dog sat on the rug",
"the cat slept on the bed",
"the dog ran to the house",
"a big cat sat on a big mat",
"a small dog ran to a red house",
"the quick brown fox jumped over the lazy dog",
"the lazy cat slept under the big tree",
"a happy bird sat on the tree",
"the red bird flew over the house",
"the brown dog slept on the mat",
"a quick cat ran under the tree",
"the happy dog jumped over the mat",
"a lazy bird sat on the red house",
"the small cat ran to the big tree",
"a brown fox slept under the mat",
"the big dog sat on a red rug",
"a happy cat jumped over the lazy dog",
"the quick bird flew to the sun",
"a small fox ran under the big house",
]
# Build vocabulary from the corpus
def build_vocab(corpus):
words = set()
for sentence in corpus:
for word in sentence.lower().split():
words.add(word)
return sorted(words)
# =============================================================================
# STEP 2: INPUT → TARGET PAIRS
# =============================================================================
# The training signal: for every token in position i, the target is the
# token at position i+1. We're teaching the model to predict what comes next.
def make_training_pairs(corpus, tokenizer, seq_len=8):
"""
Convert corpus into (input, target) tensor pairs.
For the sentence "the cat sat on the mat":
input = [the, cat, sat, on, the]
target = [cat, sat, on, the, mat]
Every position is a training example, all computed simultaneously.
"""
all_inputs = []
all_targets = []
for sentence in corpus:
ids = tokenizer.encode(sentence)
# Slide a window across the sentence to create fixed-length chunks
for start in range(0, len(ids) - 1, seq_len):
end = min(start + seq_len, len(ids) - 1)
if end - start < 2: # need at least 2 tokens
continue
inp = ids[start:end] # tokens 0..N-1
tgt = ids[start + 1:end + 1] # tokens 1..N (shifted by 1)
all_inputs.append(inp)
all_targets.append(tgt)
return all_inputs, all_targets
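# Worked example (seq_len=8): "the quick brown fox jumped over the lazy dog" has 9 tokens,
#   input  = tokens 0..7 → "the quick brown fox jumped over the lazy"
#   target = tokens 1..8 → "quick brown fox jumped over the lazy dog"
# so a single window covers it; longer sentences would be split into more windows.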
def collate_batch(inputs, targets, batch_size):
"""Group training pairs into batches, padding to equal length."""
indices = list(range(len(inputs)))
random.shuffle(indices)
batches = []
for i in range(0, len(indices), batch_size):
batch_idx = indices[i:i + batch_size]
batch_inp = [inputs[j] for j in batch_idx]
batch_tgt = [targets[j] for j in batch_idx]
# Pad to the longest sequence in this batch
max_len = max(len(s) for s in batch_inp)
padded_inp = [s + [0] * (max_len - len(s)) for s in batch_inp]
padded_tgt = [s + [0] * (max_len - len(s)) for s in batch_tgt]
batches.append((
torch.tensor(padded_inp),
torch.tensor(padded_tgt),
))
return batches
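# Note on padding: shorter sequences are padded with token id 0, and train_step below
# does not pass ignore_index to F.cross_entropy, so padded positions are still scored.
# That is acceptable for this toy corpus; real training code would mask them out.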
# =============================================================================
# STEP 3-6: THE TRAINING STEP (forward, loss, backprop, update — all in one)
# =============================================================================
def train_step(model, optimizer, input_ids, target_ids, step_num, verbose=False):
"""
One complete training step. This is the heart of training:
Step 3: Forward pass — run the model, get logits at every position
Step 4: Loss — compare predictions to targets via cross-entropy
Step 5: Backprop — compute gradients for all weights
Step 6: Update — nudge all weights via optimizer
"""
model.train()
# ---- STEP 3: FORWARD PASS ----
# Identical to inference, but we compute logits at ALL positions
logits = model(input_ids) # [batch, seq_len, vocab_size]
# ---- STEP 4: CROSS-ENTROPY LOSS ----
# Reshape for PyTorch's loss function:
# logits: [batch * seq_len, vocab_size]
# targets: [batch * seq_len]
B, T, V = logits.shape
loss = F.cross_entropy(
logits.view(B * T, V), # predicted distribution at each position
target_ids.view(B * T), # actual next token at each position
)
# loss is a single number: the average "surprise" across all positions.
# High loss = model was wrong. Low loss = model predicted well.
# ---- STEP 5: BACKPROPAGATION ----
# This single call computes ∂loss/∂weight for EVERY weight in the model.
# It works backward through every layer using the chain rule.
optimizer.zero_grad() # clear old gradients
loss.backward() # compute new gradients — this IS backprop
# At this point, every parameter has a .grad attribute:
if verbose and step_num == 0:
print("\n [Backprop] Gradient samples from different layers:")
for name, param in model.named_parameters():
if param.grad is not None and ("attn.W_q" in name or "ffn.up" in name or "token_emb" in name):
g = param.grad
print(f" {name:40s} grad mean={g.mean():.6f} grad std={g.std():.6f}")
# ---- STEP 6: GRADIENT DESCENT ----
# The optimizer adjusts every weight: w = w - lr * gradient
# (Adam is fancier — it tracks momentum and adapts per-parameter)
optimizer.step()
return loss.item()
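# Shape note for the loss reshape above (illustrative numbers): with batch=8, seq_len=5
# and a 28-word vocab, logits go from [8, 5, 28] to [40, 28] and targets from [8, 5]
# to [40]; cross-entropy then treats every position as one independent prediction.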
# =============================================================================
# STEP 7: THE TRAINING LOOP
# =============================================================================
def train(model, tokenizer, corpus, n_epochs=150, lr=3e-3, batch_size=8, seq_len=8):
"""
The complete training loop. Steps 2-6 repeat for many epochs.
An 'epoch' = one full pass through all the training data.
Real LLM training does ~1 epoch (they have so much data they don't repeat).
Our tiny dataset needs many epochs to learn patterns.
"""
# Step 2: Prepare all input→target pairs
inputs, targets = make_training_pairs(corpus, tokenizer, seq_len)
print(f" Training pairs: {len(inputs)}")
print(f" Example pair:")
print(f" input: {inputs[0]} → \"{tokenizer.decode(inputs[0])}\"")
print(f" target: {targets[0]} → \"{tokenizer.decode(targets[0])}\"")
print(f" (each position predicts the next token)\n")
# Step 6: Set up the optimizer (Adam — the standard for transformers)
optimizer = torch.optim.Adam(model.parameters(), lr=lr)
# Step 7: The loop itself
loss_history = []
for epoch in range(n_epochs):
batches = collate_batch(inputs, targets, batch_size)
epoch_loss = 0.0
n_batches = 0
for input_ids, target_ids in batches:
loss = train_step(
model, optimizer, input_ids, target_ids,
step_num=epoch * len(batches) + n_batches,
verbose=(epoch == 0 and n_batches == 0),
)
epoch_loss += loss
n_batches += 1
avg_loss = epoch_loss / n_batches
loss_history.append(avg_loss)
# Print progress at key milestones
if epoch == 0 or epoch == 4 or epoch == 9 or epoch % 25 == 24 or epoch == n_epochs - 1:
bar_len = int(max(0, (avg_loss / 4.0)) * 30)
bar = "█" * bar_len + "░" * (30 - bar_len)
print(f" Epoch {epoch+1:4d}/{n_epochs} loss={avg_loss:.4f} [{bar}]")
return loss_history
# =============================================================================
# STEP 8: BEFORE vs. AFTER — compare random model to trained model
# =============================================================================
@torch.no_grad()
def show_predictions(model, tokenizer, prompt, label="", top_k=5):
"""Show what the model predicts as the next token for a given prompt."""
model.eval()
ids = tokenizer.encode(prompt)
if not ids:
return
input_tensor = torch.tensor([ids])
logits = model(input_tensor)
last_logits = logits[0, -1, :] # logits at last position
probs = F.softmax(last_logits, dim=-1)
top_probs, top_ids = torch.topk(probs, min(top_k, tokenizer.vocab_size))
print(f" {label}\"{prompt}\" → next token predictions:")
for prob, idx in zip(top_probs, top_ids):
word = tokenizer.id_to_word[idx.item()]
bar = "█" * int(prob.item() * 40)
print(f" {prob.item():6.1%} {word:12s} {bar}")
print()
@torch.no_grad()
def generate_text(model, tokenizer, prompt, max_tokens=8, temperature=0.7):
"""Generate text autoregressively."""
model.eval()
ids = tokenizer.encode(prompt)
input_tensor = torch.tensor([ids])
for _ in range(max_tokens):
logits = model(input_tensor)
last_logits = logits[0, -1, :] / temperature
probs = F.softmax(last_logits, dim=-1)
next_id = torch.multinomial(probs, 1).item()
input_tensor = torch.cat([input_tensor, torch.tensor([[next_id]])], dim=1)
return tokenizer.decode(input_tensor[0].tolist())
# =============================================================================
# MAIN: Run the full training pipeline
# =============================================================================
if __name__ == "__main__":
print("=" * 64)
print(" TRANSFORMER TRAINING — FROM RANDOM TO MEANINGFUL")
print("=" * 64)
# ---- STEP 1: Training Data ----
print(f"\n{'─'*64}")
print("STEP 1: Training Data")
print(f"{'─'*64}")
print(f" Corpus: {len(CORPUS)} sentences")
print(f" Sample: \"{CORPUS[0]}\"")
print(f" Sample: \"{CORPUS[3]}\"")
vocab = build_vocab(CORPUS)
tokenizer = SimpleTokenizer(vocab)
print(f" Vocabulary: {tokenizer.vocab_size} words → {vocab}")
# ---- Build model ----
model = Transformer(
vocab_size=tokenizer.vocab_size,
d_model=64,
n_heads=4,
n_layers=4,
max_seq_len=32,
)
n_params = sum(p.numel() for p in model.parameters())
print(f"\n Model: {n_params:,} parameters")
print(f" (Same architecture as GPT/Claude, just tiny numbers)")
# ---- STEP 8a: BEFORE training — random weights ----
print(f"\n{'─'*64}")
print("STEP 8a: BEFORE Training (random weights)")
print(f"{'─'*64}")
show_predictions(model, tokenizer, "the cat sat on", label="[RANDOM] ")
show_predictions(model, tokenizer, "the dog ran to", label="[RANDOM] ")
print(" Generated text from random model:")
for i in range(3):
text = generate_text(model, tokenizer, "the cat", max_tokens=6)
print(f" → \"{text}\"")
print(" (Pure gibberish — weights are random noise)\n")
# ---- STEPS 2-7: Train! ----
print(f"{'─'*64}")
print("STEPS 2-7: Training Loop")
print(f"{'─'*64}")
# Step 2
print("\n STEP 2: Creating input → target pairs...")
loss_history = train(model, tokenizer, CORPUS, n_epochs=150, lr=3e-3)
# ---- Print loss trajectory ----
print(f"\n Loss trajectory:")
print(f" Start: {loss_history[0]:.4f} (random guessing ≈ -log(1/{tokenizer.vocab_size}) = {math.log(tokenizer.vocab_size):.2f})")
print(f" End: {loss_history[-1]:.4f}")
print(f" Change: {loss_history[0] - loss_history[-1]:.4f} reduction")
# Show the loss curve as ASCII art
print(f"\n Loss curve:")
max_loss = max(loss_history)
for row in range(10, -1, -1):
threshold = max_loss * row / 10
line = " "
for i in range(0, len(loss_history), max(1, len(loss_history) // 60)):
line += "█" if loss_history[i] >= threshold else " "
print(f" {threshold:5.2f} |{line}")
print(f" └{'─' * 62}")
print(f" epoch 1{' ' * 52}epoch {len(loss_history)}")
# ---- STEP 8b: AFTER training ----
print(f"\n{'─'*64}")
print("STEP 8b: AFTER Training (learned weights)")
print(f"{'─'*64}")
show_predictions(model, tokenizer, "the cat sat on", label="[TRAINED] ")
show_predictions(model, tokenizer, "the dog ran to", label="[TRAINED] ")
show_predictions(model, tokenizer, "a big cat sat on", label="[TRAINED] ")
print(" Generated text from trained model:")
for i in range(5):
text = generate_text(model, tokenizer, "the cat", max_tokens=6, temperature=0.7)
print(f" → \"{text}\"")
print()
for i in range(5):
text = generate_text(model, tokenizer, "a small", max_tokens=6, temperature=0.7)
print(f" → \"{text}\"")
# ---- Side-by-side comparison ----
print(f"\n{'─'*64}")
print("SIDE-BY-SIDE: What Changed?")
print(f"{'─'*64}")
print("""
BEFORE (random) AFTER (trained)
───────────────── ─────────────────
Every word ≈ equal prob "mat" / "rug" are likely
No pattern recognition Knows "sat on" → surface
Generates gibberish Generates coherent phrases
Loss ≈ {:.2f} Loss ≈ {:.2f}
Same architecture. Same math. Same code.
The ONLY difference: the values stored in the weight matrices.
Training sculpted random numbers into understanding.
""".format(loss_history[0], loss_history[-1]))
# ---- Peek inside the weights ----
print(f"{'─'*64}")
print("BONUS: Peek Inside the Learned Weights")
print(f"{'─'*64}")
# Show that similar words ended up with similar embeddings
print("\n Embedding similarity (cosine) between words:")
def cos_sim(w1, w2):
id1 = tokenizer.word_to_id[w1]
id2 = tokenizer.word_to_id[w2]
e1 = model.token_emb.weight[id1]
e2 = model.token_emb.weight[id2]
return F.cosine_similarity(e1.unsqueeze(0), e2.unsqueeze(0)).item()
pairs = [("cat", "dog"), ("cat", "bird"), ("sat", "slept"),
("big", "small"), ("cat", "mat"), ("on", "under")]
for w1, w2 in pairs:
sim = cos_sim(w1, w2)
bar = "█" * int(max(0, (sim + 1) / 2 * 20))
print(f" {w1:8s} ↔ {w2:8s} sim={sim:+.3f} {bar}")
print("""
cat↔dog are similar (both animals in similar contexts)
sat↔slept are similar (both actions: "X ___ on the Y")
cat↔mat are less similar (different roles despite co-occurring)
These relationships EMERGED from training. Nobody programmed them.
The model discovered that similar words appear in similar contexts
and organized its embedding space accordingly.
""")
"""
Q, K, V Walkthrough — Grounded in Our Code
This traces through CausalSelfAttention from transformer_from_scratch.py
with real numbers, showing exactly what Q, K, V are and how they interact.
Our model config:
d_model = 64 (each token is a 64-dim vector)
n_heads = 4 (4 parallel attention heads)
d_k = 16 (64 / 4 = 16 dims per head)
Input sentence: "The cat sat on the"
pos0 pos1 pos2 pos3 pos4
Run with:
python qkv_walkthrough.py
"""
import torch
import torch.nn as nn
import torch.nn.functional as F
import math
torch.manual_seed(42)
# =====================================================================
# Setup: Build the pieces from our inference code
# =====================================================================
vocab = sorted(["the", "cat", "sat", "on", "mat", "dog", "rug", "big",
"a", "ran", "to", "slept", "bed", "house"])
word_to_id = {w: i for i, w in enumerate(vocab)}
id_to_word = {i: w for i, w in enumerate(vocab)}
d_model = 64
n_heads = 4
d_k = d_model // n_heads # 16
# Embedding (token + position) — same as our code
token_emb = nn.Embedding(len(vocab), d_model)
pos_emb = nn.Embedding(32, d_model)
# The attention layer — this is what we're tracing through
W_q = nn.Linear(d_model, d_model) # 64→64, contains 64×64=4096 learned weights
W_k = nn.Linear(d_model, d_model) # 64→64, another 4096 learned weights
W_v = nn.Linear(d_model, d_model) # 64→64, another 4096 learned weights
W_out = nn.Linear(d_model, d_model)
# =====================================================================
# THE WALKTHROUGH
# =====================================================================
print("=" * 70)
print(" Q, K, V WALKTHROUGH")
print(" Tracing through CausalSelfAttention with real numbers")
print("=" * 70)
# ---- Step 0: Input tokens become vectors ----
sentence = "the cat sat on the"
token_ids = torch.tensor([[word_to_id[w] for w in sentence.split()]])
B, T = token_ids.shape # B=1 batch, T=5 tokens
print(f"\n{'─'*70}")
print(f"BEFORE ATTENTION: Input vectors")
print(f"{'─'*70}")
print(f" Sentence: \"{sentence}\"")
print(f" Token IDs: {token_ids[0].tolist()}")
# Embed tokens + positions
x = token_emb(token_ids) + pos_emb(torch.arange(T))
# x shape: [1, 5, 64] — 5 tokens, each a 64-dimensional vector
print(f" Input x shape: {list(x.shape)} — {T} tokens, each a {d_model}-dim vector")
print(f"\n Each token is now a vector of {d_model} numbers:")
for i, word in enumerate(sentence.split()):
vec = x[0, i]
print(f" pos {i} \"{word:4s}\" → [{vec[0]:.3f}, {vec[1]:.3f}, {vec[2]:.3f}, ... {vec[-1]:.3f}]")
# ---- Step 1: Project into Q, K, V ----
# This is the CRITICAL step. The same input x gets multiplied by
# three DIFFERENT learned weight matrices to produce three different
# sets of vectors.
print(f"\n{'─'*70}")
print(f"STEP 1: Project into Q, K, V")
print(f"{'─'*70}")
print(f"""
Our code (line 142-144):
Q = self.W_q(x) # x × W_q — "what am I looking for?"
K = self.W_k(x) # x × W_k — "what do I advertise as?"
V = self.W_v(x) # x × W_v — "what content do I provide?"
Same input x, three different weight matrices, three different outputs.
Each weight matrix is {d_model}×{d_model} = {d_model*d_model} learned parameters.
""")
Q_full = W_q(x) # [1, 5, 64]
K_full = W_k(x) # [1, 5, 64]
V_full = W_v(x) # [1, 5, 64]
print(f" For token 'sat' (pos 2):")
print(f" Input vector: [{x[0,2,0]:.3f}, {x[0,2,1]:.3f}, {x[0,2,2]:.3f}, ...]")
print(f" × W_q → Q: [{Q_full[0,2,0]:.3f}, {Q_full[0,2,1]:.3f}, {Q_full[0,2,2]:.3f}, ...]")
print(f" × W_k → K: [{K_full[0,2,0]:.3f}, {K_full[0,2,1]:.3f}, {K_full[0,2,2]:.3f}, ...]")
print(f" × W_v → V: [{V_full[0,2,0]:.3f}, {V_full[0,2,1]:.3f}, {V_full[0,2,2]:.3f}, ...]")
print(f"""
Notice: Q, K, V are all DIFFERENT vectors derived from the same input.
- The Q for "sat" encodes what "sat" is looking for (maybe: a subject? a location?)
- The K for "sat" encodes how "sat" advertises itself (maybe: past-tense verb, positional)
- The V for "sat" encodes the information "sat" provides (maybe: action implies surface)
These differences are entirely learned during training. W_q, W_k, W_v are
trained so that Q·K similarities are high for tokens that SHOULD attend to
each other, and the V vectors carry the information that's useful to pass forward.
""")
# ---- Step 2: Reshape for multi-head attention ----
print(f"{'─'*70}")
print(f"STEP 2: Split into {n_heads} heads")
print(f"{'─'*70}")
Q = Q_full.view(B, T, n_heads, d_k).transpose(1, 2) # [1, 4, 5, 16]
K = K_full.view(B, T, n_heads, d_k).transpose(1, 2)
V = V_full.view(B, T, n_heads, d_k).transpose(1, 2)
print(f"""
Our code (line 142, the .view().transpose() part):
Q = W_q(x).view(B, T, n_heads, d_k).transpose(1, 2)
This reshapes from [1, 5, 64] → [1, 4, 5, 16]
(batch, T, d_model) → (batch, n_heads, T, d_k)
The 64-dim vector is split into 4 heads of 16 dims each.
Each head operates independently — like 4 parallel attention mechanisms,
each looking at a different 16-dimensional "view" of the data.
Head 1 might learn to track syntax (subject-verb agreement)
Head 2 might learn to track position (what's nearby)
Head 3 might learn to track semantic role (agent, patient, location)
Head 4 might learn to track something else entirely
""")
print(f" Q shape: {list(Q.shape)} — {n_heads} heads, each processing {T} tokens of {d_k} dims")
# ---- Step 3: Compute attention scores ----
print(f"\n{'─'*70}")
print(f"STEP 3: Compute attention scores (Q · Kᵀ / √d_k)")
print(f"{'─'*70}")
scores = (Q @ K.transpose(-2, -1)) / math.sqrt(d_k)
# scores shape: [1, 4, 5, 5] — for each head, a 5×5 matrix
print(f"""
Our code (line 150):
scores = (Q @ K.transpose(-2, -1)) / math.sqrt(self.d_k)
This is the "library lookup" — each token's Query is compared against
every token's Key by dot product. High dot product = high relevance.
scores shape: {list(scores.shape)} — for each of {n_heads} heads, a {T}×{T} matrix
Entry (i, j) = "how much should token i attend to token j?"
""")
# Show the score matrix for head 0
words = sentence.split()
print(f" Score matrix for Head 1 (before masking):")
print(f" {'':12s}", end="")
for w in words:
print(f"{w:>8s}", end="")
print()
for i, w in enumerate(words):
print(f" {w:>10s} ", end="")
for j in range(T):
print(f"{scores[0, 0, i, j].item():8.2f}", end="")
print()
# ---- Step 4: Causal mask ----
print(f"\n{'─'*70}")
print(f"STEP 4: Apply causal mask (can't look at future tokens)")
print(f"{'─'*70}")
mask = torch.triu(torch.ones(T, T), diagonal=1).bool()
scores_masked = scores.clone()
scores_masked = scores_masked.masked_fill(mask, float("-inf"))
print(f"""
Our code (lines 154-155):
mask = torch.triu(torch.ones(T, T), diagonal=1).bool()
scores = scores.masked_fill(mask, float("-inf"))
Future positions → -inf → 0 after softmax.
Token "sat" (pos 2) can look at "the" (0), "cat" (1), "sat" (2)
but NOT "on" (3) or "the" (4).
""")
print(f" Score matrix for Head 1 (after masking):")
print(f" {'':12s}", end="")
for w in words:
print(f"{w:>8s}", end="")
print()
for i, w in enumerate(words):
print(f" {w:>10s} ", end="")
for j in range(T):
val = scores_masked[0, 0, i, j].item()
if val == float("-inf"):
print(f"{'−∞':>8s}", end="")
else:
print(f"{val:8.2f}", end="")
print()
# ---- Step 5: Softmax → attention weights ----
print(f"\n{'─'*70}")
print(f"STEP 5: Softmax → attention weights (each row sums to 1)")
print(f"{'─'*70}")
attn_weights = F.softmax(scores_masked, dim=-1)
print(f"""
Our code (line 158):
attn_weights = F.softmax(scores, dim=-1)
Converts raw scores to probabilities. Each row sums to 1.0.
These weights tell us: for each token, what fraction of its
attention goes to each other token?
""")
print(f" Attention weights for Head 1:")
print(f" {'':12s}", end="")
for w in words:
print(f"{w:>8s}", end="")
print(" interpretation")
for i, w in enumerate(words):
print(f" {w:>10s} ", end="")
max_j = 0
max_val = 0
for j in range(T):
val = attn_weights[0, 0, i, j].item()
if val > max_val:
max_val = val
max_j = j
print(f"{val:8.1%}", end="")
print(f" ← mostly attends to \"{words[max_j]}\"")
# ---- Step 6: Weighted sum of Values ----
print(f"\n{'─'*70}")
print(f"STEP 6: Weighted sum of Values → attention output")
print(f"{'─'*70}")
attn_output = attn_weights @ V # [1, 4, 5, 16]
print(f"""
Our code (line 163):
attn_output = attn_weights @ V
This is the punchline. Each token's output is a weighted blend of
all attending tokens' VALUE vectors, using the attention weights.
For "the" at position 4 (Head 1):
""")
# Show the blending for position 4 (the last "the")
print(f" output[pos4] = ", end="")
for j, w in enumerate(words):
weight = attn_weights[0, 0, 4, j].item()
if weight > 0.01:
print(f"{weight:.1%} × V(\"{w}\")", end="")
if j < T - 1 and any(attn_weights[0, 0, 4, k].item() > 0.01 for k in range(j+1, T)):
print(f" + ", end="")
print()
print(f"""
The output vector for "the" at position 4 now contains information
from the tokens it attended to — blended according to relevance.
This is how information flows between tokens in the sequence.
""")
# ---- Step 7: Concatenate heads and project ----
print(f"{'─'*70}")
print(f"STEP 7: Concatenate all heads → final output")
print(f"{'─'*70}")
# Concatenate heads: [1, 4, 5, 16] → [1, 5, 64]
concat = attn_output.transpose(1, 2).contiguous().view(B, T, d_model)
final_output = W_out(concat)
print(f"""
Our code (lines 166-167):
attn_output = attn_output.transpose(1, 2).contiguous().view(B, T, C)
return self.W_out(attn_output)
4 heads of 16 dims → concatenated to 64 dims → one more projection.
attn_output: {list(attn_output.shape)} → concat: {list(concat.shape)} → final: {list(final_output.shape)}
Each head's 16-dim output is concatenated, then W_out mixes them
together. This lets the model combine information across heads.
""")
# ---- Summary: What happened ----
print(f"{'─'*70}")
print(f"SUMMARY: What attention did to each token")
print(f"{'─'*70}")
for i, w in enumerate(words):
input_vec = x[0, i]
output_vec = final_output[0, i]
delta = (output_vec - input_vec).norm().item()
cosine = F.cosine_similarity(input_vec.unsqueeze(0), output_vec.unsqueeze(0)).item()
print(f" \"{w}\" (pos {i}):")
print(f" Input: [{input_vec[0]:.3f}, {input_vec[1]:.3f}, {input_vec[2]:.3f}, ...]")
print(f" Output: [{output_vec[0]:.3f}, {output_vec[1]:.3f}, {output_vec[2]:.3f}, ...]")
print(f" Change: ‖delta‖ = {delta:.3f}, cosine similarity = {cosine:.3f}")
# Show what this token attended to most (averaged across heads)
avg_attn = attn_weights[0, :, i, :].mean(dim=0) # average across 4 heads
top_vals, top_idx = torch.topk(avg_attn, min(3, i + 1))
attended = [f"\"{words[j]}\"({v:.0%})" for j, v in zip(top_idx.tolist(), top_vals.tolist())]
print(f" Attended to: {', '.join(attended)}")
print()
print(f"""
KEY INSIGHT: The output vectors are DIFFERENT from the input vectors.
Each token has been enriched with information from the tokens it
attended to. "the" at position 4 now carries information about
"cat", "sat", "on" — not just the word "the" in isolation.
This is the output of ONE attention layer. This enriched vector then
goes to the FFN (step 5), gets added back via residual connection
(step 6), and then the NEXT layer's attention does the same thing
again — but now working with already-enriched vectors.
IMPORTANT: The W_q, W_k, W_v, W_out weight matrices are FROZEN at
inference time. They were learned during training. What we just
traced is what happens every time this layer processes any input.
The weights don't change — only the input vectors change.
""")
"""
How W_q, W_k, W_v Get Trained — Connecting Inference to Training
The punchline: there is NO special code for training attention weights.
Two lines update EVERY weight in the model — W_q, W_k, W_v, W_out,
FFN weights, embeddings, everything:
loss.backward() # compute gradients for ALL weights
optimizer.step() # nudge ALL weights
This file proves it by showing the actual W_q weights before and after
a single training step, and printing the gradients that drive the change.
Run with:
python qkv_training_connection.py
"""
import torch
import torch.nn as nn
import torch.nn.functional as F
import math
torch.manual_seed(42)
# =====================================================================
# The model (condensed from transformer_from_scratch.py)
# =====================================================================
class CausalSelfAttention(nn.Module):
def __init__(self, d_model, n_heads):
super().__init__()
self.n_heads = n_heads
self.d_k = d_model // n_heads
self.W_q = nn.Linear(d_model, d_model) # ← THESE are what we're tracking
self.W_k = nn.Linear(d_model, d_model) # ←
self.W_v = nn.Linear(d_model, d_model) # ←
self.W_out = nn.Linear(d_model, d_model)
def forward(self, x):
B, T, C = x.shape
Q = self.W_q(x).view(B, T, self.n_heads, self.d_k).transpose(1, 2)
K = self.W_k(x).view(B, T, self.n_heads, self.d_k).transpose(1, 2)
V = self.W_v(x).view(B, T, self.n_heads, self.d_k).transpose(1, 2)
scores = (Q @ K.transpose(-2, -1)) / math.sqrt(self.d_k)
mask = torch.triu(torch.ones(T, T, device=x.device), diagonal=1).bool()
scores = scores.masked_fill(mask, float("-inf"))
attn = F.softmax(scores, dim=-1)
out = (attn @ V).transpose(1, 2).contiguous().view(B, T, C)
return self.W_out(out)
class FeedForward(nn.Module):
def __init__(self, d_model):
super().__init__()
self.up = nn.Linear(d_model, 4 * d_model)
self.down = nn.Linear(4 * d_model, d_model)
def forward(self, x):
return self.down(F.gelu(self.up(x)))
class TransformerBlock(nn.Module):
def __init__(self, d_model, n_heads):
super().__init__()
self.ln1 = nn.LayerNorm(d_model)
self.attn = CausalSelfAttention(d_model, n_heads)
self.ln2 = nn.LayerNorm(d_model)
self.ffn = FeedForward(d_model)
def forward(self, x):
x = x + self.attn(self.ln1(x))
x = x + self.ffn(self.ln2(x))
return x
class Transformer(nn.Module):
def __init__(self, vocab_size, d_model=64, n_heads=4, n_layers=4, max_seq_len=32):
super().__init__()
self.token_emb = nn.Embedding(vocab_size, d_model)
self.pos_emb = nn.Embedding(max_seq_len, d_model)
self.blocks = nn.ModuleList([TransformerBlock(d_model, n_heads) for _ in range(n_layers)])
self.ln_f = nn.LayerNorm(d_model)
self.lm_head = nn.Linear(d_model, vocab_size, bias=False)
def forward(self, idx):
B, T = idx.shape
x = self.token_emb(idx) + self.pos_emb(torch.arange(T, device=idx.device))
for block in self.blocks:
x = block(x)
return self.lm_head(self.ln_f(x))
# =====================================================================
# Setup
# =====================================================================
vocab = sorted(["the", "cat", "sat", "on", "mat", "dog", "rug", "big",
"a", "ran", "to", "slept", "bed", "house"])
word_to_id = {w: i for i, w in enumerate(vocab)}
id_to_word = {i: w for i, w in enumerate(vocab)}
model = Transformer(vocab_size=len(vocab), d_model=64, n_heads=4, n_layers=4)
optimizer = torch.optim.Adam(model.parameters(), lr=3e-3)
# =====================================================================
# THE WALKTHROUGH
# =====================================================================
print("=" * 70)
print(" HOW W_q, W_k, W_v GET TRAINED")
print("=" * 70)
# ---- Show all named parameters ----
print(f"\n{'─'*70}")
print(f"FIRST: What parameters does the model have?")
print(f"{'─'*70}")
print(f"\n model.named_parameters() lists EVERY learnable weight:\n")
total = 0
attn_params = 0
for name, param in model.named_parameters():
n = param.numel()
total += n
is_attn = "W_q" in name or "W_k" in name or "W_v" in name or "W_out" in name
if is_attn:
attn_params += n
marker = " ← Q/K/V WEIGHT" if ("W_q" in name or "W_k" in name or "W_v" in name) else ""
print(f" {name:45s} shape {str(list(param.shape)):18s} ({n:>6,} params){marker}")
print(f"\n Total: {total:,} parameters")
print(f" Attention (Q/K/V/Out): {attn_params:,} parameters ({attn_params/total:.0%} of total)")
print(f"""
Notice: W_q, W_k, W_v appear in every block (0, 1, 2, 3).
They are regular nn.Linear layers — nothing special about them.
PyTorch tracks them automatically because they're attributes of nn.Module.
""")
# ---- Snapshot W_q before training ----
print(f"{'─'*70}")
print(f"BEFORE TRAINING: Snapshot of Layer 0's W_q weights")
print(f"{'─'*70}")
# Get the W_q weight matrix from block 0
wq_before = model.blocks[0].attn.W_q.weight.data.clone()
print(f"\n model.blocks[0].attn.W_q.weight — shape: {list(wq_before.shape)}")
print(f" This is a {wq_before.shape[0]}×{wq_before.shape[1]} matrix of learned numbers.\n")
print(f" First 5×5 corner (randomly initialized):")
for i in range(5):
print(f" [{', '.join(f'{wq_before[i,j]:+.4f}' for j in range(5))}, ...]")
print(f"\n These are RANDOM — the model hasn't learned anything yet.")
# ---- The training input ----
print(f"\n{'─'*70}")
print(f"THE TRAINING EXAMPLE")
print(f"{'─'*70}")
sentence = "the cat sat on the mat"
words = sentence.split()
input_ids = torch.tensor([[word_to_id[w] for w in words[:-1]]]) # "the cat sat on the"
target_ids = torch.tensor([[word_to_id[w] for w in words[1:]]]) # "cat sat on the mat"
print(f"\n Sentence: \"{sentence}\"")
print(f" Input: {[id_to_word[i] for i in input_ids[0].tolist()]}")
print(f" Target: {[id_to_word[i] for i in target_ids[0].tolist()]}")
print(f" (Each input token should predict the next token)")
# ---- STEP 3: Forward pass ----
print(f"\n{'─'*70}")
print(f"STEP 3: Forward pass — the input flows through W_q, W_k, W_v")
print(f"{'─'*70}")
model.train()
logits = model(input_ids)
print(f"""
logits = model(input_ids)
What just happened inside:
1. input_ids → embedding layer → 5 vectors of dim 64
2. Those vectors enter block 0:
a. LayerNorm
b. CausalSelfAttention:
Q = W_q(x) ← the vector was multiplied by W_q's 64×64 matrix
K = W_k(x) ← same vector, different matrix
V = W_v(x) ← same vector, third matrix
scores = Q · Kᵀ / √16
attention_weights = softmax(masked scores)
output = attention_weights · V
output = W_out(output)
c. Residual addition: x = x + attention_output
d. FFN (expand → GELU → compress)
e. Residual addition: x = x + ffn_output
3. Repeat for blocks 1, 2, 3
4. Final LayerNorm → lm_head → logits
W_q was USED in the forward pass but NOT CHANGED.
At this point, W_q is still the exact same random matrix.
""")
# ---- STEP 4: Loss ----
print(f"{'─'*70}")
print(f"STEP 4: Cross-entropy loss")
print(f"{'─'*70}")
B, T, V = logits.shape
loss = F.cross_entropy(logits.view(B * T, V), target_ids.view(B * T))
print(f"\n loss = F.cross_entropy(logits, targets)")
print(f" loss = {loss.item():.4f}")
print(f"\n This single number measures: 'how wrong was the model across all 5 positions?'")
print(f" High loss = model's W_q/W_k/W_v produced bad attention patterns")
print(f" that led to bad predictions.")
# ---- STEP 5: Backpropagation ----
print(f"\n{'─'*70}")
print(f"STEP 5: loss.backward() — THIS is where W_q gets its gradient")
print(f"{'─'*70}")
# Before backward, there are no gradients
print(f"\n Before loss.backward():")
print(f" model.blocks[0].attn.W_q.weight.grad = {model.blocks[0].attn.W_q.weight.grad}")
optimizer.zero_grad()
loss.backward()
# Now every parameter has a gradient!
print(f"\n After loss.backward():")
wq_grad = model.blocks[0].attn.W_q.weight.grad
wk_grad = model.blocks[0].attn.W_k.weight.grad
wv_grad = model.blocks[0].attn.W_v.weight.grad
print(f" W_q.grad shape: {list(wq_grad.shape)} — a gradient for EVERY weight in W_q")
print(f" W_k.grad shape: {list(wk_grad.shape)}")
print(f" W_v.grad shape: {list(wv_grad.shape)}")
print(f"\n W_q gradient — first 5×5 corner:")
for i in range(5):
print(f" [{', '.join(f'{wq_grad[i,j]:+.6f}' for j in range(5))}, ...]")
print(f"""
WHAT IS THIS GRADIENT?
Each number answers: "if I increased this specific weight by a tiny amount,
how much would the loss change?"
Positive gradient (+0.003) → increasing this weight increases loss → DECREASE it
Negative gradient (-0.005) → increasing this weight decreases loss → INCREASE it
Near-zero gradient (0.0001) → this weight barely matters for this example
loss.backward() computed this by tracing the chain rule BACKWARD:
loss ← logits ← lm_head ← block3 ← ... ← block0.attn ← W_q × input
↑
∂loss/∂W_q
The gradient flows backward through every operation, all the way from
the loss to W_q. The residual connections act as a "gradient highway"
ensuring the signal stays strong even for W_q in the very first layer.
""")
# Show gradients for ALL attention parameters
print(f" Gradient statistics for all Q/K/V weights in the model:")
print(f" {'Parameter':45s} {'grad mean':>12s} {'grad std':>12s}")
print(f" {'─'*45} {'─'*12} {'─'*12}")
for name, param in model.named_parameters():
if param.grad is not None and ("W_q" in name or "W_k" in name or "W_v" in name):
g = param.grad
print(f" {name:45s} {g.mean():+12.6f} {g.std():12.6f}")
print(f"""
Every W_q, W_k, W_v in every layer got a gradient. No special code —
PyTorch's autograd traced the computation graph automatically.
""")
# ---- STEP 6: Gradient descent ----
print(f"{'─'*70}")
print(f"STEP 6: optimizer.step() — W_q weights actually change")
print(f"{'─'*70}")
print(f"\n W_q BEFORE optimizer.step() (first 5 values of row 0):")
print(f" [{', '.join(f'{model.blocks[0].attn.W_q.weight.data[0,j]:+.6f}' for j in range(5))}]")
optimizer.step()
print(f"\n W_q AFTER optimizer.step() (first 5 values of row 0):")
print(f" [{', '.join(f'{model.blocks[0].attn.W_q.weight.data[0,j]:+.6f}' for j in range(5))}]")
wq_after = model.blocks[0].attn.W_q.weight.data.clone()
delta = (wq_after - wq_before).abs()
print(f"\n The difference (first 5):")
print(f" [{', '.join(f'{delta[0,j]:.6f}' for j in range(5))}]")
print(f"\n Average change per weight: {delta.mean():.6f}")
print(f" Max change: {delta.max():.6f}")
print(f"""
TINY changes — on the order of 0.001. But this happens for:
• Every weight in W_q ({wq_before.numel():,} weights)
• Every weight in W_k ({wq_before.numel():,} weights)
• Every weight in W_v ({wq_before.numel():,} weights)
• Every weight in W_out, FFN, embeddings, etc.
• Total: {sum(p.numel() for p in model.parameters()):,} weights updated simultaneously
And then we do this AGAIN for the next batch, and the next, and the next,
for millions of batches. Each tiny nudge accumulates. After enough steps,
W_q has been sculpted from random noise into a matrix that projects tokens
into a Query space where meaningful similarity comparisons emerge.
""")
# ---- The connection ----
print(f"{'─'*70}")
print(f"THE FULL PICTURE: Inference ← Training")
print(f"{'─'*70}")
print(f"""
TRAINING (what we just traced):
1. Forward pass: input → W_q/W_k/W_v → attention → logits
2. Loss: how wrong were the predictions?
3. loss.backward(): compute ∂loss/∂W_q, ∂loss/∂W_k, ∂loss/∂W_v
4. optimizer.step(): nudge each weight to reduce loss
5. Repeat millions of times
INFERENCE (the QKV walkthrough):
1. Forward pass: input → W_q/W_k/W_v → attention → logits → sample
2. That's it. Same matrices, same computation, no updates.
The W_q matrix used at inference IS the W_q matrix that was nudged
millions of times during training. Every entry in that 64×64 matrix
was shaped by gradient descent to produce Query vectors where the
right tokens attend to each other.
There is no separate "attention training" or "Q/K/V learning" step.
It's all just: forward pass, loss, backward, step. The entire model —
embeddings, attention, FFN, output head — learns together, end to end,
from one unified signal: "predict the next token better."
""")
"""
KV Cache — Concrete Example
Shows exactly what gets cached and reused during autoregressive generation.
We trace through generating 3 tokens, showing how K and V vectors
accumulate in the cache and avoid redundant computation.
Run with:
python kv_cache_walkthrough.py
"""
import torch
import torch.nn as nn
import torch.nn.functional as F
import math
torch.manual_seed(42)
# =====================================================================
# Minimal setup
# =====================================================================
vocab = sorted(["the", "cat", "sat", "on", "mat", "dog", "rug", "big",
"a", "ran", "to", "slept", "bed", "house"])
word_to_id = {w: i for i, w in enumerate(vocab)}
id_to_word = {i: w for i, w in enumerate(vocab)}
d_model = 64
n_heads = 4
d_k = d_model // n_heads # 16
# One attention layer's components
token_emb = nn.Embedding(len(vocab), d_model)
pos_emb = nn.Embedding(32, d_model)
W_q = nn.Linear(d_model, d_model)
W_k = nn.Linear(d_model, d_model)
W_v = nn.Linear(d_model, d_model)
W_out = nn.Linear(d_model, d_model)
def embed(token_ids, positions):
"""Embed tokens with position info."""
return token_emb(token_ids) + pos_emb(positions)
def attention_full(x):
"""Standard attention — recompute Q, K, V for ALL tokens every time."""
B, T, C = x.shape
Q = W_q(x).view(B, T, n_heads, d_k).transpose(1, 2)
K = W_k(x).view(B, T, n_heads, d_k).transpose(1, 2)
V = W_v(x).view(B, T, n_heads, d_k).transpose(1, 2)
scores = (Q @ K.transpose(-2, -1)) / math.sqrt(d_k)
mask = torch.triu(torch.ones(T, T), diagonal=1).bool()
scores = scores.masked_fill(mask, float("-inf"))
attn = F.softmax(scores, dim=-1)
out = (attn @ V).transpose(1, 2).contiguous().view(B, T, C)
return W_out(out), K, V
def attention_cached(x_new, pos, kv_cache):
"""
Cached attention — only compute Q, K, V for the NEW token.
Reuse cached K, V from all previous tokens.
"""
B, T_new, C = x_new.shape # T_new = 1 (just the new token)
# Only compute Q, K, V for the new token
Q_new = W_q(x_new).view(B, T_new, n_heads, d_k).transpose(1, 2) # [1, 4, 1, 16]
K_new = W_k(x_new).view(B, T_new, n_heads, d_k).transpose(1, 2) # [1, 4, 1, 16]
V_new = W_v(x_new).view(B, T_new, n_heads, d_k).transpose(1, 2) # [1, 4, 1, 16]
# Append new K, V to cache
if kv_cache is not None:
K_cached, V_cached = kv_cache
K_full = torch.cat([K_cached, K_new], dim=2) # [1, 4, T_old+1, 16]
V_full = torch.cat([V_cached, V_new], dim=2) # [1, 4, T_old+1, 16]
else:
K_full = K_new
V_full = V_new
# Attention: new token's Q against ALL K's (cached + new)
T_full = K_full.shape[2]
scores = (Q_new @ K_full.transpose(-2, -1)) / math.sqrt(d_k) # [1, 4, 1, T_full]
# No mask needed — the new token is always the LAST position,
# so it can attend to everything before it (all cached) + itself
attn = F.softmax(scores, dim=-1) # [1, 4, 1, T_full]
out = (attn @ V_full).transpose(1, 2).contiguous().view(B, 1, C) # [1, 1, 64]
return W_out(out), (K_full, V_full)
# =====================================================================
# THE WALKTHROUGH
# =====================================================================
print("=" * 70)
print(" KV CACHE WALKTHROUGH")
print("=" * 70)
prompt = "the cat sat on the"
prompt_ids = [word_to_id[w] for w in prompt.split()]
print(f"\n Prompt: \"{prompt}\"")
print(f" Token IDs: {prompt_ids}")
print(f" We'll generate 3 tokens after this prompt.\n")
# =====================================================================
# METHOD 1: No cache (recompute everything each time)
# =====================================================================
print(f"{'─'*70}")
print(f" METHOD 1: WITHOUT KV CACHE")
print(f"{'─'*70}")
generated_no_cache = list(prompt_ids)
total_qkv_ops_no_cache = 0
for gen_step in range(3):
T = len(generated_no_cache)
ids = torch.tensor([generated_no_cache])
positions = torch.arange(T)
x = embed(ids, positions) # [1, T, 64]
out, K, V = attention_full(x)
# Count Q, K, V computations (per token)
qkv_ops = T * 3 # Q, K, V each computed for all T tokens
total_qkv_ops_no_cache += qkv_ops
    # In real inference we'd project the last position's output to logits and
    # sample from them; here we aren't running a full model, so we just pick a
    # deterministic next-token id to grow the sequence for the demo.
    last_vec = out[0, -1, :]  # [64]: the last position is the only one we'd use
    next_token = (gen_step + 6) % len(vocab)  # deterministic for demo
generated_no_cache.append(next_token)
print(f"\n Step {gen_step + 1}: generating token {gen_step + 6}")
print(f" Sequence length: {T} tokens")
print(f" Q computed for: ALL {T} tokens ({T} × W_q multiply)")
print(f" K computed for: ALL {T} tokens ({T} × W_k multiply)")
print(f" V computed for: ALL {T} tokens ({T} × W_v multiply)")
print(f" Total Q/K/V ops this step: {qkv_ops}")
print(f" But we only USE position {T-1}'s output (the last one)")
print(f" Positions 0–{T-2}'s Q, K, V → same as last step, wasted work!")
print(f"\n Total Q/K/V operations: {total_qkv_ops_no_cache}")
# =====================================================================
# METHOD 2: With KV cache
# =====================================================================
print(f"\n{'─'*70}")
print(f" METHOD 2: WITH KV CACHE")
print(f"{'─'*70}")
# First: process the full prompt (no cache yet)
T = len(prompt_ids)
ids = torch.tensor([prompt_ids])
x = embed(ids, torch.arange(T))
out, K_initial, V_initial = attention_full(x)
kv_cache = (K_initial, V_initial) # Cache K and V from the prompt
print(f"\n Initial: process full prompt ({T} tokens)")
print(f" Q/K/V computed for all {T} tokens (no cache yet)")
print(f" K cache shape: {list(kv_cache[0].shape)}")
print(f" → {kv_cache[0].shape[2]} tokens × {n_heads} heads × {d_k} dims cached")
total_qkv_ops_cache = T * 3 # Initial prompt processing
generated_cached = list(prompt_ids)
for gen_step in range(3):
# Only embed the ONE new token
new_token_id = (gen_step + 6) % len(vocab)
new_pos = len(generated_cached)
generated_cached.append(new_token_id)
x_new = embed(
torch.tensor([[new_token_id]]), # [1, 1] — just one token
torch.tensor([new_pos]), # its position
) # [1, 1, 64]
out_new, kv_cache = attention_cached(x_new, new_pos, kv_cache)
qkv_ops = 1 * 3 # Q, K, V computed for just 1 token!
total_qkv_ops_cache += qkv_ops
K_cached, V_cached = kv_cache
T_total = K_cached.shape[2]
print(f"\n Step {gen_step + 1}: generating token {gen_step + 6}")
print(f" New token embedded: {id_to_word.get(new_token_id, '?')} at position {new_pos}")
print(f" Q computed for: 1 token (just the new one)")
print(f" K computed for: 1 token (just the new one)")
print(f" V computed for: 1 token (just the new one)")
print(f" K cache: {T_total - 1} old + 1 new = {T_total} total")
print(f" V cache: {T_total - 1} old + 1 new = {T_total} total")
print(f" Attention: new Q[1,4,1,16] @ all K[1,4,{T_total},16]^T = scores[1,4,1,{T_total}]")
print(f" → new token attends to ALL {T_total} tokens (using cached K's)")
print(f" Total Q/K/V ops this step: {qkv_ops}")
print(f"\n Total Q/K/V operations: {total_qkv_ops_cache}")
# =====================================================================
# Comparison
# =====================================================================
print(f"\n{'─'*70}")
print(f" COMPARISON")
print(f"{'─'*70}")
print(f"""
Generating 3 tokens after a 5-token prompt:
Without cache:
Step 1: compute Q/K/V for 5 tokens = 15 ops
Step 2: compute Q/K/V for 6 tokens = 18 ops
Step 3: compute Q/K/V for 7 tokens = 21 ops
Total: {total_qkv_ops_no_cache} ops
With cache:
Initial: compute Q/K/V for 5 tokens = 15 ops (prompt, once)
Step 1: compute Q/K/V for 1 token = 3 ops (reuse 5 cached)
Step 2: compute Q/K/V for 1 token = 3 ops (reuse 6 cached)
Step 3: compute Q/K/V for 1 token = 3 ops (reuse 7 cached)
Total: {total_qkv_ops_cache} ops
Savings: {total_qkv_ops_no_cache - total_qkv_ops_cache} fewer Q/K/V computations ({(1 - total_qkv_ops_cache/total_qkv_ops_no_cache):.0%} less)
This gap grows with sequence length. Generating 500 tokens:
Without cache: 3×(5+6+7+...+504) = {3 * sum(range(5, 505)):,} ops
With cache: 3×5 + 3×500 = {3*5 + 3*500:,} ops
Savings: {(1 - (3*5 + 3*500) / (3 * sum(range(5, 505)))):.1%} less computation
""")
# =====================================================================
# What the cache actually looks like in memory
# =====================================================================
print(f"{'─'*70}")
print(f" WHAT THE CACHE LOOKS LIKE IN MEMORY")
print(f"{'─'*70}")
K_cached, V_cached = kv_cache
print(f"\n After generating 3 tokens (8 total in sequence):")
print(f"\n K cache shape: {list(K_cached.shape)}")
print(f" = [batch=1, heads={n_heads}, tokens={K_cached.shape[2]}, d_k={d_k}]")
print(f"\n V cache shape: {list(V_cached.shape)}")
print(f" = [batch=1, heads={n_heads}, tokens={V_cached.shape[2]}, d_k={d_k}]")
total_cached_numbers = K_cached.numel() + V_cached.numel()
print(f"\n Total numbers in cache: {total_cached_numbers:,}")
print(f" = 2 (K+V) × {n_heads} heads × {K_cached.shape[2]} tokens × {d_k} d_k")
print(f"""
In a real model (e.g., Llama 7B):
d_model = 4096, n_heads = 32, d_k = 128, n_layers = 32
For a 2048-token sequence, the KV cache holds:
2 × 32 heads × 2048 tokens × 128 d_k × 32 layers × 2 bytes (fp16)
= ~1 GB of GPU memory just for the cache
This is why long conversations are expensive — the KV cache grows
with every token generated and must stay in fast GPU memory.
It's also why context length limits exist — the cache eventually
exceeds available memory.
""")
# =====================================================================
# Important subtlety: Q is NOT cached
# =====================================================================
print(f"{'─'*70}")
print(f" WHY NOT CACHE Q?")
print(f"{'─'*70}")
print(f"""
Only K and V are cached. Q is not. Here's why:
At each generation step, we need:
scores = Q_new @ K_all.T (new token's Q against ALL keys)
output = attn @ V_all (attention weights against ALL values)
Q_new: we only need the NEW token's query — it asks "what should I
attend to?" Only the new token needs to ask this question.
Previous tokens' queries were used in their own steps and
are no longer needed.
K_all: we need ALL tokens' keys — the new token might want to
attend to any previous token. Old keys don't change, so
we cache them.
V_all: we need ALL tokens' values — once attention decides who
to attend to, it needs to read their content. Old values
don't change, so we cache them.
The asymmetry:
Q is consumed once (by the token that generated it)
K and V are consumed repeatedly (by every future token)
""")
"""
Mixture of Experts — The MoE Variant of Our Transformer
This adds MoE to the same architecture from transformer_from_scratch.py.
The ONLY change: the single FFN in each layer is replaced by multiple
expert FFNs + a learned router. Everything else is identical.
Usage:
pip install torch
python transformer_moe.py
Compares a standard dense transformer to an MoE variant on the same
training data, showing how MoE adds parameter capacity while each token
activates only a small fraction of those parameters, so compute per token
stays far below what a dense model of the same total size would need.
"""
import torch
import torch.nn as nn
import torch.nn.functional as F
import math
import random
# =============================================================================
# STANDARD COMPONENTS (same as base transformer)
# =============================================================================
class CausalSelfAttention(nn.Module):
def __init__(self, d_model, n_heads):
super().__init__()
self.n_heads = n_heads
self.d_k = d_model // n_heads
self.W_q = nn.Linear(d_model, d_model)
self.W_k = nn.Linear(d_model, d_model)
self.W_v = nn.Linear(d_model, d_model)
self.W_out = nn.Linear(d_model, d_model)
def forward(self, x):
B, T, C = x.shape
Q = self.W_q(x).view(B, T, self.n_heads, self.d_k).transpose(1, 2)
K = self.W_k(x).view(B, T, self.n_heads, self.d_k).transpose(1, 2)
V = self.W_v(x).view(B, T, self.n_heads, self.d_k).transpose(1, 2)
scores = (Q @ K.transpose(-2, -1)) / math.sqrt(self.d_k)
mask = torch.triu(torch.ones(T, T, device=x.device), diagonal=1).bool()
scores = scores.masked_fill(mask, float("-inf"))
attn = F.softmax(scores, dim=-1)
out = (attn @ V).transpose(1, 2).contiguous().view(B, T, C)
return self.W_out(out)
class FeedForward(nn.Module):
"""Standard single FFN — same as before."""
def __init__(self, d_model):
super().__init__()
self.up = nn.Linear(d_model, 4 * d_model)
self.down = nn.Linear(4 * d_model, d_model)
def forward(self, x):
return self.down(F.gelu(self.up(x)))
# =============================================================================
# THE MoE LAYER — this is the new part
# =============================================================================
class MoELayer(nn.Module):
"""
Mixture of Experts: replaces a single FFN with multiple expert FFNs
and a learned router.
Architecture:
1. Router: a single linear layer that scores each expert for each token
2. Top-K selection: pick the K highest-scoring experts
3. Expert computation: run only the selected experts
4. Weighted combination: blend expert outputs using router scores
Parameters:
d_model: dimension of the residual stream (e.g., 64)
n_experts: total number of expert FFNs (e.g., 8)
top_k: how many experts each token uses (e.g., 2)
"""
def __init__(self, d_model: int, n_experts: int = 8, top_k: int = 2):
super().__init__()
self.n_experts = n_experts
self.top_k = top_k
# THE ROUTER: a single linear layer
# Input: token vector [d_model]
# Output: one score per expert [n_experts]
# This is the ONLY new learned component. Everything else is just
# multiple copies of the same FFN architecture.
self.router = nn.Linear(d_model, n_experts, bias=False)
# THE EXPERTS: n_experts independent FFNs
# Each one is identical in architecture to the standard FFN,
# but they have different learned weights.
self.experts = nn.ModuleList([
FeedForward(d_model) for _ in range(n_experts)
])
def forward(self, x: torch.Tensor):
"""
x: [batch, seq_len, d_model]
returns: (output, aux_loss)
output: [batch, seq_len, d_model] — the blended expert outputs
aux_loss: scalar — load balancing loss to keep experts evenly used
"""
B, T, D = x.shape
# ---- STEP 1: ROUTER SCORES ----
# Each token gets a score for each expert
router_logits = self.router(x) # [B, T, n_experts]
router_probs = F.softmax(router_logits, dim=-1) # normalize to probabilities
# ---- STEP 2: TOP-K SELECTION ----
# Pick the top_k experts with highest scores for each token
top_k_probs, top_k_indices = torch.topk(router_probs, self.top_k, dim=-1)
# top_k_probs: [B, T, top_k] — the weights for combining
# top_k_indices: [B, T, top_k] — which experts were selected
# Normalize the top-k probabilities so they sum to 1
top_k_probs = top_k_probs / top_k_probs.sum(dim=-1, keepdim=True)
# ---- STEP 3 & 4: RUN SELECTED EXPERTS AND COMBINE ----
# For each token, run only its selected experts and blend the results.
#
# In production, this is heavily optimized with custom CUDA kernels.
# Here we use a simple loop for clarity.
output = torch.zeros_like(x)
for k in range(self.top_k):
# Which expert does each token use for this slot?
expert_indices = top_k_indices[:, :, k] # [B, T]
weights = top_k_probs[:, :, k] # [B, T]
# Run each expert on the tokens assigned to it
for e_idx in range(self.n_experts):
# Find which tokens selected this expert in this slot
token_mask = (expert_indices == e_idx) # [B, T] boolean
if token_mask.any():
# Get those tokens, run through expert, weight by router score
expert_input = x[token_mask] # [num_selected, D]
expert_output = self.experts[e_idx](expert_input) # [num_selected, D]
expert_weights = weights[token_mask].unsqueeze(-1) # [num_selected, 1]
output[token_mask] += expert_weights * expert_output
# ---- LOAD BALANCING LOSS ----
# Without this, the router collapses to always picking the same 1-2 experts.
# We penalize uneven distribution of tokens across experts.
#
        # For each expert: what fraction of tokens were actually routed to it
        # (hard top-k counts), and how much router probability mass did it get
        # (soft, differentiable)? Their product, summed over experts and scaled
        # by n_experts, is the Switch-Transformer-style balancing loss; it
        # pushes usage toward uniform and equals 1.0 at perfect balance.
        one_hot = F.one_hot(top_k_indices, self.n_experts).float()             # [B, T, top_k, n_experts]
        tokens_per_expert = one_hot.sum(dim=(0, 1, 2)) / (B * T * self.top_k)  # [n_experts], sums to 1
        prob_per_expert = router_probs.mean(dim=(0, 1))                        # [n_experts], sums to 1
        aux_loss = self.n_experts * (tokens_per_expert * prob_per_expert).sum()
return output, aux_loss
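# A tiny shape check (illustrative): a random batch of 2 sequences × 5 tokens
# comes back from an 8-expert, top-2 MoE layer with the same shape it went in
# with, plus a scalar load-balancing loss.
_check_moe = MoELayer(d_model=64, n_experts=8, top_k=2)
_check_out, _check_aux = _check_moe(torch.randn(2, 5, 64))
assert _check_out.shape == (2, 5, 64) and _check_aux.dim() == 0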
# =============================================================================
# TRANSFORMER BLOCKS: STANDARD vs MoE
# =============================================================================
class StandardBlock(nn.Module):
"""Standard transformer block with single FFN."""
def __init__(self, d_model, n_heads):
super().__init__()
self.ln1 = nn.LayerNorm(d_model)
self.attn = CausalSelfAttention(d_model, n_heads)
self.ln2 = nn.LayerNorm(d_model)
self.ffn = FeedForward(d_model)
def forward(self, x):
x = x + self.attn(self.ln1(x))
x = x + self.ffn(self.ln2(x))
        return x, torch.tensor(0.0, device=x.device)  # no aux loss for the dense block
class MoEBlock(nn.Module):
"""
MoE transformer block — identical to StandardBlock except
the FFN is replaced by the MoE layer.
Attention is unchanged. Residual connections are unchanged.
The only swap: single FFN → router + multiple expert FFNs.
"""
def __init__(self, d_model, n_heads, n_experts=8, top_k=2):
super().__init__()
self.ln1 = nn.LayerNorm(d_model)
self.attn = CausalSelfAttention(d_model, n_heads)
self.ln2 = nn.LayerNorm(d_model)
self.moe = MoELayer(d_model, n_experts, top_k) # ← THE ONLY CHANGE
def forward(self, x):
x = x + self.attn(self.ln1(x))
moe_out, aux_loss = self.moe(self.ln2(x))
x = x + moe_out # residual, same as before
return x, aux_loss
# =============================================================================
# FULL MODELS
# =============================================================================
class Transformer(nn.Module):
"""Configurable transformer: can use standard or MoE blocks."""
def __init__(self, vocab_size, d_model=64, n_heads=4, n_layers=4,
max_seq_len=32, use_moe=False, n_experts=8, top_k=2):
super().__init__()
self.use_moe = use_moe
self.token_emb = nn.Embedding(vocab_size, d_model)
self.pos_emb = nn.Embedding(max_seq_len, d_model)
if use_moe:
self.blocks = nn.ModuleList([
MoEBlock(d_model, n_heads, n_experts, top_k)
for _ in range(n_layers)
])
else:
self.blocks = nn.ModuleList([
StandardBlock(d_model, n_heads)
for _ in range(n_layers)
])
self.ln_f = nn.LayerNorm(d_model)
self.lm_head = nn.Linear(d_model, vocab_size, bias=False)
def forward(self, idx):
B, T = idx.shape
x = self.token_emb(idx) + self.pos_emb(torch.arange(T, device=idx.device))
total_aux_loss = torch.tensor(0.0, device=idx.device)
for block in self.blocks:
x, aux_loss = block(x)
total_aux_loss = total_aux_loss + aux_loss
return self.lm_head(self.ln_f(x)), total_aux_loss
# =============================================================================
# TOKENIZER + DATA (same as training file)
# =============================================================================
class SimpleTokenizer:
def __init__(self, vocab):
self.word_to_id = {w: i for i, w in enumerate(vocab)}
self.id_to_word = {i: w for i, w in enumerate(vocab)}
self.vocab_size = len(vocab)
def encode(self, text):
return [self.word_to_id[w] for w in text.lower().split() if w in self.word_to_id]
def decode(self, ids):
return " ".join(self.id_to_word[i] for i in ids)
CORPUS = [
"the cat sat on the mat", "the dog sat on the rug",
"the cat slept on the bed", "the dog ran to the house",
"a big cat sat on a big mat", "a small dog ran to a red house",
"the quick brown fox jumped over the lazy dog",
"the lazy cat slept under the big tree",
"a happy bird sat on the tree", "the red bird flew over the house",
"the brown dog slept on the mat", "a quick cat ran under the tree",
"the happy dog jumped over the mat", "a lazy bird sat on the red house",
"the small cat ran to the big tree", "a brown fox slept under the mat",
"the big dog sat on a red rug", "a happy cat jumped over the lazy dog",
"the quick bird flew to the sun", "a small fox ran under the big house",
]
def build_vocab(corpus):
words = set()
for s in corpus:
for w in s.lower().split():
words.add(w)
return sorted(words)
def make_pairs(corpus, tokenizer, seq_len=8):
inputs, targets = [], []
for s in corpus:
ids = tokenizer.encode(s)
for start in range(0, len(ids) - 1, seq_len):
end = min(start + seq_len, len(ids) - 1)
if end - start < 2:
continue
inputs.append(ids[start:end])
targets.append(ids[start+1:end+1])
return inputs, targets
def collate(inputs, targets, batch_size):
indices = list(range(len(inputs)))
random.shuffle(indices)
batches = []
for i in range(0, len(indices), batch_size):
batch_idx = indices[i:i+batch_size]
b_inp = [inputs[j] for j in batch_idx]
b_tgt = [targets[j] for j in batch_idx]
max_len = max(len(s) for s in b_inp)
padded_inp = [s + [0]*(max_len-len(s)) for s in b_inp]
padded_tgt = [s + [0]*(max_len-len(s)) for s in b_tgt]
batches.append((torch.tensor(padded_inp), torch.tensor(padded_tgt)))
return batches
# =============================================================================
# TRAINING
# =============================================================================
def train_model(model, tokenizer, corpus, n_epochs=150, lr=3e-3,
batch_size=8, aux_loss_weight=0.01, label="Model"):
"""Train a model and return loss history."""
inputs, targets = make_pairs(corpus, tokenizer)
optimizer = torch.optim.Adam(model.parameters(), lr=lr)
loss_history = []
for epoch in range(n_epochs):
batches = collate(inputs, targets, batch_size)
epoch_loss = 0.0
for inp, tgt in batches:
model.train()
logits, aux_loss = model(inp)
B, T, V = logits.shape
# Main loss: cross-entropy on next-token prediction
main_loss = F.cross_entropy(logits.view(B*T, V), tgt.view(B*T))
# Total loss: main + weighted auxiliary (load balancing) loss
loss = main_loss + aux_loss_weight * aux_loss
optimizer.zero_grad()
loss.backward()
optimizer.step()
epoch_loss += main_loss.item()
avg = epoch_loss / len(batches)
loss_history.append(avg)
if epoch == 0 or epoch == 9 or epoch % 25 == 24 or epoch == n_epochs - 1:
print(f" [{label}] Epoch {epoch+1:4d}/{n_epochs} loss={avg:.4f}")
return loss_history
@torch.no_grad()
def show_predictions(model, tokenizer, prompt, label=""):
model.eval()
ids = tokenizer.encode(prompt)
logits, _ = model(torch.tensor([ids]))
probs = F.softmax(logits[0, -1, :], dim=-1)
top_probs, top_ids = torch.topk(probs, 5)
print(f" {label}\"{prompt}\" → next token:")
for p, idx in zip(top_probs, top_ids):
w = tokenizer.id_to_word[idx.item()]
bar = "█" * int(p.item() * 40)
print(f" {p.item():6.1%} {w:10s} {bar}")
print()
@torch.no_grad()
def show_expert_routing(model, tokenizer, sentence):
"""Show which experts each token in a sentence gets routed to."""
model.eval()
ids = tokenizer.encode(sentence)
words = sentence.lower().split()
x = model.token_emb(torch.tensor([ids])) + model.pos_emb(torch.arange(len(ids)))
x = x.unsqueeze(0) if x.dim() == 2 else x
print(f" Expert routing for: \"{sentence}\"")
print(f" {'Token':12s} {'Layer 1':18s} {'Layer 2':18s} {'Layer 3':18s} {'Layer 4':18s}")
print(f" {'─'*12} {'─'*18} {'─'*18} {'─'*18} {'─'*18}")
# Run through each layer, tracking routing
for block in model.blocks:
x_norm = block.ln1(x)
x = x + block.attn(x_norm)
x_norm2 = block.ln2(x)
# Get router decisions
router_logits = block.moe.router(x_norm2)
router_probs = F.softmax(router_logits, dim=-1)
top_probs, top_idx = torch.topk(router_probs, block.moe.top_k, dim=-1)
        # Store this layer's (top-k expert indices, weights) per token for the table
        block._routing_info = list(zip(
            top_idx[0].tolist(),
            top_probs[0].tolist()
        ))
# Actually run the MoE to advance x
moe_out, _ = block.moe(x_norm2)
x = x + moe_out
# Print routing table
for t in range(len(words)):
line = f" {words[t]:12s}"
for block in model.blocks:
indices, probs = block._routing_info[t]
e1, e2 = indices
p1, p2 = probs
line += f" E{e1+1}({p1:.0%})+E{e2+1}({p2:.0%}) "
print(line)
# Clean up
for block in model.blocks:
if hasattr(block, '_routing_info'):
del block._routing_info
print()
# =============================================================================
# MAIN: Compare Standard vs MoE
# =============================================================================
if __name__ == "__main__":
print("=" * 64)
print(" STANDARD TRANSFORMER vs MIXTURE OF EXPERTS")
print("=" * 64)
vocab = build_vocab(CORPUS)
tokenizer = SimpleTokenizer(vocab)
print(f"\n Vocabulary: {tokenizer.vocab_size} words")
print(f" Corpus: {len(CORPUS)} sentences\n")
# Build both models
d_model = 64
n_heads = 4
n_layers = 4
standard = Transformer(
tokenizer.vocab_size, d_model, n_heads, n_layers,
use_moe=False,
)
moe = Transformer(
tokenizer.vocab_size, d_model, n_heads, n_layers,
use_moe=True, n_experts=8, top_k=2,
)
# Count parameters
std_params = sum(p.numel() for p in standard.parameters())
moe_params = sum(p.numel() for p in moe.parameters())
moe_ffn_params = sum(p.numel() for n, p in moe.named_parameters() if "experts" in n)
moe_router_params = sum(p.numel() for n, p in moe.named_parameters() if "router" in n)
print(f"{'─'*64}")
print(f" PARAMETER COMPARISON")
print(f"{'─'*64}")
print(f" Standard: {std_params:>8,} total parameters")
print(f" MoE: {moe_params:>8,} total parameters ({moe_params/std_params:.1f}× more)")
print(f" ├─ Expert FFNs: {moe_ffn_params:,} (8 experts × standard FFN size)")
print(f" └─ Routers: {moe_router_params:,} (tiny — just one linear layer per block)")
print(f"\n But each token only activates 2/8 experts = ~{2/8:.0%} of FFN parameters")
print(f" So compute per token is much closer to Standard than the param count suggests.\n")
# Train both
print(f"{'─'*64}")
print(f" TRAINING: Standard Transformer")
print(f"{'─'*64}")
std_history = train_model(standard, tokenizer, CORPUS, n_epochs=150, label="Standard")
print(f"\n{'─'*64}")
print(f" TRAINING: MoE Transformer (8 experts, top-2)")
print(f"{'─'*64}")
moe_history = train_model(moe, tokenizer, CORPUS, n_epochs=150, label="MoE")
# Compare loss
print(f"\n{'─'*64}")
print(f" LOSS COMPARISON")
print(f"{'─'*64}")
print(f" Standard: start={std_history[0]:.4f} end={std_history[-1]:.4f}")
print(f" MoE: start={moe_history[0]:.4f} end={moe_history[-1]:.4f}")
# ASCII loss curve comparison
print(f"\n Loss curves (S=Standard, M=MoE):")
max_loss = max(max(std_history), max(moe_history))
for row in range(8, -1, -1):
threshold = max_loss * row / 8
line = f" {threshold:5.2f} │"
for i in range(0, 150, 3):
s = std_history[i] >= threshold
m = moe_history[i] >= threshold
if s and m:
line += "▓"
elif s:
line += "S"
elif m:
line += "M"
else:
line += " "
print(line)
print(f" └{'─' * 50}")
print(f" S=Standard M=MoE ▓=both")
# Compare predictions
print(f"\n{'─'*64}")
print(f" PREDICTION COMPARISON")
print(f"{'─'*64}")
for prompt in ["the cat sat on", "the dog ran to", "a big cat sat on"]:
show_predictions(standard, tokenizer, prompt, label="[Standard] ")
show_predictions(moe, tokenizer, prompt, label="[MoE ] ")
# Show expert routing (MoE only)
print(f"{'─'*64}")
print(f" EXPERT ROUTING (which experts handle which tokens)")
print(f"{'─'*64}")
for sentence in ["the cat sat on the mat",
"the quick brown fox jumped over the lazy dog"]:
show_expert_routing(moe, tokenizer, sentence)
# Summary
print(f"{'─'*64}")
print(f" SUMMARY")
print(f"{'─'*64}")
print(f"""
Standard Transformer MoE Transformer
───────────────── ─────────────────
{std_params:,} params {moe_params:,} params ({moe_params/std_params:.1f}×)
1 FFN per layer 8 FFNs per layer (2 active)
All tokens, same path Different tokens, different paths
Final loss: {std_history[-1]:.4f} Final loss: {moe_history[-1]:.4f}
    The MoE model has {moe_params/std_params:.1f}× more parameters, but each token
    only runs 2 of the 8 experts in every layer (about 25% of the FFN weights).
    More knowledge capacity, at a fraction of the compute the parameter
    count would suggest.
Look at the routing table above — different tokens naturally
get sent to different experts. This specialization emerged
from training, not from any explicit programming.
""")