generalize and homogenize losses #1150
Conversation
fix #1024
@dhairyagandhi96 @MikeInnes this is ready for review
The consistency is good. It doesn't seem necessary for a lot of these functions to have an …
Both solutions are ok, although I prefer to perform the reduction within the …
I think the more Julian approach might be … On a side note, do we need to add the prefix …?
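For concreteness, a minimal sketch of the keyword-based design under discussion, using a hypothetical `myloss` (not part of Flux):

```julia
using Statistics: mean

# Hypothetical elementwise loss; the caller picks the aggregation,
# and `identity` opts out of reduction entirely.
myloss(ŷ, y; agg = mean) = agg((ŷ .- y) .^ 2)

ŷ, y = rand(3), rand(3)
myloss(ŷ, y)                  # mean over elements (the default)
myloss(ŷ, y; agg = sum)       # sum instead
myloss(ŷ, y; agg = identity)  # unaggregated elementwise values
```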
src/layers/stateless.jl
Outdated
```julia
The `ϵ` term provides numerical stability.
Penalizes an under-estimation more than an over-estimation.
"""
msle(ŷ, y; agg=mean, ϵ=epseltype(ŷ)) = agg((log.(ŷ .+ ϵ) .- log.(y .+ ϵ)).^2)
```
This is more accurate:
```julia
msle(ŷ, y; agg=mean, ϵ=epseltype(ŷ)) = agg((log.(ŷ .+ ϵ) .- log.(y .+ ϵ)).^2)  # current
msle(ŷ, y; agg=mean, ϵ=epseltype(ŷ)) = agg(log.((ŷ .+ ϵ) ./ (y .+ ϵ)).^2)      # suggested
```
and faster (only one `log` call).
For example,

```julia
julia> log(1e9) - log(1e9 + 1)
-1.000000082740371e-9

julia> log(1e9 / (1 + 1e9))
-9.999999722180686e-10
```
The first form has an error of ~1e-16, while the latter's is ~1e-17. (Exact answer: 9.99999999500000000333333333083e-10.)
done!
I think it would, but see comment #1150 (comment). Aggregating should cover 99% of use cases, so I think it is worth doing. It's no big deal, we could remove it, but I'm concerned about the large amount of breakage that would cause.
I think some functions like … Then there is the case of …

```julia
@deprecate binarycrossentropy(yhat, y) binarycrossentropy(yhat, y, agg=identity)
```
A minor point is that people really should use … Maybe we can go through a deprecation path where we unify the two versions using a keyword argument:

```julia
crossentropy(yhat, y, logits=true) # Flux v0.11
crossentropy(yhat, y)              # Flux v0.12
```
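A sketch of what that unified entry point could look like during the deprecation window (the body is illustrative, not the PR's actual code, and it assumes NNlib's `softmax`):

```julia
using Statistics: mean
using NNlib: softmax  # assumption: the same softmax Flux re-exports

# One entry point; `logits = true` applies softmax first (v0.11-style),
# so the keyword can later be dropped in favor of a separate function.
function crossentropy(ŷ, y; logits = false, agg = mean, dims = 1)
    p̂ = logits ? softmax(ŷ; dims = dims) : ŷ
    agg(.-sum(y .* log.(p̂); dims = dims))
end
```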
There's a lot going on in this PR. Especially because of the breakage, I think it'd be good to narrow it down to just the bare minimum needed to fix these loss functions, without e.g. other documentation fixes or movement of code.
I would be fine with adding a … Perhaps it'd be useful to have a `Losses` module.
I don't see a clear-cut distinction between core mathematical functions and loss functions. Losses are scalar-valued functions that happen to be commonly used as optimization objectives in machine learning. Sometimes they are explicitly devised for this purpose, sometimes not. Many losses, e.g. crossentropy, bce, poisson, and even MSE, can be interpreted as negative log-likelihoods, and the first two also as, well, cross entropies. Moreover, crossentropy and bce should really be treated consistently: they are the same thing applied to different distributions, so they should either both aggregate or both not aggregate; right now one aggregates and the other doesn't. Those functions are in Flux only because they are commonly used as losses, and we should not treat binarycrossentropy differently from mse or huber. We are not really planning for (binary)crossentropy to be something else, e.g. we don't support objects from Distributions.jl and probably we shouldn't. On the other hand, we should always offer the option for all losses to skip aggregation for extra flexibility, and this is what this PR does as well.
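To make the "same thing applied to different distributions" point concrete, a small sketch (the names `bce` and `ce` are illustrative, not Flux's):

```julia
# Binary cross entropy for a single prediction ŷ of a target y ∈ {0, 1}...
bce(ŷ, y) = -y * log(ŷ) - (1 - y) * log(1 - ŷ)

# ...is cross entropy applied to the two-outcome distribution (p, 1 - p):
ce(p̂, p) = -sum(p .* log.(p̂))

ŷ, y = 0.8, 1.0
ce([ŷ, 1 - ŷ], [y, 1 - y]) ≈ bce(ŷ, y)  # true
```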
crossentropy doesn't act on Distributions.jl objects, which would maybe be the closest approximation, as a Julia function, of the mathematical definition. crossentropy acts on array representations of categorical distributions, or batches of such. That the representation is given by …
I agree, we should have a Losses submodule; I will add that. But we need to agree on what goes in and what stays out. In addition to what I said above, I'll mention that all of the functions I touched in this PR are categorized as losses in PyTorch and Keras. In both frameworks, you have a class and a function counterpart for each loss.
But there clearly is a distinction, and you've said so yourself. "Cross entropy" refers to two different things: (1) a function of two distributions (possibly represented as probability vectors or …), and (2) … I don't think you're really saying there's no clear-cut distinction, but that we just shouldn't bother trying to support the mathematical concepts at all, only the loss-function versions – which is fine by me. I think we both agree that the current situation (mixing mathematical-like BCE and loss-like CE in one module) is not great. And we both seem to agree that it's a good idea to have a `Losses` module. I personally think it'd be nice to share concepts with a generic implementation (perhaps Distributions.jl) in principle (e.g. "turn this distance metric into a loss function that supports batches and aggregation") but definitely don't feel that needs to block this cleanup.
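As a rough sketch of that last idea, assuming a hypothetical helper `as_loss` (not part of Flux):

```julia
using Statistics: mean

# Lift a scalar metric into a loss that broadcasts over batches
# and aggregates, in the spirit of the losses in this PR.
function as_loss(metric; agg = mean)
    (ŷ, y) -> agg(metric.(ŷ, y))
end

sqdist(a, b) = (a - b)^2
my_mse = as_loss(sqdist)    # behaves like an aggregated MSE
my_mse(rand(4), rand(4))
```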
Ok, we are on the same page. I'll merge this as it is, since this PR is already quite involved, and create the `Losses` module in a follow-up PR.
bors r+
Build succeeded.
1264: create Losses module r=CarloLucibello a=CarloLucibello

Continuation of #1150, grouping losses under a `Losses` module as discussed. An alternative for the module name could be `Loss`, but `Losses` seemed more natural.

I also renamed `bce` back to `binarycrossentropy`, since now that the function is within a module, I could provide a deprecation path without changing the function name with respect to the last tagged release.

Some of the functions contain the `_loss` suffix (e.g. `hinge_loss`, `poisson_loss`). We could drop that now that we have a namespace disambiguating the meaning, but again it seems more natural to keep it, closer to the way people refer to these losses when speaking.

### PR Checklist

- [x] Tests are added
- [x] Entry in NEWS.md
- [x] Documentation, if applicable
- [ ] Final review from `@MikeInnes` or `@dhairyagandhi96` (for API changes)

Co-authored-by: CarloLucibello <carlo.lucibello@gmail.com>
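A sketch of how the namespaced API reads after that PR (names taken from the description above; exact exports and defaults may differ):

```julia
using Flux.Losses: mse, binarycrossentropy, hinge_loss

ŷ, y = rand(Float32, 3), rand(Float32, 3)
mse(ŷ, y)                            # mean-aggregated by default
binarycrossentropy(ŷ, y; agg = sum)  # keyword-selected aggregation
hinge_loss(ŷ, y; agg = identity)     # unaggregated, per-element values
```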
In order to enforce some consistency in the loss interface, this PR does the following:

- Every loss takes an `agg` keyword, which defaults to the function `mean`. This defines the aggregation type (typically `mean` or `sum`). One can use `identity` for no aggregation.
- Losses take a `dims` keyword when meaningful.

For instance, the crossentropy definition becomes:
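(A sketch under the conventions above; the PR's actual body may differ, and the `ϵ` default here is an assumption.)

```julia
using Statistics: mean

# Sum over the class dimension, then aggregate over the batch.
crossentropy(ŷ, y; dims = 1, agg = mean, ϵ = eps(eltype(ŷ))) =
    agg(.-sum(y .* log.(ŷ .+ ϵ); dims = dims))
```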