Build Your First LLM from Scratch · Part 4 · Section 6 of 7

Masked Attention (Causal Masking)

The Problem

During training, we already have the full sequence:

Input:  "two plus three"
Target: "five"

Full sequence: [START] two plus three [END] five

But during generation, we can only see previous tokens:

Step 1: See [START] two plus three [END], predict "five"
Step 2: See [START] two plus three [END] five, predict next...

If we let the model see future tokens during training, it cheats!
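
To make this concrete, here is a minimal sketch of the training setup using made-up token IDs purely for illustration (they are not the tutorial's real vocabulary): the full target sequence is shifted to give a decoder input and a next-token label for every position.

import torch

# Made-up token IDs for "[START] five [END]" (illustration only)
full_target = torch.tensor([0, 7, 1])    # [START]=0, five=7, [END]=1

decoder_input = full_target[:-1]         # [START] five  -> fed into the decoder
labels        = full_target[1:]          # five [END]    -> what each position must predict

# Without a causal mask, the position that must predict "five" could simply
# attend to the "five" token sitting one step later in decoder_input.
print(decoder_input)  # tensor([0, 7])
print(labels)         # tensor([7, 1])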

The Solution: Mask Future Tokens

Attention matrix for "[START] five [END]":

          START  five   END
START  [  1.0    0.0    0.0  ]  ← Can only see itself
five   [  0.5    0.5    0.0  ]  ← Can see START and itself
END    [  0.3    0.4    0.3  ]  ← Can see everything before

The upper triangle is masked (set to -infinity before softmax)
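
As a quick numeric check, here is a short sketch (with arbitrary example scores, not values from the trained model) showing that filling the upper triangle with -infinity drives those attention weights to exactly zero after softmax:

import torch
import torch.nn.functional as F

# Arbitrary raw scores for the 3-token sequence [START] five [END]
scores = torch.tensor([[2.0, 1.0, 0.5],
                       [1.0, 2.0, 0.5],
                       [0.5, 1.0, 2.0]])

mask = torch.triu(torch.ones(3, 3), diagonal=1).bool()   # True above the diagonal
masked = scores.masked_fill(mask, float('-inf'))

print(F.softmax(masked, dim=-1))
# tensor([[1.0000, 0.0000, 0.0000],    <- START only sees itself
#         [0.2689, 0.7311, 0.0000],    <- "five" sees START and itself
#         [0.1402, 0.2312, 0.6285]])   <- END sees everything before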

Implementation

import torch
import torch.nn.functional as F

def create_causal_mask(seq_len: int) -> torch.Tensor:
    """Create a mask that prevents attending to future tokens."""
    # Upper triangle = True (these positions will be masked)
    mask = torch.triu(torch.ones(seq_len, seq_len), diagonal=1).bool()
    return mask
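
For a 3-token sequence, the mask produced by this function looks like this (True marks the future positions that will be blocked):

mask = create_causal_mask(3)
print(mask)
# tensor([[False,  True,  True],
#         [False, False,  True],
#         [False, False, False]])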

# In the attention forward method:
def forward(self, x, mask=None):
    # ... compute scores ...

    if mask is not None:
        scores = scores.masked_fill(mask, float('-inf'))

    attention = F.softmax(scores, dim=-1)
    # scores set to -inf get attention weight 0 after softmax
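
Putting the pieces together, here is a minimal single-head attention sketch that reuses create_causal_mask from above. It is a simplified stand-in, not the tutorial's actual attention module; the class name TinyCausalAttention and the single-head, three-projection layout are illustrative assumptions.

import math
import torch
import torch.nn as nn
import torch.nn.functional as F

class TinyCausalAttention(nn.Module):
    """Minimal single-head self-attention with an optional causal mask (illustrative)."""

    def __init__(self, d_model: int):
        super().__init__()
        self.q_proj = nn.Linear(d_model, d_model)
        self.k_proj = nn.Linear(d_model, d_model)
        self.v_proj = nn.Linear(d_model, d_model)

    def forward(self, x, mask=None):
        q, k, v = self.q_proj(x), self.k_proj(x), self.v_proj(x)
        scores = q @ k.transpose(-2, -1) / math.sqrt(x.size(-1))  # (batch, seq, seq)

        if mask is not None:
            scores = scores.masked_fill(mask, float('-inf'))      # block future tokens

        attention = F.softmax(scores, dim=-1)                     # masked entries become 0
        return attention @ v

# Usage: attention weights above the diagonal come out as exactly zero
x = torch.randn(1, 3, 8)                                          # (batch, seq_len, d_model)
out = TinyCausalAttention(d_model=8)(x, mask=create_causal_mask(3))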

When to Use Masking

Architecture          Masking                                  Use Case
GPT (decoder-only)    Causal mask                              Text generation - only sees the past
BERT (encoder-only)   No mask                                  Understanding - sees all tokens
Our Calculator        Encoder: no mask; decoder: causal mask   Encoder-decoder architecture

Causal masking is critical for autoregressive generation (GPT-style models).