How It Learns

[Diagram: the whole LLM factory floor showing the training loop — raw text goes through tokenizer, embeddings, and positional encoding into the transformer; the output is compared against the target on a loss scale, blame flows backward, weights are updated, and the cycle repeats until the model learns.]
The Training Loop: How the factory learns from mistakes

We've seen how data flows through the model: text → tokens → embeddings → attention → output. But there's a critical question we haven't answered:

How does the model know that "two plus three" should output "five"?

The answer: it learns from examples. Thousands of them. This process is called training.

The Learning Problem

When we first create a model, all its weights (the numbers inside) are random. It knows nothing. Ask it "two plus three" and it might say "seventy".

Before training:
Input:  "two plus three equals"
Output: "seventy"  ← Random guess (wrong!)

After training:
Input:  "two plus three equals"
Output: "five"     ← Learned the pattern!

Training is the process of adjusting those random weights until the model gives correct answers.

The Training Task: Next-Token Prediction

Here's the core idea: we show the model incomplete sequences and ask it to predict what comes next.

Training example:
Input:  "two plus three equals ___"
Target: "five"

The model sees everything up to the blank,
then tries to predict the next word.

This is called next-token prediction. It's the same task used to train GPT, Claude, and every modern LLM. The only difference is scale: GPT-4 trained on trillions of tokens, while we'll train on thousands of examples.
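One sentence actually yields several training pairs, since every prefix predicts its next token. A minimal sketch in Python (the sentence and word-level tokenization are illustrative):

```python
# Turn one training sentence into next-token prediction pairs.
# Each prefix becomes an input; the token that follows it is the target.
sentence = ["two", "plus", "three", "equals", "five"]

pairs = [(sentence[:i], sentence[i]) for i in range(1, len(sentence))]

for context, target in pairs:
    print(" ".join(context), "->", target)
```

So a five-token sentence gives four training examples, from `"two" -> "plus"` all the way to `"two plus three equals" -> "five"`.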

Loss: Measuring How Wrong We Are

After the model makes a prediction, we need to measure how wrong it was. This measurement is called loss.

Model predicts: "four" (47%), "five" (42%), "six" (11%)
Correct answer: "five"

Loss = How surprised we are that "five" wasn't #1
     = -log(0.42) = 0.87

Lower loss = better prediction
Loss of 0 = perfect (100% confidence in right answer)

The specific formula is called cross-entropy loss. You don't need to memorize it — just know that:

  • High loss = model is confidently wrong, or uncertain
  • Low loss = model is confidently correct
  • Goal = minimize loss across all training examples
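For the curious, the calculation above is a one-liner. A sketch of cross-entropy loss for a single prediction, using the made-up probabilities from the example:

```python
import math

def cross_entropy(probs, correct):
    """Loss = -log of the probability the model gave the correct token."""
    return -math.log(probs[correct])

probs = {"four": 0.47, "five": 0.42, "six": 0.11}
loss = cross_entropy(probs, "five")
print(round(loss, 2))  # 0.87 — matches the example above
```

Note the two properties from the bullet list: if the model had put 100% on "five", the loss would be -log(1.0) = 0; the lower its confidence in the right answer, the larger -log gets.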

Backpropagation: Tracing the Blame

Now we know how wrong the model was. But which weights caused the error? We need to trace the blame backwards through the network.

Think of our calculator pipeline. If the model predicted "four" instead of "five":

"two plus three" → Tokenizer → Embeddings → Attention → Output → "four" (wrong!)
                              ↑          ↑           ↑         ↑
                           5% blame   15% blame   30% blame  50% blame

Which stage caused the error?
- Maybe the attention weights didn't connect "plus" to the numbers correctly
- Maybe the output layer mapped to the wrong word
- Backpropagation figures out exactly how much each weight contributed

Backpropagation is the algorithm that calculates how much each weight contributed to the error. It works backwards from the output, using calculus (specifically, the chain rule) to assign blame to every weight in the network.

You don't need to implement backpropagation. PyTorch does it automatically with loss.backward(). But understanding the concept helps you debug training issues.
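To make the chain rule less abstract, here is blame-tracing done by hand on a made-up two-weight pipeline (pure Python, no PyTorch — the numbers are purely illustrative):

```python
# A toy two-stage pipeline: hidden = w1 * x, output = w2 * hidden.
# Loss is squared error. Backprop applies the chain rule stage by stage,
# working backwards from the loss to assign blame to each weight.
x, target = 2.0, 10.0
w1, w2 = 1.0, 3.0

# Forward pass
hidden = w1 * x               # 2.0
output = w2 * hidden          # 6.0
loss = (output - target) ** 2  # 16.0

# Backward pass (chain rule)
d_output = 2 * (output - target)  # dLoss/dOutput = -8.0
d_w2 = d_output * hidden          # dLoss/dW2    = -16.0
d_hidden = d_output * w2          # dLoss/dHidden = -24.0
d_w1 = d_hidden * x               # dLoss/dW1    = -48.0

print(d_w1, d_w2)  # w1 carries more blame than w2 in this example
```

PyTorch's loss.backward() performs this same backward sweep automatically, for every weight in the network.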

Gradient Descent: Taking a Step

Once we know how much each weight is to blame, we adjust it slightly. That blame signal is the gradient: the slope of the loss with respect to the weight, which tells us the direction and magnitude of the change needed.

weight_old = 0.5
gradient = 0.1      (positive = increasing this weight would increase the loss)
learning_rate = 0.01

weight_new = weight_old - learning_rate × gradient
           = 0.5 - 0.01 × 0.1
           = 0.499

Tiny adjustment, but multiply by millions of weights
and thousands of examples = big change over time.

This is gradient descent: repeatedly taking small steps downhill on the "loss landscape" until we reach a low point (good model).
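The update rule above, as a tiny function (the numbers match the worked example):

```python
def sgd_step(weight, gradient, learning_rate=0.01):
    """One gradient-descent update: move the weight opposite the gradient."""
    return weight - learning_rate * gradient

print(round(sgd_step(0.5, 0.1), 3))  # 0.499
```

Subtracting the gradient is what makes this "descent": a positive slope means the loss rises to the right, so we step left, and vice versa.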

The Training Loop

Put it all together, and training is just a loop:

for each example in training_data:
    1. Forward pass: "two plus three equals" → model → "four" (35%)
    2. Compute loss: Target was "five", loss = 1.05 (pretty wrong)
    3. Backward pass: Attention got 30% blame, output got 50% blame...
    4. Update weights: Nudge them slightly toward "five"

After 1000 examples:
    "two plus three equals" → model → "five" (94%)
    Loss = 0.06 (much better!)

That's it. Every LLM — from our tiny calculator to GPT-4 — learns this way. The difference is just:

  • Model size — Our calculator: ~100K weights. GPT-4: reportedly ~1.7 trillion
  • Training data — Our calculator: ~15K examples. GPT-4: trillions of tokens
  • Training time — Our calculator: minutes. GPT-4: months on thousands of GPUs

Same algorithm, different scale. The training loop you'll write in Part 5 is fundamentally identical to the one that trained GPT-4. Understanding it at our small scale means you understand it at any scale.
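Here is the whole loop in miniature: a made-up two-weight linear model learning addition via the same four steps. It isn't a transformer, but the shape of the loop is the same:

```python
import random

# Toy model: predicts a + b as w1*a + w2*b. Training should drive
# both weights toward 1.0. Same loop shape as any LLM trainer:
# forward pass -> loss -> backward pass -> weight update.
random.seed(0)
w1, w2 = random.random(), random.random()  # start from random weights
lr = 0.003

for step in range(1000):
    a, b = random.randint(0, 9), random.randint(0, 9)
    # 1. Forward pass
    pred = w1 * a + w2 * b
    # 2. Compute loss (squared error)
    loss = (pred - (a + b)) ** 2
    # 3. Backward pass (chain rule by hand)
    d_pred = 2 * (pred - (a + b))
    d_w1, d_w2 = d_pred * a, d_pred * b
    # 4. Update weights
    w1 -= lr * d_w1
    w2 -= lr * d_w2

print(round(w1, 3), round(w2, 3))  # both weights end up close to 1.0
```

Swap the linear model for a transformer, squared error for cross-entropy, and the hand-written backward pass for loss.backward(), and this is the trainer from Part 5.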

What the Model Actually Learns

Through training, the model discovers patterns:

  • Attention patterns — "When I see 'plus', look at the numbers before and after it"
  • Number representations — "'two' and 'three' are small numbers, 'ninety' is large"
  • Operation semantics — "'plus' means add, 'minus' means subtract, 'times' means multiply"
  • Output mapping — "'two plus three' should produce 'five'"

None of this is programmed explicitly. The model discovers these patterns by trying to minimize loss on thousands of examples.
