I love working in the deep learning space. It is, quite frankly, a vast field with a plethora of techniques and frameworks to pore over and learn. And the real thrill of building deep learning and computer vision models comes when I watch real-world applications like facial recognition and ball tracking in cricket, among other things.
And one of my favorite computer vision and deep learning concepts is object detection. The ability to build a model that can go through images and tell me what objects are present – it’s a priceless feeling!
When humans look at an image, we recognize the object of interest in a matter of seconds. This is not the case with machines. Object detection is the computer vision problem of getting a machine to do this: locating instances of objects in an image and identifying what they are.
Here’s the good news – object detection applications are easier to develop than ever before. Current approaches focus on end-to-end pipelines, which have significantly improved performance and made real-time use cases possible.
In this article, I will walk you through how to build an object detection model using the popular TensorFlow API. If you are a newcomer to deep learning, computer vision and the world of object detection, I recommend going through the below resources:
Typically, we follow three steps when building an object detection framework:
First, a deep learning model or algorithm generates a large set of bounding boxes spanning the full image (the object localization component)
Next, visual features are extracted from each bounding box and used to determine whether an object is present in the box and, if so, which one (the object classification component)
In the final post-processing step, overlapping boxes that cover the same object are reduced to a single bounding box by keeping the highest-scoring box and discarding the rest (non-maximum suppression)
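The post-processing step can be illustrated with a minimal, pure-Python sketch of greedy non-maximum suppression. It assumes boxes are given as (x_min, y_min, x_max, y_max) tuples with one confidence score each, and that all boxes belong to the same class (detectors typically run NMS separately per class). The function names `iou` and `non_max_suppression` and the 0.5 threshold are illustrative choices, not the API's own code.

```python
from typing import List, Sequence, Tuple

Box = Tuple[float, float, float, float]  # (x_min, y_min, x_max, y_max)


def iou(a: Box, b: Box) -> float:
    """Intersection-over-union of two axis-aligned boxes."""
    ix_min, iy_min = max(a[0], b[0]), max(a[1], b[1])
    ix_max, iy_max = min(a[2], b[2]), min(a[3], b[3])
    inter = max(0.0, ix_max - ix_min) * max(0.0, iy_max - iy_min)
    area_a = (a[2] - a[0]) * (a[3] - a[1])
    area_b = (b[2] - b[0]) * (b[3] - b[1])
    union = area_a + area_b - inter
    return inter / union if union > 0 else 0.0


def non_max_suppression(boxes: Sequence[Box], scores: Sequence[float],
                        iou_threshold: float = 0.5) -> List[int]:
    """Greedy NMS: return the indices of kept boxes, highest score first."""
    order = sorted(range(len(boxes)), key=lambda i: scores[i], reverse=True)
    keep: List[int] = []
    for i in order:
        # Keep box i only if it does not overlap too much with any box already kept.
        if all(iou(boxes[i], boxes[j]) <= iou_threshold for j in keep):
            keep.append(i)
    return keep


if __name__ == "__main__":
    boxes = [(0, 0, 10, 10), (1, 1, 11, 11), (20, 20, 30, 30)]
    scores = [0.9, 0.8, 0.7]
    print(non_max_suppression(boxes, scores))  # [0, 2]
```

The first two boxes overlap with an IoU of about 0.68, so the lower-scoring one is suppressed, while the distant third box survives. TensorFlow also ships a built-in version of this operation as `tf.image.non_max_suppression`.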
That’s it – you’re ready with your first object detection framework!
What is an API? Why do we need an API?
API stands for Application Programming Interface. An API provides developers a set of common operations so that they don’t have to write code from scratch.
Think of an API like the menu in a restaurant that provides a list of dishes along with a description for each dish. When we specify what dish we want, the restaurant does the work and provides us finished dishes. We don’t know exactly how the restaurant prepares that food, and we don’t really need to.
In one sense, APIs are great time savers. They also offer users convenience in many cases. Think about it – Facebook users (including myself!) appreciate the ability to sign into many apps and sites using their Facebook ID. How do you think this works? Using Facebook’s APIs of course!
So in this article, we will look at the TensorFlow API developed for the task of object detection.
TensorFlow Object Detection API
The TensorFlow Object Detection API is a framework for building deep learning networks that solve object detection problems.
The framework already includes pretrained models, which it refers to as the Model Zoo. This includes a collection of models pretrained on the COCO dataset, the KITTI dataset, and the Open Images Dataset. These models can be used directly for inference if the categories we are interested in are among those in these datasets.
They are also useful for initializing your own models when training on a new dataset. The various architectures used in the pretrained models are described in this table:
MobileNet-SSD
The SSD architecture is a single convolutional network that learns to predict bounding box locations and classify these locations in one pass. Hence, SSD can be trained end-to-end. The SSD network consists of a base architecture (MobileNet in this case) followed by several convolution layers:
SSD operates on feature maps to detect the location of bounding boxes. Remember – a feature map is of size Df × Df × M. For each feature map location, k bounding boxes are predicted. Each bounding box carries with it the following information:
SSD does not invent box shapes from scratch. Each of the k boxes at a location starts from a default box with a predetermined shape, set prior to actual training, and the network predicts offsets relative to that default box along with the class scores. For example, in the figure above, there are 4 boxes, meaning k=4.
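To make the counting concrete, here is a small sketch of the arithmetic behind an SSD prediction head. It assumes, as in the SSD paper, that every default box gets C class scores (including a background class) plus 4 box offsets, so a Df × Df feature map yields Df · Df · k boxes and k(C + 4) output channels per location. The function names and the example sizes are hypothetical.

```python
from typing import Sequence, Tuple


def ssd_head_size(d_f: int, k: int, num_classes: int) -> Tuple[int, int, int]:
    """For one d_f x d_f feature map, return
    (channels per location, number of boxes, total predicted values)."""
    channels = k * (num_classes + 4)  # C class scores + 4 offsets per default box
    boxes = d_f * d_f * k             # k default boxes at every location
    return channels, boxes, d_f * d_f * channels


def total_default_boxes(map_sizes: Sequence[int], boxes_per_location: Sequence[int]) -> int:
    """Sum the default boxes over several feature maps of different resolutions."""
    return sum(d * d * k for d, k in zip(map_sizes, boxes_per_location))


if __name__ == "__main__":
    # Hypothetical head: 19 x 19 map, k = 4, C = 91 (90 class ids plus background).
    print(ssd_head_size(19, 4, 91))  # (380, 1444, 137180)
    # The original SSD300 configuration (VGG-16 base) yields 8732 default boxes.
    print(total_default_boxes([38, 19, 10, 5, 3, 1], [4, 6, 6, 6, 4, 4]))  # 8732
```

Because SSD predicts from several feature maps of decreasing resolution, the total number of candidate boxes is the sum over all of them, which is why non-maximum suppression is needed afterwards.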
Loss in MobileNet-SSD
With the final set of matched boxes, we can compute the loss like this:
$$L = \frac{1}{N}\left(L_{\text{class}} + L_{\text{box}}\right)$$
Here, N is the number of matched boxes, L_class is the softmax (cross-entropy) loss for classification, and L_box is the smooth L1 loss measuring the localization error of the matched boxes. Smooth L1 loss is a modification of L1 loss that is more robust to outliers. If N is 0, the loss is set to 0 as well.
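The loss can be sketched in plain Python as below. This is a simplification under stated assumptions: only matched boxes contribute, box offsets are 4-value sequences, and there is no hard negative mining and no weighting term between the two losses (the SSD paper includes both). The helper names are hypothetical.

```python
import math
from typing import Sequence


def smooth_l1(x: float) -> float:
    """Smooth L1: quadratic near zero, linear (outlier-robust) beyond |x| = 1."""
    ax = abs(x)
    return 0.5 * x * x if ax < 1.0 else ax - 0.5


def softmax_cross_entropy(logits: Sequence[float], label: int) -> float:
    """-log(softmax(logits)[label]), computed with the max-shift for stability."""
    m = max(logits)
    log_sum_exp = m + math.log(sum(math.exp(z - m) for z in logits))
    return log_sum_exp - logits[label]


def ssd_loss(class_logits: Sequence[Sequence[float]], labels: Sequence[int],
             box_preds: Sequence[Sequence[float]],
             box_targets: Sequence[Sequence[float]]) -> float:
    """L = (L_class + L_box) / N over N matched boxes; 0 when N == 0."""
    n = len(labels)
    if n == 0:
        return 0.0
    l_class = sum(softmax_cross_entropy(z, y) for z, y in zip(class_logits, labels))
    # Smooth L1 applied element-wise to each predicted offset vs. its target.
    l_box = sum(smooth_l1(p - t)
                for pred, target in zip(box_preds, box_targets)
                for p, t in zip(pred, target))
    return (l_class + l_box) / n


if __name__ == "__main__":
    # One matched box, two classes, offsets off by 0.5 and 2.0 in two coordinates.
    print(ssd_loss([[0.0, 0.0]], [0], [[0, 0, 0, 0]], [[0.5, 0, 0, 2]]))
```

The linear branch of `smooth_l1` is what makes it robust: an offset error of 2.0 costs 1.5 rather than the 2.0 a squared loss would charge at half weight, and large errors grow only linearly.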
MobileNet
The MobileNet model is based on depthwise separable convolutions which are a form of factorized convolutions. These factorize a standard convolution into a depthwise convolution and a 1 × 1 convolution called a pointwise convolution.
For MobileNets, the depthwise convolution applies a single filter to each input channel. The pointwise convolution then applies a 1 × 1 convolution to combine the outputs of the depthwise convolution.
A standard convolution both filters and combines inputs into a new set of outputs in one step. The depthwise separable convolution splits this into two layers – a separate layer for filtering and a separate layer for combining. This factorization has the effect of drastically reducing computation and model size.
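The savings can be checked with a little arithmetic. Counting multiply-adds for a Dk × Dk kernel, M input channels, N output channels and a Df × Df output map (ignoring biases), the sketch below compares the two approaches; the ratio works out to 1/N + 1/Dk², as in the MobileNet paper. The layer sizes in the example are illustrative.

```python
def standard_conv_cost(d_k: int, m: int, n: int, d_f: int) -> int:
    """Multiply-adds of a standard d_k x d_k convolution, m -> n channels, d_f x d_f output."""
    return d_k * d_k * m * n * d_f * d_f


def depthwise_separable_cost(d_k: int, m: int, n: int, d_f: int) -> int:
    """Multiply-adds of a depthwise d_k x d_k conv followed by a 1 x 1 pointwise conv."""
    depthwise = d_k * d_k * m * d_f * d_f  # one d_k x d_k filter per input channel
    pointwise = m * n * d_f * d_f          # 1 x 1 conv combines m channels into n
    return depthwise + pointwise


if __name__ == "__main__":
    std = standard_conv_cost(3, 32, 64, 112)
    sep = depthwise_separable_cost(3, 32, 64, 112)
    print(std, sep)            # 231211008 29302784
    print(sep / std)           # ~0.1267, i.e. 1/64 + 1/9
```

With a 3 × 3 kernel the separable version needs roughly 8 times fewer operations here. In Keras, the same factorization is available as `tf.keras.layers.SeparableConv2D`, or as `tf.keras.layers.DepthwiseConv2D` followed by a 1 × 1 `Conv2D`.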
How to load the model?
Below is a step-by-step process you can follow on Google Colab to quickly visualize object detection results. You can follow along with the code as well.
Install the Model
Get tensorflow/models or cd to parent directory of the repository:
Compile protobufs and install the object_detection package:
Label maps map indices to category names, so that when our convolutional network predicts 5, we know that this corresponds to an airplane:
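As an illustration, the lookup can be sketched with a hand-written subset of the COCO label map, stored in the `{id: {"id": ..., "name": ...}}` layout that the API's category index uses. In the API itself, the full index is typically built from a `.pbtxt` label map file with `label_map_util.create_category_index_from_labelmap`; the dictionary below and the `class_name` helper are hypothetical stand-ins for that.

```python
from typing import Dict

# Hand-written subset of the COCO label map (ids 1-5 only).
category_index: Dict[int, Dict[str, object]] = {
    1: {"id": 1, "name": "person"},
    2: {"id": 2, "name": "bicycle"},
    3: {"id": 3, "name": "car"},
    4: {"id": 4, "name": "motorcycle"},
    5: {"id": 5, "name": "airplane"},
}


def class_name(class_id: float, index: Dict[int, Dict[str, object]] = category_index) -> str:
    """Map a predicted class id (possibly a float) to a readable name."""
    entry = index.get(int(class_id))
    return str(entry["name"]) if entry else f"unknown ({class_id})"


if __name__ == "__main__":
    print(class_name(5))    # airplane
    print(class_name(5.0))  # airplane (models often emit class ids as floats)
```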
For the sake of simplicity, we will test on 2 images:
Check the model’s input signature (it expects a batch of 3-channel color images of type uint8):
Add a wrapper function to call the model and clean up the outputs:
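The cleanup part of such a wrapper can be sketched with plain Python lists standing in for the model’s output tensors. The keys (`num_detections`, `detection_boxes`, `detection_scores`, `detection_classes`) are the ones the API’s detection models return, each with a leading batch dimension. The `clean_outputs` function and its `min_score` filter are hypothetical additions for illustration; a real wrapper would also convert the image to a tensor and call the model first.

```python
from typing import Any, Dict


def clean_outputs(raw: Dict[str, Any], min_score: float = 0.5) -> Dict[str, Any]:
    """Strip the batch dimension, keep only the first `num_detections` entries
    (the rest is padding), drop low-scoring detections, and cast class ids to int."""
    num = int(raw["num_detections"][0])
    boxes = raw["detection_boxes"][0][:num]   # [ymin, xmin, ymax, xmax], normalized to [0, 1]
    scores = raw["detection_scores"][0][:num]
    classes = [int(c) for c in raw["detection_classes"][0][:num]]  # ids arrive as floats
    keep = [i for i, s in enumerate(scores) if s >= min_score]
    return {
        "num_detections": len(keep),
        "detection_boxes": [boxes[i] for i in keep],
        "detection_scores": [scores[i] for i in keep],
        "detection_classes": [classes[i] for i in keep],
    }


if __name__ == "__main__":
    # Fake batch-of-one output: two real detections followed by one padding row.
    raw = {
        "num_detections": [2.0],
        "detection_boxes": [[[0.1, 0.2, 0.5, 0.6], [0.3, 0.3, 0.9, 0.8], [0.0, 0.0, 0.0, 0.0]]],
        "detection_scores": [[0.92, 0.31, 0.0]],
        "detection_classes": [[5.0, 1.0, 1.0]],
    }
    print(clean_outputs(raw))
```

Filtering by score here is a design choice for readability; the API’s visualization utilities apply a similar minimum-score threshold when drawing boxes.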
Below is the example image tested on ssd_mobilenet_v1_coco (MobileNet-SSD trained on the COCO dataset):
Inception-SSD
The architecture of the Inception-SSD model is similar to that of the MobileNet-SSD model above. The difference is that the base architecture here is the Inception model. To learn more about the Inception network, see Understanding the Inception Network from Scratch.
How to load the model?
Just change the model name in the Detection part of the API:
Then make the prediction using the steps we followed earlier. Voila!
Faster RCNN
State-of-the-art object detection networks depend on region proposal algorithms to hypothesize object locations. Advances like SPPnet and Fast R-CNN have reduced the running time of these detection networks, exposing region proposal computation as a bottleneck.
In Faster RCNN, we feed the input image to a convolutional neural network to generate a convolutional feature map. A Region Proposal Network (RPN) operating on this feature map proposes candidate object regions. An RoI (Region of Interest) pooling layer then reshapes each proposed region into a fixed size so that it can be fed into fully connected layers.
From the RoI feature vector, a softmax layer predicts the class of the proposed region, and a parallel regression layer predicts the offset values for the bounding box.
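RoI pooling can be sketched for a single-channel feature map as below. The sketch assumes the proposal has already been projected onto feature-map coordinates as integer cell indices (y0, x0, y1, x1), end-exclusive, and it uses one common quantization choice (floor for bin starts, ceil for bin ends). Real implementations work on batches of multi-channel tensors and differ in these details; the function name is hypothetical.

```python
from typing import List, Sequence, Tuple


def roi_max_pool(feature: Sequence[Sequence[float]], roi: Tuple[int, int, int, int],
                 out_h: int, out_w: int) -> List[List[float]]:
    """Max-pool region roi = (y0, x0, y1, x1) of a 2-D feature map into out_h x out_w bins."""
    y0, x0, y1, x1 = roi
    h, w = y1 - y0, x1 - x0
    pooled = []
    for i in range(out_h):
        # Floor the bin start and ceil the bin end so every cell is covered.
        ys = y0 + (i * h) // out_h
        ye = y0 + -((-(i + 1) * h) // out_h)
        row = []
        for j in range(out_w):
            xs = x0 + (j * w) // out_w
            xe = x0 + -((-(j + 1) * w) // out_w)
            row.append(max(feature[y][x] for y in range(ys, ye) for x in range(xs, xe)))
        pooled.append(row)
    return pooled


if __name__ == "__main__":
    fmap = [[1, 2, 3, 4], [5, 6, 7, 8], [9, 10, 11, 12], [13, 14, 15, 16]]
    print(roi_max_pool(fmap, (0, 0, 4, 4), 2, 2))  # [[6, 8], [14, 16]]
    print(roi_max_pool(fmap, (1, 1, 4, 4), 2, 2))  # [[11, 12], [15, 16]]
```

Whatever the size of the proposal, the output is always out_h × out_w, which is what allows regions of different shapes to share the same fully connected layers.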
Just change the model name in the Detection part of the API again:
Then make the prediction using the same steps as we followed above. Below is the example image when given to a Faster RCNN model:
As you can see, the detections are noticeably better than those of the SSD-MobileNet model. But there is a tradeoff – it is much slower than the previous model. These are the kinds of decisions you’ll need to make when picking the right object detection model for your deep learning and computer vision project.
Frequently Asked Questions
Q1. What is TensorFlow in object detection?
A. TensorFlow is an open-source machine learning framework that includes a comprehensive library of tools and algorithms for building and training machine learning models. TensorFlow’s object detection API provides pre-trained models and tools for training custom object detection models, allowing developers to quickly build and deploy applications for identifying and tracking objects in images and videos.
Q2. How do you train an object detection model with TensorFlow?
A. To train an object detection model with TensorFlow, the following steps can be taken:
1. Collect and label a dataset of images.
2. Choose a pre-trained model or create a custom model architecture.
3. Configure and train the model using TensorFlow’s object detection API.
4. Evaluate the model’s performance and fine-tune it as needed.
5. Export the trained model for deployment in a production environment.
Tools such as TensorFlow’s object detection API and libraries like OpenCV can simplify and streamline the process of training object detection models.
Which Object Detection Model Should You Choose?
Depending on your specific requirements, you can choose the right model from the TensorFlow API. If we want a high-speed model that can process a video feed at a high frame rate, the single-shot detection (SSD) network works best. As its name suggests, SSD predicts all bounding boxes and class probabilities in a single pass; hence, it is a much faster model.
However, with single-shot detection, you gain speed at the cost of accuracy. With Faster RCNN, we get higher accuracy but lower speed. So explore, and in the process you’ll realize how powerful this TensorFlow API can be!
Aspiring Data Scientist with a passion for playing and wrangling with data and getting insights from it, to help the community know the upcoming trends and products for a better future. With an ambition to develop products used by millions that make their lives easier and better.
Hi Alakh,
Thank you for the blog. It was helpful to understand the terminology and some of the concepts.
I have a question. I am now starting to get into object detection with tensorflow, however object detection api tutorials page suggests that models are supported up to version 1.15. And also, the example you have here and on the tutorials page are based on the model from 2017 so I had the feeling that I am not learning state-of-the-art. But since your blog post is only 2 months old, I found it hard to understand what versions or tutorials should I start with or what are the latest models that I can use and train with object detection API.
Thanks for your help in advance.
Suman
Hi
How do I remove pre-trained weights in TFOD? I want to train it from scratch
Thanks