In the previous post, we learned about the basics of Transformer models and how they’re used to process language. Now, let’s go a bit deeper and look at how to build each important part of a Transformer. By the end of this post, we’ll have a full, working Transformer model built in Python with PyTorch.
Multi-Head Self-Attention
Self-attention lets each word in a sentence weigh every other word by how relevant it is, so the model can figure out which words are related to each other. Multi-head self-attention runs several of these attention computations in parallel, so the model can capture different kinds of relationships between words at the same time, giving it more “views” on the input.
What is Multi-Head Self-Attention?
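- We split each word’s embedding vector into several equal-sized chunks, one per head (embed_size / heads dimensions each).
- Each head computes its own attention pattern, so different heads can focus on different word relationships.
- The outputs of all heads are concatenated and passed through a final linear layer, giving a richer representation of the sentence.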
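For reference, this is how the original Transformer paper writes the computation. $X$ is the input matrix with one row per word, $d_{\text{model}}$ is the embedding size (`embed_size`), $h$ is the number of heads, $d_k$ is the per-head dimension (`head_dim` in the code below), and $W_i^Q, W_i^K, W_i^V, W^O$ are learned projection matrices:

```latex
\begin{aligned}
\mathrm{Attention}(Q, K, V) &= \mathrm{softmax}\!\left(\frac{Q K^\top}{\sqrt{d_k}}\right) V \\
\mathrm{head}_i &= \mathrm{Attention}\left(X W_i^Q,\; X W_i^K,\; X W_i^V\right), \quad i = 1, \dots, h \\
\mathrm{MultiHead}(X) &= \mathrm{Concat}\left(\mathrm{head}_1, \dots, \mathrm{head}_h\right) W^O, \quad d_k = \frac{d_{\text{model}}}{h}
\end{aligned}
```

Dividing by $\sqrt{d_k}$ keeps the dot products from growing with the head size, which would otherwise push the softmax toward nearly one-hot weights.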
Code for Multi-Head Self-Attention:
```python
import torch
import torch.nn as nn

# Multi-Head Self-Attention Layer
class MultiHeadSelfAttention(nn.Module):
    def __init__(self, embed_size, heads):
        super(MultiHeadSelfAttention, self).__init__()
        assert embed_size % heads == 0, "embed_size must be divisible by heads"
        self.embed_size = embed_size
        self.heads = heads
        self.head_dim = embed_size // heads  # Split the embedding size across heads

        # Linear layers that project the inputs to 'values', 'keys', and 'queries'
        self.values = nn.Linear(embed_size, embed_size, bias=False)
        self.keys = nn.Linear(embed_size, embed_size, bias=False)
        self.queries = nn.Linear(embed_size, embed_size, bias=False)
        self.fc_out = nn.Linear(heads * self.head_dim, embed_size)

    def forward(self, values, keys, query):
        N = query.shape[0]  # batch size
        value_len, key_len, query_len = values.shape[1], keys.shape[1], query.shape[1]

        # Project the inputs, then split the embeddings into multiple heads
        values = self.values(values).reshape(N, value_len, self.heads, self.head_dim)
        keys = self.keys(keys).reshape(N, key_len, self.heads, self.head_dim)
        queries = self.queries(query).reshape(N, query_len, self.heads, self.head_dim)

        # Attention scores (how strongly each query word relates to each key word)
        # Shape: (N, heads, query_len, key_len)
        attention = torch.einsum("nqhd,nkhd->nhqk", [queries, keys])

        # Scale by sqrt(head_dim), then softmax over the key dimension
        attention = torch.softmax(attention / (self.head_dim ** 0.5), dim=3)

        # Weighted sum of values, then concatenate the heads back together
        out = torch.einsum("nhql,nlhd->nqhd", [attention, values]).reshape(
            N, query_len, self.heads * self.head_dim
        )

        # Final linear layer mixes information across heads
        out = self.fc_out(out)
        return out
```
What’s happening here:
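- Values, Keys, and Queries: Each input vector is projected by a learned linear layer into a query, a key, and a value. The dot product of one word’s query with another word’s key measures how much the first word should attend to the second.
- Multiple Heads: The projections are split into heads, and each head computes its own attention weights, so different heads can focus on different relationships.
- Output: Each head returns a weighted sum of the values; the heads are concatenated and passed through `fc_out` to produce the final output, a richer representation of the input.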
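To see the scaled dot-product step in isolation, here is a small self-contained sketch for a single head. Random tensors stand in for the projected queries, keys, and values, and the sizes are arbitrary illustration values. It checks two properties the class relies on: each row of attention weights is a probability distribution over the keys, and the output has one vector per query.

```python
import torch

torch.manual_seed(0)

seq_len, head_dim = 4, 8  # arbitrary sizes for illustration

# Random stand-ins for one head's projected queries, keys, and values
q = torch.randn(seq_len, head_dim)
k = torch.randn(seq_len, head_dim)
v = torch.randn(seq_len, head_dim)

# Scores: how strongly each query matches each key, scaled by sqrt(head_dim)
scores = q @ k.T / head_dim ** 0.5       # shape: (seq_len, seq_len)
weights = torch.softmax(scores, dim=-1)  # each row sums to 1
out = weights @ v                        # weighted sum of values, shape: (seq_len, head_dim)

print(weights.sum(dim=-1))  # approximately all ones
print(out.shape)            # torch.Size([4, 8])
```

The two `einsum` calls in `MultiHeadSelfAttention` perform these same matrix products for every batch item and every head at once.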
Feedforward Neural Network
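After the self-attention layer, the model passes each word’s vector through a small feedforward network. The same two-layer network is applied to every position independently, adding extra non-linear processing on top of the information that attention gathered from the other words.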
Code for Feedforward Network:
```python
# Feedforward Network
class FeedForward(nn.Module):
    def __init__(self, embed_size, forward_expansion):
        super(FeedForward, self).__init__()
        self.fc1 = nn.Linear(embed_size, forward_expansion * embed_size)
        self.fc2 = nn.Linear(forward_expansion * embed_size, embed_size)

    def forward(self, x):
        x = torch.relu(self.fc1(x))  # Apply a ReLU function to add non-linearity
        return self.fc2(x)
```
What’s happening here:
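- The first linear layer expands each word’s vector to forward_expansion * embed_size dimensions, and the second projects it back to embed_size.
- The ReLU between the two layers adds non-linearity; without it, the two linear layers would collapse into a single linear transformation.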
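Because the feedforward network contains only linear layers and an element-wise ReLU, it treats every position independently: running it on a whole sequence gives the same result as running it on each word vector one at a time. The sketch below checks this using an `nn.Sequential` stand-in with the same structure as `FeedForward` (the sizes are arbitrary).

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

embed_size, forward_expansion = 16, 4  # arbitrary illustration values

# Same structure as FeedForward: expand, ReLU, project back
ffn = nn.Sequential(
    nn.Linear(embed_size, forward_expansion * embed_size),
    nn.ReLU(),
    nn.Linear(forward_expansion * embed_size, embed_size),
)

x = torch.randn(2, 5, embed_size)  # (batch, seq_len, embed_size)

whole = ffn(x)  # process all positions at once
# Process each position separately, then stack the results back along the sequence axis
per_position = torch.stack([ffn(x[:, t, :]) for t in range(x.size(1))], dim=1)

print(whole.shape)                                     # torch.Size([2, 5, 16])
print(torch.allclose(whole, per_position, atol=1e-5))  # True
```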
Positional Encoding
Self-attention on its own has no notion of word order: shuffling the input words simply shuffles the outputs in the same way. Transformers therefore add a positional encoding, which injects information about each word’s position in the sentence.
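The sinusoidal encoding from the original Transformer paper defines the even and odd dimensions of the vector for position $pos$ as follows, where $d$ is the embedding size (`embed_size`) and $k$ indexes pairs of dimensions:

```latex
PE_{(pos,\,2k)} = \sin\!\left(\frac{pos}{10000^{2k/d}}\right), \qquad
PE_{(pos,\,2k+1)} = \cos\!\left(\frac{pos}{10000^{2k/d}}\right)
```

When looping over even indices `i = 2k`, the exponent is therefore `i / d`, not `2 * i / d`.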
Code for Positional Encoding:
```python
import math

# Positional Encoding (sinusoidal; assumes embed_size is even)
class PositionalEncoding(nn.Module):
    def __init__(self, embed_size, max_len):
        super(PositionalEncoding, self).__init__()
        encoding = torch.zeros(max_len, embed_size)
        for pos in range(max_len):
            for i in range(0, embed_size, 2):
                # i is already the even index 2k, so the exponent is i / embed_size
                angle = pos / (10000 ** (i / embed_size))
                encoding[pos, i] = math.sin(angle)
                encoding[pos, i + 1] = math.cos(angle)
        # A buffer moves with the module on .to(device) but is not a trainable parameter
        self.register_buffer("encoding", encoding)

    def forward(self, x):
        seq_len = x.size(1)
        return x + self.encoding[:seq_len, :]
```
What’s happening here:
- Sine and Cosine Functions: Each position gets a fixed vector in which even dimensions use sine and odd dimensions use cosine, with a different frequency for each pair of dimensions.
- This positional vector is added to the word embedding, so the model can tell the same word apart at different positions in the sentence.
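A quick way to sanity-check an implementation of this formula is to build the table with tensor operations and inspect a few known values: at position 0 every sine is 0 and every cosine is 1, and all entries lie in [-1, 1]. The sketch below is a vectorized version of the same formula, assuming an even `embed_size`; the function name `sinusoidal_table` is hypothetical, chosen for illustration.

```python
import math
import torch

def sinusoidal_table(max_len, embed_size):
    """Hypothetical helper: sinusoidal positional encodings; assumes embed_size is even."""
    position = torch.arange(max_len, dtype=torch.float32).unsqueeze(1)  # (max_len, 1)
    even_idx = torch.arange(0, embed_size, 2, dtype=torch.float32)      # 0, 2, 4, ...
    freq = torch.exp(-math.log(10000.0) * even_idx / embed_size)        # 1 / 10000^(2k/d)
    table = torch.zeros(max_len, embed_size)
    table[:, 0::2] = torch.sin(position * freq)  # even dimensions
    table[:, 1::2] = torch.cos(position * freq)  # odd dimensions
    return table

pe = sinusoidal_table(max_len=50, embed_size=16)
print(pe.shape)   # torch.Size([50, 16])
print(pe[0, :4])  # tensor([0., 1., 0., 1.])
```

Computing all positions with broadcasting avoids the nested Python loops, which matters when `max_len` and `embed_size` are large.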
Layer Normalization and Residual Connections
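To keep training stable, each sub-layer (attention and feedforward) is wrapped in a residual connection followed by layer normalization. The residual connection adds the sub-layer’s input to its output, so the original signal (and its gradient) can pass straight through; layer normalization rescales each word’s vector to zero mean and unit variance, followed by a learned scale and shift.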
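Written out, each sub-layer follows the post-norm pattern of the original Transformer, where $\mathrm{Sublayer}$ is either the attention or the feedforward network. Layer normalization is computed over the $d$ features of each word vector $x$, with learned scale $\gamma$, learned shift $\beta$, and a small constant $\epsilon$ for numerical stability:

```latex
\begin{aligned}
y &= \mathrm{LayerNorm}\big(x + \mathrm{Sublayer}(x)\big) \\
\mathrm{LayerNorm}(x) &= \gamma \odot \frac{x - \mu}{\sqrt{\sigma^2 + \epsilon}} + \beta,
\quad \mu = \frac{1}{d}\sum_{j=1}^{d} x_j,
\quad \sigma^2 = \frac{1}{d}\sum_{j=1}^{d} (x_j - \mu)^2
\end{aligned}
```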
Code for Layer Normalization and Residual Connections:
```python
# Transformer Block with Layer Normalization and Residual Connections
class TransformerBlock(nn.Module):
    def __init__(self, embed_size, heads, forward_expansion, dropout):
        super(TransformerBlock, self).__init__()
        # Multi-head attention
        self.attention = MultiHeadSelfAttention(embed_size, heads)
        # Feedforward network
        self.feed_forward = FeedForward(embed_size, forward_expansion)
        # Layer normalization
        self.norm1 = nn.LayerNorm(embed_size)
        self.norm2 = nn.LayerNorm(embed_size)
        # Dropout to reduce overfitting
        self.dropout = nn.Dropout(dropout)

    def forward(self, value, key, query):
        # Multi-head attention, then dropout, residual connection, and normalization
        attention = self.attention(value, key, query)
        x = self.norm1(query + self.dropout(attention))
        # Feedforward, then dropout, another residual connection, and normalization
        forward = self.feed_forward(x)
        out = self.norm2(x + self.dropout(forward))
        return out
```
What’s happening here:
- Layer Normalization keeps activations on a consistent scale by normalizing each word’s vector across its features.
- Residual Connections add each sub-layer’s input back to its output, so the original information is preserved and gradients flow more easily through a deep stack of blocks.
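To make the normalization step concrete, this sketch applies a residual connection followed by `nn.LayerNorm`, then recomputes the normalization by hand and compares the two. A freshly created `nn.LayerNorm` has its scale initialized to 1 and its shift to 0, and its default `eps` is 1e-5, so the results should match. The sizes and the random "sub-layer output" are illustration values only.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

embed_size = 8
x = torch.randn(2, 3, embed_size)             # (batch, seq_len, embed_size)
sublayer_out = torch.randn(2, 3, embed_size)  # stand-in for an attention or feedforward output

norm = nn.LayerNorm(embed_size)  # gamma = 1, beta = 0 at initialization
y = norm(x + sublayer_out)       # residual connection, then normalization

# Manual layer norm over the last dimension (population variance, eps = 1e-5)
h = x + sublayer_out
mean = h.mean(dim=-1, keepdim=True)
var = ((h - mean) ** 2).mean(dim=-1, keepdim=True)
manual = (h - mean) / torch.sqrt(var + 1e-5)

print(torch.allclose(y, manual, atol=1e-5))  # True
```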
Building the Full Transformer Model
Now we put all the pieces together! We’ll stack multiple Transformer blocks and add a final linear layer that produces a score for every vocabulary word at each position, which can be used to predict the next word in the sequence.
Code for Full Transformer Model:
```python
# Full Transformer Model
class Transformer(nn.Module):
    def __init__(self, embed_size, heads, depth, forward_expansion, max_len, dropout, vocab_size):
        super(Transformer, self).__init__()
        # Embedding layer to convert word indices to dense vectors
        self.embedding = nn.Embedding(vocab_size, embed_size)
        # Positional encoding to inject word order
        self.pos_encoding = PositionalEncoding(embed_size, max_len)
        # Stack multiple Transformer blocks
        self.layers = nn.ModuleList(
            [TransformerBlock(embed_size, heads, forward_expansion, dropout) for _ in range(depth)]
        )
        # Final output layer: one score (logit) per vocabulary word at each position
        self.fc_out = nn.Linear(embed_size, vocab_size)

    def forward(self, x):
        # x: (batch, seq_len) word indices -> (batch, seq_len, embed_size) embeddings
        x = self.embedding(x)
        # Add positional encodings
        x = self.pos_encoding(x)
        # Self-attention: values, keys, and queries all come from the same sequence
        for layer in self.layers:
            x = layer(x, x, x)
        # Output shape: (batch, seq_len, vocab_size)
        out = self.fc_out(x)
        return out
```
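One thing to keep in mind: as written, every position can attend to every other position, including words that come after it. For next-word prediction, decoder-style Transformers usually add a causal mask so that position t only sees positions up to t. The sketch below shows the idea on raw attention scores; `causal_attention_weights` is a hypothetical helper for illustration, and adding a mask to `MultiHeadSelfAttention` would mean applying the same `masked_fill` to its scores before the softmax.

```python
import torch

def causal_attention_weights(scores):
    """Hypothetical helper: causal mask + softmax for (..., seq_len, seq_len) scores.

    Entries above the diagonal (future positions) are set to -inf,
    so they receive zero weight after the softmax.
    """
    seq_len = scores.size(-1)
    future = torch.triu(
        torch.ones(seq_len, seq_len, dtype=torch.bool, device=scores.device), diagonal=1
    )
    scores = scores.masked_fill(future, float("-inf"))
    return torch.softmax(scores, dim=-1)

torch.manual_seed(0)
scores = torch.randn(1, 2, 4, 4)  # (batch, heads, query_len, key_len), like the einsum output
weights = causal_attention_weights(scores)

print(weights[0, 0])  # lower-triangular; the first row is [1, 0, 0, 0]
```

The diagonal is never masked, so every row keeps at least one finite score and the softmax stays well defined.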
Explanation of Output
Let’s break down what happens when we give the model a sentence:
- Input Sentence (Word Indices): The input sentence is a tensor of integers, each an index into the vocabulary. For example:
tensor([[4, 1, 2, 9, 8]])
Each number represents one word.
- Embedding: These word indices are converted into dense vectors that the model can process.
- Positional Encoding: The model adds positional information so it knows the order of the words in the sentence.
- Transformer Blocks: The input goes through several Transformer blocks, where the model applies multi-head self-attention and processes each word through the feedforward network.
- Output: The final output has shape (batch, seq_len, vocab_size): for each position, a score (logit) for every word in the vocabulary. These scores can be used to predict the next word or to classify the sentence.
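As a small illustration of that last step, the sketch below turns output scores into a next-word prediction. Random numbers stand in for the model’s output, assuming a hypothetical vocabulary of 10 words and the 5-word input above. The scores at the last position are converted to probabilities with a softmax, and the highest-probability index is taken as the predicted next word (this choice is called greedy decoding).

```python
import torch

torch.manual_seed(0)

vocab_size = 10                          # hypothetical vocabulary size
logits = torch.randn(1, 5, vocab_size)   # stand-in for model output: (batch, seq_len, vocab_size)

last_scores = logits[:, -1, :]              # scores at the final position
probs = torch.softmax(last_scores, dim=-1)  # probabilities over the vocabulary
next_word = probs.argmax(dim=-1)            # index of the most likely next word

print(probs.sum().item())  # approximately 1.0
print(next_word.shape)     # torch.Size([1])
```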
Conclusion
In this post, we built simplified versions of the core components of a Transformer:
- Multi-Head Self-Attention allows the model to look at different word relationships simultaneously.
- Feedforward Networks process the words further.
- Positional Encoding helps the model keep track of word order.
- Layer Normalization and Residual Connections make sure the model is stable and doesn’t lose important information.
In the next post, we’ll train this model and see it in action!