Implementation of Focal loss #1489
Conversation
Thanks for this contribution! Left some small comments, but the approach looks good to me.
Rather than loop iteration, it might be faster to do array operations, especially when you actually have to materialize the result with bigger arrays.
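To illustrate the suggestion (a hypothetical sketch, not the PR's code), here is the same reduction written comprehension-style versus with fused broadcasts:

```julia
# Comprehension/loop style: iterate the elements pairwise.
loss_loop(ŷ, y) =
    sum(-yi * log(ŷi) - (1 - yi) * log(1 - ŷi) for (ŷi, yi) in zip(ŷ, y)) / length(y)

# Broadcast (array-operation) style, matching how Flux's other losses are written.
loss_bcast(ŷ, y) =
    sum(-y .* log.(ŷ) .- (1 .- y) .* log.(1 .- ŷ)) / length(y)
```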
That paper is cited enough that it may be worth having this loss in Flux, although PyTorch doesn't have it and TensorFlow has it only as an addon.
Yeah, having never heard of this loss before, I checked the paper. It's well-cited, so that's good enough for me to include in Flux.
@darsnack @CarloLucibello Need help on converting the list comprehension to array operations. Any useful leads?
@shikhargoswami You might take a look at how the other loss functions are written. I didn't see a performance difference when testing, but I didn't use large array sizes. Even if there isn't a gap in performance, it would be good to use the same style as the other loss functions.
Assuming 0/1 labels you can write `p_t = y .* ŷ + (1 .- y) .* (1 .- ŷ)`. So this is used only for binary classification? This should be mentioned.
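Spelled out, that suggestion leads to something like the following sketch in the style of Flux's other losses (the `agg` keyword and the `ϵ` guard are assumptions here, not necessarily the exact code in this PR):

```julia
using Statistics: mean

# Binary focal loss sketch: y holds 0/1 labels, ŷ holds probabilities.
# γ is the focusing parameter from the RetinaNet paper; ϵ keeps log away from 0.
function binary_focal_loss(ŷ, y; agg=mean, γ=2, ϵ=eps(eltype(ŷ)))
    ŷ = ŷ .+ ϵ
    p_t = y .* ŷ + (1 .- y) .* (1 .- ŷ)  # probability assigned to the true class
    ce = -log.(p_t)                      # per-element binary cross-entropy
    agg((1 .- p_t) .^ γ .* ce)           # down-weight well-classified examples
end
```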
Force-pushed from eb3882d to ba4b299 (compare)
I guess I only implemented binary classification. Here are the changes made:
Check if it needs any other changes.
Personally, I would prefer the names `focal_loss`/`binary_focal_loss` to match `crossentropy`/`binarycrossentropy`.
Some suggestions on the docstrings.
I would commit the changes that address the comments above. I think this is close to being ready for approval, but we'll need to come to a consensus on the numerical stability vs. performance issue. That will require input from other maintainers, so probably we won't be able to merge today.
I'll also need to review the API before we make a final call.
This looks good to go for me, but @DhairyaLGandhi will need to approve the final API.
Thanks, I've added a few last thoughts, but this is looking good. We might want to review how we write our losses, but that would be a more general shift, so keeping it consistent with the current style seems sensible to me. Thanks again for the good work @shikhargoswami.
I think it would be better to assume logits as inputs instead of probabilities. In model-zoo we use …
This goes back to the numerical stability question. We would need to rework the implementation to make the logits version stable. If we do switch to logits, maybe the naming should be …
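For illustration, the stable rework could follow the same trick as `logitbinarycrossentropy`, pushing everything through NNlib's numerically stable `logσ` (a hedged sketch; the name `logit_binary_focal_loss` is hypothetical):

```julia
using Statistics: mean
using NNlib: logσ  # numerically stable log-sigmoid

# Hypothetical logit-input variant: ŷ is raw logits, y is 0/1 labels.
function logit_binary_focal_loss(ŷ, y; γ=2)
    log_pt = y .* logσ.(ŷ) .+ (1 .- y) .* logσ.(-ŷ)  # log p_t, computed stably
    mean(-(1 .- exp.(log_pt)) .^ γ .* log_pt)        # focal-weighted NLL
end
```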
The fact that we have both …
I like the possibility of making "logit" a kwarg, but I think it is better in a separate PR. I think for here, let's just decide which version (logit or not) we want, and we can keep it consistent with what we already have.
I think the less magical the function, the better, and this makes certain contracts very implicit, like which softmax function we use. It goes against the self-documenting code we expect from Flux. I'd need to see some very compelling reasons to do opaque-looking keywords in Flux. It's one reason I feel like we should remove the …
Let's keep it as it is then, it's already done, and …
I agree, the names need work.
@shikhargoswami that should fix it. Can you address the remaining comments?
Co-authored-by: Carlo Lucibello <carlo.lucibello@gmail.com>
Let's wait for CI to pass, then I think this can merge.
@darsnack Thanks a lot for the help!
No problem! Thank you for your contribution!
bors r+
1489: Implementation of Focal loss r=darsnack a=shikhargoswami

Focal loss was introduced in the RetinaNet paper (https://arxiv.org/pdf/1708.02002.pdf). Focal loss is useful for classification when you have highly imbalanced classes. It down-weights well-classified examples and focuses on hard examples: the loss value is much higher for a sample misclassified by the classifier than for a well-classified example. It is used in single-shot object detection, where the imbalance between the background class and the other classes is extremely high. Here's its TensorFlow implementation (https://github.com/tensorflow/addons/blob/v0.12.0/tensorflow_addons/losses/focal_loss.py#L26-L81).

### PR Checklist

- [x] Tests are added
- [x] Entry in NEWS.md
- [x] Documentation, if applicable
- [ ] Final review from `@dhairyagandhi96` (for API changes).

Co-authored-by: Shikhar Goswami <shikhargoswami2308@gmail.com>
Co-authored-by: Shikhar Goswami <44720861+shikhargoswami@users.noreply.github.com>
bors r-
Canceled. |
@shikhargoswami Sorry for the delay, but I realized you didn't add the focal loss docstring to the actual docs in docs/models/losses.md.
Thanks!
bors r+
Build succeeded: …
Kyle, for the future: this is an API-related change, and the process does require a final approval from me. And good job @shikhargoswami!
Sorry, my mistake, I thought you had approved the API. Next time I will leave the final merge to you on API changes.
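For reference, usage of the merged functions looks roughly like this (assuming the final exported names `focal_loss`/`binary_focal_loss` in `Flux.Losses`, with columns as samples in the multi-class case):

```julia
using Flux
using Flux.Losses: focal_loss, binary_focal_loss

# Binary case: ŷ and y are same-shape arrays of probabilities and 0/1 labels.
ŷ = [0.9 0.2 0.8; 0.1 0.7 0.4]
y = [1.0 0.0 1.0; 0.0 1.0 0.0]
binary_focal_loss(ŷ, y)           # well-classified entries contribute little

# Multi-class case: rows are class probabilities, columns are samples.
ŷm = softmax(randn(Float32, 3, 5))
ym = Flux.onehotbatch(rand(1:3, 5), 1:3)
focal_loss(ŷm, ym; γ=2)           # γ tunes how strongly easy examples are down-weighted
```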