
Implementation of Focal loss #1489

Merged · 14 commits into FluxML:master, Feb 5, 2021
Conversation

@shikhargoswami (Contributor) commented Jan 30, 2021

Focal loss was introduced in the RetinaNet paper (https://arxiv.org/pdf/1708.02002.pdf).

Focal loss is useful for classification when you have highly imbalanced classes. It down-weights well-classified examples and focuses on hard examples: the loss value is much higher for a sample the classifier misclassifies than for a well-classified one.

It is used in single-shot object detection, where the imbalance between the background class and the other classes is extremely high.

Here's its TensorFlow implementation (https://github.com/tensorflow/addons/blob/v0.12.0/tensorflow_addons/losses/focal_loss.py#L26-L81).
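For reference, here is a minimal Julia sketch of the binary case, illustrating the paper's formula FL(pₜ) = -(1 - pₜ)^γ log(pₜ). This is not the code proposed in this PR; the function name and keyword defaults are placeholders.

```julia
using Statistics: mean

# Minimal sketch of binary focal loss, assuming ŷ are probabilities in (0, 1)
# and y are 0/1 labels. γ = 2 is the paper's default; ϵ guards against log(0).
function binary_focal_sketch(ŷ, y; γ=2, ϵ=eps(eltype(ŷ)))
    ŷ   = clamp.(ŷ, ϵ, 1 - ϵ)
    p_t = y .* ŷ .+ (1 .- y) .* (1 .- ŷ)   # probability assigned to the true class
    mean(@. -(1 - p_t)^γ * log(p_t))       # modulating factor down-weights easy examples
end
```

Setting γ = 0 recovers plain binary cross-entropy, which is why the loss is usually described as a generalization of it.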

PR Checklist

  • Tests are added
  • Entry in NEWS.md
  • Documentation, if applicable
  • Final review from @dhairyagandhi96 (for API changes).

@darsnack (Member) left a comment

Thanks for this contribution! Left some small comments, but the approach looks good to me.

(4 inline review comments on src/losses/functions.jl, resolved)
@DhairyaLGandhi (Member)

Rather than loop iteration, it might be faster to use array operations, especially when you actually have to materialize the result with bigger arrays.

@CarloLucibello (Member)

That paper is cited enough that it may be worth having this loss in Flux, although PyTorch doesn't have it and TensorFlow has it only as an addon.

@darsnack (Member)

Yeah, having never heard of this loss before, I checked the paper. It's well-cited, so that's good enough for me to include it in Flux.

@shikhargoswami (Contributor, Author)

@darsnack @CarloLucibello I need help converting the list comprehension to array operations. Any useful leads?

@darsnack (Member) commented Jan 30, 2021

@shikhargoswami You might take a look at how binarycrossentropy is written to see how to address @DhairyaLGandhi's comment. I believe the xlogy utility is what you need.

I didn't see a performance difference when testing, but I didn't use large array sizes. Even if there isn't a gap in performance, it would be good to use the same style as the other loss functions.
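For context, the xlogy pattern looks roughly like this (paraphrased from memory rather than copied from the Flux source; binarycrossentropy_sketch is a hypothetical name):

```julia
using Statistics: mean

# x * log(y), but returning zero when x == 0 so 0 * log(0) doesn't yield NaN.
function xlogy(x, y)
    result = x * log(y)
    ifelse(iszero(x), zero(result), result)
end

# Whole-array binary cross-entropy in the same loop-free, broadcast style.
binarycrossentropy_sketch(ŷ, y; ϵ=eps(eltype(ŷ))) =
    mean(@. -xlogy(y, ŷ + ϵ) - xlogy(1 - y, 1 - ŷ + ϵ))
```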

@CarloLucibello (Member)

Assuming 0/1 targets, you can write

p_t = y .* ŷ  + (1 .- y) .* (1 .- ŷ)  

So is this used only for binary classification? This should be mentioned.
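To spell out the identity: when y = 1 this gives p_t = ŷ, and when y = 0 it gives p_t = 1 - ŷ. A quick check with made-up numbers:

```julia
y   = [1, 0, 1]
ŷ   = [0.9, 0.2, 0.4]
p_t = y .* ŷ .+ (1 .- y) .* (1 .- ŷ)   # == [0.9, 0.8, 0.4]
```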

(inline review comment on test/losses.jl, resolved)
@shikhargoswami (Contributor, Author)

I guess I had only implemented binary classification. Here are the changes I made:

  • List comprehension -> Array operation
  • Implemented categorical_focal_loss and added its test
  • Minor changes in docstring

Please check whether it needs any other changes.

@darsnack (Member) left a comment

Personally, I would prefer the names focal_loss/binary_focal_loss to match crossentropy/binarycrossentropy.

(inline review comments on src/losses/functions.jl and src/losses/Losses.jl, resolved)
@darsnack (Member) left a comment

Some suggestions on the docstrings.

(7 inline review comments on src/losses/functions.jl, resolved)
@darsnack (Member) commented Jan 31, 2021

I would commit the changes that address the comments above. I think this is close to being ready for approval, but we'll need to come to a consensus on the numerical stability vs. performance issue. That will require input from other maintainers, so we probably won't be able to merge today.

@DhairyaLGandhi (Member)

I'll also need to review the API before we make a final call.

(inline review comments on src/losses/functions.jl and test/losses.jl, resolved)
@darsnack (Member) left a comment

This looks good to go for me, but @DhairyaLGandhi will need to approve the final API.

(inline review comment on src/losses/functions.jl, resolved)
@DhairyaLGandhi (Member) left a comment

Thanks, I've added a few last thoughts, but this is looking good. We might want to review how we write our losses, but that would be a more general shift, so keeping it consistent with the current style seems sensible to me. Thanks again for the good work @shikhargoswami!

(inline review comments on src/losses/functions.jl and test/losses.jl, resolved)
@CarloLucibello (Member)

I think it would be better to assume logits as inputs instead of probabilities. In model-zoo we use logitcrossentropy everywhere instead of crossentropy for numerical stability.

@darsnack (Member) commented Feb 4, 2021

This goes back to the numerical stability question. We would need to rework the implementation to make the logits version stable. If we do switch to logits, maybe the naming should be logitfocalloss and logitbinaryfocalloss. It's tough on the eyes, but consistent with the cross-entropy convention.
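One way such a rework could look (a purely hypothetical sketch, not what was merged; logit_binary_focal_sketch is a made-up name) is to compute log(pₜ) directly from the raw scores via logσ, so the log never sees a clamped probability:

```julia
using Statistics: mean
using NNlib: logσ

# Hypothetical logit-space binary focal loss: x are raw scores, y are 0/1 labels.
function logit_binary_focal_sketch(x, y; γ=2)
    logp_t = @. y * logσ(x) + (1 - y) * logσ(-x)   # log pₜ, numerically stable
    mean(@. -(1 - exp(logp_t))^γ * logp_t)          # (1 - pₜ)^γ · (-log pₜ)
end
```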

@CarloLucibello (Member)

The fact that we have both crossentropy and logitcrossentropy is quite unfortunate. When I revisited the loss functions, I thought about having a single definition, function crossentropy(yhat, y; logits=true), but couldn't figure out a nice deprecation path. Working with unnormalized log-probabilities is what one should want in most cases for numerical stability. Should we experiment with the keyword-arg approach here? Or we can just leave this as it is; it's not worth too much overthinking.
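For concreteness, the floated single-definition idea might look like this (a hypothetical sketch only; it was never adopted, and crossentropy_kw is a placeholder name):

```julia
using Statistics: mean
using NNlib: logsoftmax

# One crossentropy accepting either probabilities or unnormalized logits.
function crossentropy_kw(ŷ, y; logits=false, dims=1)
    lp = logits ? logsoftmax(ŷ; dims=dims) : log.(ŷ)
    mean(-sum(y .* lp; dims=dims))
end
```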

@darsnack (Member) commented Feb 4, 2021

I like the possibility of making "logit" a kwarg, but I think it's better done in a separate PR. For here, let's just decide which version (logit or not) we want, and keep it consistent with what we already have.

@DhairyaLGandhi (Member)

I think the less magical the function, the better, and this makes certain contracts very implicit, like which softmax function we use. It goes against the self-documenting code we expect from Flux. I'd need to see some very compelling reasons to add opaque-looking keywords in Flux.

It's one reason I feel like we should remove agg; it doesn't justify being in every loss function when it clearly doesn't add much to most of them.

@CarloLucibello (Member)

Let's keep it as it is then; it's already done, and logitbinaryfocalloss is quite horrifying 😄

@DhairyaLGandhi (Member)

I agree, the names... need work.

@darsnack (Member) commented Feb 5, 2021

@shikhargoswami that should fix it. Can you address the remaining comments?

Co-authored-by: Carlo Lucibello <carlo.lucibello@gmail.com>
@darsnack (Member) left a comment

Let's wait for CI to pass then I think this can merge.

@shikhargoswami (Contributor, Author)

@darsnack Thanks a lot for the help!

@darsnack (Member) commented Feb 5, 2021

No problem! Thank you for your contribution!

@darsnack (Member) commented Feb 5, 2021

bors r+

bors bot added a commit that referenced this pull request Feb 5, 2021
1489: Implementation of Focal loss r=darsnack a=shikhargoswami

Co-authored-by: Shikhar Goswami <shikhargoswami2308@gmail.com>
Co-authored-by: Shikhar Goswami <44720861+shikhargoswami@users.noreply.github.com>
@darsnack (Member) commented Feb 5, 2021

bors r-


@bors (bot) commented Feb 5, 2021

Canceled.

@darsnack (Member) left a comment

@shikhargoswami sorry for the delay, but I realized you didn't add the focal loss docstring to the actual docs in docs/models/losses.md.

@darsnack (Member) left a comment

Thanks!

@darsnack (Member) commented Feb 5, 2021

bors r+

@bors (bot) commented Feb 5, 2021

Build succeeded:

bors bot merged commit d341500 into FluxML:master, Feb 5, 2021
@DhairyaLGandhi (Member) commented Feb 6, 2021

Kyle, for the future: this is an API-related change, and the process does require a final approval from me.

And good job, @shikhargoswami!

@darsnack (Member) commented Feb 6, 2021

Sorry, my mistake; I thought you had approved the API. Next time I will leave the final merge to you on API changes.
