The Embedding Layer
import torch
import torch.nn as nn

class Embedding(nn.Module):
    def __init__(self, vocab_size: int = 36, embed_dim: int = 64):
        super().__init__()
        # Create a lookup table: vocab_size rows, embed_dim columns
        self.embedding = nn.Embedding(vocab_size, embed_dim)

    def forward(self, token_ids: torch.Tensor) -> torch.Tensor:
        # Look up each token ID to get its vector
        return self.embedding(token_ids)

Usage:
embed = Embedding(vocab_size=36, embed_dim=64)
token_ids = torch.tensor([5, 31, 6]) # "two plus three"
vectors = embed(token_ids)
print(vectors.shape) # torch.Size([3, 64])
# 3 tokens, each represented by 64 numbers
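
A minimal sketch of what the lookup does under the hood: nn.Embedding simply selects rows of its learned weight matrix, and it also accepts batched token IDs, passing any leading batch dimension straight through (the batch values below are just illustrative, reusing the embed and token_ids defined above):

# The lookup is plain row indexing into the learned weight matrix
manual = embed.embedding.weight[token_ids]
print(torch.equal(vectors, manual))  # True

# Batched input: shape (batch, seq_len) -> (batch, seq_len, embed_dim)
batch = torch.tensor([[5, 31, 6], [6, 31, 5]])  # illustrative batch of 2 sequences
print(embed(batch).shape)  # torch.Size([2, 3, 64])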