Multi-Head Attention
llms
In this post, I implement multi-head attention in three stages, following Sebastian Raschka's excellent treatment in Chapter 16 of his machine learning and deep learning text.
The three stages are:
- Basic form of self-attention - really easy to understand, but not used in practice. A good starting point for seeing that attention is just a weighted average of the input vectors, where the weights are determined by the pairwise similarity of the inputs. In matrix terms, the context vectors are the product of the (softmax-normalized) similarity matrix and the matrix of input vectors.
- Parameterized self-attention - this introduces trainable parameters in the form of query, key, and value projection matrices. This is the form of self-attention used in the Transformer architecture.
- Multi-head attention - this introduces "heads", parallel attention mechanisms that are analogous to channels in a convolutional neural network.
In all cases, this is also a good exercise in matrix multiplication and broadcasting.
0. Imports & Setup
In [1]:
import torch
import torch.nn.functional as F
torch.manual_seed(123)
Out[1]:
<torch._C.Generator at 0x12516e690>
Create the matrix of word embeddings to be used as input to the attention mechanism. This has dimensions (sequence_length, embedding_dimension).
In [2]:
d_sentence = 12 # number of words in the sentence
d_embedding = 14 # embedding dimension
sentence = torch.randperm(d_sentence)
embeddings = torch.nn.Embedding(d_sentence, d_embedding)
embedding_sentence = embeddings(sentence).detach()
embedding_sentence.shape
Out[2]:
torch.Size([12, 14])
1. Basic Form of Self-Attention
In [3]:
# similarity matrix - omega_ij is the dot product between the embeddings of word i and word j
omega = embedding_sentence.matmul(embedding_sentence.T)
In [4]:
# attention weights - row-wise softmax normalization of the similarity matrix
attention_weights = F.softmax(omega, dim=1)
attention_weights.shape
Out[4]:
torch.Size([12, 12])
In [5]:
# context vector - weighted sum of the embedding vectors
context_vector = attention_weights.matmul(embedding_sentence)
context_vector.shape
Out[5]:
torch.Size([12, 14])
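As a quick check (a small verification sketch added for this post, not part of the original walkthrough), each context vector should match an explicit weighted sum of the embedding vectors, with weights taken from the corresponding row of the attention matrix:

# verify the "weighted average" interpretation with an explicit loop
manual = torch.stack([
    (attention_weights[i].unsqueeze(1) * embedding_sentence).sum(dim=0)
    for i in range(d_sentence)
])
torch.allclose(context_vector, manual)  # expected: True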
2. Parameterized Self-Attention
In [6]:
# initialize projection matrices
d_query = d_key = d_value = 10 # dimension of query, key and value projection matrices
U_query = torch.rand(d_query, d_embedding)
U_key = torch.rand(d_key, d_embedding)
U_value = torch.rand(d_value, d_embedding)
In [7]:
query = U_query.matmul(embedding_sentence.T).T
key = U_key.matmul(embedding_sentence.T).T
values = U_value.matmul(embedding_sentence.T).T
query.shape, key.shape, values.shape #d_sentence x d_query
Out[7]:
(torch.Size([12, 10]), torch.Size([12, 10]), torch.Size([12, 10]))
In [8]:
# similarity between projected query and key vectors
omega = query.matmul(key.T) # d_sentence x d_sentence
In [9]:
# normalization - scale by the square root of the query/key dimension before the softmax (scaled dot-product attention)
attention_weights = F.softmax(omega / d_query**0.5, dim=1)
In [10]:
# context vector - weighted sum of the projected value vectors
context_vector = attention_weights.matmul(values)
context_vector.shape #d_sentence x d_value
Out[10]:
torch.Size([12, 10])
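For reference, and assuming PyTorch 2.0 or later, the built-in F.scaled_dot_product_attention performs exactly this softmax(QK^T / sqrt(d)) V computation with a default scale of 1/sqrt(d_query), so it should reproduce the manual result:

# compare against PyTorch's built-in scaled dot-product attention
# add a batch dimension for the built-in call, then drop it again
builtin = F.scaled_dot_product_attention(
    query.unsqueeze(0), key.unsqueeze(0), values.unsqueeze(0)
).squeeze(0)
torch.allclose(context_vector, builtin)  # expected: True (up to floating-point tolerance)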
3. Multi-Head Attention
The primary challenge is to arrange the matrix multiplications so that all heads are computed in parallel, while each head individually performs the same computation as above. The other difference is a final linear layer that maps the concatenated head outputs back to the embedding dimension.
In [11]:
# initialize projection matrices
h = 8 # number of heads
multihead_U_query = torch.rand(h, d_query, d_embedding)
multihead_U_key = torch.rand(h, d_key, d_embedding)
multihead_U_value = torch.rand(h, d_value, d_embedding)
In [12]:
multihead_query = multihead_U_query.matmul(embedding_sentence.T).transpose(2,1)
multihead_key = multihead_U_key.matmul(embedding_sentence.T).transpose(2,1)
multihead_values = multihead_U_value.matmul(embedding_sentence.T).transpose(2,1)
multihead_query.shape, multihead_key.shape, multihead_values.shape #h x d_sentence x d_query
Out[12]:
(torch.Size([8, 12, 10]), torch.Size([8, 12, 10]), torch.Size([8, 12, 10]))
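The broadcasting in the batched matmul applies each head's projection to the same sentence. A per-head loop (a sanity check added for this post) produces the same tensors:

# verify broadcasting: head i of the batched projection equals a single-head matmul
per_head_query = torch.stack([
    multihead_U_query[i].matmul(embedding_sentence.T).T for i in range(h)
])
torch.allclose(multihead_query, per_head_query)  # expected: True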
In [13]:
# similarity between projected query and key vectors, for each of the 8 heads in parallel
omega = multihead_query.matmul(multihead_key.transpose(2,1))
omega.shape #h x d_sentence x d_sentence
Out[13]:
torch.Size([8, 12, 12])
In [14]:
# normalization
attention_weights = F.softmax(omega / d_query**0.5, dim=2)
attention_weights.shape #h x d_sentence x d_sentence
Out[14]:
torch.Size([8, 12, 12])
In [15]:
# context vector - with each head separate
context_vector = attention_weights.matmul(multihead_values)
context_vector.shape # h x d_sentence x d_value
Out[15]:
torch.Size([8, 12, 10])
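Equivalently, each head independently runs the parameterized self-attention from section 2. A loop over heads (another verification sketch) recovers the same context vectors:

# verify: the batched matmuls run the section-2 computation once per head
per_head_context = torch.stack([
    F.softmax(multihead_query[i].matmul(multihead_key[i].T) / d_query**0.5, dim=1)
        .matmul(multihead_values[i])
    for i in range(h)
])
torch.allclose(context_vector, per_head_context)  # expected: True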
In [16]:
# adding a linear layer to combine the heads for each word in the sentence
linear = torch.nn.Linear(h*d_value, d_embedding)
In [17]:
# move the sentence dimension to the front, then concatenate the head outputs for each word
context_vector = context_vector.transpose(0, 1).reshape(d_sentence, h*d_value)
context_vector.shape # d_sentence x (h*d_value)
Out[17]:
torch.Size([12, 80])
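To confirm that the reshape concatenates head outputs per word (rather than mixing words across heads), row 0 should equal word 0's eight head outputs laid side by side. This check is an addition for this post:

# recompute the per-head context vectors and compare the first row
per_head = attention_weights.matmul(multihead_values)  # h x d_sentence x d_value
torch.allclose(context_vector[0], torch.cat([per_head[j, 0] for j in range(h)]))  # expected: True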
In [18]:
context_vector_linear = linear(context_vector)
context_vector_linear.shape # d_sentence x d_embedding
Out[18]:
torch.Size([12, 14])
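To wrap up, the steps above can be bundled into a single module. The sketch below is my own consolidation of the code in this post, not Raschka's implementation; the class name MultiHeadSelfAttention and the choice to store the projection tensors as nn.Parameter are assumptions made for illustration:

import torch
import torch.nn.functional as F

class MultiHeadSelfAttention(torch.nn.Module):
    """Projection, per-head attention, and output layer from this post, in one module."""
    def __init__(self, h, d_embedding, d_qkv):
        super().__init__()
        # one (d_qkv x d_embedding) projection per head, as in sections 2 and 3
        self.U_query = torch.nn.Parameter(torch.rand(h, d_qkv, d_embedding))
        self.U_key = torch.nn.Parameter(torch.rand(h, d_qkv, d_embedding))
        self.U_value = torch.nn.Parameter(torch.rand(h, d_qkv, d_embedding))
        self.out = torch.nn.Linear(h * d_qkv, d_embedding)
        self.d_qkv = d_qkv

    def forward(self, x):                                # x: d_sentence x d_embedding
        q = self.U_query.matmul(x.T).transpose(2, 1)     # h x d_sentence x d_qkv
        k = self.U_key.matmul(x.T).transpose(2, 1)
        v = self.U_value.matmul(x.T).transpose(2, 1)
        weights = F.softmax(q.matmul(k.transpose(2, 1)) / self.d_qkv**0.5, dim=2)
        context = weights.matmul(v)                      # h x d_sentence x d_qkv
        context = context.transpose(0, 1).reshape(x.shape[0], -1)
        return self.out(context)                         # d_sentence x d_embedding

mha = MultiHeadSelfAttention(h=8, d_embedding=14, d_qkv=10)
mha(embedding_sentence).shape  # torch.Size([12, 14])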