Feed-Forward Network
After attention blends information between tokens, each token passes through a feed-forward network (FFN) independently. The FFN is a simple two-layer network: a linear expansion, a ReLU activation, and a linear projection back to the embedding dimension.
```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class FeedForward(nn.Module):
    """Position-wise feed-forward network."""

    def __init__(self, embed_dim: int, ff_dim: int, dropout: float = 0.1):
        super().__init__()
        self.linear1 = nn.Linear(embed_dim, ff_dim)
        self.linear2 = nn.Linear(ff_dim, embed_dim)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.linear2(self.dropout(F.relu(self.linear1(x))))
```

Why 4x expansion? The FFN temporarily expands each token's representation to a larger dimension (64 → 256 in this model) to give the network more capacity to learn complex transformations, then contracts back to the embedding dimension.
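To make the expand-then-contract shape flow and the "position-wise" behavior concrete, here is a minimal sketch. The 64 and 256 dimensions follow the text; the batch and sequence sizes are arbitrary, and dropout is omitted so the result is deterministic:

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# Illustrative sizes from the text: expand 64 -> 256, then contract back to 64.
linear1 = nn.Linear(64, 256)
linear2 = nn.Linear(256, 64)

x = torch.randn(2, 10, 64)               # (batch, seq_len, embed_dim)
hidden = torch.relu(linear1(x))          # expanded: (2, 10, 256)
out = linear2(hidden)                    # contracted: (2, 10, 64)
print(out.shape)                         # torch.Size([2, 10, 64])

# "Position-wise" means each token is transformed independently:
# running one token through the same layers reproduces its row of the output.
single = linear2(torch.relu(linear1(x[0, 0])))
print(torch.allclose(out[0, 0], single, atol=1e-6))  # True
```

Because the same two linear layers are applied to every position, the FFN never mixes information between tokens; all cross-token mixing happens in the attention layer.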