[RFC] Loss Functions in Torchvision #2980

Open
7 of 20 tasks
oke-aditya opened this issue Nov 10, 2020 · 29 comments

@oke-aditya
Contributor

oke-aditya commented Nov 10, 2020

🚀 Feature

A loss functions API in torchvision.

Motivation

The request is simple: torchvision already has loss functions, e.g. sigmoid_focal_loss and l1_loss, but they are quite scattered, and we have to reach them through paths such as torchvision.ops.sigmoid_focal_loss.

In the future, we might need to include further loss functions, e.g. dice_loss.

Since loss functions are differentiable, we can put them under nn.
We can have

torchvision.nn.losses.sigmoid_focal_loss and so on.

This keeps the scope of nn open for other differentiable functions such as layers, etc.
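For readers unfamiliar with the loss named above, here is a rough pure-Python sketch of what a sigmoid focal loss computes for a single binary prediction. The function name `focal_loss_sketch` is hypothetical and this is not the torchvision implementation; the defaults `alpha=0.25` and `gamma=2.0` are assumed here for illustration, and the naive `log` calls are not numerically stable for extreme logits.

```python
import math

def focal_loss_sketch(logit, target, alpha=0.25, gamma=2.0):
    """Focal loss for one binary prediction (illustrative sketch only).

    logit  -- raw model output before the sigmoid
    target -- 1.0 for the positive class, 0.0 for the negative class
    alpha  -- class-balancing weight; a negative value disables it
    gamma  -- focusing exponent; 0 reduces the loss to cross-entropy
    """
    p = 1.0 / (1.0 + math.exp(-logit))  # sigmoid probability
    # Plain binary cross-entropy for this element.
    ce = -(target * math.log(p) + (1.0 - target) * math.log(1.0 - p))
    # Probability assigned to the true class.
    p_t = p * target + (1.0 - p) * (1.0 - target)
    # Down-weight easy examples (p_t close to 1).
    loss = ce * (1.0 - p_t) ** gamma
    if alpha >= 0:
        alpha_t = alpha * target + (1.0 - alpha) * (1.0 - target)
        loss = alpha_t * loss
    return loss

# A confident correct prediction contributes far less than a confident wrong one.
easy = focal_loss_sketch(4.0, 1.0)
hard = focal_loss_sketch(-4.0, 1.0)
```

Nothing in the sketch depends on where the function lives, which is the point of the proposal: only the import path would change.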

Pitch

These losses are very specific and pertain to the vision domain. They are really useful and, in general, not tied to any specific model.
Loss functions usually live in torch, though. If we keep these under an nn namespace, future migration stays simple:

Instead of torchvision.nn.sigmoid_focal_loss it would be torch.nn.sigmoid_focal_loss.

This Pitch comes from the above issues.
More Loss Functions

Alternatives

Alternatively, this should go in torch. But if we keep the above idea, we can support them in torchvision and later deprecate and move to torch (when needed).

Currently, we include them under ops, but these are not really operations; they are differentiable loss functions.

Other ops, by contrast, are not differentiable and perform transformations or other manipulation over boxes/layers.

Additional context

Here is a list of loss functions we would like to include.

References

We can refer to Kornia, fvcore and a few PyTorch issues that need this feature.

@mthrok
Contributor

mthrok commented Nov 12, 2020

cc. @dongreenberg @cpuhrsch I think this is very reasonable; there is similar work going on for torchaudio too. We should resume the library naming convention discussion and wrap it up to provide a comprehensive solution for loss/metrics.

@mthrok
Contributor

mthrok commented Dec 4, 2020

@oke-aditya
The domain team had a brief discussion on this:

  • We agree that domain-specific loss functions are coming up.
  • But we would like to hold off on creating a dedicated module until we actually have a good number of functions that fall into the category. It's easy to add such a module, but once added we cannot remove it.
  • Meanwhile, we can update the documentation and add a new category so that the existing loss is easy to find.

@oke-aditya What do you think?

cc @fmassa

@oke-aditya
Contributor Author

I agree with your thoughts @mthrok. It may be too early to call for such an API.
Yes, it would be nice to update the documentation.

@datumbox
Contributor

datumbox commented Jan 3, 2021

I have a use-case that requires LabelSmoothing. Unfortunately CrossEntropyLoss does not support it in PyTorch (pytorch/pytorch#7455). This is a highly requested feature but unfortunately it's been blocked for more than 2 years. Thus I'm tempted to add it on TorchVision side until the above is resolved, but as @oke-aditya pointed out there is no great place to put it.

@oke-aditya It might be worth keeping track of the losses requested to be added here, so that we can see if we have a critical mass to move this forward. Would you be able to update the ticket description with the list of the current loss functions we want to add on the domain side?
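To make the label-smoothing request concrete, here is a minimal pure-Python sketch of a label-smoothed cross-entropy for a single sample. The name `label_smoothing_ce_sketch` is hypothetical, and the convention of spreading `smoothing` uniformly over all classes (including the target) is an assumption; other variants spread it only over the non-target classes.

```python
import math

def label_smoothing_ce_sketch(logits, target_index, smoothing=0.1):
    """Cross-entropy against a smoothed one-hot target (illustrative sketch).

    logits       -- list of raw class scores for one sample
    target_index -- index of the true class
    smoothing    -- mass moved from the true class to a uniform distribution
    """
    num_classes = len(logits)
    # Log-softmax, subtracting the max for numerical stability.
    m = max(logits)
    log_sum = m + math.log(sum(math.exp(z - m) for z in logits))
    log_probs = [z - log_sum for z in logits]
    # Smoothed target: (1 - smoothing) * one_hot + smoothing / num_classes.
    off_value = smoothing / num_classes
    on_value = 1.0 - smoothing + off_value
    return -sum(
        (on_value if i == target_index else off_value) * lp
        for i, lp in enumerate(log_probs)
    )

# With smoothing=0.0 this is ordinary cross-entropy; with smoothing > 0 an
# over-confident prediction is penalised slightly even when it is correct.
plain = label_smoothing_ce_sketch([10.0, 0.0, 0.0], 0, smoothing=0.0)
smoothed = label_smoothing_ce_sketch([10.0, 0.0, 0.0], 0, smoothing=0.1)
```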

@oke-aditya
Contributor Author

Great point @datumbox
Sure 😄 I will update the issue description.
Let's keep this issue for tracking purposes. Feel free to modify it if I miss something.

@yassineAlouini
Contributor

I guess this issue still needs discussion and there is no point in wanting to contribute a loss for now? 🤔

@datumbox
Contributor

@yassineAlouini Wow what a coincidence! Today I was working on something related. :)

At #5444, I have an experimental private function that makes it possible to switch between losses. There are no plans for it to become public any time soon but I was thinking of implementing the Distance-IoU & Complete-IoU losses listed on the ticket.

If you are interested in contributing them, let me know.

@oke-aditya
Contributor Author

Adding cIoU and dIoU should be straightforward. It's been on my mind for a while too.

Let me know if you need help. Or I can pick it up as well :) @yassineAlouini

@datumbox
Contributor

@oke-aditya @yassineAlouini It would be awesome if you could help with the development and review of these 2 losses. For now, we will put them flat in the ops package, similar to giou.

@oke-aditya
Contributor Author

Sure 😃

@yassineAlouini
Contributor

yassineAlouini commented Mar 30, 2022

Yes, it works for me, thanks @datumbox. 👌
Which one should I pick?
Do you have a preference @oke-aditya?

@oke-aditya
Contributor Author

Pick anything you like :)

@yassineAlouini
Contributor

Thanks @oke-aditya. Let me give dIoU a try and see if I can also do cIoU next. I guess you can help me with the review since you know this part of the repo better. I can work on this around 1 day per week. 👌

@abhi-glitchhg
Contributor

@oke-aditya, have you started working on cIoU? If not, can I take it?
Thanks.

@oke-aditya
Contributor Author

Sure @abhi-glitchhg feel free to take it. I'm happy reviewing the PR.

@oke-aditya
Contributor Author

In case you didn't know, here is the detectron2 implementation of both of these:
https://github.com/facebookresearch/detectron2/blob/main/detectron2/layers/losses.py#L66
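As a rough illustration of the two losses, independent of the implementation linked above, here is a pure-Python sketch for a single pair of boxes in (x1, y1, x2, y2) format. The function names are hypothetical, boxes are assumed to have positive width and height, and `eps` is an assumed small constant to avoid division by zero.

```python
import math

def _iou(a, b, eps):
    """IoU of two (x1, y1, x2, y2) boxes."""
    iw = max(0.0, min(a[2], b[2]) - max(a[0], b[0]))
    ih = max(0.0, min(a[3], b[3]) - max(a[1], b[1]))
    inter = iw * ih
    union = (a[2] - a[0]) * (a[3] - a[1]) + (b[2] - b[0]) * (b[3] - b[1]) - inter
    return inter / (union + eps)

def diou_loss_sketch(a, b, eps=1e-7):
    """Distance-IoU loss: 1 - IoU + (center distance / enclosing diagonal)^2."""
    iou = _iou(a, b, eps)
    # Squared diagonal of the smallest box enclosing both inputs.
    cw = max(a[2], b[2]) - min(a[0], b[0])
    ch = max(a[3], b[3]) - min(a[1], b[1])
    diag_sq = cw ** 2 + ch ** 2 + eps
    # Squared distance between the two box centers.
    dx = (a[0] + a[2]) / 2 - (b[0] + b[2]) / 2
    dy = (a[1] + a[3]) / 2 - (b[1] + b[3]) / 2
    return 1.0 - iou + (dx ** 2 + dy ** 2) / diag_sq

def ciou_loss_sketch(a, b, eps=1e-7):
    """Complete-IoU loss: DIoU loss plus an aspect-ratio consistency term."""
    iou = _iou(a, b, eps)
    wa, ha = a[2] - a[0], a[3] - a[1]
    wb, hb = b[2] - b[0], b[3] - b[1]
    # v measures how different the two aspect ratios are.
    v = (4 / math.pi ** 2) * (math.atan(wb / hb) - math.atan(wa / ha)) ** 2
    alpha = v / (1.0 - iou + v + eps)
    return diou_loss_sketch(a, b, eps) + alpha * v
```

The two losses differ only in the `alpha * v` term, so for boxes with the same aspect ratio they give the same value.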

@yassineAlouini
Contributor

Logistics question: Should I create an issue and add some details or is it enough to start a branch from my work and then do a PR later? @datumbox

@oke-aditya
Contributor Author

This issue tracks it. You can create a fresh branch from main and raise a PR :)

@datumbox
Contributor

datumbox commented Mar 31, 2022

@yassineAlouini What Aditya said ^. :)

No need for a separate ticket, we've got plenty that mention it already. When you bring the PR, I'll tag it accordingly. Just make sure you mark it as draft until you are ready for review.

@abhi-glitchhg
Contributor

abhi-glitchhg commented Apr 5, 2022

Should I also add a test for CIOU loss?

I tried finding a test for generalised iou loss in test_ops.py but did not find any. So just want to confirm.

@oke-aditya
Contributor Author

Yeah, the tests are not present yet, see #5688. We can add them along with the PR; mostly you can check cases such as overlapping boxes, side-by-side boxes, etc.
Something along the lines of https://github.com/facebookresearch/fvcore/blob/main/tests/test_giou_loss.py

But first, could we have a look at the implementation?
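The kind of cases mentioned above could look roughly like this. It is a pure-Python sketch with a hypothetical `giou_loss_sketch` helper for one pair of (x1, y1, x2, y2) boxes, not the fvcore or torchvision test code; the tolerance and `eps` values are assumptions.

```python
def giou_loss_sketch(a, b, eps=1e-7):
    """Generalized-IoU loss, 1 - GIoU, for two (x1, y1, x2, y2) boxes."""
    iw = max(0.0, min(a[2], b[2]) - max(a[0], b[0]))
    ih = max(0.0, min(a[3], b[3]) - max(a[1], b[1]))
    inter = iw * ih
    union = (a[2] - a[0]) * (a[3] - a[1]) + (b[2] - b[0]) * (b[3] - b[1]) - inter
    iou = inter / (union + eps)
    # Area of the smallest box enclosing both inputs.
    enclose = (max(a[2], b[2]) - min(a[0], b[0])) * (max(a[3], b[3]) - min(a[1], b[1]))
    giou = iou - (enclose - union) / (enclose + eps)
    return 1.0 - giou

def check_giou_cases(tol=1e-6):
    # Identical boxes: perfect overlap, loss is ~0.
    assert abs(giou_loss_sketch((0, 0, 2, 2), (0, 0, 2, 2))) < tol
    # Side-by-side boxes sharing an edge: IoU is 0 and the enclosing box
    # equals the union, so GIoU is 0 and the loss is ~1.
    assert abs(giou_loss_sketch((0, 0, 1, 1), (1, 0, 2, 1)) - 1.0) < tol
    # Partially overlapping boxes: inter=1, union=7, enclose=9.
    expected = 1.0 - (1 / 7 - 2 / 9)
    assert abs(giou_loss_sketch((0, 0, 2, 2), (1, 1, 3, 3)) - expected) < tol
    # Far-apart boxes: the loss grows towards its upper bound of 2.
    assert abs(giou_loss_sketch((0, 0, 1, 1), (9, 0, 10, 1)) - 1.8) < tol

check_giou_cases()
```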

@yassineAlouini
Contributor

@oke-aditya @abhi-glitchhg I think we need to keep our code in sync, otherwise we might end up implementing something that is quite different for cIoU and dIoU. I will try to push a draft of my MR as soon as possible (around the end of this week or the start of next week).

@oke-aditya
Contributor Author

Hey, don't worry @yassineAlouini, code reviews will make sure we are consistent :) We will sync it up. Feel free to work independently.

@datumbox
Contributor

datumbox commented Apr 5, 2022

Agreed. It might be worth starting with the implementations; by then the tests should be in. I'll keep an eye out for your PRs, as I'm currently keeping track of all related work at #5410.

@yassineAlouini
Contributor

Some progress here (it is still a draft): #5786

@oke-aditya
Contributor Author

It seems the variety of IoU variants and their losses keeps evolving. Now, after gIoU, dIoU and cIoU, we have sIoU:

https://arxiv.org/abs/2205.12740

(How many letters of the alphabet will IoU get? 4/26 currently 😁😁)

@abhi-glitchhg
Contributor

I think SSIM is also a good candidate here.

There is also an issue for it in the PyTorch repo: pytorch/pytorch#6934

Any thoughts?
@datumbox @oke-aditya

@oke-aditya
Contributor Author

Yes but we don't have any task or usage in torchvision for SSIM.

@AngledLuffa

Rather than opening a new issue about focal loss, I figured it might be simplest to comment here. Is there a timeline for reorganizing sigmoid_focal_loss and/or upgrading it to multiclass? It would also be useful to have it as a subclass of _WeightedLoss from torch.nn.modules. Thanks!
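One possible reading of "multiclass" here is a softmax-based focal loss with optional per-class weights, in the spirit of the `weight` argument that `_WeightedLoss` subclasses accept. The sketch below is pure Python for a single sample; the name `softmax_focal_loss_sketch` and its signature are hypothetical and do not correspond to any existing torchvision API.

```python
import math

def softmax_focal_loss_sketch(logits, target_index, gamma=2.0, class_weights=None):
    """Multiclass focal loss for one sample (illustrative sketch only).

    logits        -- list of raw class scores
    target_index  -- index of the true class
    gamma         -- focusing exponent; 0 reduces this to cross-entropy
    class_weights -- optional per-class rescaling weights
    """
    # Log-softmax of the true class, subtracting the max for stability.
    m = max(logits)
    log_sum = m + math.log(sum(math.exp(z - m) for z in logits))
    log_p_t = logits[target_index] - log_sum
    p_t = math.exp(log_p_t)
    weight = 1.0 if class_weights is None else class_weights[target_index]
    # Cross-entropy scaled by the focusing factor (1 - p_t)^gamma.
    return -weight * (1.0 - p_t) ** gamma * log_p_t
```

With `gamma=0.0` and no weights the result equals plain cross-entropy, which makes the sketch easy to sanity-check against an existing loss.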


6 participants