Transformer Block

A transformer block combines self-attention (which lets tokens exchange information with each other) with a feed-forward network (which transforms each token independently). We add layer normalization and residual connections for stable training.

python
class TransformerBlock(nn.Module):
    """A single transformer decoder block (Post-LN variant).

    Two sub-layers — multi-head self-attention, then a position-wise
    feed-forward network — each wrapped in a residual connection with
    layer normalization applied *after* the addition.
    """

    def __init__(
        self, embed_dim: int, num_heads: int, ff_dim: int, dropout: float = 0.1
    ):
        super().__init__()
        # Sub-layer 1: self-attention and its post-norm.
        self.attention = MultiHeadAttention(embed_dim, num_heads, dropout)
        self.norm1 = nn.LayerNorm(embed_dim)
        # Sub-layer 2: position-wise feed-forward and its post-norm.
        self.feed_forward = FeedForward(embed_dim, ff_dim, dropout)
        self.norm2 = nn.LayerNorm(embed_dim)
        # Applied to each sub-layer's output before the residual add.
        self.dropout = nn.Dropout(dropout)

    def forward(
        self, x: torch.Tensor, mask: torch.Tensor | None = None
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """Apply the block to *x*; *mask* is forwarded to attention.

        Returns the transformed hidden states together with the
        attention weights produced by the attention sub-layer.
        """
        # Attention sub-layer: dropout -> residual add -> normalize.
        attn_out, attn_weights = self.attention(x, mask)
        x = self.norm1(x + self.dropout(attn_out))
        # Feed-forward sub-layer: same dropout/residual/post-norm pattern.
        ff_out = self.feed_forward(x)
        x = self.norm2(x + self.dropout(ff_out))
        return x, attn_weights

This is the Post-LN (Post-LayerNorm) variant from the original Attention Is All You Need paper, which applies layer normalization after each sub-layer.

Helpful?