Build Your First LLM from Scratch · Part 4, Section 4 of 7
Putting It Together: Single-Head Attention
Everything we've discussed so far—Q, K, V projections, computing attention weights, gathering information—is called single-head attention. It's "single-head" because there's one set of Q, K, V matrices looking at the input in one way.
Think of it as having one perspective on the data. For our calculator, that one perspective might learn: "operations should look at numbers." But what if we want multiple perspectives? That's where multi-head attention comes in (next section). For now, let's see single-head in code:
import torch
import torch.nn as nn
import torch.nn.functional as F
import math

class SingleHeadAttention(nn.Module):
    def __init__(self, embed_dim: int):
        """
        Args:
            embed_dim: Dimension of input embeddings (64 for our model)
        """
        super().__init__()
        self.W_q = nn.Linear(embed_dim, embed_dim)  # Query projection
        self.W_k = nn.Linear(embed_dim, embed_dim)  # Key projection
        self.W_v = nn.Linear(embed_dim, embed_dim)  # Value projection
        self.scale = math.sqrt(embed_dim)

    def forward(self, x):
        """
        Args:
            x: [batch_size, seq_len, embed_dim]
        Returns:
            [batch_size, seq_len, embed_dim]
        """
        Q = self.W_q(x)  # [batch, seq_len, embed_dim]
        K = self.W_k(x)  # [batch, seq_len, embed_dim]
        V = self.W_v(x)  # [batch, seq_len, embed_dim]

        # Attention scores: Q @ K^T, scaled by sqrt(embed_dim)
        scores = torch.matmul(Q, K.transpose(-2, -1)) / self.scale
        # [batch, seq_len, seq_len]

        # Convert scores to probabilities (each row sums to 1)
        attention_weights = F.softmax(scores, dim=-1)

        # Weighted sum of values
        output = torch.matmul(attention_weights, V)
        # [batch, seq_len, embed_dim]
        return output

Mapping Code to Concepts
Let's trace through what happens when "two plus three" goes through this code:
1. Input x has shape [1, 3, 64] (1 example, 3 tokens, 64-dimensional embeddings):
   "two"   → [0.8, 0.1, ...]
   "plus"  → [0.1, 0.9, ...]
   "three" → [0.7, 0.2, ...]
2. Q = W_q(x) → each token now has a "query" (what it's looking for).
3. K = W_k(x) → each token now has a "key" (what it offers).
4. V = W_v(x) → each token now has a "value" (its content).
5. scores = Q @ K.T → a 3×3 matrix (each token's query scored against every key). This is the attention matrix we visualized!
6. softmax(scores) → normalize each row so it sums to 1. The "plus" row becomes [0.45, 0.10, 0.45].
7. output = weights @ V → each token gathers information from the others. "plus" now knows it's adding 2 and 3.

That's it! The entire attention mechanism in about 15 lines. The complexity comes from doing this at scale: at GPT-3 scale, with 12,288-dimensional embeddings instead of 64, the Q, K, V projections alone are roughly 450 million parameters per layer.