Generation
Your model predicts one token at a time. Generation means feeding output back as input, over and over.
python
def generate(
    model: CalculatorLLM,
    tokenizer: Tokenizer,
    prompt: str,
    max_new_tokens: int = 10,
    device: str = "cpu",
) -> str:
    """Greedily decode up to ``max_new_tokens`` tokens after ``prompt``."""
    model.eval()

    # Drop the trailing end token so the model keeps generating.
    prompt_tokens = tokenizer.encode(prompt, add_special_tokens=True)[:-1]
    input_ids = torch.tensor([prompt_tokens]).to(device)

    with torch.no_grad():
        for _ in range(max_new_tokens):
            logits = model(input_ids)
            predicted = logits[0, -1, :].argmax().item()

            # Stop as soon as the model emits the end-of-sequence token.
            if predicted == tokenizer.end_token_id:
                break

            # Append the new token and feed the whole sequence back in.
            next_ids = torch.tensor([[predicted]]).to(device)
            input_ids = torch.cat([input_ids, next_ids], dim=1)

    return tokenizer.decode(input_ids[0].tolist())
29
def solve(
    model: CalculatorLLM,
    tokenizer: Tokenizer,
    problem: str,
    device: str = "cpu",
) -> str:
    """Solve an English math problem."""
    # Normalize case/whitespace and make sure the prompt ends with "equals".
    normalized = problem.lower().strip()
    if not normalized.endswith("equals"):
        normalized = f"{normalized} equals"

    completion = generate(model, tokenizer, normalized, device=device)

    # The answer is whatever the model produced after the last "equals".
    if "equals" in completion:
        return completion.split("equals")[-1].strip()
    return completion
49
# Try it!
print(solve(model, tokenizer, "two plus three"))   # → "five"
print(solve(model, tokenizer, "seven times six"))  # → "forty two"

| Temperature | Effect | Use Case |
|---|---|---|
| 0 (greedy) | Deterministic | Math, code, facts |
| 0.7-1.0 | Balanced | General conversation |
| 1.5+ | Creative | Brainstorming |
Tests
python
# tests/test_generate.py
def test_generate_returns_string(model, tokenizer):
    """generate() should hand back plain text."""
    output = generate(model, tokenizer, "two plus three equals")
    assert isinstance(output, str)
def test_solve_returns_string(model, tokenizer):
    """solve() should hand back plain text."""
    answer = solve(model, tokenizer, "two plus three")
    assert isinstance(answer, str)
def test_solve_handles_uppercase(model, tokenizer):
    """solve() lowercases its input, so uppercase prompts must still work."""
    answer = solve(model, tokenizer, "TWO PLUS THREE")
    assert answer is not None
def test_evaluate_returns_accuracy_and_errors(model, tokenizer):
    test_data = [{"input": "two plus three", "output": "five"}]
    accuracy, errors = evaluate_model(model, tokenizer, test_data)
    assert 0 <= accuracy <= 1
    assert isinstance(errors, list)

Run tests: `pytest tests/test_generate.py -v`
Helpful?