-
-
Notifications
You must be signed in to change notification settings - Fork 608
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
1150: generalize and homogenize losses r=CarloLucibello a=CarloLucibello In order to enforce some consistency in the loss interface, this PR does the following: - adds to every loss an `agg` keyword, which defaults to the function `mean` . This defines the aggregation type (typically `mean` or `sum`). One can use `identity` for no aggregation. - add a `dims` keyword when meaningful. - fix other little inconsistencies among the losses For instance, the crossentropy definition becomes ```julia function crossentropy(ŷ, y; dims=1, agg=mean, ϵ=eps(eltype(ŷ))) agg(.-sum(y .* log.(ŷ .+ ϵ); dims=dims)) end ``` Co-authored-by: CarloLucibello <carlo.lucibello@gmail.com>
- Loading branch information
Showing
15 changed files
with
231 additions
and
234 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,40 @@ | ||
## Loss Functions | ||
|
||
Flux provides a large number of common loss functions used for training machine learning models. | ||
|
||
Loss functions for supervised learning typically expect as inputs a target `y`, and a prediction `ŷ`. | ||
In Flux's convention, the order of the arguments is the following | ||
|
||
```julia | ||
loss(ŷ, y) | ||
``` | ||
|
||
Most loss functions in Flux have an optional argument `agg`, denoting the type of aggregation performed over the | ||
batch: | ||
|
||
```julia | ||
loss(ŷ, y) # defaults to `mean` | ||
loss(ŷ, y, agg=sum) # use `sum` for reduction | ||
loss(ŷ, y, agg=x->sum(x, dims=2)) # partial reduction | ||
loss(ŷ, y, agg=x->mean(w .* x)) # weighted mean | ||
loss(ŷ, y, agg=identity) # no aggregation. | ||
``` | ||
|
||
### Losses Reference | ||
|
||
```@docs | ||
Flux.mae | ||
Flux.mse | ||
Flux.msle | ||
Flux.huber_loss | ||
Flux.crossentropy | ||
Flux.logitcrossentropy | ||
Flux.bce_loss | ||
Flux.logitbce_loss | ||
Flux.kldivergence | ||
Flux.poisson_loss | ||
Flux.hinge_loss | ||
Flux.squared_hinge_loss | ||
Flux.dice_coeff_loss | ||
Flux.tversky_loss | ||
``` |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,2 +1,7 @@ | ||
@deprecate param(x) x | ||
@deprecate data(x) x | ||
# v0.11 deprecations | ||
@deprecate poisson poisson_loss | ||
@deprecate hinge hinge_loss | ||
@deprecate squared_hinge squared_hinge_loss | ||
@deprecate binarycrossentropy(ŷ, y) bce_loss(ŷ, y, agg=identity) | ||
@deprecate logitbinarycrossentropy(ŷ, y) logitbce_loss(ŷ, y, agg=identity) | ||
@deprecate normalise(x) normalise(x, dims=1) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Empty file.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.