Build Your First LLM from Scratch · Part 4 · Section 5 of 7

Multi-Head Attention

Why Multiple Heads?

One attention head learns one type of pattern. But language has many patterns:

A single head might learn: "operations look at numbers."

But we also need:
- Position-based patterns
- Semantic similarity
- Syntactic relationships
- Many patterns we can't name

The Solution: Multiple Heads in Parallel

      Input embeddings
             |
    +--------+--------+--------+
    |        |        |        |
    v        v        v        v
  Head 1   Head 2   Head 3   Head 4
    |        |        |        |
    +--------+--------+--------+
             |
         Concatenate
             |
       Linear projection
             |
       Combined output
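
To make the diagram concrete, here is a minimal sketch that follows it literally: four independent small heads run in a loop, their outputs concatenated and passed through a final linear projection. The class name NaiveMultiHead and the loop-over-heads structure are illustrative only; the batched implementation later in this section computes the same thing more efficiently.

import math

import torch
import torch.nn as nn
import torch.nn.functional as F


class NaiveMultiHead(nn.Module):
    """Illustrative loop-over-heads version of the diagram above."""

    def __init__(self, embed_dim: int = 64, num_heads: int = 4):
        super().__init__()
        self.head_dim = embed_dim // num_heads
        # One small Q/K/V projection per head: embed_dim -> head_dim
        self.q = nn.ModuleList([nn.Linear(embed_dim, self.head_dim) for _ in range(num_heads)])
        self.k = nn.ModuleList([nn.Linear(embed_dim, self.head_dim) for _ in range(num_heads)])
        self.v = nn.ModuleList([nn.Linear(embed_dim, self.head_dim) for _ in range(num_heads)])
        self.out = nn.Linear(embed_dim, embed_dim)  # final projection after concatenation

    def forward(self, x):
        head_outputs = []
        for q_proj, k_proj, v_proj in zip(self.q, self.k, self.v):
            Q, K, V = q_proj(x), k_proj(x), v_proj(x)           # [batch, seq_len, head_dim]
            scores = Q @ K.transpose(-2, -1) / math.sqrt(self.head_dim)
            head_outputs.append(F.softmax(scores, dim=-1) @ V)  # one head's small output
        return self.out(torch.cat(head_outputs, dim=-1))        # concatenate, then project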

How It Works

Instead of one 64-dimensional attention head, we use 4 heads of 16 dimensions each:

embed_dim = 64
num_heads = 4
head_dim = embed_dim / num_heads = 64 / 4 = 16

Each head:
- Q, K, V are 16-dimensional (not 64)
- Learns its own patterns
- Output is 16-dimensional

Concatenate 4 heads: 4 × 16 = 64 dimensions
Final linear layer: 64 → 64
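
A quick shape check of this arithmetic (the batch size of 2 and sequence length of 10 are just example values):

import torch

embed_dim, num_heads = 64, 4
head_dim = embed_dim // num_heads             # 16

x = torch.randn(2, 10, embed_dim)             # [batch=2, seq_len=10, 64]
split = x.view(2, 10, num_heads, head_dim)    # split into heads: [2, 10, 4, 16]
merged = split.view(2, 10, embed_dim)         # concatenate back: [2, 10, 64]
print(head_dim, split.shape, merged.shape)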

Implementation

import math

import torch
import torch.nn as nn
import torch.nn.functional as F


class MultiHeadAttention(nn.Module):
    def __init__(self, embed_dim: int, num_heads: int):
        super().__init__()
        self.num_heads = num_heads
        self.head_dim = embed_dim // num_heads

        self.W_q = nn.Linear(embed_dim, embed_dim)
        self.W_k = nn.Linear(embed_dim, embed_dim)
        self.W_v = nn.Linear(embed_dim, embed_dim)
        self.W_o = nn.Linear(embed_dim, embed_dim)  # Output projection

    def forward(self, x, mask=None):
        batch_size, seq_len, _ = x.shape

        # Project to Q, K, V
        Q = self.W_q(x)  # [batch, seq_len, embed_dim]
        K = self.W_k(x)
        V = self.W_v(x)

        # Split into heads: [batch, seq_len, num_heads, head_dim]
        Q = Q.view(batch_size, seq_len, self.num_heads, self.head_dim)
        K = K.view(batch_size, seq_len, self.num_heads, self.head_dim)
        V = V.view(batch_size, seq_len, self.num_heads, self.head_dim)

        # Transpose for attention: [batch, num_heads, seq_len, head_dim]
        Q = Q.transpose(1, 2)
        K = K.transpose(1, 2)
        V = V.transpose(1, 2)

        # Attention per head
        scores = torch.matmul(Q, K.transpose(-2, -1)) / math.sqrt(self.head_dim)

        if mask is not None:
            # Positions where the mask is True are blocked (e.g. future tokens
            # under a causal mask): set their scores to -inf so softmax gives
            # them zero weight.
            scores = scores.masked_fill(mask, float('-inf'))

        attention = F.softmax(scores, dim=-1)
        out = torch.matmul(attention, V)

        # Concatenate heads: [batch, seq_len, embed_dim]
        out = out.transpose(1, 2).contiguous().view(batch_size, seq_len, -1)

        # Final projection
        return self.W_o(out)
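
A minimal usage sketch, assuming a causal (look-backward-only) mask; the batch size and sequence length here are arbitrary example values:

import torch

mha = MultiHeadAttention(embed_dim=64, num_heads=4)
x = torch.randn(2, 10, 64)                    # [batch, seq_len, embed_dim]

# Causal mask: True above the diagonal means "this position is blocked".
# Shape [10, 10] broadcasts over the [batch, num_heads, seq_len, seq_len] scores.
mask = torch.triu(torch.ones(10, 10), diagonal=1).bool()

out = mha(x, mask=mask)
print(out.shape)                              # torch.Size([2, 10, 64])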

Our Model vs. Models at Scale

Model            Embed Dim   Num Heads   Head Dim
Our Calculator   64          4           16
GPT-2 Small      768         12          64
GPT-3            12,288      96          128
GPT-4            ~12,288     ~96         ~128

Same pattern: more heads, larger dimensions.
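
As a sanity check, the head dimension in each row is just the embedding dimension divided by the number of heads (the GPT-4 figures are unconfirmed estimates, so only the first three rows are checked):

configs = {
    "Our Calculator": (64, 4),
    "GPT-2 Small": (768, 12),
    "GPT-3": (12_288, 96),
}
for name, (embed_dim, num_heads) in configs.items():
    print(f"{name}: head_dim = {embed_dim // num_heads}")
# Our Calculator: head_dim = 16
# GPT-2 Small: head_dim = 64
# GPT-3: head_dim = 128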
