What Does Focal Loss Mean for Training Neural Networks?
Facebook AI Research (FAIR) published a paper on what they call Focal Loss, a new loss function for improving Average Precision (AP) in single-stage object detectors.
What’s important to know here is that object detectors are used in computer vision to locate and classify specific objects in a visual scene. The two common approaches are one-stage and two-stage detectors. The main differences are the number of stages and the processing time. One-stage detectors are faster (and in machine learning it’s important to improve efficiency), but have historically been less accurate than two-stage detectors.
Two-stage detectors first propose candidate locations where an object might be in an image, then classify what is inside each candidate box. This can mean generating 1,000-2,000 candidate boxes before looking for the objects themselves. One-stage detectors like YOLO detect objects without first creating proposals.
So imagine you’re an autonomous car. From your ocular sensor, you “see” something like this:
When you try to process what’s going on around you, you need to be aware of objects in the foreground, possibly within a few feet of you, and objects in the background, like street lights and tall buildings. Your ocular sensor’s code is trying to predict whether items are in the foreground or background. Then you can figure out whether you might hit that object later on.
One major flaw with neural networks is that they tend to be overconfident in their predictions.
FAIR proposes a new loss function that focuses a neural network’s attention on the examples it finds hard to classify. Instead of trying to reduce the influence of outliers (predictions that are far off from the truth), Focal Loss reduces the weight (or impact) carried by the examples the network already predicts correctly. A loss function is just a mathematical way of saying how far off a guess is from the real value of a data point. Cross Entropy (CE) Loss is frequently used in computer vision, and Focal Loss adds a probability-dependent weight to it.
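For reference, the two losses can be written side by side. This follows the notation of the Focal Loss paper for binary classification, where p is the model's estimated probability for the class y = 1 and γ ≥ 0 is a tunable "focusing" parameter (the paper reports γ = 2 working best in its experiments):

```latex
p_t =
\begin{cases}
  p     & \text{if } y = 1 \\
  1 - p & \text{otherwise}
\end{cases}
\qquad
\mathrm{CE}(p_t) = -\log(p_t)
\qquad
\mathrm{FL}(p_t) = -(1 - p_t)^{\gamma} \log(p_t)
```

Setting γ = 0 makes the extra factor equal to 1, so Focal Loss reduces to plain Cross Entropy.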
It seems counterintuitive, but it’s actually a great method for improving neural network models. One of the major problems with AI is that bad models (ones that don’t predict correctly) are often certain in their predictions, while good models (ones that generalize well to new data) can be less certain in theirs.
Focal Loss is meant to help when there’s a lot of data in one class and sparse data in another (foreground or background). Imagine you were looking at something like this. How many objects are there in the foreground vs. the background? Focal Loss is helpful when the imbalance between classes is on the order of 1:1000 or more.
What happens if we feed this into a regular single-stage object detector? Often the network gets overwhelmed by the thousands of objects in the background and has a hard time predicting the objects in the foreground.
In training, your object detector looks through the thousands of background objects and correctly identifies them as background, but it can incorrectly classify the foreground objects and be like:
I did so well on thousands of these things, it’s not so bad I got every foreground object wrong, there were only a handful of those!
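To put rough numbers on that complaint, here is a minimal Python sketch of how plain cross entropy adds up across an imbalanced batch. The counts and probabilities are made up for illustration and are not taken from the paper.

```python
import math

# Hypothetical numbers for illustration only (not from the paper):
# 1,000 easy background examples where the true class gets probability 0.9,
# and 1 hard foreground example where the true class gets probability 0.1.
n_easy, p_easy = 1000, 0.9
n_hard, p_hard = 1, 0.1


def cross_entropy(p_t):
    """Cross entropy loss given p_t, the probability assigned to the true class."""
    return -math.log(p_t)


easy_total = n_easy * cross_entropy(p_easy)
hard_total = n_hard * cross_entropy(p_hard)
easy_share = easy_total / (easy_total + hard_total)

print(f"loss from easy background examples:     {easy_total:.1f}")
print(f"loss from the hard foreground example:  {hard_total:.1f}")
print(f"share of total loss from easy examples: {easy_share:.1%}")
```

Each easy example contributes very little on its own, but a thousand of them together account for about 98% of the total loss in this toy setup, so the one example the model actually gets wrong barely moves the training signal.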
This is what focal loss solves. It puts more weight on the objects that were hard to classify and decreases the impact of easy, correct predictions. Mathematically, the cross entropy loss is multiplied by a scaling factor. This factor decreases (or decays towards 0) as the confidence in a correct prediction goes up.
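The decay of that scaling factor is easy to see numerically. The sketch below assumes the form given in the paper, (1 - p_t)^γ with γ = 2, where p_t is the probability the model assigns to the true class. The function names are my own.

```python
import math


def modulating_factor(p_t, gamma=2.0):
    """Scaling factor (1 - p_t)^gamma; shrinks towards 0 as p_t approaches 1."""
    return (1.0 - p_t) ** gamma


def focal_loss(p_t, gamma=2.0):
    """Cross entropy -log(p_t) multiplied by the modulating factor."""
    return modulating_factor(p_t, gamma) * -math.log(p_t)


# Compare cross entropy and focal loss as the model grows more confident
# in the correct class.
for p_t in (0.1, 0.5, 0.9, 0.99):
    ce = -math.log(p_t)
    print(
        f"p_t={p_t:.2f}  factor={modulating_factor(p_t):.4f}  "
        f"CE={ce:.4f}  FL={focal_loss(p_t):.6f}"
    )
```

With γ = 2, an example classified correctly at p_t = 0.9 counts about 100 times less than it would under plain cross entropy, while a badly misclassified example at p_t = 0.1 keeps most of its loss. Passing gamma=0.0 turns the factor into 1 and recovers ordinary cross entropy.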
Check out how RetinaNet (the method using focal loss) outperforms the current top-performing two-stage detectors.
The FAIR team tested their model on the COCO dataset, but I can see focal loss being helpful in other settings with highly imbalanced classes, so networks aren’t overconfident in their correct predictions on easy examples. As you can see from the chart above, this truly novel loss function helped RetinaNet achieve a higher AP than both previous state-of-the-art one-stage and two-stage methods, while maintaining efficiency.
🤩 What does this mean for you?
If you’re training a model on image data (or any data) with highly imbalanced classes, such as credit card fraud, online conversion rates, or factory production defects, you can apply focal loss to your cross entropy loss so that training focuses on the hard examples instead of overlooking the rare class and getting a misleadingly high accuracy.
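As a starting point, here is a minimal pure-Python sketch of a binary focal loss you could adapt for a problem like fraud detection. It uses the α-balanced variant from the paper, with the paper's reported defaults of γ = 2 and α = 0.25. The function name, the clamping constant eps, the averaging over the batch, and the example numbers are my own assumptions, not the authors' released code.

```python
import math


def binary_focal_loss(probs, labels, gamma=2.0, alpha=0.25, eps=1e-7):
    """Mean alpha-balanced focal loss over a batch.

    probs  -- predicted probabilities for the positive class (e.g. fraud)
    labels -- ground-truth labels, 1 for positive and 0 for negative
    Averaging over the batch is an assumption made for this sketch.
    """
    total = 0.0
    for p, y in zip(probs, labels):
        p = min(max(p, eps), 1.0 - eps)           # avoid log(0)
        p_t = p if y == 1 else 1.0 - p            # probability of the true class
        alpha_t = alpha if y == 1 else 1.0 - alpha
        total += -alpha_t * (1.0 - p_t) ** gamma * math.log(p_t)
    return total / len(probs)


# Hypothetical batch: four confident, correct "legitimate" predictions
# and one fraud case the model mostly missed.
probs = [0.02, 0.02, 0.02, 0.02, 0.30]
labels = [0, 0, 0, 0, 1]
print(f"batch focal loss: {binary_focal_loss(probs, labels):.6f}")
```

In this toy batch nearly all of the loss comes from the single missed fraud case, which is the behaviour you want when the rare class is the one that matters. In a real project you would use your framework's tensor operations rather than a Python loop, and tune γ and α on validation data.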
I’m excited to dive further into using this loss function and their open source code this weekend.
More on Object Detection: A Step-by-Step Introduction to the Basic Object Detection Algorithms
You can read the full paper here: Focal Loss for Dense Object Detection