Build Your First LLM from Scratch · Part 4 · Section 4 of 7

Putting It Together: Single-Head Attention

Everything we've discussed so far—Q, K, V projections, computing attention weights, gathering information—is called single-head attention. It's "single-head" because there's one set of Q, K, V matrices looking at the input in one way.

Think of it as having one perspective on the data. For our calculator, that one perspective might learn: "operations should look at numbers." But what if we want multiple perspectives? That's where multi-head attention comes in (next section). For now, let's see single-head in code:

import torch
import torch.nn as nn
import torch.nn.functional as F
import math

class SingleHeadAttention(nn.Module):
    def __init__(self, embed_dim: int):
        """
        Args:
            embed_dim: Dimension of input embeddings (64 for our model)
        """
        super().__init__()
        self.W_q = nn.Linear(embed_dim, embed_dim)  # Query projection
        self.W_k = nn.Linear(embed_dim, embed_dim)  # Key projection
        self.W_v = nn.Linear(embed_dim, embed_dim)  # Value projection
        self.scale = math.sqrt(embed_dim)

    def forward(self, x):
        """
        Args:
            x: [batch_size, seq_len, embed_dim]
        Returns:
            [batch_size, seq_len, embed_dim]
        """
        Q = self.W_q(x)  # [batch, seq_len, embed_dim]
        K = self.W_k(x)  # [batch, seq_len, embed_dim]
        V = self.W_v(x)  # [batch, seq_len, embed_dim]

        # Attention scores: Q @ K^T
        scores = torch.matmul(Q, K.transpose(-2, -1)) / self.scale
        # [batch, seq_len, seq_len]

        # Convert to probabilities
        attention_weights = F.softmax(scores, dim=-1)

        # Weighted sum of values
        output = torch.matmul(attention_weights, V)
        # [batch, seq_len, embed_dim]

        return output
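
To sanity-check the shapes, here is a minimal usage sketch. The random tensor is just a stand-in for real token embeddings, not data from our calculator model:

# 1 example, 3 tokens, 64-dim embeddings (a stand-in for "two plus three")
attn = SingleHeadAttention(embed_dim=64)
x = torch.randn(1, 3, 64)
out = attn(x)
print(out.shape)  # torch.Size([1, 3, 64]), same shape in and out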

Mapping Code to Concepts

Let's trace through what happens when "two plus three" goes through this code:

Input x: shape [1, 3, 64]  (1 example, 3 tokens, 64-dim embeddings)
         "two"   → [0.8, 0.1, ...]
         "plus"  → [0.1, 0.9, ...]
         "three" → [0.7, 0.2, ...]

Q = W_q(x)  →  Each token now has a "query" (what it's looking for)
K = W_k(x)  →  Each token now has a "key" (what it offers)
V = W_v(x)  →  Each token now has a "value" (its content)

scores = Q @ K.T / √64  →  3×3 matrix (each token's query scored against every key)
                           This is the attention matrix we visualized!

softmax(scores)   →  Normalize each row to sum to 1
                     "plus" row: [0.45, 0.10, 0.45]

output = weights @ V  →  Each token gathers info from others
                         "plus" now knows it's adding 2 and 3

That's it! The entire attention mechanism in ~15 lines. The complexity comes from doing this at scale: GPT-3 uses 12,288-dimensional embeddings instead of 64, which works out to roughly 450 million parameters per layer just for the Q, K, V projections.
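
You can check that figure with a quick back-of-the-envelope computation. Each projection is a d×d weight matrix plus a bias, so at d = 12,288 (GPT-3's published embedding size) the three projections come to 3 × (12,288² + 12,288) ≈ 453 million parameters. The sketch below just evaluates that formula:

d_small, d_large = 64, 12_288

# Our toy model: 3 * (64*64 + 64) = 12,480 parameters
attn = SingleHeadAttention(embed_dim=d_small)
print(sum(p.numel() for p in attn.parameters()))   # 12480

# Same formula at GPT-3 scale
print(3 * (d_large * d_large + d_large))           # 453021696, roughly 453 million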