Build Your First LLM from Scratch · Part 4, Section 6 of 7
Masked Attention (Causal Masking)
The Problem
When training a generative model, we have the full sequence available:

```
Input:  "two plus three"
Target: "five"
Full sequence: [START] two plus three [END] five
```

But during generation, the model can only see the tokens produced so far:

```
Step 1: See [START] two plus three [END], predict "five"
Step 2: See [START] two plus three [END] five, predict next...
```

If we let the model see future tokens during training, it cheats!
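To make this concrete, here is a minimal sketch of the shift-by-one training setup (teacher forcing). The token strings are made up from the calculator example for illustration; a real model works on token ids. The point is that every position's target is the token immediately to its right, so if attention could reach that token, the model would simply copy the answer.

```python
# Minimal sketch of the shift-by-one training setup (teacher forcing).
# The token strings are illustrative; a real model works on token ids.
tokens = ["[START]", "two", "plus", "three", "[END]", "five"]

inputs  = tokens[:-1]   # what the model reads:  [START] two plus three [END]
targets = tokens[1:]    # what it must predict:  two plus three [END] five

for pos, target in enumerate(targets):
    visible = inputs[: pos + 1]   # tokens position `pos` is allowed to see
    print(f"position {pos}: sees {visible} -> must predict {target!r}")
```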
The Solution: Mask Future Tokens
Attention matrix for "[START] five [END]":

```
        START  five  END
START [  1.0   0.0   0.0 ]   ← Can only see itself
five  [  0.5   0.5   0.0 ]   ← Can see START and itself
END   [  0.3   0.4   0.3 ]   ← Can see everything before it, plus itself
```

The upper triangle is masked (set to -infinity before the softmax).

Implementation
```python
import torch

def create_causal_mask(seq_len: int) -> torch.Tensor:
    """Create a mask that prevents attending to future tokens."""
    # Upper triangle (above the diagonal) = True -> these positions get masked
    mask = torch.triu(torch.ones(seq_len, seq_len), diagonal=1).bool()
    return mask
```
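As a quick sanity check (my addition, not part of the lesson), calling the helper for a three-token sequence shows `True` at exactly the future positions that will be blocked:

```python
print(create_causal_mask(3))
# tensor([[False,  True,  True],
#         [False, False,  True],
#         [False, False, False]])
```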
```python
# In the attention forward method (F is torch.nn.functional):
def forward(self, x, mask=None):
    # ... compute scores ...
    if mask is not None:
        scores = scores.masked_fill(mask, float('-inf'))
    attention = F.softmax(scores, dim=-1)   # masked (-inf) positions become 0 after softmax
```
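Putting the two snippets together, here is a self-contained sketch of single-head scaled dot-product attention with the causal mask applied. The projection-free setup (queries, keys, and values are just the raw inputs) is a simplification for illustration, not the lesson's actual attention module; the printed weights are zero everywhere above the diagonal.

```python
import math
import torch
import torch.nn.functional as F

torch.manual_seed(0)

seq_len, d_model = 4, 8
x = torch.randn(1, seq_len, d_model)      # (batch, seq_len, d_model) dummy embeddings

# Simplification: use x directly as queries/keys/values
# (a real attention layer applies learned projections first).
q, k, v = x, x, x

scores = q @ k.transpose(-2, -1) / math.sqrt(d_model)    # (1, seq_len, seq_len)

mask = torch.triu(torch.ones(seq_len, seq_len), diagonal=1).bool()
scores = scores.masked_fill(mask, float('-inf'))          # block future positions

weights = F.softmax(scores, dim=-1)
print(weights[0])                         # upper triangle is all zeros
output = weights @ v                      # (1, seq_len, d_model)
```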
When to Use Masking

| Architecture | Masking | Use Case |
|---|---|---|
| GPT (decoder-only) | Causal mask | Text generation - only sees past |
| BERT (encoder-only) | No mask | Understanding - sees all tokens |
| Our Calculator | Encoder: no mask, Decoder: causal mask | Encoder-decoder architecture |
Causal masking is critical for autoregressive generation (GPT-style models).
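To ground the encoder-decoder row in code, here is a rough sketch of how the two masking policies could be wired with PyTorch's built-in nn.Transformer; the dimensions and random embeddings are placeholders rather than the calculator model's actual configuration. The encoder input is passed with no mask, while the decoder's self-attention receives the causal mask from create_causal_mask.

```python
import torch
import torch.nn as nn

d_model, nhead = 32, 4                     # placeholder sizes, not the calculator's real config
model = nn.Transformer(d_model=d_model, nhead=nhead, batch_first=True)

src = torch.randn(1, 5, d_model)           # stand-in for "[START] two plus three [END]" embeddings
tgt = torch.randn(1, 3, d_model)           # stand-in for "[START] five [END]" embeddings

# Decoder self-attention gets the causal mask (True = blocked);
# the encoder input is passed with no mask at all.
tgt_mask = torch.triu(torch.ones(tgt.size(1), tgt.size(1)), diagonal=1).bool()  # same as create_causal_mask(3)

out = model(src, tgt, tgt_mask=tgt_mask)   # (1, 3, d_model)
```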