Deep learning is a vast field, but there are a few common challenges most of us face when building models.
Here, we talk about four such challenges and tricks to improve your deep learning model’s performance.
This is a hands-on, code-focused article, so get your Python IDE ready and improve your deep learning model!
Introduction
I’ve spent the majority of the last two years working almost exclusively in the deep learning space. It’s been quite an experience – I’ve worked on multiple projects, including ones involving image and video data.
Before that, I was on the fringes – I skirted around deep learning concepts like object detection and face recognition – but didn’t take a deep dive until late 2017. I’ve come across a variety of challenges during this time. And I want to talk about four very common ones that most deep learning practitioners and enthusiasts face in their journey.
If you’ve worked on a deep learning project before, you’ll be able to relate to all of these obstacles. And here’s the good news – overcoming them is not as difficult as you might think!
We’ll take a very hands-on approach in this article. First, we’ll establish the four common challenges I mentioned above. Then we’ll dive straight into the Python code and learn key tips and tricks to combat and overcome these challenges. There’s a lot to unpack here so let’s get the ball rolling!
Here is what we will cover in this article:
Brief Overview of the Vehicle Classification Case Study
Understanding Each Challenge and How to Overcome it to Improve your Deep Learning Model’s Performance
Case Study: Improving the Performance of our Vehicle Classification Model
Common Challenges with Deep Learning Models
Deep Learning models usually perform really well on most kinds of data. And when it comes to image data, deep learning models, especially convolutional neural networks (CNNs), outperform almost all other models.
My usual approach is to use a CNN model whenever I encounter an image related project, like an image classification one.
This approach works well, but there are cases when CNNs or other deep learning models fail to perform. I have encountered this a couple of times: my data was good, the model architecture was properly defined, the loss function and optimizer were set correctly, but my model kept falling short of what I expected.
And this is a common challenge that most of us face while working with deep learning models.
As I mentioned above, I will be covering four such challenges:
Paucity of Data available for training
Overfitting
Underfitting
High training time
Before diving deeper and understanding these challenges, let’s quickly look at the case study which we’ll solve in this article.
Brief Overview of the Vehicle Classification Case Study
This article is part of the PyTorch for beginners series I’ve been writing. We’ll be referencing a few things from the previous three articles in that series.
We’ll be picking up the case study which we saw in the previous article. The aim here is to classify the images of vehicles as emergency or non-emergency.
Let’s first quickly build a CNN model which we will use as a benchmark. We will also try to improve the performance of this model. The steps are pretty straightforward and we have already seen them a couple of times in the previous articles.
Hence, I will not be diving deep into each step here. Instead, we will focus on the code, and you can always check out these steps in more detail in the previous articles I’ve linked above. You can get the dataset from here.
Here is the complete code to build a CNN model for our vehicle classification project.
Importing the libraries
# importing the libraries
import pandas as pd
import numpy as np
from tqdm import tqdm

# for reading and displaying images
from skimage.io import imread
from skimage.transform import resize
import matplotlib.pyplot as plt

# loading the dataset
train = pd.read_csv('train.csv')
print(train.head())

# loading and resizing the training images
# (column names follow the dataset's train.csv; adjust if yours differ)
train_img = []
for img_name in tqdm(train['image_names']):
    img = imread('images/' + img_name)
    img = resize(img, output_shape=(224, 224), preserve_range=True) / 255.0
    train_img.append(img.astype('float32'))
train_x = np.array(train_img)
train_y = train['emergency_or_not'].values
Creating the training and validation set
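Since the full notebook is covered in the previous articles, here is a minimal, self-contained sketch of the remaining steps – creating the training and validation sets, converting them to torch tensors, and defining a benchmark CNN. The placeholder arrays, the 90:10 split, and the exact layer sizes are illustrative assumptions, not the original code:

```python
import numpy as np
import torch
import torch.nn as nn
from sklearn.model_selection import train_test_split

# placeholder data standing in for the loaded images:
# float32 array of shape (N, 224, 224, 3) and 0/1 labels
train_x = np.random.rand(20, 224, 224, 3).astype('float32')
train_y = np.random.randint(0, 2, 20)

# creating the training and validation sets (90:10 split)
train_x, val_x, train_y, val_y = train_test_split(train_x, train_y, test_size=0.1)

# converting images to torch tensors in (N, C, H, W) format
train_x = torch.from_numpy(train_x.transpose(0, 3, 1, 2))
train_y = torch.from_numpy(train_y).long()

# a small benchmark CNN: two conv blocks followed by a linear classifier
class Net(nn.Module):
    def __init__(self):
        super().__init__()
        self.cnn_layers = nn.Sequential(
            nn.Conv2d(3, 16, kernel_size=3, stride=1, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(2),
            nn.Conv2d(16, 32, kernel_size=3, stride=1, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(2),
        )
        # 224 -> 112 -> 56 after two pooling layers
        self.linear_layers = nn.Linear(32 * 56 * 56, 2)

    def forward(self, x):
        x = self.cnn_layers(x)
        x = x.view(x.size(0), -1)
        return self.linear_layers(x)

model = Net()
print(model)
```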
This is our CNN model. The training accuracy is around 88% and the validation accuracy is close to 70%.
We will try to improve the performance of this model. But before we get into that, let’s spend some time understanding the different challenges which might be the reason behind this low performance.
Deep Learning Challenge #1: Paucity of Data Available for Training our Model
Deep learning models usually require a lot of data for training. In general, the more the data, the better will be the performance of the model. The problem with a lack of data is that our deep learning model might not learn the pattern or function from the data and hence it might not give a good performance on unseen data.
If you look at the case study of vehicle classification, we only have around 1650 images and hence the model was unable to perform well on the validation set. The challenge of less data is very common while working with computer vision and deep learning models.
And as you can imagine, gathering data manually is a tedious and time-consuming task. So, instead of spending days collecting data, we can make use of data augmentation techniques.
Data augmentation is the process of generating new data or increasing the data for training the model without actually collecting new data.
There are multiple data augmentation techniques for image data, and you can refer to this article, which explains them in detail. Some of the commonly used augmentation techniques are rotation, shear, and flip.
It is a very vast topic and hence I have decided to dedicate a complete article to it. My plan is to cover these techniques along with their implementation in PyTorch in my next article.
Deep Learning Challenge #2: Model Overfitting
I’m sure you’ve heard of overfitting before. It’s one of the most common challenges (and mistakes) aspiring data scientists make when they’re new to machine learning. But this issue actually transcends fields – it applies to deep learning as well.
A model is said to overfit when it performs really well on the training set but the performance drops on the validation set (or unseen data).
For example, let’s say we have a training and a validation set. We train the model using the training data and check its performance on both the training and validation sets (evaluation metric is accuracy). The training accuracy comes out to be 95% whereas the validation accuracy is 62%. Sounds familiar?
Since the validation accuracy is way less than the training accuracy, we can infer that the model is overfitting. The below illustration will give you a better understanding of what overfitting is:
The portion marked in blue in the above image is the overfitting region, since the training error is very low while the test error is very high. The reason for overfitting is that the model learns even the unnecessary information from the training data, and hence it performs really well on the training set.
But when new data is introduced, it fails to perform. We can introduce dropout to the model’s architecture to overcome this problem of overfitting.
Using dropout, we randomly switch off some of the neurons of the neural network during training. Let’s say we add a dropout of 0.5 to a layer which originally had 20 neurons. Then, on average, 10 of these 20 neurons will be switched off in each forward pass, and we end up with a less complex architecture.
Hence, the model will not learn complex patterns and we can avoid overfitting. If you wish to learn more about dropouts, feel free to go through this article. Let’s now add a dropout layer to our architecture and check its performance.
Model Architecture
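Here is a minimal sketch of what such an architecture can look like – a two-block CNN with a Dropout layer in each convolutional block. The layer sizes are illustrative assumptions; `nn.Dropout()` uses its default p = 0.5:

```python
import torch
import torch.nn as nn

# a two-block CNN with a Dropout layer in each convolutional block
class NetWithDropout(nn.Module):
    def __init__(self):
        super().__init__()
        self.cnn_layers = nn.Sequential(
            nn.Conv2d(3, 16, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(2),
            nn.Dropout(),   # p defaults to 0.5
            nn.Conv2d(16, 32, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(2),
            nn.Dropout(),
        )
        self.linear_layers = nn.Linear(32 * 56 * 56, 2)

    def forward(self, x):
        x = self.cnn_layers(x)
        return self.linear_layers(x.view(x.size(0), -1))

out = NetWithDropout()(torch.randn(2, 3, 224, 224))
print(out.shape)
```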
Here, I have added a dropout layer in each convolutional block. The default value is 0.5 which means that half of the neurons will be randomly switched off. This is a hyperparameter and you can pick any value between 0 and 1.
Next, we will define the parameters of the model like the loss function, optimizer, and learning rate.
Model Parameters
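For reference, a minimal sketch of how the loss function, optimizer, and learning rate can be defined in PyTorch. The stand-in model and the learning rate of 0.0001 are illustrative assumptions:

```python
import torch
import torch.nn as nn
from torch.optim import Adam

# a small stand-in model for illustration
model = nn.Sequential(nn.Flatten(), nn.Linear(3 * 224 * 224, 2))

# defining the optimizer with an assumed learning rate
optimizer = Adam(model.parameters(), lr=0.0001)

# cross-entropy loss for the two-class classification problem
criterion = nn.CrossEntropyLoss()
print(optimizer)
```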
Here, you can see that the default value of p in dropout is 0.5. Finally, let’s train the model after adding the dropout layer:
Training the model
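As a stand-in for the original training code, here is a minimal sketch of a typical PyTorch training loop. The tiny model and random data are placeholders; only the loop structure matters:

```python
import torch
import torch.nn as nn
from torch.optim import Adam

torch.manual_seed(0)

# placeholder data and model, just to show the loop structure
x = torch.randn(16, 10)
y = torch.randint(0, 2, (16,))
model = nn.Linear(10, 2)
optimizer = Adam(model.parameters(), lr=0.01)
criterion = nn.CrossEntropyLoss()

losses = []
for epoch in range(25):
    optimizer.zero_grad()          # clear gradients from the previous step
    output = model(x)              # forward pass
    loss = criterion(output, y)    # compute the loss
    loss.backward()                # backpropagate
    optimizer.step()               # update the weights
    losses.append(loss.item())
print(losses[0], losses[-1])
```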
Let’s now check the training and validation accuracy using this trained model.
Checking model performance
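For reference, a minimal sketch of how accuracy can be computed from the model’s outputs. The sample outputs and labels below are made up for illustration:

```python
import torch

# the predicted class is the argmax over the two output scores
def accuracy(outputs, labels):
    preds = torch.argmax(outputs, dim=1)
    return (preds == labels).float().mean().item()

# made-up outputs and labels: 3 of the 4 predictions are correct
outputs = torch.tensor([[2.0, 0.1], [0.2, 1.5], [3.0, 0.0], [0.1, 0.9]])
labels = torch.tensor([0, 1, 0, 0])
print(accuracy(outputs, labels))
```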
Comparing the accuracies without and with dropout: without dropout, the training and validation accuracies of the model are not in sync – the training accuracy is too high while the validation accuracy is lower. Hence, this was a possible case of overfitting.
When we introduced dropout, both the training and validation accuracies came in sync. Hence, if your model is overfitting, you can try to add dropout layers to it and reduce the complexity of the model.
The amount of dropout to be added is a hyperparameter and you can play around with that value. Let’s now look at another challenge.
Deep Learning Challenge #3: Model Underfitting
Deep learning models can underfit as well, as unlikely as it sounds.
Underfitting is when the model is not able to learn the patterns from the training data itself and hence the performance on the training set is low.
This might be due to multiple reasons, such as not enough training data, an architecture that is too simple, the model being trained for too few epochs, etc.
To overcome underfitting, you can try the below solutions:
Increase the training data
Increase the model’s complexity
Increase the training epochs
For our problem, underfitting is not an issue and hence we will move forward to the next method for improving a deep learning model’s performance.
Deep Learning Challenge #4: Training Time is too High
There are cases when you might find that your neural network is taking a lot of time to converge. The main reason behind this is the change in the distribution of inputs to the layers of the neural network.
During the training process, the weights of each layer of the neural network change, and hence the activations also change. Now, these activations are the inputs for the next layer and hence the distribution changes with each successive iteration.
Due to this change in distribution, each layer has to adapt to the changing inputs – that’s why the training time increases.
To overcome this problem, we can apply batch normalization, wherein we normalize the activations of the hidden layers and try to bring them to a similar distribution.
You can read more about batch normalization in this article.
Let’s now add batchnorm layers to the architecture and check how it performs for the vehicle classification problem:
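Here is a minimal sketch of what such an architecture can look like – the same kind of two-block CNN with a BatchNorm2d layer after each convolution. The layer sizes are illustrative assumptions:

```python
import torch
import torch.nn as nn

# a two-block CNN with batch normalization after each convolution
class NetWithBN(nn.Module):
    def __init__(self):
        super().__init__()
        self.cnn_layers = nn.Sequential(
            nn.Conv2d(3, 16, kernel_size=3, padding=1),
            nn.BatchNorm2d(16),   # normalizes this layer's activations
            nn.ReLU(),
            nn.MaxPool2d(2),
            nn.Conv2d(16, 32, kernel_size=3, padding=1),
            nn.BatchNorm2d(32),
            nn.ReLU(),
            nn.MaxPool2d(2),
        )
        self.linear_layers = nn.Linear(32 * 56 * 56, 2)

    def forward(self, x):
        x = self.cnn_layers(x)
        return self.linear_layers(x.view(x.size(0), -1))

out = NetWithBN()(torch.randn(2, 3, 224, 224))
print(out.shape)
```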
Clearly, the model is able to learn very quickly. We got a training loss of 0.3386 in the 5th epoch itself, whereas the training loss after the 25th epoch was 0.3851 (when we did not use batch normalization).
So, the introduction of batch normalization has definitely reduced the training time. Let’s check the performance on the training and validation sets:
Adding batch normalization reduced the training time but we have an issue here. Can you figure out what it is? The model is now overfitting since we got an accuracy of 91% on training and 63% on the validation set. Remember – we did not add the dropout layer in the latest model.
These are some of the tricks we can use to improve the performance of our deep learning model. Let’s now combine all the techniques that we have learned so far.
Case Study: Improving the Performance of the Vehicle Classification Model
We have seen how dropout and batch normalization help to reduce overfitting and quicken the training process. It’s finally time to combine all these techniques together and build a model.
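Here is a minimal sketch of what the combined architecture can look like – BatchNorm after each convolution plus Dropout in each convolutional block. The layer sizes and dropout value are illustrative assumptions:

```python
import torch
import torch.nn as nn

# combining both tricks: BatchNorm after each convolution and
# Dropout in each convolutional block
class FinalNet(nn.Module):
    def __init__(self):
        super().__init__()
        self.cnn_layers = nn.Sequential(
            nn.Conv2d(3, 16, kernel_size=3, padding=1),
            nn.BatchNorm2d(16),
            nn.ReLU(),
            nn.MaxPool2d(2),
            nn.Dropout(),   # p defaults to 0.5
            nn.Conv2d(16, 32, kernel_size=3, padding=1),
            nn.BatchNorm2d(32),
            nn.ReLU(),
            nn.MaxPool2d(2),
            nn.Dropout(),
        )
        self.linear_layers = nn.Linear(32 * 56 * 56, 2)

    def forward(self, x):
        x = self.cnn_layers(x)
        return self.linear_layers(x.view(x.size(0), -1))

out = FinalNet()(torch.randn(2, 3, 224, 224))
print(out.shape)
```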
The validation accuracy has clearly improved to 73%. Awesome!
End Notes
In this article, we looked at different challenges that we can face when using deep learning models like CNNs. We also learned the solutions to all these challenges and finally, we built a model using these solutions.
The accuracy of the model on the validation set improved after we added these techniques to the model. There is always scope for improvement and here are some of the things that you can try out:
Tune the dropout rate
Add or reduce the number of convolutional layers
Add or reduce the number of dense layers
Tune the number of neurons in hidden layers, etc.
Do share your results in the comments section below!