Image Recognition Using Pytorch Lightning

keegan Last Updated : 20 Jul, 2021

This article was published as a part of the Data Science Blogathon

About

Pytorch – lightning

Beginners are often intimidated by the amount of code that deep learning requires. Complicated code and poor documentation are usually to blame, and for the same reason even veterans in data science can have trouble understanding someone else's code. This is why I recommend PyTorch Lightning, an open-source library built on top of PyTorch. PyTorch Lightning automates much of the coding that comes with deep learning and neural networks so you can focus on model building. It also helps you write cleaner code that is easily reproducible. For more information, check the official PyTorch Lightning website.

The Data

The dataset we will use here is the Yoga-poses dataset available on Kaggle. It is already structured in a way that makes building the model easier. The dataset has two main folders, “Train” and “Test”, and each contains 5 sub-folders. The sub-folders contain the images, and the name of each sub-folder is the class of the images inside it. The dataset is small compared to other image datasets, so we will use data augmentation during pre-processing. I’d recommend running this on a remote notebook such as a Kaggle notebook, as it can be computationally expensive to run any image recognition model on a local notebook. Now let’s get coding.
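
Before building anything, it can help to confirm the folder layout described above. The following is a minimal sketch that counts the files in each class sub-folder; the helper name count_images_per_class is hypothetical and not part of the original notebook.

```python
import os

def count_images_per_class(split_dir):
    """Return {class_name: number_of_files} for one split folder."""
    counts = {}
    for class_name in sorted(os.listdir(split_dir)):
        class_dir = os.path.join(split_dir, class_name)
        if os.path.isdir(class_dir):  # skip stray files next to the class folders
            counts[class_name] = len(os.listdir(class_dir))
    return counts
```

Calling it on the train folder and then on the test folder should return 5 entries each if the dataset is laid out as described.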

The Model

Prerequisites

If you want to follow along in a local environment, you will need to install PyTorch Lightning before we start coding. Notebooks running on Kaggle or Colab should already have it installed.

To install it in a local Python environment:

pip install pytorch-lightning

To install it in a local conda environment:

conda install -c conda-forge pytorch-lightning

I’d also recommend installing torchvision and OpenCV (imported as cv2) to easily pre-process image data.

Dependencies

Let’s run all the imports we will require to get started.

import torch
from torch.nn import functional as F
from torch import nn
import pytorch_lightning as pl
from pytorch_lightning import LightningModule  # top-level import; the core.lightning path is not available in later releases
from torch.utils.data import DataLoader
from torch.utils.data import Dataset
import pandas as pd
import torchvision
import cv2
import torchvision.transforms as transforms
import os
from random import randint
!jupyter nbextension enable --py widgetsnbextension

If this cell fails to run, one or more libraries haven’t been installed in your environment.

 

Preparing the Data

# declaring the path of the train and test folders
train_path = "../input/yoga-poses-dataset/DATASET/TRAIN"
test_path = "../input/yoga-poses-dataset/DATASET/TEST"
# each sub-folder of the train folder is one class
classes_dir_data = os.listdir(train_path)
num_of_classes = len(classes_dir_data)
print("Total Number of Classes :", num_of_classes)
# num_dict maps each numeric index to its corresponding class name.
# classes_dict maps each class name to its corresponding numeric index.
classes_dict = {}
num_dict = {}
for num, c in enumerate(classes_dir_data):
    classes_dict[c] = num
    num_dict[num] = c
classes_dict
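
One caveat with the mapping above: os.listdir returns entries in arbitrary order, so the index assigned to each class can differ between machines or runs. A small sketch of a reproducible variant follows; the helper name build_class_maps is hypothetical, and sorting the names is an assumption about what ordering you want.

```python
def build_class_maps(class_names):
    """Return (name -> index, index -> name) using a stable, sorted ordering."""
    ordered = sorted(class_names)
    name_to_index = {name: i for i, name in enumerate(ordered)}
    index_to_name = {i: name for name, i in name_to_index.items()}
    return name_to_index, index_to_name
```

Passing the list of class folder names to this function yields the same two dictionaries as the loop above, with indices that stay fixed across runs.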

output:-

 

The Image Dataset

#creating the dataset
class Image_Dataset(Dataset):
    def __init__(self, classes, image_base_dir, transform=None, target_transform=None):
        """
        classes: The classes in the dataset
        image_base_dir: The directory of the folders containing the images
        transform: The transformations for the images
        target_transform: The transformations for the target
        """
        self.img_labels = classes
        self.imge_base_dir = image_base_dir
        self.transform = transform
        self.target_transform = target_transform

    def __len__(self):
        # one item per class; the image itself is sampled in __getitem__
        return len(self.img_labels)

    def __getitem__(self, idx):
        # pick a random image from the folder of the class at position idx
        img_dir_list = os.listdir(os.path.join(self.imge_base_dir, self.img_labels[idx]))
        image_name = img_dir_list[randint(0, len(img_dir_list) - 1)]
        image_path = os.path.join(self.imge_base_dir, self.img_labels[idx], image_name)
        image = cv2.imread(image_path)  # NumPy array in BGR channel order
        label = self.img_labels[idx]
        if self.transform:
            image = self.transform(image)
        if self.target_transform:
            label = self.target_transform(label)
        return image, label
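
Note that this dataset has one item per class and draws a random image each time it is indexed, so its length equals the number of classes rather than the number of images. A common alternative, which is not the approach used in this article, is to list every image exactly once. A minimal sketch of that idea, with a hypothetical helper named list_samples:

```python
import os

def list_samples(image_base_dir, class_to_index):
    """Return a list of (image_path, label_index) pairs, one per file."""
    samples = []
    for class_name, label in sorted(class_to_index.items()):
        class_dir = os.path.join(image_base_dir, class_name)
        for file_name in sorted(os.listdir(class_dir)):
            samples.append((os.path.join(class_dir, file_name), label))
    return samples
```

A dataset built this way would return len(samples) from __len__ and read samples[idx] in __getitem__, so every epoch visits every image.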

Transformations

These are all the transformations that will be run on this dataset. The basic transformations are the minimum required to pass the data to the model, so they give us a quick pipeline. The training transformations add data augmentation on top of them.

size = 50  # assumed image height and width; the model's first linear layer expects 12*50*50 inputs
basic_transformations = transforms.Compose([
    transforms.ToPILImage(),
    transforms.Resize((size, size)),
    transforms.Grayscale(1),
    transforms.ToTensor()])
training_transformations = transforms.Compose([
    transforms.ToPILImage(),
    transforms.Resize((size, size)),
    transforms.RandomRotation(degrees = 45),
    transforms.RandomHorizontalFlip(p = 0.005),
    transforms.Grayscale(1),
    transforms.ToTensor()
])
def target_transformations(x):
    # map a class name to its numeric index
    return torch.tensor(classes_dict.get(x))

Data Module

This PyTorch Lightning data module makes passing data to the model easier for us.

class YogaDataModule(pl.LightningDataModule):
    def __init__(self):
        super().__init__()

    def setup(self, stage=None):
        # build the datasets; the test folder serves as both validation and test data
        self.train = Image_Dataset(classes_dir_data, train_path, training_transformations, target_transformations)
        self.valid = Image_Dataset(classes_dir_data, test_path, basic_transformations, target_transformations)
        self.test = Image_Dataset(classes_dir_data, test_path, basic_transformations, target_transformations)

    def train_dataloader(self):
        return DataLoader(self.train, batch_size = 64, shuffle = True)

    def val_dataloader(self):
        # evaluation data does not need shuffling
        return DataLoader(self.valid, batch_size = 64, shuffle = False)

    def test_dataloader(self):
        return DataLoader(self.test, batch_size = 64, shuffle = False)

Model

All the convolutions in the model retain the original input dimensions. The training_step and validation_step methods handle the training and validation of the data. By default, Lightning also saves a checkpoint of the model as training progresses. If you want to track metrics, just call self.log() and they will be saved to your preferred logger (be careful: logging on every step consumes memory). For more information about convolutions, I’d recommend checking out deeplizard’s free course on deep learning.
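
To see why the dimensions are retained, recall that for a convolution or pooling layer with dilation 1 the output size along one axis is floor((n + 2 * padding - kernel) / stride) + 1. The sketch below applies that formula to the settings used in this model (3x3 kernels, padding 1, stride 1); the helper name conv_output_size is hypothetical, and the input size of 50 is an assumption implied by the 12*50*50 input of the first linear layer.

```python
def conv_output_size(n, kernel, padding, stride):
    """Spatial output size of a convolution or pooling layer along one axis."""
    return (n + 2 * padding - kernel) // stride + 1

size = 50  # assumed input height/width
# three convolutions and three pooling steps, all 3x3 with padding 1 and stride 1
for _ in range(6):
    size = conv_output_size(size, kernel=3, padding=1, stride=1)

flattened = 12 * size * size  # 12 channels come out of the last convolution
print(size, flattened)  # prints: 50 30000
```

If you change the image size, kernel size, padding, or stride, the input size of the first linear layer has to change accordingly.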

class YogaModel(LightningModule):
    def __init__(self):
        super().__init__()
        # The convolutions are arranged so that the image keeps its x and y
        # dimensions; only the number of channels changes.
        self.layer_1 = nn.Conv2d(in_channels = 1, out_channels = 3, kernel_size = (3,3), padding = (1,1), stride = (1,1))
        self.layer_2 = nn.Conv2d(in_channels = 3, out_channels = 6, kernel_size = (3,3), padding = (1,1), stride = (1,1))
        self.layer_3 = nn.Conv2d(in_channels = 6, out_channels = 12, kernel_size = (3,3), padding = (1,1), stride = (1,1))
        self.pool = nn.MaxPool2d(kernel_size = (3,3), padding = (1,1), stride = (1,1))
        # the input dimensions are (number of channels * height * width)
        self.layer_5 = nn.Linear(12*50*50, 1000)
        self.layer_6 = nn.Linear(1000, 100)
        self.layer_7 = nn.Linear(100, 50)
        self.layer_8 = nn.Linear(50, 10)
        self.layer_9 = nn.Linear(10, 5)  # 5 outputs, one per class

    def forward(self, x):
        """
        x is the input data
        """
        # note: no activation functions are applied between the layers
        x = self.layer_1(x)
        x = self.pool(x)
        x = self.layer_2(x)
        x = self.pool(x)
        x = self.layer_3(x)
        x = self.pool(x)
        x = x.view(x.size(0), -1)  # flatten to (batch_size, 12*50*50)
        x = self.layer_5(x)
        x = self.layer_6(x)
        x = self.layer_7(x)
        x = self.layer_8(x)
        x = self.layer_9(x)
        return x

    def configure_optimizers(self):
        optimizer = torch.optim.Adam(self.parameters(), lr = 1e-7)
        return optimizer

    # The PyTorch Lightning module handles all the iterations of the epoch

    def training_step(self, batch, batch_idx):
        x, y = batch
        y_pred = self(x)
        loss = F.cross_entropy(y_pred, y)
        return loss

    def validation_step(self, batch, batch_idx):
        x, y = batch
        y_pred = self(x)
        loss = F.cross_entropy(y_pred, y)
        return loss

    def test_step(self, batch, batch_idx):
        x, y = batch
        y_pred = self(x)
        loss = F.cross_entropy(y_pred, y)
        self.log("loss", loss)
        return loss
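
Each step above calls F.cross_entropy, which takes the raw outputs of the last layer (logits) and the integer class labels, applies a log-softmax, and averages the negative log-likelihood over the batch by default. For intuition, here is a plain-Python sketch of the same quantity for a single sample; the function name cross_entropy is illustrative only and this is not PyTorch's implementation.

```python
import math

def cross_entropy(logits, target):
    """Cross-entropy for one sample: -log(softmax(logits)[target])."""
    shift = max(logits)  # subtract the max for numerical stability
    log_sum_exp = shift + math.log(sum(math.exp(z - shift) for z in logits))
    return log_sum_exp - logits[target]

# A model that scores all 5 classes equally has a loss of ln(5), about 1.609
print(round(cross_entropy([0.0, 0.0, 0.0, 0.0, 0.0], 2), 3))  # prints: 1.609
```

This gives a useful reference point: with 5 classes, a loss close to 1.609 means the model is doing no better than guessing.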

Training

Now we will finally train the model. PyTorch Lightning makes using hardware easy: just declare the devices (CPUs or GPUs) you want to use in the Trainer, and Lightning will handle the rest.

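%%time
from pytorch_lightning import Trainer
model = YogaModel()
module = YogaDataModule()
# Don't go over 10000 - 100000 epochs or it will take 5 - 53+ hours to iterate.
# With no device arguments, Lightning falls back to its default device selection;
# the names of the device arguments differ between Lightning versions.
trainer = Trainer(max_epochs=1)
trainer.fit(model, module)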

output:-

training | Pytorch Lightning

Testing

The final cell will check the loss of the model on unseen data

trainer.test(model, datamodule=module)

output:-

testing | Pytorch Lightning

 

Notes

Improving the model

  • If the loss on the train set is very high, the model is under-fitting. To decrease the loss, increase the maximum number of epochs or the learning rate. You can also add more non-linear layers to the model.
  • If the loss on the test set is much higher than on the train set, the model has over-fitted to the train set. Decreasing the number of epochs, tuning the learning rate, or adding dropout layers should do the trick.
  • This notebook hasn’t used any callbacks. To see the callbacks available in Lightning, check out their official website.
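
One callback that relates directly to the over-fitting note is early stopping, which Lightning provides as EarlyStopping in pytorch_lightning.callbacks. The idea behind it can be sketched in plain Python; the function below is a hypothetical illustration of patience-based stopping and not Lightning's implementation.

```python
def stopping_epoch(val_losses, patience=3, min_delta=0.0):
    """Return the index of the epoch at which training would stop, or None."""
    best = float("inf")
    epochs_without_improvement = 0
    for epoch, loss in enumerate(val_losses):
        if loss < best - min_delta:
            # validation loss improved: remember it and reset the counter
            best = loss
            epochs_without_improvement = 0
        else:
            epochs_without_improvement += 1
            if epochs_without_improvement >= patience:
                return epoch
    return None

# made-up validation losses: the best value is at epoch 2, then three epochs without improvement
print(stopping_epoch([1.0, 0.8, 0.7, 0.75, 0.72, 0.71, 0.9]))  # prints: 5
```

To use the real callback, the monitored metric has to be logged with self.log() inside validation_step.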
The media shown in this article are not owned by Analytics Vidhya and are used at the Author’s discretion.
