Object detection is one of the most widely studied topics in the computer vision community. It has been breaking into various industries, with use cases ranging from image security, surveillance and automated vehicle systems to machine inspection.
Currently, deep learning-based object detection can be broadly classified into two groups: two-stage detectors and one-stage detectors.
One-stage detectors, which are applied over a regular, dense sampling of anchor boxes (possible object locations), have the potential to be faster and simpler. However, their accuracy has trailed that of two-stage detectors because of the extreme class imbalance encountered during training.
FAIR (Facebook AI Research) released a paper in 2018 that introduced the concept of focal loss to handle this class imbalance problem in their one-stage detector, RetinaNet.
Before we dive into the nitty-gritty of focal loss, let's first understand what this class imbalance problem is and what problems it causes.
Classic one-stage detection methods, such as boosted detectors and DPM, and more recent methods, such as SSD, all evaluate roughly $10^4$ to $10^5$ candidate locations per image. Only a few of these locations contain objects (i.e., foreground), and the rest are just background. This leads to the class imbalance problem.
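This imbalance causes two problems: (1) training is inefficient, because most locations are easy negatives that contribute no useful learning signal; and (2) en masse, the easy negatives can overwhelm training and lead to degenerate models.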
In simple terms, Focal Loss (FL) is an improved version of Cross-Entropy Loss (CE). It handles the class imbalance problem by assigning more weight to hard, easily misclassified examples, such as background with noisy texture, a partially visible object, or the object of interest. It also down-weights easy examples, such as plain background.
So focal loss reduces the loss contribution from easy examples and increases the importance of correcting misclassified examples.
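So, let's first understand cross-entropy loss for binary classification.
The idea behind cross-entropy loss is to penalize wrong predictions far more than it rewards right ones. The loss approaches 0 as the predicted probability of the true class approaches 1, but it grows without bound as that probability approaches 0.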
Cross-entropy loss for binary classification is written as follows:

$$CE = -Y_{act}\ln(Y_{pred}) - (1 - Y_{act})\ln(1 - Y_{pred})$$

where:
$Y_{act}$ = actual value of Y
$Y_{pred}$ = predicted value of Y
For notational convenience, let's write $Y_{pred}$ as p and $Y_{act}$ as Y, where:
$Y \in \{0, 1\}$ is the ground-truth class, and
$p \in [0, 1]$ is the model's estimated probability for the class with Y = 1.
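With this notation, the loss for each class is:

$$CE(p, Y) = \begin{cases} -\ln(p) & \text{if } Y = 1 \\ -\ln(1 - p) & \text{if } Y = 0 \end{cases}$$

For notational convenience, we define $p_t$ as the model's estimated probability for the ground-truth class:

$$p_t = \begin{cases} p & \text{if } Y = 1 \\ 1 - p & \text{if } Y = 0 \end{cases}$$

We can then rewrite the above equation as:

$$CE(p, Y) = CE(p_t) = -\ln(p_t)$$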
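To make the two forms concrete, here is a minimal Python sketch. The function names are illustrative and not taken from any library. It evaluates the full binary cross-entropy formula and the compact $-\ln(p_t)$ form, and checks that they agree for both classes:

```python
import math

def binary_cross_entropy(p: float, y: int) -> float:
    """Full form: -[y*ln(p) + (1 - y)*ln(1 - p)], with y in {0, 1} and 0 < p < 1."""
    return -(y * math.log(p) + (1 - y) * math.log(1 - p))

def prob_true_class(p: float, y: int) -> float:
    """p_t: the probability the model assigns to the ground-truth class."""
    return p if y == 1 else 1 - p

def ce_pt(p: float, y: int) -> float:
    """Compact form: CE(p_t) = -ln(p_t)."""
    return -math.log(prob_true_class(p, y))

# Both forms should print the same loss for every (p, y) pair.
for p, y in [(0.95, 1), (0.05, 0), (0.3, 1), (0.3, 0)]:
    print(f"p={p}, Y={y}: full={binary_cross_entropy(p, y):.4f}, p_t form={ce_pt(p, y):.4f}")
```

The $p_t$ form is what the rest of the derivation builds on, because it lets a single expression cover both classes.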
The blue curve (plain cross-entropy) in the figure below shows that even easily classified examples can incur a loss of non-trivial magnitude. These are examples with large $p_t$ ($p_t > 0.5$), i.e., p close to 1 when Y = 1 or p close to 0 when Y = 0.
Fig: The focal loss down-weights easy examples with a factor of $(1 - p_t)^\gamma$.
Let's understand this with an example. Say the foreground (let's call it class 1) is correctly classified with p = 0.95:
CE(FG) = -ln(0.95) ≈ 0.05
And the background (let's call it class 0) is correctly classified with p = 0.05:
CE(BG) = -ln(1 - 0.05) ≈ 0.05
Each of these losses is small. However, in a class-imbalanced dataset, summing them over all the locations in an image can overwhelm the overall (total) loss, which leads to degenerate models.
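A quick Python sketch of the arithmetic above. It adds a hypothetical image with 100,000 easy background locations (in line with the number of candidate locations a dense detector evaluates) and 10 hard foreground locations with $p_t = 0.1$. These counts are made-up illustrative values, not measurements:

```python
import math

def ce(p_t: float) -> float:
    """Cross-entropy written in terms of p_t, the probability of the true class."""
    return -math.log(p_t)

# Easy, correctly classified examples from the text.
ce_fg = ce(0.95)       # foreground, Y = 1, p = 0.95 -> p_t = 0.95
ce_bg = ce(1 - 0.05)   # background, Y = 0, p = 0.05 -> p_t = 0.95
print(f"CE(FG) = {ce_fg:.4f}, CE(BG) = {ce_bg:.4f}")

# Hypothetical image: many easy negatives vs. a handful of hard positives.
num_easy_bg, num_hard_fg = 100_000, 10
easy_total = num_easy_bg * ce(0.95)  # each easy negative adds only ~0.05
hard_total = num_hard_fg * ce(0.1)   # each hard positive adds ~2.3
share = easy_total / (easy_total + hard_total)
print(f"easy negatives: {easy_total:.1f}, hard positives: {hard_total:.1f}")
print(f"share of total loss from easy negatives: {share:.1%}")
```

Each easy negative contributes very little on its own, but their sheer number makes them dominate the total loss.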
A common approach to addressing such a class imbalance problem is to introduce a weighting factor $\alpha \in [0, 1]$ for class 1 and $1 - \alpha$ for class 0.
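For notational convenience, we define $\alpha_t$ analogously to $p_t$ ($\alpha_t = \alpha$ if Y = 1 and $\alpha_t = 1 - \alpha$ if Y = 0). The α-balanced CE loss can then be written as:

$$CE(p_t) = -\alpha_t \ln(p_t)$$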
As you can see, this is a simple extension of cross-entropy.
As the RetinaNet authors' experiments show, the large class imbalance encountered during the training of dense detectors overwhelms the cross-entropy loss.
Easily classified negatives make up the majority of the loss and dominate the gradient. α balances the importance of positive and negative examples, but it does not differentiate between easy and hard examples.
Let's understand this with an example, using α = 0.25. Say the foreground (class 1) is correctly classified with p = 0.95:
CE(FG) = -0.25 * ln(0.95) ≈ 0.0128
And the background (class 0) is correctly classified with p = 0.05:
CE(BG) = -(1 - 0.25) * ln(1 - 0.05) ≈ 0.038
So α does a good job of balancing the positive and negative classes, but it still does not differentiate between easy and hard examples.
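The α-balanced examples can be reproduced with the sketch below, using α = 0.25 as in the text. The function name is hypothetical. The last lines illustrate the point just made. α rescales every example of a class by the same constant, so the ratio between a hard and an easy foreground loss stays the same as under plain CE:

```python
import math

def alpha_balanced_ce(p: float, y: int, alpha: float = 0.25) -> float:
    """CE(p_t) = -alpha_t * ln(p_t); alpha_t = alpha for Y = 1, 1 - alpha for Y = 0."""
    p_t = p if y == 1 else 1 - p
    alpha_t = alpha if y == 1 else 1 - alpha
    return -alpha_t * math.log(p_t)

print(f"CE(FG) = {alpha_balanced_ce(0.95, 1):.4f}")  # easy foreground
print(f"CE(BG) = {alpha_balanced_ce(0.05, 0):.4f}")  # easy background

# Easy vs. hard foreground: alpha cancels out of the ratio.
easy, hard = alpha_balanced_ce(0.95, 1), alpha_balanced_ce(0.1, 1)
print(f"hard/easy with alpha: {hard / easy:.1f}")
print(f"hard/easy plain CE:   {math.log(0.1) / math.log(0.95):.1f}")
```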
And that's where focal loss comes to the rescue.
Focal loss is another extension of the cross-entropy loss function. It down-weights easy examples and focuses training on hard negatives.
To achieve this, the researchers proposed adding a modulating factor $(1 - p_t)^\gamma$ to the cross-entropy loss, with a tunable focusing parameter $\gamma \ge 0$.
The RetinaNet object detector uses an α-balanced variant of the focal loss, for which α = 0.25 and γ = 2 work best.
So focal loss can be defined as:

$$FL(p_t) = -\alpha_t (1 - p_t)^\gamma \ln(p_t)$$
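Here is a minimal, framework-free sketch of the α-balanced focal loss for a single prediction, written directly from the formula above. The function name and the defaults α = 0.25 and γ = 2 are choices made for illustration:

```python
import math

def focal_loss(p: float, y: int, alpha: float = 0.25, gamma: float = 2.0) -> float:
    """Alpha-balanced focal loss for one binary prediction.

    p is the predicted probability of class 1 and y in {0, 1} is the ground truth.
    FL(p_t) = -alpha_t * (1 - p_t) ** gamma * ln(p_t)
    """
    p_t = p if y == 1 else 1 - p              # probability of the true class
    alpha_t = alpha if y == 1 else 1 - alpha  # class-balancing weight
    modulating_factor = (1 - p_t) ** gamma    # near 0 for easy examples, near 1 for hard ones
    return -alpha_t * modulating_factor * math.log(p_t)

print(f"easy foreground (p=0.95): {focal_loss(0.95, 1):.4e}")
print(f"hard foreground (p=0.10): {focal_loss(0.10, 1):.4f}")
print(f"easy background (p=0.05): {focal_loss(0.05, 0):.4e}")
```

A real detector applies this loss to every anchor and class. RetinaNet normalizes the summed loss by the number of anchors assigned to ground-truth boxes. For production use, torchvision provides `torchvision.ops.sigmoid_focal_loss`, which takes raw logits rather than probabilities.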
Figure 1 visualizes the focal loss for several values of $\gamma \in [0, 5]$.
We note the following properties of the focal loss:
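When an example is misclassified and $p_t$ is small, the modulating factor is near 1 and the loss is unaffected. As $p_t \to 1$, the factor goes to 0 and the loss for well-classified examples is down-weighted.
As γ increases, the effect of the modulating factor increases as well. After many experiments and trials, the researchers found γ = 2 to work best.
Note: when γ = 0, FL is equivalent to CE (shown as the blue curve in the figure).
Intuitively, the modulating factor reduces the loss contribution from easy examples and extends the range in which an example receives low loss.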
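The properties above can be checked numerically. For clarity, this sketch drops α to isolate the modulating factor. It then prints how many times smaller the loss is than CE for a few values of $p_t$ and γ. The ratio works out to $1 / (1 - p_t)^\gamma$, so with γ = 2 an example with $p_t = 0.9$ gets a 100× lower loss than under CE:

```python
import math

def ce(p_t: float) -> float:
    """Cross-entropy in terms of p_t."""
    return -math.log(p_t)

def fl_no_alpha(p_t: float, gamma: float) -> float:
    """Focal loss without the alpha term, to isolate the modulating factor."""
    return -((1 - p_t) ** gamma) * math.log(p_t)

# gamma = 0 recovers plain cross-entropy.
assert all(math.isclose(fl_no_alpha(pt, 0.0), ce(pt)) for pt in (0.1, 0.5, 0.9))

# CE / FL = 1 / (1 - p_t) ** gamma: the reduction grows as p_t -> 1
# (easier examples) and as gamma increases.
pts = (0.1, 0.5, 0.9, 0.99)
for gamma in (0.5, 1.0, 2.0, 5.0):
    ratios = [ce(pt) / fl_no_alpha(pt, gamma) for pt in pts]
    print(f"gamma={gamma}: " + ", ".join(f"{r:,.1f}x" for r in ratios))
```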
Let's see these properties in action by comparing cross-entropy and focal loss in a few scenarios and looking at the impact of focal loss on training:
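In each scenario, the example is foreground ($\alpha_t = 0.25$) and γ = 2.
scenario-1 (easily classified, $p_t = 0.95$): CE = 0.05129 and FL = $3.2058 \times 10^{-5}$, so FL is 0.05129 / ($3.2058 \times 10^{-5}$) ≈ 1,600 times smaller.
scenario-2 (misclassified, $p_t = 0.1$): CE ≈ 2.3 and FL ≈ 0.466, so FL is 2.3 / 0.466 ≈ 4.9 times smaller.
scenario-3 (very easily classified, $p_t = 0.99$): CE ≈ 0.01 and FL ≈ 0.00000025, so FL is 0.01 / 0.00000025 = 40,000 times smaller.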
These three cases show how focal loss heavily down-weights well-classified examples, while keeping a comparatively large weight on misclassified or hard examples.
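The three scenarios can be recomputed with the sketch below. It assumes each example is foreground ($\alpha_t = 0.25$) with γ = 2, and it infers each $p_t$ from its CE value. Because the ratio CE/FL simplifies to $1 / (\alpha_t (1 - p_t)^\gamma)$, it depends only on $p_t$, $\alpha_t$ and γ:

```python
import math

ALPHA_T, GAMMA = 0.25, 2.0  # assumed: every scenario is a foreground (class 1) example

def ce(p_t: float) -> float:
    """Cross-entropy in terms of p_t."""
    return -math.log(p_t)

def fl(p_t: float, alpha_t: float = ALPHA_T, gamma: float = GAMMA) -> float:
    """Alpha-balanced focal loss in terms of p_t."""
    return -alpha_t * (1 - p_t) ** gamma * math.log(p_t)

# p_t for each scenario, inferred from its CE value (e.g. -ln(0.95) = 0.05129).
scenarios = {"scenario-1": 0.95, "scenario-2": 0.1, "scenario-3": 0.99}

ratios = {}
for name, p_t in scenarios.items():
    ratios[name] = ce(p_t) / fl(p_t)
    print(f"{name}: CE={ce(p_t):.5f}  FL={fl(p_t):.4e}  CE/FL={ratios[name]:,.1f}")
```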
After many trials and experiments, the researchers found α = 0.25 and γ = 2 to work best.
We went through the complete evolution from cross-entropy loss to focal loss in object detection. I've tried to explain focal loss as simply as possible. Please feel free to comment with your queries; I'll be more than happy to answer them.
If you enjoyed this article, leave a few claps; it will encourage me to explore more machine learning techniques and pen them down 🙂
Happy learning. Cheers!!
Praveen Kumar Anwla
I've been working as a Data Scientist with product-based and Big 4 audit firms for almost 5 years now. I have worked with various NLP, machine learning and cutting-edge deep learning frameworks to solve business problems. Please feel free to check out my personal blog, where I cover topics ranging from machine learning, AI and chatbots to visualization tools (Tableau, QlikView, etc.) and various cloud platforms like Azure, IBM and AWS.