Implementing Multi-Head Attention from Scratch using TensorFlow and Keras

In this tutorial, we will implement Multi-Head Attention using the TensorFlow and Keras libraries in Python.

Let’s first understand what we mean by Transformer Multi-Head Attention.

What is Transformer Multi-Head Attention?

Let’s first understand what we mean by transformer:

A transformer is a deep learning model used mainly in fields such as natural language processing (NLP) and computer vision. It is built around a self-attention mechanism, which weighs the significance of each part of the input data differently.

The Transformer Multi-Head Attention

The Multi-Head Attention module runs an attention mechanism several times in parallel.

The attention mechanism takes three inputs: Query, Key, and Value. The Multi-Head Attention module splits its Query, Key, and Value parameters N ways and passes each split independently through a separate head. The attention outputs of all heads are then combined to produce the final attention output.
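To get a feel for the splitting step, here is a rough NumPy-only sketch (the sizes are illustrative assumptions, not the values used later in this tutorial): the feature dimension is divided across the heads while the sequence dimension stays intact.

import numpy as np

batch, seq_len, d_model, n_heads = 2, 5, 16, 4            # small illustrative sizes
q = np.random.random((batch, seq_len, d_model))           # a batch of query vectors

# Split the last dimension across the heads:
# (batch, seq, d_model) -> (batch, heads, seq, d_model // heads)
q_heads = q.reshape(batch, seq_len, n_heads, d_model // n_heads).transpose(0, 2, 1, 3)
print(q_heads.shape)   # (2, 4, 5, 4): each head sees the full sequence with a smaller feature slice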

Implementing Multi-Head Attention from Scratch

Let’s start installing the Python libraries:

pip install tensorflow
pip install keras
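To confirm the installation worked (an optional sanity check, not part of the original steps), you can print the installed TensorFlow version:

import tensorflow as tf

print(tf.__version__)   # any recent TensorFlow 2.x release is fine for this tutorial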

Let’s start implementing Multi-Head Attention now.

First, we import the required libraries:

from tensorflow import matmul, math, cast, float32, reshape, shape, transpose
from tensorflow.keras.layers import Layer, Dense
from tensorflow.nn import softmax
from numpy import random

After importing all the libraries, let’s create a class for scaled dot-product attention:

class DotProduct(Layer):
    def __init__(self, **kwargs):
        super(DotProduct, self).__init__(**kwargs)

    def call(self, query, key, value, k, mask=None):
        # Score each query against every key, scaled by the square root of k
        scores = matmul(query, key, transpose_b=True) / math.sqrt(cast(k, float32))

        # Push masked positions towards -infinity so they receive ~0 weight after the softmax
        if mask is not None:
            scores += -1e9 * mask

        # Turn the scores into attention weights
        weights = softmax(scores)

        # Weight the values by the attention weights
        return matmul(weights, value)
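As a quick sanity check (an assumed example, not part of the original tutorial), the layer can be called directly on small random tensors; the output has the same shape as the value tensor:

# Cast the NumPy inputs to float32 so they match the float32 scaling factor inside the layer
sample_q = random.random((1, 3, 8)).astype('float32')   # query: (batch, sequence, k)
sample_k = random.random((1, 3, 8)).astype('float32')   # key:   (batch, sequence, k)
sample_v = random.random((1, 3, 8)).astype('float32')   # value: (batch, sequence, v)

print(DotProduct()(sample_q, sample_k, sample_v, 8).shape)   # (1, 3, 8)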

Now, let’s build a class for Multi-Head Attention:

class MultiHead(Layer):
    def __init__(self, h, k, v, model, **kwargs):
        super(MultiHead, self).__init__(**kwargs)
        self.attention = DotProduct()  # scaled dot-product attention
        self.heads = h                 # number of attention heads
        self.k = k                     # dimensionality of the queries and keys
        self.v = v                     # dimensionality of the values
        self.model = model             # dimensionality of the model output
        self.W_q = Dense(k)            # projection matrix for the queries
        self.W_k = Dense(k)            # projection matrix for the keys
        self.W_v = Dense(v)            # projection matrix for the values
        self.W_o = Dense(model)        # projection matrix for the output

    def reshape_tensor(self, x, heads, flag):
        if flag:
            # Split the last dimension across the heads and bring the heads axis forward
            x = reshape(x, shape=(shape(x)[0], shape(x)[1], heads, -1))
            x = transpose(x, perm=(0, 2, 1, 3))
        else:
            # Reverse the operation: merge the heads back into the last dimension
            x = transpose(x, perm=(0, 2, 1, 3))
            x = reshape(x, shape=(shape(x)[0], shape(x)[1], self.k))
        return x

    def call(self, query, key, value, mask=None):
        # Project the queries, keys, and values, then split them across the heads
        reshaped_query = self.reshape_tensor(self.W_q(query), self.heads, True)
        reshaped_key = self.reshape_tensor(self.W_k(key), self.heads, True)
        reshaped_value = self.reshape_tensor(self.W_v(value), self.heads, True)

        # Apply scaled dot-product attention to every head in parallel
        reshaped_output = self.attention(reshaped_query, reshaped_key, reshaped_value, self.k, mask)

        # Concatenate the heads back together
        output = self.reshape_tensor(reshaped_output, self.heads, False)

        # Apply the final output projection
        return self.W_o(output)

Now, let’s test the implementation by assigning some example values:

sequence_length = 3   # length of the input sequences
h = 6                 # number of attention heads (k and v must be divisible by h)
k = 54                # dimensionality of the queries and keys
v = 54                # dimensionality of the values
model = 125           # dimensionality of the model output
batch_size = 54       # number of sequences processed in a batch

query = random.random((batch_size, sequence_length, k))
key = random.random((batch_size, sequence_length, k))
value = random.random((batch_size, sequence_length, v))

final_attention = MultiHead(h, k, v, model)
print(final_attention(query, key, value))
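Going one step further (a minimal sketch that reuses the hyperparameters above; the toy_model name is just for illustration), the layer can also be dropped into a Keras functional model as a self-attention block by passing the same tensor as query, key, and value:

from tensorflow.keras.layers import Input
from tensorflow.keras.models import Model

inputs = Input(shape=(sequence_length, k))
outputs = MultiHead(h, k, v, model)(inputs, inputs, inputs)   # self-attention over the input
toy_model = Model(inputs=inputs, outputs=outputs)
toy_model.summary()   # final output shape: (None, sequence_length, model)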

Output:

Running the code prints a tensor of shape (batch_size, sequence_length, model), i.e. (54, 3, 125) for the values above, containing the final multi-head attention output.
