Vision Transformers (ViT): Revolutionizing Computer Vision

Premanand S Last Updated : 15 Jun, 2023

Introduction

Vision Transformers (ViTs) have emerged as a revolutionary approach in computer vision, changing the way we analyze visual data. Traditionally, Convolutional Neural Networks (CNNs) have been the go-to models for visual tasks, but ViTs offer a novel alternative. By leveraging self-attention and the Transformer architecture, ViTs are not restricted to the local receptive fields of CNN convolutions.

This enables ViTs to capture global dependencies and long-range interactions within an image, which has led to strong results on computer vision tasks including image classification, object detection, and image generation.

With their ability to model high-dimensional visual data effectively, ViTs are reshaping the field of computer vision and opening up new possibilities.

This article was published as a part of the Data Science Blogathon.

Neural Networks

Neural networks are algorithms inspired by the structure and function of the human brain. They are an effective tool for complex problems such as image recognition, speech recognition, natural language processing, and many more. A neural network's architecture describes how its neurons are organized and connected. Common architectures include feedforward networks, recurrent neural networks (RNNs), convolutional neural networks (CNNs), and transformers.

Feedforward networks, often called multi-layer perceptrons (MLPs), are the most basic type of neural network architecture. They comprise an input layer, one or more hidden layers, and an output layer. Each layer's neurons are fully connected to the next layer's neurons, and each neuron computes a weighted sum of its inputs and passes it through a non-linear activation function (the output layer often omits the non-linearity or uses a task-specific one such as softmax).
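A minimal NumPy sketch of the forward pass described above, assuming one hidden layer with a ReLU activation. The layer sizes and random weights are arbitrary placeholders, not values from any real model.

```python
import numpy as np

def relu(x):
    return np.maximum(x, 0.0)

def mlp_forward(x, w1, b1, w2, b2):
    """Forward pass of a one-hidden-layer MLP.

    Every input unit feeds every hidden unit (x @ w1), and every hidden unit
    feeds every output unit (h @ w2): the layers are fully connected.
    """
    h = relu(x @ w1 + b1)   # hidden layer: weighted sum followed by a non-linearity
    return h @ w2 + b2      # output layer: raw scores (logits), no activation here

rng = np.random.default_rng(0)
x = rng.normal(size=(4, 8))                        # batch of 4 inputs, 8 features each
w1, b1 = rng.normal(size=(8, 16)), np.zeros(16)    # input -> 16 hidden units
w2, b2 = rng.normal(size=(16, 3)), np.zeros(3)     # hidden -> 3 outputs
logits = mlp_forward(x, w1, b1, w2, b2)
print(logits.shape)  # (4, 3)
```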

RNNs are designed to process sequential data such as time series or natural language text. Their recurrent connections carry information from one time step to the next, which lets them learn dependencies over time.

CNNs are built to handle spatial data such as images. They extract features from the input with convolutional layers, reduce the spatial dimensions of those features with pooling layers, and then generate the final prediction with fully connected layers.
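A minimal single-channel NumPy sketch of the convolution-then-pooling step, with a hand-picked edge-detecting kernel standing in for the learned filters of a real CNN. As in most deep learning libraries, the "convolution" here is computed as cross-correlation. Note how each output value depends only on a small window of the input, which is the local receptive field that ViTs move away from.

```python
import numpy as np

def conv2d_valid(image, kernel):
    """Single-channel 2D convolution (cross-correlation), no padding, stride 1."""
    kh, kw = kernel.shape
    out_h, out_w = image.shape[0] - kh + 1, image.shape[1] - kw + 1
    out = np.empty((out_h, out_w))
    for i in range(out_h):
        for j in range(out_w):
            # Each output value only sees a kh x kw window: its local receptive field.
            out[i, j] = np.sum(image[i:i + kh, j:j + kw] * kernel)
    return out

def max_pool2d(feature_map, size=2):
    """Non-overlapping max pooling; rows/cols that don't fill a full window are dropped."""
    h, w = feature_map.shape[0] // size, feature_map.shape[1] // size
    trimmed = feature_map[:h * size, :w * size]
    return trimmed.reshape(h, size, w, size).max(axis=(1, 3))

image = np.arange(36, dtype=float).reshape(6, 6)
edge_kernel = np.array([[1.0, 0.0, -1.0]] * 3)  # simple vertical-edge detector
features = conv2d_valid(image, edge_kernel)      # shape (4, 4)
pooled = max_pool2d(features)                    # shape (2, 2): lower spatial resolution
```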

Transformers

Transformers are a type of neural network architecture that processes input data through self-attention. Because every position in the input can attend to every other position, the network can capture both local and global relationships.

Vision transformers apply this architecture to images, modelling relationships between distant image regions directly, and have achieved state-of-the-art performance on many computer vision benchmarks.

The self-attention mechanism is an essential component of vision transformers: it lets the network weight different parts of the input differently for each position, capturing both local and global associations.

In a conventional feedforward neural network, each neuron in a given layer is connected to all neurons in the next layer. In self-attention, by contrast, each element (token) of the input sequence interacts with every element of the same sequence, including itself.

For each token, the network then computes a weighted sum over all tokens in the sequence, with weights that depend on how similar the current token is to each of the others.

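The self-attention mechanism can be expressed mathematically as follows:

$$\mathrm{Attention}(Q, K, V) = \mathrm{softmax}\left(\frac{QK^{\top}}{\sqrt{d_k}}\right)V$$

where Q, K, and V are the query, key, and value matrices obtained by linear projections of the input tokens, and d_k is the key dimension. The softmax is applied row by row, so each token's attention weights sum to 1.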

In the context of vision transformers, the input image is divided into a grid of patches, and each patch is treated as an element in the input sequence. The self-attention mechanism is used to build a new set of embeddings representing the image’s local and global spatial relationships.
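Scaled dot-product self-attention over a sequence of patch embeddings can be sketched in NumPy as a single attention head. The projection matrices w_q, w_k, and w_v are random placeholders here; in a trained model they are learned.

```python
import numpy as np

def softmax(x, axis=-1):
    x = x - x.max(axis=axis, keepdims=True)   # subtract the max for numerical stability
    e = np.exp(x)
    return e / e.sum(axis=axis, keepdims=True)

def self_attention(x, w_q, w_k, w_v):
    """Single-head scaled dot-product self-attention over x of shape (n, d_model)."""
    q, k, v = x @ w_q, x @ w_k, x @ w_v        # queries, keys, values: (n, d_k)
    d_k = k.shape[-1]
    scores = q @ k.T / np.sqrt(d_k)             # (n, n): every token scored against every token
    weights = softmax(scores, axis=-1)          # each row sums to 1
    return weights @ v, weights                 # each output is a weighted sum of all values

rng = np.random.default_rng(0)
n_tokens, d_model, d_k = 5, 16, 8              # e.g. 5 patch embeddings of size 16
x = rng.normal(size=(n_tokens, d_model))
w_q, w_k, w_v = (rng.normal(size=(d_model, d_k)) for _ in range(3))
out, attn = self_attention(x, w_q, w_k, w_v)
print(out.shape, attn.shape)  # (5, 8) (5, 5)
```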

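By using self-attention instead of convolutions, vision transformers can capture long-range dependencies and interactions between patches more directly: every patch can attend to every other patch within a single layer, whereas a CNN must stack many convolutional layers before distant pixels can influence each other. This has contributed to state-of-the-art performance on many computer vision applications.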

Attention Mechanism in Computer Vision (CV)

Attention mechanisms were used in computer vision well before vision transformers, particularly in image captioning and object detection, where a model needs to focus on different parts of an image at different stages of its computation.

For Example

In image captioning, the model must generate a natural language description of a picture. It produces one word per time step and, at each step, must decide which parts of the image to attend to when generating that word. An attention mechanism does this by computing a weighted sum of the visual features, with weights based on how relevant each image region is to the word currently being generated.
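A rough NumPy sketch of one captioning step, assuming simple dot-product scoring between the decoder's current state and each region's feature vector (some captioning models score with a small learned network instead). The region features and the query are random, hypothetical stand-ins; the query is deliberately built to align with region 10, so that region should receive the largest weight.

```python
import numpy as np

def attend(query, region_features):
    """Soft attention of one query vector over k region feature vectors.

    query: (d,) e.g. the caption decoder's state before emitting the next word.
    region_features: (k, d) one feature vector per image region.
    Returns the context vector (d,) and the attention weights (k,).
    """
    scores = region_features @ query            # dot-product relevance of each region
    scores = scores - scores.max()              # numerical stability
    weights = np.exp(scores) / np.exp(scores).sum()
    context = weights @ region_features          # weighted sum of region features
    return context, weights

rng = np.random.default_rng(1)
regions = rng.normal(size=(49, 32))   # e.g. a 7x7 CNN feature grid flattened to 49 regions
decoder_state = regions[10] * 3.0     # hypothetical query strongly aligned with region 10
context, weights = attend(decoder_state, regions)
print(weights.argmax())               # index of the most-attended region for this word
```

The context vector is what the decoder would consume when predicting the next word.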

Similarly, an object detection model must determine which objects are present in an image and where they are. A convolutional neural network (CNN) is often used to extract image features, followed by a region proposal network (RPN) that generates candidate object regions. In some detectors, these candidate regions are then refined using an attention mechanism that focuses on the relevant parts of the image.

Vision transformers use attention in a similar way to capture the image's local and global spatial relationships. Instead of using convolutions to extract image features, they partition the input image into a grid of patches and treat each patch as a sequence element. Self-attention is then applied to the sequence of patch embeddings to produce a new set of embeddings that represent the spatial relationships between the patches.

Using self-attention rather than convolutions, vision transformers can capture long-range dependencies between patches and have achieved state-of-the-art results on tasks such as image classification and object detection. The attention mechanism also lets the model concentrate on the most informative regions of the image, which helps when processing complex visual input.

Patch-based Processing

Vision transformers use a patch-based approach to image processing: they break the input image into smaller, fixed-size patches and treat each patch as a single token. This method has both advantages and disadvantages.
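A minimal NumPy sketch of turning an image into a sequence of patch tokens, assuming an (H, W, C) array whose sides are divisible by the patch size. A 224 x 224 RGB image with 16 x 16 patches (the ViT-Base/16 configuration) becomes a sequence of 196 tokens, each holding 16 * 16 * 3 = 768 values.

```python
import numpy as np

def patchify(image, patch_size):
    """Split an (H, W, C) image into non-overlapping patches and flatten each one.

    Returns an array of shape (num_patches, patch_size * patch_size * C): a
    sequence of tokens in row-major (left-to-right, top-to-bottom) patch order.
    """
    h, w, c = image.shape
    assert h % patch_size == 0 and w % patch_size == 0, "image must divide evenly into patches"
    gh, gw = h // patch_size, w // patch_size
    patches = image.reshape(gh, patch_size, gw, patch_size, c)
    patches = patches.transpose(0, 2, 1, 3, 4)          # (gh, gw, p, p, c): group pixels by patch
    return patches.reshape(gh * gw, patch_size * patch_size * c)

image = np.random.default_rng(0).random((224, 224, 3))
tokens = patchify(image, patch_size=16)
print(tokens.shape)  # (196, 768): a 14 x 14 grid of 16x16x3 patches
```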

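One advantage of patch-based processing is flexibility in input size: an image of a different resolution simply yields a different number of patches, that is, a longer or shorter token sequence. In practice, pretrained ViTs with learned position embeddings are tied to the resolution they were trained on, so changing the input size usually requires interpolating the position embeddings. This flexibility is especially useful for applications like object detection and segmentation, where the size and shape of objects in the image can vary significantly.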

Another advantage is that the self-attention mechanism can model interactions between patches across the entire image, which helps capture global image context. This is especially important for tasks like scene understanding or image captioning, where the context and interactions between objects in the image are critical for producing accurate descriptions.

However, patch-based processing also has drawbacks. Self-attention by itself is permutation-equivariant: reordering the patches merely reorders the outputs, so the model has no notion of where each patch came from unless position information is added explicitly. ViTs address this with position embeddings, but these encode spatial layout only implicitly, and fine detail within each patch is compressed into a single token. This can hurt performance in tasks that rely heavily on spatial relationships, such as fine-grained object recognition or geometric reasoning.
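The following NumPy check illustrates why position information matters, using random (hypothetical) weights and embeddings: without position embeddings, shuffling the patches only shuffles the self-attention outputs, while adding a per-position embedding makes the result depend on where each patch sits.

```python
import numpy as np

def softmax(x):
    e = np.exp(x - x.max(axis=-1, keepdims=True))
    return e / e.sum(axis=-1, keepdims=True)

def self_attention(x, w_q, w_k, w_v):
    q, k, v = x @ w_q, x @ w_k, x @ w_v
    return softmax(q @ k.T / np.sqrt(k.shape[-1])) @ v

rng = np.random.default_rng(0)
n, d = 6, 8
x = rng.normal(size=(n, d))                          # 6 patch embeddings
w_q, w_k, w_v = (rng.normal(size=(d, d)) for _ in range(3))
perm = np.array([1, 2, 3, 4, 5, 0])                  # a fixed shuffle of the patch order

# Without position information: shuffling the inputs just shuffles the outputs.
out = self_attention(x, w_q, w_k, w_v)
out_shuffled = self_attention(x[perm], w_q, w_k, w_v)
print(np.allclose(out[perm], out_shuffled))          # True

# With a position embedding added per slot, the output depends on where each patch sits.
pos = rng.normal(size=(n, d))
out_pos = self_attention(x + pos, w_q, w_k, w_v)
out_pos_shuffled = self_attention(x[perm] + pos, w_q, w_k, w_v)
print(np.allclose(out_pos[perm], out_pos_shuffled))  # False
```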

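Another potential disadvantage is the computational and memory cost of processing many patches: self-attention compares every patch with every other patch, so its cost grows quadratically with the number of patches, and therefore grows quickly with image resolution or with smaller patch sizes. Techniques such as hierarchical processing or restricting attention to local windows can mitigate this to some extent, but it remains a substantial difficulty for large-scale applications.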
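A back-of-the-envelope calculation of that quadratic growth, assuming a square image, one token per patch (ignoring the [CLS] token), and the 12 heads and 12 layers of ViT-Base:

```python
def attention_cost(image_size, patch_size, num_heads=12, num_layers=12):
    """Token count and total attention-score entries for a square image (no [CLS] token)."""
    tokens = (image_size // patch_size) ** 2
    entries_per_head = tokens ** 2                 # one score per (query, key) pair
    return tokens, entries_per_head * num_heads * num_layers

for image_size, patch_size in [(224, 16), (224, 8), (384, 16)]:
    tokens, entries = attention_cost(image_size, patch_size)
    print(f"{image_size}px, patch {patch_size}: {tokens} tokens, {entries:,} attention scores")
# 224px, patch 16: 196 tokens, 5,531,904 attention scores
# 224px, patch 8: 784 tokens, 88,510,464 attention scores
# 384px, patch 16: 576 tokens, 47,775,744 attention scores
```

Halving the patch size from 16 to 8 quadruples the number of tokens and multiplies the number of attention scores by 16.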

Overall, patch-based processing is central to vision transformers and underpins their strong results on many computer vision benchmarks. However, it is important to weigh its benefits and drawbacks for each application and to investigate techniques that mitigate its limitations.

Patch Embeddings

The input image is split into a grid of non-overlapping patches, and each patch is flattened and mapped to a vector (the patch embedding) by a learned linear projection. The patch embeddings are then stacked along the sequence dimension, a learnable classification ([CLS]) token is prepended, and position embeddings are added before the sequence is passed to the transformer encoder.
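A minimal NumPy sketch of the embedding step with ViT-Base/16 dimensions (196 patches of 16 x 16 x 3 values, 768-dimensional embeddings). The random patches stand in for a flattened real image, and the projection, [CLS] token, and position embeddings, which are learned parameters in a real model, are randomly initialized here for illustration. Note that the [CLS] token is concatenated along the sequence axis, not the channel axis. Libraries often implement the same projection as a convolution whose kernel size and stride both equal the patch size.

```python
import numpy as np

rng = np.random.default_rng(0)
num_patches, patch_dim, d_model = 196, 16 * 16 * 3, 768   # ViT-Base/16 on a 224x224 RGB image

patches = rng.random((num_patches, patch_dim))                        # stand-in for flattened patches
w_embed = rng.normal(scale=0.02, size=(patch_dim, d_model))           # learned in practice
b_embed = np.zeros(d_model)
cls_token = rng.normal(scale=0.02, size=(1, d_model))                 # learned in practice
pos_embed = rng.normal(scale=0.02, size=(num_patches + 1, d_model))   # learned in practice

patch_embeddings = patches @ w_embed + b_embed                  # (196, 768): one vector per patch
tokens = np.concatenate([cls_token, patch_embeddings], axis=0)  # prepend [CLS]: (197, 768)
tokens = tokens + pos_embed                                      # add position information
print(tokens.shape)  # (197, 768)
```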

Multi-head Self-attention

The transformer encoder consists of a stack of identical blocks built around multi-head self-attention, which lets the model capture local and global interactions between patches. Each block combines a multi-head self-attention layer, layer normalization, a feedforward network (MLP), and residual connections around the attention and MLP sub-layers.

Multi-Head Attention

Self-attention lets each patch weight every other patch according to its relevance, capturing local and global correlations. Each patch embedding is projected into a query, a key, and a value; attention weights are computed from the similarity between queries and keys, and the layer's output is the attention-weighted sum of the values. Multi-head attention runs several of these attention operations (heads) in parallel on lower-dimensional projections, concatenates their outputs, and applies a final linear projection, so that different heads can focus on different kinds of relationships.
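A compact NumPy sketch of multi-head self-attention, assuming a fused query/key/value projection and no biases. The weights are random placeholders, and d_model must be divisible by the number of heads.

```python
import numpy as np

def softmax(x):
    e = np.exp(x - x.max(axis=-1, keepdims=True))
    return e / e.sum(axis=-1, keepdims=True)

def multi_head_attention(x, w_qkv, w_out, num_heads):
    """Multi-head self-attention over x of shape (n, d_model).

    w_qkv: (d_model, 3 * d_model) joint projection to queries, keys and values.
    w_out: (d_model, d_model) output projection applied after concatenating heads.
    """
    n, d_model = x.shape
    assert d_model % num_heads == 0, "d_model must be divisible by num_heads"
    d_head = d_model // num_heads
    q, k, v = np.split(x @ w_qkv, 3, axis=-1)                    # each (n, d_model)
    # Split the model dimension into independent heads: (num_heads, n, d_head).
    q, k, v = (t.reshape(n, num_heads, d_head).transpose(1, 0, 2) for t in (q, k, v))
    weights = softmax(q @ k.transpose(0, 2, 1) / np.sqrt(d_head))   # (num_heads, n, n)
    heads = weights @ v                                               # (num_heads, n, d_head)
    # Concatenate the heads back to (n, d_model) and mix them with the output projection.
    concat = heads.transpose(1, 0, 2).reshape(n, d_model)
    return concat @ w_out, weights

rng = np.random.default_rng(0)
n, d_model, num_heads = 197, 64, 4                 # [CLS] + 196 patches, small model width
x = rng.normal(size=(n, d_model))
w_qkv = rng.normal(scale=d_model ** -0.5, size=(d_model, 3 * d_model))
w_out = rng.normal(scale=d_model ** -0.5, size=(d_model, d_model))
out, weights = multi_head_attention(x, w_qkv, w_out, num_heads)
print(out.shape, weights.shape)  # (197, 64) (4, 197, 197)
```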

Normalization Layer

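In ViT, the normalization layer is layer normalization, which normalizes each token's feature vector to zero mean and unit variance and then applies a learned scale and shift. This keeps the distribution of activations reasonably consistent and helps stabilize training. In the original ViT, layer normalization is applied before the attention and MLP sub-layers rather than after them.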
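A minimal NumPy sketch of layer normalization. The eps value of 1e-6 is an assumed setting, and gamma and beta, fixed to ones and zeros here, would be learned in a real model. The statistics are computed per token over the feature axis, not across the batch.

```python
import numpy as np

def layer_norm(x, gamma, beta, eps=1e-6):
    """Normalize each token's feature vector to zero mean / unit variance, then rescale."""
    mean = x.mean(axis=-1, keepdims=True)      # statistics over the feature axis, per token
    var = x.var(axis=-1, keepdims=True)
    return gamma * (x - mean) / np.sqrt(var + eps) + beta

rng = np.random.default_rng(0)
x = rng.normal(loc=5.0, scale=3.0, size=(197, 768))   # 197 tokens with 768 features each
y = layer_norm(x, gamma=np.ones(768), beta=np.zeros(768))
print(y.mean(axis=-1)[:3], y.std(axis=-1)[:3])        # approximately 0 and 1 for every token
```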

By stacking many encoder blocks, each combining multi-head self-attention and a feedforward network, on top of the patch embeddings, the vision transformer builds increasingly abstract representations of the input image, capturing both low-level features and high-level semantic information.
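Putting the pieces together, here is a simplified NumPy sketch of a pre-norm encoder block (layer norm, multi-head self-attention, and a GELU MLP, each sub-layer wrapped in a residual connection) stacked several times, with a linear head reading the final [CLS] token. Biases and the learnable layer-norm scale and shift are omitted, the dimensions are deliberately small, and all weights and inputs are random placeholders rather than a real ViT.

```python
import numpy as np

def layer_norm(x, eps=1e-6):
    mean = x.mean(axis=-1, keepdims=True)
    return (x - mean) / np.sqrt(x.var(axis=-1, keepdims=True) + eps)  # no learned scale/shift

def softmax(x):
    e = np.exp(x - x.max(axis=-1, keepdims=True))
    return e / e.sum(axis=-1, keepdims=True)

def gelu(x):
    # tanh approximation of the GELU activation
    return 0.5 * x * (1.0 + np.tanh(np.sqrt(2.0 / np.pi) * (x + 0.044715 * x ** 3)))

def attention(x, p, num_heads):
    n, d = x.shape
    dh = d // num_heads
    q, k, v = (
        (x @ p[name]).reshape(n, num_heads, dh).transpose(1, 0, 2)
        for name in ("w_q", "w_k", "w_v")
    )
    heads = softmax(q @ k.transpose(0, 2, 1) / np.sqrt(dh)) @ v     # (heads, n, dh)
    return heads.transpose(1, 0, 2).reshape(n, d) @ p["w_o"]

def encoder_block(x, p, num_heads):
    """Pre-norm block: x + MHSA(LN(x)), then x + MLP(LN(x))."""
    x = x + attention(layer_norm(x), p, num_heads)
    x = x + gelu(layer_norm(x) @ p["w_1"]) @ p["w_2"]
    return x

def init_block(rng, d, hidden):
    s = d ** -0.5
    return {
        "w_q": rng.normal(scale=s, size=(d, d)), "w_k": rng.normal(scale=s, size=(d, d)),
        "w_v": rng.normal(scale=s, size=(d, d)), "w_o": rng.normal(scale=s, size=(d, d)),
        "w_1": rng.normal(scale=s, size=(d, hidden)),
        "w_2": rng.normal(scale=hidden ** -0.5, size=(hidden, d)),
    }

rng = np.random.default_rng(0)
d_model, num_heads, depth, num_classes = 64, 4, 6, 10
tokens = rng.normal(size=(1 + 16, d_model))          # [CLS] + 16 patch embeddings (stand-in)
blocks = [init_block(rng, d_model, 4 * d_model) for _ in range(depth)]
for p in blocks:                                     # the same sequence flows through every block
    tokens = encoder_block(tokens, p, num_heads)
w_head = rng.normal(scale=d_model ** -0.5, size=(d_model, num_classes))
logits = layer_norm(tokens)[0] @ w_head             # classify from the final [CLS] token
print(tokens.shape, logits.shape)  # (17, 64) (10,)
```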

1. Import Libraries

import torch
import torchvision
from torchvision import transforms
from transformers import ViTForImageClassification

2. Importing Dataset

# google/vit-base-patch16-224 expects 224x224 inputs normalized with mean = std = 0.5
# per channel, so the 32x32 CIFAR-10 images are resized and normalized up front.
transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
])
data = torchvision.datasets.CIFAR10(root='./data', train=True, download=True,
                                    transform=transform)

3. Splitting Data

train_size = int(0.8 * len(data))
val_size = len(data) - train_size
train_data, val_data = torch.utils.data.random_split(data, [train_size, val_size])
train_loader = torch.utils.data.DataLoader(train_data, batch_size=32, shuffle=True)
val_loader = torch.utils.data.DataLoader(val_data, batch_size=32)

4. Defining Model

# The checkpoint ships with a 1000-class ImageNet head; num_labels=10 together with
# ignore_mismatched_sizes=True replaces it with a freshly initialized 10-class head.
model = ViTForImageClassification.from_pretrained(
    'google/vit-base-patch16-224', num_labels=10, ignore_mismatched_sizes=True)
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model.to(device)

5. Loss and Optimizer

optimizer = torch.optim.Adam(model.parameters(), lr=0.0001)
criterion = torch.nn.CrossEntropyLoss()

6. Training the Model

model.train()
for epoch in range(10):
    for inputs, labels in train_loader:
        inputs, labels = inputs.to(device), labels.to(device)
        optimizer.zero_grad()
        outputs = model(pixel_values=inputs)
        loss = criterion(outputs.logits, labels)
        loss.backward()
        optimizer.step()
    print(f'Epoch {epoch + 1}: last batch loss {loss.item():.3f}')

7. Evaluating the Model

model.eval()
correct = 0
total = 0
with torch.no_grad():
    for inputs, labels in val_loader:
        inputs, labels = inputs.to(device), labels.to(device)
        outputs = model(pixel_values=inputs)
        predicted = outputs.logits.argmax(dim=1)
        total += labels.size(0)
        correct += (predicted == labels).sum().item()

print('Accuracy on validation set: %d %%' % (100 * correct / total))

Pre-trained Models

The primary advantage of pre-trained vision transformer models is that they can be fine-tuned for specific applications with limited labelled data. This kind of transfer learning particularly benefits domains where labelled data is scarce, such as medical imaging or satellite imagery. By building on the model's pre-trained features, fine-tuning can improve accuracy and shorten training for specific tasks.

To adapt a pre-trained vision transformer to a new task, we replace the final classification layer with a task-specific head; for dense tasks such as object detection or image segmentation, this head is typically a larger detection or segmentation module rather than a single layer. One option is to keep the pre-trained weights frozen and train only the new head to minimize the task-specific loss; full fine-tuning instead updates all weights, usually with a small learning rate.

Moreover, we can utilize pre-trained models as feature extractors to construct high-dimensional embeddings for subsequent tasks like image retrieval or clustering.
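A small NumPy sketch of embedding-based image retrieval with cosine similarity. The gallery embeddings are random, hypothetical stand-ins; in practice they would come from a pretrained model with its classification head removed (for example, timm models created with num_classes=0 return pooled features).

```python
import numpy as np

def cosine_retrieve(query, gallery, top_k=3):
    """Return indices of the top_k gallery embeddings most similar to query, plus all scores."""
    q = query / np.linalg.norm(query)
    g = gallery / np.linalg.norm(gallery, axis=1, keepdims=True)
    sims = g @ q                                   # cosine similarity to every gallery item
    return np.argsort(-sims)[:top_k], sims

rng = np.random.default_rng(0)
gallery = rng.normal(size=(1000, 768))               # hypothetical embeddings of 1,000 images
query = gallery[42] + 0.1 * rng.normal(size=768)     # a slightly perturbed copy of image 42
top, sims = cosine_retrieve(query, gallery)
print(top[0])  # 42
```

The same normalized embeddings can be fed to a clustering algorithm such as k-means for the clustering use case.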

Overall, the availability of pre-trained vision transformer models has dramatically lowered the barrier to entry for computer vision research and application development. By fine-tuning these models for specific tasks or using them as feature extractors, practitioners can achieve cutting-edge performance with less data and fewer computational resources.

Python Snippet: Pre-trained Model

1. Install the Packages

!pip install torch torchvision timm

2. Importing Libraries

import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms
import timm

3. Data Transformation

# deit_base_patch16_224 expects 224x224 inputs normalized with ImageNet statistics,
# so CIFAR-10's 32x32 images are resized. Augmentation is applied to training data only.
imagenet_mean, imagenet_std = (0.485, 0.456, 0.406), (0.229, 0.224, 0.225)
train_transform = transforms.Compose([
    transforms.RandomCrop(32, padding=4),
    transforms.RandomHorizontalFlip(),
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(imagenet_mean, imagenet_std)])
test_transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(imagenet_mean, imagenet_std)])

4. Loading Dataset

trainset = torchvision.datasets.CIFAR10(root='./data', train=True,
                                        download=True, transform=train_transform)
trainloader = torch.utils.data.DataLoader(trainset, batch_size=128,
                                          shuffle=True, num_workers=2)
testset = torchvision.datasets.CIFAR10(root='./data', train=False,
                                       download=True, transform=test_transform)
testloader = torch.utils.data.DataLoader(testset, batch_size=100,
                                         shuffle=False, num_workers=2)
classes = ('plane', 'car', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck')

5. Defining Model

# num_classes=10 makes timm replace the 1000-class ImageNet head with a new 10-class head.
model = timm.create_model('deit_base_patch16_224', pretrained=True, num_classes=10)
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model.to(device)

6. Loss and Optimizer

criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(model.parameters(), lr=0.001, momentum=0.9)

7. Training the Model

model.train()
for epoch in range(10):
    running_loss = 0.0
    for i, (inputs, labels) in enumerate(trainloader):
        inputs, labels = inputs.to(device), labels.to(device)
        optimizer.zero_grad()
        outputs = model(inputs)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()
        running_loss += loss.item()
        if i % 100 == 99:  # report the mean loss over the last 100 mini-batches
            print('[%d, %5d] loss: %.3f' %
                  (epoch + 1, i + 1, running_loss / 100))
            running_loss = 0.0
print('Finished Training')

8. Evaluating the Model

model.eval()
correct = 0
total = 0
with torch.no_grad():
    for images, labels in testloader:
        images, labels = images.to(device), labels.to(device)
        outputs = model(images)
        predicted = outputs.argmax(dim=1)
        total += labels.size(0)
        correct += (predicted == labels).sum().item()

print('Accuracy of the network on the 10000 test images: %d %%' % (
    100 * correct / total))

Interpretability

Vision transformers are often considered more interpretable than typical convolutional neural networks (CNNs). Interpretability means providing insight into how a model arrives at its decisions or predictions. In computer vision, it can help users understand why a model made a particular classification or detection, which is especially important in applications such as medical imaging, where accuracy and dependability are critical.

This interpretability stems from the self-attention mechanism: the attention weights show which regions of the image the model focuses on, so users can inspect which parts of the image contributed to a prediction. By contrast, the intermediate feature maps of traditional CNNs can be difficult to interpret. Attention weights indicate where the model looks rather than giving a complete explanation, so they are best treated as one diagnostic among several.
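One common way to turn a ViT's attention maps into a single relevance map is attention rollout (Abnar and Zuidema, 2020), which averages the heads, adds the identity to account for residual connections, re-normalizes, and multiplies the per-layer maps together. The NumPy sketch below uses random, hypothetical attention maps in place of ones extracted from a trained model, with a [CLS] token followed by a 14 x 14 patch grid.

```python
import numpy as np

def attention_rollout(attentions):
    """Attention rollout over a list of per-layer attention maps.

    attentions: list of arrays shaped (num_heads, n, n), one per layer, rows summing to 1.
    Returns an (n, n) matrix estimating how much each output token draws on each input token.
    """
    n = attentions[0].shape[-1]
    rollout = np.eye(n)
    for attn in attentions:
        a = attn.mean(axis=0)                    # average over heads
        a = a + np.eye(n)                        # account for the residual connection
        a = a / a.sum(axis=-1, keepdims=True)    # re-normalize rows
        rollout = a @ rollout                    # compose with the earlier layers
    return rollout

rng = np.random.default_rng(0)
num_layers, num_heads, grid = 4, 3, 14
n = 1 + grid * grid                               # [CLS] + 14 x 14 patches
raw = rng.normal(size=(num_layers, num_heads, n, n))
attentions = [np.exp(r) / np.exp(r).sum(axis=-1, keepdims=True) for r in raw]  # stand-in maps
rollout = attention_rollout(attentions)
cls_to_patches = rollout[0, 1:].reshape(grid, grid)   # relevance of each patch for [CLS]
print(cls_to_patches.shape)  # (14, 14)
```

Upsampled to the input resolution, cls_to_patches can be overlaid on the image as a heat map.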

Furthermore, saliency maps can be computed for vision transformers, as for other models, to highlight the portions of the input image that matter most for a specific prediction. These can help users understand how the model makes decisions and identify potential flaws or biases in its predictions.

Overall, vision transformer interpretability can be helpful in various applications where understanding the model’s decision-making process is vital. This includes medical imaging, self-driving cars, and other safety-sensitive applications where model accuracy and reliability are crucial.

Hybrid Architectures

The Transformer in a Convolutional Neural Network (T-CNN) is an example of a hybrid architecture for object detection that combines a vision transformer with a CNN. In this design, the CNN extracts low-level features, which are then passed to the vision transformer for high-level feature extraction and object detection.
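A minimal NumPy sketch of the hand-off in such a hybrid, assuming the CNN produces a 14 x 14 grid of 1024-channel features (the shape is illustrative, and the feature map is random). Each spatial position of the feature map becomes one token, and a linear projection, random here and learned in practice, maps it to the transformer's embedding size.

```python
import numpy as np

rng = np.random.default_rng(0)
# Stand-in for a CNN backbone output: a 14 x 14 grid of 1024-channel features
# (e.g. what an intermediate ResNet stage might produce for a 224 x 224 image).
feature_map = rng.normal(size=(14, 14, 1024))
d_model = 768

h, w, c = feature_map.shape
tokens = feature_map.reshape(h * w, c)                    # each spatial position becomes one token
w_proj = rng.normal(scale=c ** -0.5, size=(c, d_model))  # learned linear projection in practice
tokens = tokens @ w_proj                                  # (196, 768), ready for the transformer encoder
print(tokens.shape)
```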

These hybrid designs can offer various benefits, including improved performance, lower computational cost, and greater interpretability. By combining the strengths of vision transformers and CNNs, hybrid architectures can provide cutting-edge performance on a wide range of computer vision applications while potentially being more interpretable than traditional CNN architectures.

Furthermore, hybrid architectures may use memory and compute more efficiently, for example because the convolutional stage reduces the spatial resolution, and therefore the number of tokens, before the comparatively expensive attention layers. This is particularly important in applications that require real-time performance, such as autonomous driving.

Overall, hybrid designs that combine vision transformers with other neural network architectures have the potential to push the boundaries of computer vision and enable a wide range of applications that were previously difficult to perform with traditional CNNs or vision transformers alone.

Comparison with Other Techniques

CNNs and vision transformers are both neural networks used for computer vision tasks, but they process images differently: vision transformers split images into patches and relate them with self-attention, whereas CNNs extract features with convolutional filters.

Unlike RNNs, which are widely used for sequence data and process it step by step, vision transformers are well suited to image data because they process all patches in parallel and can model long-range dependencies between them.

Graph neural networks (GNNs) are used to process graph-structured data such as social networks or molecules. While vision transformers do not deal with graph data directly, they can be applied to object detection, where objects can be viewed as nodes in a graph.

In general, each technique has its own advantages and disadvantages, making it suitable for specific types of data and tasks. The appropriate approach depends on the problem at hand and the characteristics of the data.

Advantages

  • The transformer architecture gives Vision Transformers a versatile, modular foundation for building and adapting models to various requirements. The attention mechanism lets the model learn relationships between patches, so it can gather both local and global information. The modular design also allows architectural changes such as stacking more layers or adding task-specific heads, letting researchers tailor the model to individual needs.
  • Vision Transformers can be pretrained on large-scale datasets such as ImageNet to learn general visual representations. Pretraining on large datasets lets the model capture a wide range of visual concepts and attributes, which can then be fine-tuned for specific downstream applications. Even with limited annotated data for the target task, transfer learning with Vision Transformers has shown good performance and generalization across diverse computer vision tasks.

Application

Vision transformers have many applications; two notable ones are:

  • Image Classification: Image classification, assigning an image to one of several predefined categories, is the most common use of vision transformers. Vision transformers have demonstrated performance competitive with or superior to standard CNN-based models on image classification benchmarks such as ImageNet and CIFAR-100, particularly when pretrained on large datasets such as ImageNet-21K.
  • Generative Tasks: Researchers have also applied transformers to generative tasks, where the goal is to generate new images that resemble a given training dataset. One approach, used by OpenAI's Image GPT, is a "GPT-style" autoregressive transformer that treats an image as a sequence of pixels and learns to predict each element from the ones before it.

Limitations

Vision Transformer performance depends strongly on large-scale pretraining. Lacking the built-in inductive biases of CNNs, such as locality and translation equivariance, ViTs trained from scratch on mid-sized datasets tend to underperform comparable CNNs, so they are usually pretrained on large datasets such as ImageNet-21K before being fine-tuned on the target task. Self-supervised pretraining on a proxy task (e.g., predicting the positions of image patches or reconstructing masked patches) can reduce the need for labels, but the demand for large-scale data remains challenging for specialized or domain-specific tasks with limited annotated data.

Conclusion

Vision transformers are a relatively new and exciting development in computer vision. They process images using a transformer architecture built on self-attention and have shown promising results in image classification, object detection, and image segmentation.

  • Vision transformers offer several significant advantages, including capturing long-range dependencies, flexibility in processing inputs of varied sizes, and the potential for better generalization to new data. However, they also have disadvantages, such as high computational cost, large memory requirements, and difficulty of training.
  • Despite these obstacles, I believe vision transformers will continue to play a significant role in computer vision research and applications. As researchers develop ways to reduce their computing and memory requirements and to make them more interpretable and easier to train, we can expect even more impressive results and new applications for this technology.
  • Overall, the invention of vision transformers is an exciting achievement in the science of computer vision, with significant potential for improving our comprehension and capacity to analyze visual data in various fields.

Frequently Asked Questions

Q1. What is a vision transformer in computer vision?

A. A vision transformer in computer vision refers to a specific type of neural network architecture that applies self-attention mechanisms to transform visual data representation. It breaks down the input image into patches, processes them using attention mechanisms, and captures global relationships among these patches to understand the visual content and make predictions.

Q2. What is computer vision in AI?

A. Computer vision in AI refers to the field that focuses on enabling machines to gain visual understanding from images or videos. It involves developing algorithms and models to interpret and extract meaningful information from visual data, enabling applications like object recognition, image generation, and video analysis.

Q3. What is the advantage of a vision transformer?

A. The main advantage of a vision transformer is its ability to capture long-range dependencies in visual data, which helps it model complex patterns and structures. Compared with traditional convolutional neural networks, this can lead to improved performance in tasks such as object detection, segmentation, and image classification, particularly when large-scale pretraining data is available.

Q4. What is the disadvantage of a Vision Transformer?

A. The main disadvantage of a Vision Transformer is its computational complexity and memory requirements: the cost of self-attention grows quadratically with the number of patches, which makes ViTs less efficient than traditional convolutional neural networks (CNNs) on high-resolution images. Training ViTs can also be more challenging because they typically need large-scale datasets.

The media shown in this article is not owned by Analytics Vidhya and is used at the Author’s discretion.

Premanand S is a dedicated academic with over a decade of research experience, specializing in Bio-signal Processing, Machine Learning, and Deep Learning. He completed his B.Tech in 2009 from Amrita Vishwa Vidyapeetham, Bangalore, and his M.E. in 2011 from Rajalakshmi Engineering College, Chennai, where his thesis focused on Deep Learning for ECG Signal Processing.

He is pursuing his Ph.D. at VIT-Chennai, with a tentative research title of "Deep Learning Approaches for Enhanced ECG Signal Processing and Arrhythmia Classification." His research aims to leverage cutting-edge deep learning techniques to improve the accuracy and efficiency of ECG signal analysis, contributing significantly to cardiac health monitoring.

A recipient of the prestigious TCS-RSP (Research Scholarship) in 2014, Cycle 9, Premanand has become a recognized figure in the academic community. He has delivered several invited talks on Data Science, Machine Learning, and Deep Learning at prominent institutions across India.

In his role as an Assistant Professor at VIT-Chennai, he continues to inspire the next generation of researchers while advancing the boundaries of knowledge in his field.
