Unveiling the Inner Workings: A Deep Dive into BERT’s Attention Mechanism

Nibedita Dutta Last Updated : 02 Jan, 2024
8 min read

Introduction

BERT, short for Bidirectional Encoder Representations from Transformers, is a pre-trained Natural Language Processing (NLP) model developed by Google. It is based on the Transformer architecture, which uses self-attention mechanisms to capture relationships between words in a sentence. It is bidirectional because it considers both the left and right context of each word when encoding its representation. BERT has been pre-trained on a large text corpus and can be fine-tuned for specific NLP tasks such as classification and question answering.

BERT was pre-trained on two NLP tasks: Masked Language Modeling (MLM) and Next Sentence Prediction (NSP). In MLM, BERT randomly masks a percentage of the input tokens and learns to predict them. This yields a more robust language representation because the model has to understand the context on both the left and the right of the masked token. In Next Sentence Prediction, BERT learns to predict whether one sentence follows another, helping it understand relationships between sentences.
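As a small illustration of the masked-language-modeling objective, the sketch below uses the Hugging Face transformers fill-mask pipeline (the same library used later in this article) to let a pre-trained BERT predict a masked token; the example sentence is our own.

import torch
from transformers import pipeline

# Load a pre-trained BERT together with its masked-language-modeling head
fill_mask = pipeline("fill-mask", model="bert-base-uncased")

# BERT's mask token is [MASK]; the pipeline returns the top candidate fillers
for prediction in fill_mask("The cat [MASK] over the fence."):
    print(prediction["token_str"], round(prediction["score"], 3))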

Also Read: What is BERT?

Learning Objectives

  • Understand the attention mechanism in BERT
  • Learn how tokenization is done in BERT
  • Learn how attention weights are computed in BERT
  • Walk through a Python implementation of a BERT model

This article was published as a part of the Data Science Blogathon.

Attention Mechanism in BERT

Let’s start with what attention means in the simplest terms. Attention is one of the ways a model puts more weight on the input features that matter most for a given prediction.

Let us consider the following examples to understand how attention works fundamentally.

Example 1

[Image: an example sentence in which some words receive higher attention than others]

In the above sentence, the BERT model may want to put more weight on the word “cat” and the verb “jumped” than on “bag”, since knowing them is more critical for predicting the next word “fell” than knowing where the cat jumped from.

Example 2

Consider the following sentence –

[Image: an example sentence in which some words receive higher attention than others]

For predicting the word “spaghetti”, the attention mechanism gives more weight to the verb “eating” than to the adjective “bland” describing the spaghetti.

Example 3

Similarly, for a translation task like the following:

Input sentence: How was your day

Target sentence: Comment se passe ta journée

[Image: attention weights mapped between the English input words and the French output words. Source: https://blog.floydhub.com/attention-mechanism/]

For each word in the output phrase, the attention mechanism maps the significant and pertinent words from the input sentence and gives these input words a larger weight. In the above image, notice how the French word ‘Comment’ assigns the highest weight (represented by dark blue) to the word ‘How’, and for the word ‘journée’, the input word ‘day’ receives the highest weight. This is how the attention mechanism helps achieve higher output accuracy: by putting more weight on the words that are most critical for the relevant prediction.

The question that comes to mind is how the model assigns these different weights to the different input words. Let us see in the next section how attention weights enable exactly this.

Attention Weights For Composite Representations

BERT uses attention weights to process sequences. Consider a sequence X comprising three vectors, each with four elements. The attention function transforms X into a new sequence Y with the same length. Each Y vector is a weighted average of the X vectors, with weights termed attention weights. These weights applied to X’s word embeddings produce composite embeddings in Y.

[Image: attention weights combining the x vectors into composite y vectors]

The calculation of each vector in Y relies on different attention weights assigned to x1, x2, and x3, depending on how much attention each input should receive when generating the corresponding vector in Y. Mathematically, it looks something like this:

"

In the above equations, the coefficients w11, w12, and so on (values such as 0.4, 0.3, and 0.2) are the attention weights assigned to x1, x2, and x3 when computing the composite embeddings y1, y2, and y3. As can be seen, the weights assigned to x1, x2, and x3 differ completely across y1, y2, and y3.
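Here is a minimal numeric sketch of this weighted averaging; the weight values below are made up for illustration and do not come from a trained model.

import torch

# Three input vectors x1, x2, x3, each with four elements
X = torch.tensor([[1.0, 0.0, 2.0, 1.0],
                  [0.0, 1.0, 1.0, 3.0],
                  [2.0, 2.0, 0.0, 1.0]])

# One row of attention weights per output vector; each row sums to 1
W = torch.tensor([[0.4, 0.3, 0.3],
                  [0.1, 0.8, 0.1],
                  [0.2, 0.2, 0.6]])

Y = W @ X          # each y_i is a weighted average of x1, x2, x3
print(Y)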

Attention is critical for understanding the context of the sentence, as it enables the model to understand how different words relate to each other in addition to understanding the individual words. For example, consider a language model trying to predict the next word in the following sentence:

“The restless cat was ___”

The model should understand the composite notion of “restless cat” in addition to understanding the concepts of “restless” or “cat” individually; for example, a restless cat often jumps, so “jumping” could be a fair next word in the sentence.

Keys & Query Vectors For Acquiring Attention Weights

By now we know that attention weights give us composite representations of the input words by computing a weighted average of the inputs. The next question is where these attention weights come from. They come from two vectors known as the key and query vectors.

BERT measures attention between word pairs using a function that assigns a score to each pair based on their relationship. It uses query and key vectors, derived from the word embeddings, to assess compatibility. The compatibility score is calculated by taking the dot product of the query vector of one word and the key vector of the other. For instance, the score between ‘jumping’ and ‘cat’ is computed as the dot product of the query vector (q1) of ‘jumping’ and the key vector (k2) of ‘cat’, i.e. q1·k2.

[Image: query and key vectors being combined to produce attention weights]

To convert compatibility scores into valid attention weights, they need to be normalized. BERT does this by applying the softmax function to the scores, ensuring they are positive and sum to one. The resulting values are the final attention weights for each word. Notably, the key and query vectors are computed dynamically from the output of the previous layer, letting BERT adjust its attention mechanism depending on the specific context.
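The snippet below sketches this computation for a single head with random vectors. The scaling by the square root of the head size follows the standard Transformer formulation; this is a simplification for illustration, not BERT’s actual source code.

import torch
import torch.nn.functional as F

torch.manual_seed(0)
seq_len, d = 5, 8                      # 5 tokens, 8-dimensional head size

Q = torch.randn(seq_len, d)            # one query vector per token
K = torch.randn(seq_len, d)            # one key vector per token
V = torch.randn(seq_len, d)            # one value vector per token

scores = Q @ K.T / d ** 0.5            # compatibility score q_i · k_j
weights = F.softmax(scores, dim=-1)    # each row of attention weights sums to 1
context = weights @ V                  # composite representation per token

print(weights.sum(dim=-1))             # tensor of ones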

Attention Heads in BERT

BERT learns multiple attention mechanisms, known as heads, which operate concurrently. Having multiple heads helps BERT capture relationships between words better than a single head could.

BERT splits its Query, Key, and Value parameters N ways. Each of these N sets independently passes through a separate head, performing its own attention calculation. The results are then combined to produce a final attention output. This is why it is termed ‘multi-head attention’: it gives BERT an enhanced capability to capture multiple relationships and nuances for each word.
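A rough sketch of this splitting, assuming the bert-base configuration of a 768-dimensional hidden size and 12 heads; this mirrors the usual multi-head formulation rather than BERT’s exact implementation.

import torch
import torch.nn.functional as F

batch, seq_len, hidden, n_heads = 1, 5, 768, 12
head_dim = hidden // n_heads                        # 64 dimensions per head

x = torch.randn(batch, seq_len, hidden)
Wq = torch.nn.Linear(hidden, hidden)
Wk = torch.nn.Linear(hidden, hidden)
Wv = torch.nn.Linear(hidden, hidden)

def split_heads(t):
    # (batch, seq, hidden) -> (batch, heads, seq, head_dim)
    return t.view(batch, seq_len, n_heads, head_dim).transpose(1, 2)

Q, K, V = split_heads(Wq(x)), split_heads(Wk(x)), split_heads(Wv(x))
scores = Q @ K.transpose(-2, -1) / head_dim ** 0.5
weights = F.softmax(scores, dim=-1)                 # one attention map per head
out = (weights @ V).transpose(1, 2).reshape(batch, seq_len, hidden)
print(weights.shape, out.shape)                     # (1, 12, 5, 5) and (1, 5, 768)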

[Image: multi-head attention]

BERT also stacks multiple layers of attention.  Each layer takes the output from the previous layer and pays attention to it. By doing this many times, BERT can create very detailed representations as it goes deeper into the model.

Depending on the specific BERT model, there are either 12 layers with 12 attention heads each (BERT-base) or 24 layers with 16 attention heads each (BERT-large). Because the weights are not shared between layers, a single BERT model can have up to 24 × 16 = 384 distinct attention mechanisms.
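You can inspect these attention maps yourself: the transformers library can return them directly. The sketch below assumes bert-base-uncased, which has 12 layers with 12 heads each.

import torch
from transformers import BertTokenizer, BertModel

tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
model = BertModel.from_pretrained("bert-base-uncased", output_attentions=True)

inputs = tokenizer("The restless cat was jumping", return_tensors="pt")
with torch.no_grad():
    outputs = model(**inputs)

attentions = outputs.attentions      # tuple with one tensor per layer
print(len(attentions))               # 12 layers
print(attentions[0].shape)           # (batch, heads, seq_len, seq_len) -> (1, 12, n, n)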

Python Implementation of a BERT model

Step 1. Import the Necessary Libraries

We need to import the ‘torch’ library to use PyTorch, as well as BertTokenizer and BertForSequenceClassification from the transformers library. The tokenizer handles tokenization of the text, while BertForSequenceClassification is used for text classification.

import torch
from transformers import BertTokenizer, BertForSequenceClassification

Step 2. Load Pre-trained BERT Model and Tokenizer

In this step, we pass the name of the “bert-base-uncased” pre-trained model to BertForSequenceClassification’s from_pretrained method. Since we want to carry out a simple sentiment classification here, we set num_labels to 2, which represents the “positive” and “negative” classes.

model_name = 'bert-base-uncased'
tokenizer = BertTokenizer.from_pretrained(model_name)
model = BertForSequenceClassification.from_pretrained(model_name, num_labels=2)

Step 3. Set Device to GPU if Available

This step switches the device to GPU if one is available; otherwise it sticks to the CPU.

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model.to(device)

Step 4. Define the Input Text and Tokenize

In this step, we define the input text we want to classify and call the tokenizer, which converts the text into a sequence of tokens, the basic units of information the model can understand. The ‘max_length’ parameter sets the maximum length of the tokenized sequence; if the sequence exceeds this length, it will be truncated. The ‘padding’ parameter dictates that a shorter sequence will be padded with zeros up to the maximum length. The ‘truncation’ parameter indicates whether to truncate a sequence that exceeds the maximum length; since it is set to True, the sequence will be truncated if necessary.

The ‘return_tensors’ parameter specifies the format of the returned sequence; in this case, the tokenizer returns PyTorch tensors. Finally, we move the generated ‘input_ids’ and ‘attention_mask’ to the specified device. The attention mask here is a binary tensor indicating which positions contain real tokens and which are padding, so that the model ignores the padded positions.

text = "I did not really enjoy this movie. It was fantastic!"
#Tokenize the input text
tokens = tokenizer.encode_plus(
    text,
    max_length=128,
    padding='max_length',
    truncation=True,
    return_tensors='pt'
)
# Move input tensors to the device
input_ids = tokens['input_ids'].to(device)
attention_mask = tokens['attention_mask'].to(device)
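If you are curious what the tokenizer actually produced, you can decode the ids back into tokens; this small check is an optional addition to the walkthrough.

# Inspect the tokens BERT actually sees, including [CLS], [SEP] and [PAD]
print(tokenizer.convert_ids_to_tokens(input_ids[0].tolist())[:12])
print(attention_mask[0][:12])   # 1 for real tokens, 0 for padded positions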

Step 5.  Perform Sentiment Prediction

In the next step, the model generates a prediction for the given input_ids and attention_mask. Keep in mind that the classification head added on top of bert-base-uncased is freshly initialized rather than fine-tuned on a sentiment dataset, so this snippet only demonstrates the prediction pipeline; for reliable sentiment results, the model would first be fine-tuned on labelled data.

model.eval()  # ensure evaluation mode (disables dropout)
with torch.no_grad():
    outputs = model(input_ids=input_ids, attention_mask=attention_mask)
predicted_label = torch.argmax(outputs.logits, dim=1).item()
sentiment = 'positive' if predicted_label == 1 else 'negative'
print(f"The sentiment of the input text is {sentiment}.")

Output

The sentiment of the input text is positive.
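If you also want class probabilities rather than only the argmax label, the logits can be passed through a softmax; this is an optional addition to the snippet above.

# Convert the raw logits into class probabilities
probs = torch.softmax(outputs.logits, dim=1)
print(probs)   # shape (1, 2): probabilities of the negative and positive classes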

Conclusion

This article covered attention in BERT, highlighting its importance in understanding sentence context and word relationships. We explored attention weights, which produce composite representations of the input words through weighted averages. These weights are computed from key and query vectors: BERT determines the compatibility score between two words by taking the dot product of these vectors and normalizing the scores with a softmax. BERT learns several such attention mechanisms, known as heads, and multiple heads and layers enhance its understanding of word relationships. Finally, we looked into the Python implementation of a pre-trained BERT model.

Key Takeaways

  • BERT is based on the Transformer architecture, which uses self-attention mechanisms to capture relationships between words in a sentence. It’s bidirectional because it considers each word’s right and left context when encoding its representation.
  • BERT uses ‘attention’ to prioritize relevant input features in sentences, aiding in understanding word relationships and contexts.
  • Attention weights calculate a weighted average of inputs for composite representations. The use of multiple attention heads and layers allows BERT to create detailed word representations by focusing on previous layer outputs.

Frequently Asked Questions

Q1. What is BERT?

A. BERT, short for Bidirectional Encoder Representations from Transformers, is based on the Transformer architecture, which uses self-attention mechanisms to capture relationships between words in a sentence. It is bidirectional because it considers each word’s right and left context when encoding its representation.

Q2. Is the BERT model pretrained?

A. Yes. It undergoes pre-training through two unsupervised tasks: masked language modeling and next-sentence prediction.

Q3. What are the application areas of BERT models?

A. BERT models are used for a variety of NLP applications, including but not limited to text classification, sentiment analysis, question answering, text summarization, machine translation, spell and grammar checking, and content recommendation.

Q4. What is the meaning of attention in BERT?

A. Attention is one of the ways by which the BERT model tries to put more weight on those input features that are more important for a sentence. Attention is critical for understanding the context of the sentence as it enables the model to understand how different words are related to each other in addition to understanding the individual words.

The media shown in this article is not owned by Analytics Vidhya and is used at the Author’s discretion.

Nibedita completed her master’s in Chemical Engineering from IIT Kharagpur in 2014 and is currently working as a Senior Data Scientist. In her current capacity, she works on building intelligent ML-based solutions to improve business processes.
