
generalize and homogenize losses #1150

Merged
merged 5 commits into master from cl/loss on Jul 1, 2020

Conversation

CarloLucibello
Member

In order to enforce some consistency in the loss interface, this PR does the following:

  • adds to every loss an agg keyword, which defaults to the function mean. This defines the aggregation type (typically mean or sum); one can use identity for no aggregation.
  • adds a dims keyword where meaningful.
  • fixes other small inconsistencies among the losses

For instance, the crossentropy definition becomes

function crossentropy(ŷ, y; dims=1, agg=mean, ϵ=eps(eltype(ŷ)))
    agg(.-sum(y .* log.(ŷ .+ ϵ); dims=dims))
end
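
For illustration, here is how the new agg keyword would be used (the example data below is made up; the crossentropy definition is the one above):

using Statistics  # for `mean`

ŷ = [0.9 0.2; 0.1 0.8]   # predicted probabilities, classes along dims=1
y = [1.0 0.0; 0.0 1.0]   # one-hot targets

crossentropy(ŷ, y)                # scalar: mean over the batch (default)
crossentropy(ŷ, y; agg=sum)       # scalar: sum over the batch
crossentropy(ŷ, y; agg=identity)  # 1×2 array: one loss per sample, no aggregation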

@CarloLucibello CarloLucibello changed the title from "generalize and homogenize losses" to "[wip] generalize and homogenize losses" on Apr 27, 2020
@CarloLucibello CarloLucibello added this to the v0.11 milestone Apr 29, 2020
@CarloLucibello
Member Author

fix #1024

@CarloLucibello CarloLucibello changed the title from "[wip] generalize and homogenize losses" to "generalize and homogenize losses" on Apr 30, 2020
@CarloLucibello
Member Author

@dhairyagandhi96 @MikeInnes this is ready for review

@MikeInnes
Member

The consistency is good. It doesn't seem necessary for a lot of these functions to have a loss(..., agg = f) keyword when that could be written f(loss(...)), so if we want consistency perhaps we should just go that route.

@CarloLucibello
Member Author

Both solutions are ok, although I prefer to perform the reduction within the loss (with agg=mean as default) for a few reasons:

  • mean is what you want in the large majority of use cases, so it's better to avoid some visual clutter and point ML newcomers in the right direction
  • most losses currently perform reduction, so this PR causes little breakage
  • it is what other DL frameworks do

src/layers/stateless.jl (outdated review thread, resolved)
@DhairyaLGandhi
Member

I think the more julian approach might be agg(loss(...))?

On a side note, do we need to add the prefix _loss? Doesn't seem necessary.

The `ϵ` term provides numerical stability.
Penalizes an under-estimation more than an over-estimation.
"""
msle(ŷ, y; agg=mean, ϵ=epseltype(ŷ)) = agg((log.(ŷ .+ ϵ) .- log.(y .+ ϵ)).^2)
Contributor

@cossio cossio Jun 8, 2020


This is more accurate:

Suggested change
msle(ŷ, y; agg=mean, ϵ=epseltype(ŷ)) = agg((log.(ŷ .+ ϵ) .- log.(y .+ ϵ)).^2)
msle(ŷ, y; agg=mean, ϵ=epseltype(ŷ)) = agg(log.((ŷ .+ ϵ) ./ (y .+ ϵ)).^2)

and faster (only one log call).

For example,

julia> log(1e9) - log(1e9 + 1)
-1.000000082740371e-9
julia> log(1e9/(1 + 1e9))
-9.999999722180686e-10

The first form has an error of ~1e-16, while the latter's is ~1e-17. (The exact value is 9.99999999500000000333333333083e-10 in magnitude.)

Member Author


done!

@CarloLucibello
Member Author

I think the more julian approach might be agg(loss(...))?

I think it would, but see comment #1150 (comment)

Aggregating should cover 99% of use cases, so I think it is worth doing it. It's no big deal, we could remove it, but I'm concerned about the large amount of breakage it would cause.

On a side note, do we need to add the prefix _loss? Doesn't seem necessary.

I think some functions like poisson and hinge are just more clear and unambiguous with the _loss prefix.

Then there is the case of binarycrossentropy, which was one of the few cases where aggregation was not performed before, so I had to rename it to bce_loss to avoid breakage and give a gentler deprecation path. An alternative deprecation path would be to make the agg keyword mandatory just for binarycrossentropy and add

@deprecate binarycrossentropy(yhat, y) binarycrossentropy(yhat, y, agg=identity)

@CarloLucibello
Member Author

CarloLucibello commented Jun 9, 2020

A minor point is that people really should use logitcrossentropy and logitbinarycrossentropy (now logitbce_loss) for numerical stability, instead of crossentropy and binarycrossentropy (now bce_loss). Those names are so ugly though :(

Maybe we can go through a deprecation path where we unify the two versions using a keyword argument logits. We would make it mandatory in v0.11, then have it default to true in v0.12. That is

crossentropy(yhat, y, logits=true)   # Flux v0.11
crossentropy(yhat, y)   # Flux v0.12

@MikeInnes
Member

There's a lot going on in this PR. Especially because of the breakage, I think it'd be good to narrow it down to the bare minimum needed to fix these loss functions, without e.g. other documentation fixes or movement of code.

binarycrossentropy doesn't aggregate because it's not a loss function, it's an implementation of cross entropy for a Bernoulli variable, which you can use to build a loss function. Where possible I think it's useful to preserve a clear distinction between core mathematical functions and those that have Flux-specific behaviour (assuming a batch dimension, aggregating via mean etc).
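
To make the distinction concrete, a minimal sketch (illustrative signatures, not the exact Flux definitions): the first function is the per-element cross entropy of a Bernoulli variable, the second is a loss-style wrapper built on top of it that broadcasts over a batch and aggregates.

using Statistics: mean

# "mathematical" version: scalar in, scalar out, no batching or aggregation assumptions
binarycrossentropy(ŷ, y; ϵ=eps(ŷ)) = -y * log(ŷ + ϵ) - (1 - y) * log(1 - ŷ + ϵ)

# loss-style wrapper: assumes a batch and aggregates via `agg`
bce_loss(ŷ, y; agg=mean) = agg(binarycrossentropy.(ŷ, y))

bce_loss([0.9, 0.2, 0.7], [1.0, 0.0, 1.0])   # ≈ 0.228, mean loss over the batch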

I would be fine with adding a logits kwarg to crossentropy but would be pretty against changing its default behaviour, for much the same reasons: crossentropy names a mathematical function, rather than a Flux utility, and it'd be really confusing for it not to do what it says.

Perhaps it'd be useful to have a Loss module, so you can write eg Loss.poisson(...) and import the name directly if you want to. That'd solve the naming problem of adding _loss to everything, and it'd set a clear boundary for the distinction above. That direction is already breaking enough that removing aggregation-by-default would be reasonable (but I'm also not as strongly against it if there's a clear separation).
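
For concreteness, a rough sketch of what such a module boundary could look like (the module and function names here are placeholders, not the final API):

module Losses

using Statistics: mean

# a couple of losses as they might look inside the module
poisson(ŷ, y; agg=mean) = agg(ŷ .- y .* log.(ŷ))
hinge(ŷ, y; agg=mean)   = agg(max.(0, 1 .- ŷ .* y))

end # module

Losses.poisson([0.5, 1.5], [1.0, 1.0])   # qualified call; `using .Losses: poisson` also works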

@CarloLucibello
Member Author

CarloLucibello commented Jun 20, 2020

binarycrossentropy doesn't aggregate because it's not a loss function, it's an implementation of cross entropy for a Bernoulli variable, which you can use to build a loss function. Where possible I think it's useful to preserve a clear distinction between core mathematical functions and those that have Flux-specific behaviour (assuming a batch dimension, aggregating via mean etc)

I don't see a clear cut distinction between core mathematical functions and loss functions. Losses are scalar-valued functions that happen to be commonly used as optimization objectives in machine learning. Sometimes they are explicitly devised for this purpose, sometimes not. Many losses, e.g. crossentropy, bce, poisson, and even MSE, can be interpreted as negative loglikelihoods, and the first two also as, well, crossentropies.

Moreover, crossentropy and bce should really be treated consistently: they are the same thing applied to different distributions, so they should either both aggregate or both not; right now one aggregates and the other doesn't.

Those functions are in Flux only because they are commonly used as losses, and we should not treat binarycrossentropy differently from mse or huber. We are not really planning for (binary)crossentropy to be anything else; e.g. we don't support objects from Distributions.jl, and probably we shouldn't. On the other hand, we should always offer all losses the option to skip aggregation for extra flexibility, which is also what this PR does.

I would be fine with adding a logits kwarg to crossentropy but would be pretty against changing its default behaviour, for much the same reasons: crossentropy names a mathematical function, rather than a Flux utility, and it'd be really confusing for it not to do what it says.

crossentropy doesn't act on Distributions.jl objects, which would arguably be the closest Julia approximation of the mathematical definition. It acts on array representations of categorical distributions, or batches of them. Whether the representation is given by ŷ[k] = pₖ or ŷ[k] = log pₖ is just an implementation choice; I don't see any confusion or any "not doing what it says" about that. I suggest using the latter, because people keep asking "why do I get NaNs?"
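
To illustrate the NaN point with simplified definitions (these are sketches, not the exact Flux/NNlib implementations): the log-probability route stays in log space, while the probability route can overflow inside softmax before the log is taken.

using Statistics: mean

softmax(x; dims=1) = exp.(x) ./ sum(exp.(x); dims=dims)

function logsoftmax(x; dims=1)              # shift by the max so exp never overflows
    m = maximum(x; dims=dims)
    x .- m .- log.(sum(exp.(x .- m); dims=dims))
end

crossentropy(ŷ, y; dims=1, agg=mean)      = agg(.-sum(y .* log.(ŷ); dims=dims))
logitcrossentropy(x, y; dims=1, agg=mean) = agg(.-sum(y .* logsoftmax(x; dims=dims); dims=dims))

x = [800.0 -800.0; -800.0 800.0]   # extreme logits
y = [1.0 0.0; 0.0 1.0]

crossentropy(softmax(x), y)   # NaN: exp overflows inside softmax before the log
logitcrossentropy(x, y)       # 0.0: everything stays in log space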

Perhaps it'd be useful to have a Loss module, so you can write eg Loss.poisson(...) and import the name directly if you want to. That'd solve the naming problem of adding _loss to everything, and it'd set a clear boundary for the distinction above. That direction is already breaking enough that removing aggregation-by-default would be reasonable (but I'm also not as strongly against it if there's a clear separation).

I agree, we should have a Losses submodule; I will add that. But we need to agree on what goes in and what stays out.

In addition to what I said above, I'll mention that all of the functions I touched in this PR are categorized as losses in PyTorch and Keras:
https://keras.io/api/losses/
https://pytorch.org/docs/stable/nn.html#loss-functions
so I think we should do the same.

In both frameworks, you have a class and a function counterpart for each loss.
PyTorch aggregates over the batch by default in both cases, while Keras aggregates when using the class but not when using the function.
I would go with aggregation by default here, unless we also want to introduce loss types (possibly useful for serialization) and do what Keras does.

@MikeInnes
Member

But there clearly is a distinction, and you've said so yourself. "Cross entropy" refers to two different things: (1) a function of two distributions (possibly represented as probability vectors or Distributions objects) and (2) a loss function that accepts matrices representing batches of distributions and aggregates them. (The same could be said of 'euclidean distance' and 'mean squared error', but in that case the concepts helpfully have different names.) This PR actually makes this distinction clearer, by ensuring all loss functions have a common feature set – that they are all explicitly devised for purpose – which I think is a good thing.

I don't think you're really saying there's no clear-cut distinction, but that we just shouldn't bother trying to support the mathematical concepts at all, only the loss function versions – which is fine by me. I think we both agree that the current situation (mixing mathematical-like BCE and loss-like CE in one module) is not great. And we both seem to agree that it's a good idea to have a Loss module to clearly signal that these are loss functions, rather than generic implementations, so I think we're on the same page.

I personally think it'd be nice to share concepts with a generic implementation (perhaps Distributions.jl) in principle (eg "turn this distance metric into a loss function that supports batches and aggregation") but definitely don't feel that needs to block this cleanup.

@CarloLucibello CarloLucibello force-pushed the cl/loss branch 2 times, most recently from 3c515c3 to 1d7a838, on July 1, 2020 10:45
@CarloLucibello
Member Author

OK, we are on the same page. I'll merge this as it is, since this PR is already quite involved, and create the Loss module in a follow-up PR.

@CarloLucibello
Member Author

bors r+

@bors
Contributor

bors bot commented Jul 1, 2020

Build succeeded:

@bors bors bot merged commit 822f13c into master Jul 1, 2020
bors bot added a commit that referenced this pull request Jul 9, 2020
1264: create Losses module r=CarloLucibello a=CarloLucibello

Continuation of #1150, grouping losses under a `Losses` module as discussed. 

An alternative for the module name could be `Loss`, but `Losses` seemed more natural.

I also renamed `bce` back to `binarycrossentropy`, since now that the function is within a module I could provide a deprecation path without changing the function name with respect to the last tagged release.

Some of the functions contain the `_loss` suffix (e.g. `hinge_loss`, `poisson_loss`). We could drop it now that we have a namespace disambiguating the meaning, but again it seems more natural to keep it, since it is closer to the way people refer to them when speaking.

### PR Checklist

- [x] Tests are added
- [x] Entry in NEWS.md
- [x] Documentation, if applicable
- [ ] Final review from `@MikeInnes` or `@dhairyagandhi96` (for API changes).


Co-authored-by: CarloLucibello <carlo.lucibello@gmail.com>
@CarloLucibello CarloLucibello deleted the cl/loss branch January 7, 2021 08:47