Multi-Head Attention
Why Multiple Heads?
One attention head learns one type of pattern. In our calculator, a single head might learn something like "operations look at numbers." But language has many kinds of patterns, so we also need:
- Position-based patterns
- Semantic similarity
- Syntactic relationships
- Many patterns we can't name

The Solution: Multiple Heads in Parallel
```
          Input embeddings
                  |
   +---------+----+----+---------+
   |         |         |         |
   v         v         v         v
 Head 1    Head 2    Head 3    Head 4
   |         |         |         |
   +---------+----+----+---------+
                  |
             Concatenate
                  |
          Linear projection
                  |
           Combined output
```
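Read literally, the diagram is just four independent attention heads run side by side, concatenated, and mixed by one final linear layer. Here is a minimal sketch of that literal reading (the `NaiveMultiHead` name and the per-head linear layers are illustrative, not part of the tutorial's model, and the causal mask is omitted for brevity); the implementation further down does the same computation with fewer, larger matrix multiplications:

```python
import math
import torch
import torch.nn as nn
import torch.nn.functional as F

class NaiveMultiHead(nn.Module):
    """Illustrative only: one small attention computation per head, run in a loop."""
    def __init__(self, embed_dim: int, num_heads: int):
        super().__init__()
        head_dim = embed_dim // num_heads
        # One set of Q/K/V projections per head, each mapping embed_dim -> head_dim
        self.q = nn.ModuleList([nn.Linear(embed_dim, head_dim) for _ in range(num_heads)])
        self.k = nn.ModuleList([nn.Linear(embed_dim, head_dim) for _ in range(num_heads)])
        self.v = nn.ModuleList([nn.Linear(embed_dim, head_dim) for _ in range(num_heads)])
        self.W_o = nn.Linear(embed_dim, embed_dim)  # mixes the concatenated heads
        self.head_dim = head_dim

    def forward(self, x):
        outputs = []
        for q_proj, k_proj, v_proj in zip(self.q, self.k, self.v):
            Q, K, V = q_proj(x), k_proj(x), v_proj(x)             # [batch, seq, head_dim]
            scores = Q @ K.transpose(-2, -1) / math.sqrt(self.head_dim)
            outputs.append(F.softmax(scores, dim=-1) @ V)         # one head's output
        return self.W_o(torch.cat(outputs, dim=-1))               # concat -> [batch, seq, embed_dim]
```

The loop makes the idea obvious but is slow. The version we actually build below reshapes one big tensor so every head is computed in a single batched operation.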
How It Works

Instead of a single 64-dim attention head, we use 4 heads of 16 dims each:
```
embed_dim = 64
num_heads = 4
head_dim  = embed_dim / num_heads = 64 / 4 = 16
```
Each head:
- Q, K, V are 16-dimensional (not 64)
- Learns its own patterns
- Output is 16-dimensional
Concatenate 4 heads: 4 × 16 = 64 dimensions
Final linear layer: 64 → 64
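To make the bookkeeping concrete, here is a short, self-contained shape check on a random tensor (illustrative only; the batch size of 2 and sequence length of 10 are arbitrary):

```python
import torch

embed_dim, num_heads = 64, 4
head_dim = embed_dim // num_heads           # 16

x = torch.randn(2, 10, embed_dim)           # [batch=2, seq_len=10, embed_dim=64]

# Split the last dimension into (num_heads, head_dim)
heads = x.view(2, 10, num_heads, head_dim)  # [2, 10, 4, 16]
heads = heads.transpose(1, 2)               # [2, 4, 10, 16] -- one 16-dim sequence per head

# Undo the split: transpose back and merge the last two dimensions
merged = heads.transpose(1, 2).contiguous().view(2, 10, embed_dim)
print(merged.shape)                         # torch.Size([2, 10, 64])
print(torch.equal(merged, x))               # True -- splitting and re-concatenating is lossless
```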
Implementation

```python
import math

import torch
import torch.nn as nn
import torch.nn.functional as F


class MultiHeadAttention(nn.Module):
    def __init__(self, embed_dim: int, num_heads: int):
        super().__init__()
        self.num_heads = num_heads
        self.head_dim = embed_dim // num_heads
        self.W_q = nn.Linear(embed_dim, embed_dim)
        self.W_k = nn.Linear(embed_dim, embed_dim)
        self.W_v = nn.Linear(embed_dim, embed_dim)
        self.W_o = nn.Linear(embed_dim, embed_dim)  # Output projection

    def forward(self, x, mask=None):
        batch_size, seq_len, _ = x.shape

        # Project to Q, K, V
        Q = self.W_q(x)  # [batch, seq_len, embed_dim]
        K = self.W_k(x)
        V = self.W_v(x)

        # Split into heads: [batch, seq_len, num_heads, head_dim]
        Q = Q.view(batch_size, seq_len, self.num_heads, self.head_dim)
        K = K.view(batch_size, seq_len, self.num_heads, self.head_dim)
        V = V.view(batch_size, seq_len, self.num_heads, self.head_dim)

        # Transpose for attention: [batch, num_heads, seq_len, head_dim]
        Q = Q.transpose(1, 2)
        K = K.transpose(1, 2)
        V = V.transpose(1, 2)

        # Attention per head: scores are [batch, num_heads, seq_len, seq_len]
        scores = torch.matmul(Q, K.transpose(-2, -1)) / math.sqrt(self.head_dim)
        if mask is not None:
            scores = scores.masked_fill(mask, float('-inf'))
        attention = F.softmax(scores, dim=-1)
        out = torch.matmul(attention, V)

        # Concatenate heads: [batch, seq_len, embed_dim]
        out = out.transpose(1, 2).contiguous().view(batch_size, seq_len, -1)

        # Final projection
        return self.W_o(out)
```
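A quick smoke test, continuing from the code above. The causal mask here is built inline with `torch.triu` just for illustration; your mask from earlier sections may be constructed differently, but this module expects `True` wherever attention should be blocked:

```python
embed_dim, num_heads, seq_len = 64, 4, 8

mha = MultiHeadAttention(embed_dim, num_heads)
x = torch.randn(2, seq_len, embed_dim)  # [batch, seq_len, embed_dim]

# Causal mask: True above the diagonal = positions each token may NOT attend to
mask = torch.triu(torch.ones(seq_len, seq_len, dtype=torch.bool), diagonal=1)

out = mha(x, mask=mask)
print(out.shape)  # torch.Size([2, 8, 64]) -- same shape in, same shape out
```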
Our Model vs. At Scale

| Model | Embed Dim | Num Heads | Head Dim |
|---|---|---|---|
| Our Calculator | 64 | 4 | 16 |
| GPT-2 Small | 768 | 12 | 64 |
| GPT-3 | 12,288 | 96 | 128 |
| GPT-4 (unconfirmed) | ~12,288 | ~96 | ~128 |
The pattern is identical at every scale: bigger models just use more heads and larger dimensions.
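In every case the head dimension is simply the embedding dimension divided by the number of heads, which a few lines confirm (GPT-4 is left out because its figures are unconfirmed estimates):

```python
models = {"Our Calculator": (64, 4), "GPT-2 Small": (768, 12), "GPT-3": (12288, 96)}
for name, (embed_dim, num_heads) in models.items():
    print(f"{name}: head_dim = {embed_dim // num_heads}")
# Our Calculator: head_dim = 16
# GPT-2 Small: head_dim = 64
# GPT-3: head_dim = 128
```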