Implementing Multi-Head Attention from Scratch using TensorFlow and Keras
In this tutorial, we will implement Multi-Head Attention using the TensorFlow and Keras libraries in Python.
Let’s first understand what we mean by Transformer Multi-Head Attention.
What is Transformer Multi-Head Attention?
Let’s first understand what we mean by transformer:
A transformer is a deep learning model used mainly in fields like NLP and computer vision. It relies on a self-attention mechanism that weighs the significance of each part of the input data differently.
The Transformer Multi-Head Attention
The Multi-Head Attention module runs the attention mechanism several times in parallel and combines the results.
The attention mechanism takes three inputs: Query, Key, and Value. Multi-Head Attention splits its Query, Key, and Value parameters N ways and passes each split independently through a separate head. The outputs of all the heads are then combined to produce the final attention output.
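Concretely, each head computes scaled dot-product attention. Here is a minimal NumPy sketch of that computation (the shapes and the helper name are only for illustration, not part of the implementation below):

import numpy as np

def scaled_dot_product_attention(query, key, value):
    # query and key have shape (seq_len, d_k); value has shape (seq_len, d_v)
    d_k = query.shape[-1]
    # Score every query against every key and scale by sqrt(d_k)
    scores = query @ key.T / np.sqrt(d_k)
    # Softmax over the keys turns the scores into attention weights
    weights = np.exp(scores) / np.exp(scores).sum(axis=-1, keepdims=True)
    # The output is a weighted sum of the values
    return weights @ value

# Toy example: 3 tokens, d_k = d_v = 4
q = np.random.random((3, 4))
k = np.random.random((3, 4))
v = np.random.random((3, 4))
print(scaled_dot_product_attention(q, k, v).shape)  # (3, 4)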
Implementing Multi-Head Attention from Scratch
Let’s start by installing the Python libraries:
pip install tensorflow
pip install keras
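To confirm that the installation worked (assuming a TensorFlow 2.x release), you can print the installed version:

import tensorflow as tf
print(tf.__version__)  # should print a 2.x version string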
Let’s start implementing Multi-Head Attention now.
First, let’s import the required libraries:
from tensorflow import matmul, math, cast, float32, reshape, shape, transpose
from tensorflow.keras.layers import Layer, Dense
from tensorflow.keras.backend import softmax
from numpy import random
After importing all the libraries, let’s create a class for scaled dot-product attention:
class DotProduct(Layer):
    def __init__(self, **kwargs):
        super(DotProduct, self).__init__(**kwargs)

    def call(self, query, key, value, k, mask=None):
        # Score every query against every key and scale by the key dimension
        scores = matmul(query, key, transpose_b=True) / math.sqrt(cast(k, float32))

        # Apply the mask (if any) by pushing masked positions towards -infinity
        if mask is not None:
            scores += -1e9 * mask

        # Softmax turns the scores into attention weights
        weights = softmax(scores)

        # The output is a weighted sum of the values
        return matmul(weights, value)
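As a quick sanity check (the shapes below are arbitrary and only for illustration), the layer can be called directly on random tensors:

from numpy import random

d_k = 8
# Random float32 inputs of shape (batch, seq_len, d_k)
q = random.random((2, 5, d_k)).astype('float32')
k_in = random.random((2, 5, d_k)).astype('float32')
v_in = random.random((2, 5, d_k)).astype('float32')

attention = DotProduct()
print(attention(q, k_in, v_in, d_k).shape)  # (2, 5, 8)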
Now, let’s build a class for Multi-Head Attention:
class MultiHead(Layer):
    def __init__(self, h, k, v, model, **kwargs):
        super(MultiHead, self).__init__(**kwargs)
        self.attention = DotProduct()  # scaled dot-product attention defined above
        self.heads = h      # number of attention heads
        self.k = k          # dimensionality of the queries and keys
        self.v = v          # dimensionality of the values
        self.model = model  # dimensionality of the model output
        self.W_q = Dense(k)      # projection for the queries
        self.W_k = Dense(k)      # projection for the keys
        self.W_v = Dense(v)      # projection for the values
        self.W_o = Dense(model)  # final output projection

    def reshape_tensor(self, x, heads, flag):
        if flag:
            # Split the last dimension across the heads: (batch, seq_len, heads, -1),
            # then move the heads axis forward: (batch, heads, seq_len, -1)
            x = reshape(x, shape=(shape(x)[0], shape(x)[1], heads, -1))
            x = transpose(x, perm=(0, 2, 1, 3))
        else:
            # Reverse the split, giving (batch, seq_len, k)
            x = transpose(x, perm=(0, 2, 1, 3))
            x = reshape(x, shape=(shape(x)[0], shape(x)[1], self.k))
        return x

    def call(self, query, key, value, mask=None):
        # Project the inputs and split them across the heads
        reshaped_query = self.reshape_tensor(self.W_q(query), self.heads, True)
        reshaped_key = self.reshape_tensor(self.W_k(key), self.heads, True)
        reshaped_value = self.reshape_tensor(self.W_v(value), self.heads, True)

        # Run scaled dot-product attention on every head in parallel
        reshaped_output = self.attention(reshaped_query, reshaped_key, reshaped_value, self.k, mask)

        # Rejoin the heads and apply the final linear projection
        output = self.reshape_tensor(reshaped_output, self.heads, False)
        return self.W_o(output)
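A note on reshape_tensor: with flag=True it splits the last dimension of a (batch, seq_len, k) tensor into separate heads, giving (batch, heads, seq_len, k / heads), so that each head attends over its own slice; with flag=False it reverses the operation before the final W_o projection. This assumes that k and v are divisible by the number of heads.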
Now, let’s test the layer by assigning some sample values:
sequence_length = 3  # length of the input sequence
h = 6                # number of heads; k and v must be divisible by h
k = 54               # dimensionality of the queries and keys
v = 54               # dimensionality of the values
model = 125          # dimensionality of the model output
batch_size = 54

query = random.random((batch_size, sequence_length, k))
key = random.random((batch_size, sequence_length, k))
value = random.random((batch_size, sequence_length, v))

final_attention = MultiHead(h, k, v, model)
print(final_attention(query, key, value))
Output:
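The exact numbers will differ on every run because the inputs and the layer weights are initialized randomly, but the printed tensor has shape (batch_size, sequence_length, model), i.e. (54, 3, 125) in this example.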