From 27a9a7c33b9f04637ce037342929e43c203d2fb1 Mon Sep 17 00:00:00 2001
From: CarloLucibello
Date: Mon, 27 Apr 2020 11:44:16 +0200
Subject: [PATCH 1/5] new loss interface

---
 NEWS.md                           |  23 ++-
 docs/make.jl                      |   3 +-
 docs/src/models/layers.md         |  20 +--
 docs/src/models/losses.md         |  40 +++++
 docs/src/models/regularisation.md |  17 +-
 src/Flux.jl                       |   1 +
 src/deprecations.jl               |   9 +-
 src/layers/losses.jl              |   0
 src/layers/stateless.jl           | 270 ++++++++++++++++--------------
 src/utils.jl                      |   3 +
 test/cuda/cuda.jl                 |   4 +-
 test/cuda/layers.jl               |   7 +-
 test/layers/stateless.jl          |  48 ++----
 13 files changed, 248 insertions(+), 197 deletions(-)
 create mode 100644 docs/src/models/losses.md
 create mode 100644 src/layers/losses.jl

diff --git a/NEWS.md b/NEWS.md
index 3b95b70c31..4151b2d955 100644
--- a/NEWS.md
+++ b/NEWS.md
@@ -1,12 +1,11 @@
 # v0.11
-* Add [kaiming initialization](https://arxiv.org/abs/1502.01852) methods: `kaiming_uniform` and `kaiming_normal` [https://github.com/FluxML/Flux.jl/pull/1243]
-* Change to `DataLoader`'s constructor [https://github.com/FluxML/Flux.jl/pull/1152]
-* Use `DataLoader` with `NamedTuple`s, so that tensors can be accessed by name [https://github.com/FluxML/Flux.jl/pull/1221].
-* Error if Dense layers weights and biases are not arrays [https://github.com/FluxML/Flux.jl/pull/1218].
-* Add `Adaptive Pooling` in Flux layers [https://github.com/FluxML/Flux.jl/pull/1239].
-* Optimistic ADAM (OADAM) optimizer for adversarial training [https://github.com/FluxML/Flux.jl/pull/1246].
-
-# v0.10.5
+* Add [kaiming initialization](https://arxiv.org/abs/1502.01852) methods: [kaiming_uniform and kaiming_normal](https://github.com/FluxML/Flux.jl/pull/1243)
+* Use `DataLoader` with `NamedTuple`s, so that tensors can be accessed [by name](https://github.com/FluxML/Flux.jl/pull/1221).
+* Error if Dense layers weights and biases are [not arrays](https://github.com/FluxML/Flux.jl/pull/1218).
+* Add [Adaptive Pooling](https://github.com/FluxML/Flux.jl/pull/1239) in Flux layers.
+* Change to `DataLoader`'s [constructor](https://github.com/FluxML/Flux.jl/pull/1152)
+* Uniform loss [interface](https://github.com/FluxML/Flux.jl/pull/1150)
+* Optimistic ADAM (OADAM) optimizer for [adversarial training](https://github.com/FluxML/Flux.jl/pull/1246).
 * Add option for [same padding](https://github.com/FluxML/Flux.jl/pull/901) to conv and pooling layers by setting `pad=SamePad()`.
 * Added option to set `bias` to [Flux.Zeros](https://github.com/FluxML/Flux.jl/pull/873) to eliminating `bias` from being trained.
 * Added `GlobalMaxPool` and `GlobalMeanPool` [layers](https://github.com/FluxML/Flux.jl/pull/950) for performing global pooling operations.
@@ -16,14 +15,19 @@
 * Testing suite improvements now test for gradients of all layers along with GPU support.
 * Functors have now moved to [Functors.jl](https://github.com/FluxML/Flux.jl/pull/1174) to allow for their use outside of Flux.
 * Added [helper functions](https://github.com/FluxML/Flux.jl/pull/873) `Flux.convfilter` and `Flux.depthwiseconvfilter` to construct weight arrays for convolutions outside of layer constructors so as to not have to depend on the default layers for custom implementations.
+* and many more fixes and additions...
+
+# v0.10.1 - v0.10.4
+See GitHub's releases.
 
 # v0.10.0
+
 * The default AD engine has switched from [Tracker to Zygote.jl](https://github.com/FluxML/Flux.jl/pull/669)
   - The dependency on Tracker.jl has been removed.
- This means Flux now does not depend on using a specialised `TrackedArray` type, and can be used with normal Array implementations directly. - Tracker compatibility is maintained in most common cases, but Zygote will be the preferred AD backend for Flux from now on. * The CUDNN wrappers have been [moved from Flux into CuArrays](https://github.com/FluxML/Flux.jl/pull/874), to allow for better supporting the CUDA backend, and improve user experience, not to mention making Flux lean. -* `*crossentropy` functions now [work as expected with CuArrays](https://github.com/FluxML/Flux.jl/pull/926). [PR for binarycrossentropy](https://github.com/FluxML/Flux.jl/pull/940). +* `*crossentropy` functions now [work as expected with CuArrays](https://github.com/FluxML/Flux.jl/pull/926). [PR for bce_loss](https://github.com/FluxML/Flux.jl/pull/940). * Added [clearer docs](https://github.com/FluxML/Flux.jl/pull/904) around training and the Optimiser interface. * [Layer initialisations](https://github.com/FluxML/Flux.jl/pull/937) have been improved with a clearer API on how to extend it for other purposes. * [Better messaging around CUDA availability](https://github.com/FluxML/Flux.jl/pull/924), with hooks to initialize the GPU as default where possible. @@ -31,6 +35,7 @@ * `testmode!` is deprecated in favour of [istraining](https://github.com/FluxML/Flux.jl/pull/669) # v0.9.0 + * [Depthwise convolutional layer API changes](https://github.com/FluxML/Flux.jl/pull/756) from `in => mult` channel specification to `in => out` channel specification, and deprecates implicit `out` constructor. * New [SkipConnection](https://github.com/FluxML/Flux.jl/pull/446), which can be used to train residual neural network architectures. * New [RADAM](https://github.com/FluxML/Flux.jl/pull/842) optimiser. diff --git a/docs/make.jl b/docs/make.jl index 2f24a022b9..b4e1b8b0b0 100644 --- a/docs/make.jl +++ b/docs/make.jl @@ -8,8 +8,9 @@ makedocs(modules=[Flux, NNlib], "Building Models" => ["Basics" => "models/basics.md", "Recurrence" => "models/recurrence.md", - "Regularisation" => "models/regularisation.md", "Model Reference" => "models/layers.md", + "Loss Functions" => "models/losses.md", + "Regularisation" => "models/regularisation.md", "Advanced Model Building" => "models/advanced.md", "NNlib" => "models/nnlib.md"], "Handling Data" => diff --git a/docs/src/models/layers.md b/docs/src/models/layers.md index 87178536ae..5f2cdbcfd1 100644 --- a/docs/src/models/layers.md +++ b/docs/src/models/layers.md @@ -73,22 +73,4 @@ Many normalisation layers behave differently under training and inference (testi ```@docs Flux.testmode! trainmode! -``` - -## Cost Functions -```@docs -Flux.mae -Flux.mse -Flux.msle -Flux.huber_loss -Flux.crossentropy -Flux.logitcrossentropy -Flux.binarycrossentropy -Flux.logitbinarycrossentropy -Flux.kldivergence -Flux.poisson -Flux.hinge -Flux.squared_hinge -Flux.dice_coeff_loss -Flux.tversky_loss -``` +``` \ No newline at end of file diff --git a/docs/src/models/losses.md b/docs/src/models/losses.md new file mode 100644 index 0000000000..5333efa351 --- /dev/null +++ b/docs/src/models/losses.md @@ -0,0 +1,40 @@ +## Loss Functions + +Flux provides a large number of common loss functions used for training machine learning models. + +Loss functions for supervised learning typically expect as inputs a target `y`, and a prediction `ŷ`. 
+In Flux's convention, the order of the arguments is the following + +```julia +loss(ŷ, y) +``` + +Most loss functions in Flux have an optional argument `agg`, denoting the type of aggregation performed over the +batch: + +```julia +loss(ŷ, y) # defaults to `mean` +loss(ŷ, y, agg=sum) # use `sum` for reduction +loss(ŷ, y, agg=x->sum(x, dims=2)) # partial reduction +loss(ŷ, y, agg=x->mean(w .* x)) # weighted mean +loss(ŷ, y, agg=identity) # no aggregation. +``` + +### Losses Reference + +```@docs +Flux.mae +Flux.mse +Flux.msle +Flux.huber_loss +Flux.crossentropy +Flux.logitcrossentropy +Flux.bce_loss +Flux.logitbce_loss +Flux.kldivergence +Flux.poisson_loss +Flux.hinge_loss +Flux.squared_hinge_loss +Flux.dice_coeff_loss +Flux.tversky_loss +``` diff --git a/docs/src/models/regularisation.md b/docs/src/models/regularisation.md index 535dd09695..ee4350f0b6 100644 --- a/docs/src/models/regularisation.md +++ b/docs/src/models/regularisation.md @@ -7,9 +7,10 @@ add the result to the overall loss. For example, say we have a simple regression. ```julia -using Flux: crossentropy +using Flux +using Flux: logitcrossentropy m = Dense(10, 5) -loss(x, y) = crossentropy(softmax(m(x)), y) +loss(x, y) = logitcrossentropy(m(x), y) ``` We can regularise this by taking the (L2) norm of the parameters, `m.W` and `m.b`. @@ -18,19 +19,19 @@ We can regularise this by taking the (L2) norm of the parameters, `m.W` and `m.b using LinearAlgebra penalty() = norm(m.W) + norm(m.b) -loss(x, y) = crossentropy(softmax(m(x)), y) + penalty() +loss(x, y) = logitcrossentropy(m(x), y) + penalty() ``` When working with layers, Flux provides the `params` function to grab all -parameters at once. We can easily penalise everything with `sum(norm, params)`. +parameters at once. We can easily penalise everything with `sum`: ```julia -julia> params(m) +julia> Flux.params(m) 2-element Array{Any,1}: param([0.355408 0.533092; … 0.430459 0.171498]) param([0.0, 0.0, 0.0, 0.0, 0.0]) -julia> sum(norm, params(m)) +julia> sum(norm, Flux.params(m)) 26.01749952921026 ``` @@ -40,9 +41,9 @@ Here's a larger example with a multi-layer perceptron. 
m = Chain( Dense(28^2, 128, relu), Dense(128, 32, relu), - Dense(32, 10), softmax) + Dense(32, 10)) -loss(x, y) = crossentropy(m(x), y) + sum(norm, params(m)) +loss(x, y) = logitcrossentropy(m(x), y) + sum(norm, Flux.params(m)) loss(rand(28^2), rand(10)) ``` diff --git a/src/Flux.jl b/src/Flux.jl index e40f9ea7ce..3ab7be34e5 100644 --- a/src/Flux.jl +++ b/src/Flux.jl @@ -35,6 +35,7 @@ include("onehot.jl") include("functor.jl") include("layers/stateless.jl") +include("layers/losses.jl") include("layers/basic.jl") include("layers/conv.jl") include("layers/recurrent.jl") diff --git a/src/deprecations.jl b/src/deprecations.jl index ccaac27aaf..3c8a5bc673 100644 --- a/src/deprecations.jl +++ b/src/deprecations.jl @@ -1,2 +1,7 @@ -@deprecate param(x) x -@deprecate data(x) x +# v0.11 deprecations +@deprecate poisson poisson_loss +@deprecate hinge hinge_loss +@deprecate squared_hinge squared_hinge_loss +@deprecate binarycrossentropy(ŷ, y) bce_loss(ŷ, y, agg=identity) +@deprecate logitbinarycrossentropy(ŷ, y) logitbce_loss(ŷ, y, agg=identity) +@deprecate normalise(x) normalise(x, dims=1) diff --git a/src/layers/losses.jl b/src/layers/losses.jl new file mode 100644 index 0000000000..e69de29bb2 diff --git a/src/layers/stateless.jl b/src/layers/stateless.jl index 175ceafdc0..f652b912fa 100644 --- a/src/layers/stateless.jl +++ b/src/layers/stateless.jl @@ -1,43 +1,35 @@ -# Cost functions """ - mae(ŷ, y) + mae(ŷ, y; agg=mean) -Return the mean of absolute error; calculated as -`sum(abs.(ŷ .- y)) / length(y)`. -""" -mae(ŷ, y) = sum(abs.(ŷ .- y)) * 1 // length(y) +Return the loss corresponding to mean absolute error: + agg(abs.(ŷ .- y)) +""" +mae(ŷ, y; agg=mean) = agg(abs.(ŷ .- y)) """ - mse(ŷ, y) + mse(ŷ, y; agg=mean) -Return the mean squared error between ŷ and y; calculated as -`sum((ŷ .- y).^2) / length(y)`. +Return the loss corresponding to mean square error: + + agg((ŷ .- y).^2) +""" +mse(ŷ, y; agg=mean) = agg((ŷ .- y).^2) -# Examples -```jldoctest -julia> Flux.mse([0, 2], [1, 1]) -1//1 -``` """ -mse(ŷ, y) = sum((ŷ .- y).^2) * 1 // length(y) + msle(ŷ, y; agg=mean, ϵ=eps(ŷ)) +The loss corresponding to mean squared logarithmic errors, calculated as -""" - msle(ŷ, y; ϵ=eps(eltype(ŷ))) + agg((log.(ŷ .+ ϵ) .- log.(y .+ ϵ)).^2) -Return the mean of the squared logarithmic errors; calculated as -`sum((log.(ŷ .+ ϵ) .- log.(y .+ ϵ)).^2) / length(y)`. The `ϵ` term provides numerical stability. - -Penalizes an under-predicted estimate greater than an over-predicted estimate. +Penalizes an under-estimation more than an over-estimatation. """ -msle(ŷ, y; ϵ=eps(eltype(ŷ))) = sum((log.(ŷ .+ ϵ) .- log.(y .+ ϵ)).^2) * 1 // length(y) - - +msle(ŷ, y; agg=mean, ϵ=epseltype(ŷ)) = agg((log.((ŷ .+ ϵ) ./ (y .+ ϵ))).^2) """ - huber_loss(ŷ, y; δ=1.0) + huber_loss(ŷ, y; δ=1, agg=mean) Return the mean of the [Huber loss](https://en.wikipedia.org/wiki/Huber_loss) given the prediction `ŷ` and true values `y`. @@ -46,109 +38,89 @@ given the prediction `ŷ` and true values `y`. 
                  | 0.5 * |ŷ - y|^2,            for |ŷ - y| <= δ
     Huber loss = |
                  |  δ * (|ŷ - y| - 0.5 * δ), otherwise
 """
-#TODO: remove dropgrad when Zygote can handle this function with CuArrays
-function huber_loss(ŷ, y; δ=eltype(ŷ)(1))
+function huber_loss(ŷ, y; agg=mean, δ=ofeltype(ŷ, 1))
   abs_error = abs.(ŷ .- y)
+  #TODO: remove dropgrad when Zygote can handle this function with CuArrays
   temp = Zygote.dropgrad(abs_error .< δ)
-  x = eltype(ŷ)(0.5)
-  hub_loss = sum(((abs_error.^2) .* temp) .* x .+ δ*(abs_error .- x*δ) .* (1 .- temp)) * 1 // length(y)
-end
-
-function _crossentropy(ŷ::AbstractVecOrMat, y::AbstractVecOrMat, weight::Nothing)
-  return -sum(xlogy.(y, ŷ)) * 1 // size(y, 2)
-end
-
-function _crossentropy(ŷ::AbstractVecOrMat, y::AbstractVecOrMat, weight::Number)
-  return -sum(xlogy.(y, ŷ)) .* weight * 1 // size(y, 2)
-end
-
-function _crossentropy(ŷ::AbstractVecOrMat, y::AbstractVecOrMat, weight::AbstractVector)
-  return -sum(xlogy.(y, ŷ) .* weight) * 1 // size(y, 2)
+  x = ofeltype(ŷ, 0.5)
+  agg(((abs_error.^2) .* temp) .* x .+ δ*(abs_error .- x*δ) .* (1 .- temp))
 end
 
 """
-    crossentropy(ŷ, y; weight = nothing)
+    crossentropy(ŷ, y; dims=1, ϵ=eps(ŷ), agg=mean)
 
 Return the cross entropy between the given probability distributions;
-calculated as `-sum(y .* log.(ŷ) .* weight) / size(y, 2)`.
+calculated as
 
-`weight` can be `Nothing`, a `Number` or an `AbstractVector`.
-`weight=nothing` acts like `weight=1` but is faster.
+    agg(-sum(y .* log.(ŷ .+ ϵ); dims=dims))
 
-See also: [`Flux.logitcrossentropy`](@ref), [`Flux.binarycrossentropy`](@ref), [`Flux.logitbinarycrossentropy`](@ref)
+Cross entropy is typically used as a loss in multi-class classification,
+in which case the labels `y` are given in a one-hot format.
+`dims` specifies the dimension (or the dimensions) containing the class probabilities.
+The prediction `ŷ` is supposed to sum to one across `dims`,
+as would be the case with the output of a [`softmax`](@ref) operation.
 
-# Examples
-```jldoctest
-julia> Flux.crossentropy(softmax([-1.1491, 0.8619, 0.3127]), [1, 1, 0])
-3.085467254747739
-```
+Use of [`logitcrossentropy`](@ref) is recommended over `crossentropy` for
+numerical stability.
+
+See also: [`Flux.logitcrossentropy`](@ref), [`Flux.bce_loss`](@ref), [`Flux.logitbce_loss`](@ref)
 """
-crossentropy(ŷ::AbstractVecOrMat, y::AbstractVecOrMat; weight=nothing) = _crossentropy(ŷ, y, weight)
+function crossentropy(ŷ, y; dims=1, agg=mean, ϵ=epseltype(ŷ))
+    agg(.-sum(xlogy.(y, ŷ .+ ϵ); dims=dims))
+end
 
 """
-    logitcrossentropy(ŷ, y; weight = 1)
+    logitcrossentropy(ŷ, y; dims=1, ϵ=eps(ŷ), agg=mean)
 
 Return the crossentropy computed after a [`Flux.logsoftmax`](@ref) operation;
-calculated as `-sum(y .* logsoftmax(ŷ) .* weight) / size(y, 2)`.
+calculated as
+
+    agg(.-sum(y .* logsoftmax(ŷ; dims=dims); dims=dims))
 
 `logitcrossentropy(ŷ, y)` is mathematically equivalent to
 [`Flux.crossentropy(softmax(ŷ), y)`](@ref) but it is more numerically stable.
 
-See also: [`Flux.crossentropy`](@ref), [`Flux.binarycrossentropy`](@ref), [`Flux.logitbinarycrossentropy`](@ref)
-
-# Examples
-```jldoctest
-julia> Flux.logitcrossentropy([-1.1491, 0.8619, 0.3127], [1, 1, 0])
-3.085467254747738
-```
+See also: [`Flux.crossentropy`](@ref), [`Flux.bce_loss`](@ref), [`Flux.logitbce_loss`](@ref)
 """
-function logitcrossentropy(ŷ::AbstractVecOrMat, y::AbstractVecOrMat; weight = 1)
-  return -sum(y .* logsoftmax(ŷ) .* weight) * 1 // size(y, 2)
+function logitcrossentropy(ŷ, y; dims=1, agg=mean)
+    agg(.-sum(y .* logsoftmax(ŷ; dims=dims); dims=dims))
 end
 
 """
-    binarycrossentropy(ŷ, y; ϵ=eps(ŷ))
+    bce_loss(ŷ, y; agg=mean, ϵ=eps(ŷ))
 
-Return ``-y*\log(ŷ + ϵ) - (1-y)*\log(1-ŷ + ϵ)``. The `ϵ` term provides numerical stability.
+Return the binary cross-entropy loss, computed as
+
+    agg(@.(-y*log(ŷ + ϵ) - (1-y)*log(1-ŷ + ϵ)))
+
+The `ϵ` term provides numerical stability.
 
 Typically, the prediction `ŷ` is given by the output of a [`sigmoid`](@ref) activation.
 
-See also: [`Flux.crossentropy`](@ref), [`Flux.logitcrossentropy`](@ref), [`Flux.logitbinarycrossentropy`](@ref)
+Use of `logitbce_loss` is recommended over `bce_loss` for numerical stability.
 
-# Examples
-```jldoctest
-julia> Flux.binarycrossentropy.(σ.([-1.1491, 0.8619, 0.3127]), [1, 1, 0])
-3-element Array{Float64,1}:
- 1.424397097347566
- 0.35231664672364077
- 0.8616703662235441
-```
+See also: [`Flux.crossentropy`](@ref), [`Flux.logitcrossentropy`](@ref), [`Flux.logitbce_loss`](@ref)
 """
-binarycrossentropy(ŷ, y; ϵ=eps(ŷ)) = -xlogy(y, ŷ + ϵ) - xlogy(1 - y, 1 - ŷ + ϵ)
-
+function bce_loss(ŷ, y; agg=mean, ϵ=epseltype(ŷ))
+    agg(@.(-xlogy(y, ŷ+ϵ) - xlogy(1-y, 1-ŷ+ϵ)))
+end
 # Re-definition to fix interaction with CuArrays.
-CuArrays.@cufunc binarycrossentropy(ŷ, y; ϵ=eps(ŷ)) = -y*log(ŷ + ϵ) - (1 - y)*log(1 - ŷ + ϵ)
+# CuArrays.@cufunc bce_loss(ŷ, y; ϵ=eps(ŷ)) = -y*log(ŷ + ϵ) - (1 - y)*log(1 - ŷ + ϵ)
 
 """
-    logitbinarycrossentropy(ŷ, y)
-
-`logitbinarycrossentropy(ŷ, y)` is mathematically equivalent to
-[`Flux.binarycrossentropy(σ(ŷ), y)`](@ref) but it is more numerically stable.
+    logitbce_loss(ŷ, y; agg=mean)
 
-See also: [`Flux.crossentropy`](@ref), [`Flux.logitcrossentropy`](@ref), [`Flux.binarycrossentropy`](@ref)
+Mathematically equivalent to
+[`Flux.bce_loss(σ(ŷ), y)`](@ref) but is more numerically stable.
 
-# Examples
-```jldoctest
-julia> Flux.logitbinarycrossentropy.([-1.1491, 0.8619, 0.3127], [1, 1, 0])
-3-element Array{Float64,1}:
- 1.4243970973475661
- 0.35231664672364094
- 0.8616703662235443
+See also: [`Flux.crossentropy`](@ref), [`Flux.logitcrossentropy`](@ref), [`Flux.bce_loss`](@ref)
-```
 """
-logitbinarycrossentropy(ŷ, y) = (1 - y)*ŷ - logσ(ŷ)
-
+function logitbce_loss(ŷ, y; agg=mean)
+    agg(@.((1-y)*ŷ - logσ(ŷ)))
+end
 # Re-definition to fix interaction with CuArrays.
+<<<<<<< HEAD
 CuArrays.@cufunc logitbinarycrossentropy(ŷ, y) = (1 - y)*ŷ - logσ(ŷ)
 
 """
@@ -182,9 +154,13 @@ function normalise(x::AbstractArray; dims=1, ϵ=1.0f-5)
   σ′ = std(x, dims = dims, mean = μ′, corrected=false)
   return (x .- μ′) ./ (σ′ .+ ϵ)
 end
+=======
+# CuArrays.@cufunc logitbce_loss(ŷ, y) = (1 - y)*ŷ - logσ(ŷ)
+
+>>>>>>> 210028b3... new loss interface
 
 """
-    kldivergence(ŷ, y)
+    kldivergence(ŷ, y; agg=mean)
 
 Return the
 [Kullback-Leibler divergence](https://en.wikipedia.org/wiki/Kullback%E2%80%93Leibler_divergence)
 from the other.
 It is always non-negative and zero only when both the
 distributions are equal everywhere.
""" -function kldivergence(ŷ, y) - entropy = sum(xlogx.(y)) * 1 //size(y,2) - cross_entropy = crossentropy(ŷ, y) +function kldivergence(ŷ, y; dims=1, agg=mean, ϵ=epseltype(ŷ)) + entropy = agg(sum(xlogx.(y), dims=dims)) + cross_entropy = crossentropy(ŷ, y; dims=dims, agg=agg, ϵ=ϵ) return entropy + cross_entropy end """ - poisson(ŷ, y) - -Return how much the predicted distribution `ŷ` diverges from the expected Poisson -distribution `y`; calculated as `sum(ŷ .- y .* log.(ŷ)) / size(y, 2)`. + poisson_loss(ŷ, y) -[More information.](https://peltarion.com/knowledge-center/documentation/modeling-view/build-an-ai-model/loss-functions/poisson). +# Return how much the predicted distribution `ŷ` diverges from the expected Poisson +# distribution `y`; calculated as `sum(ŷ .- y .* log.(ŷ)) / size(y, 2)`. +REDO +[More information.](https://peltarion.com/knowledge-center/documentation/modeling-view/build-an-ai-model/loss-functions/poisson_loss). """ -poisson(ŷ, y) = sum(ŷ .- xlogy.(y, ŷ)) * 1 // size(y,2) +poisson_loss(ŷ, y; agg=mean) = agg(ŷ .- xlogy.(y, ŷ)) """ - hinge(ŷ, y) + hinge_loss(ŷ, y; agg=mean) -Return the [hinge loss](https://en.wikipedia.org/wiki/Hinge_loss) given the +Return the [hinge_loss loss](https://en.wikipedia.org/wiki/Hinge_loss) given the prediction `ŷ` and true labels `y` (containing 1 or -1); calculated as `sum(max.(0, 1 .- ŷ .* y)) / size(y, 2)`. -See also: [`squared_hinge`](@ref) +See also: [`squared_hinge_loss`](@ref) """ -hinge(ŷ, y) = sum(max.(0, 1 .- ŷ .* y)) * 1 // size(y, 2) +hinge_loss(ŷ, y; agg=mean) = agg(max.(0, 1 .- ŷ .* y)) """ - squared_hinge(ŷ, y) + squared_hinge_loss(ŷ, y) -Return the squared hinge loss given the prediction `ŷ` and true labels `y` +Return the squared hinge_loss loss given the prediction `ŷ` and true labels `y` (containing 1 or -1); calculated as `sum((max.(0, 1 .- ŷ .* y)).^2) / size(y, 2)`. -See also: [`hinge`](@ref) +See also: [`hinge_loss`](@ref) """ -squared_hinge(ŷ, y) = sum((max.(0, 1 .- ŷ .* y)).^2) * 1 // size(y, 2) +squared_hinge_loss(ŷ, y; agg=mean) = agg((max.(0, 1 .- ŷ .* y)).^2) """ dice_coeff_loss(ŷ, y; smooth=1) @@ -239,39 +215,37 @@ Return a loss based on the dice coefficient. Used in the [V-Net](https://arxiv.org/pdf/1606.04797v1.pdf) image segmentation architecture. Similar to the F1_score. Calculated as: - 1 - 2*sum(|ŷ .* y| + smooth) / (sum(ŷ.^2) + sum(y.^2) + smooth)` + + 1 - 2*sum(|ŷ .* y| + smooth) / (sum(ŷ.^2) + sum(y.^2) + smooth) """ -dice_coeff_loss(ŷ, y; smooth=eltype(ŷ)(1.0)) = 1 - (2*sum(y .* ŷ) + smooth) / (sum(y.^2) + sum(ŷ.^2) + smooth) +dice_coeff_loss(ŷ, y; smooth=ofeltype(ŷ, 1.0)) = 1 - (2*sum(y .* ŷ) + smooth) / (sum(y.^2) + sum(ŷ.^2) + smooth) #TODO agg """ tversky_loss(ŷ, y; β=0.7) Return the [Tversky loss](https://arxiv.org/pdf/1706.05721.pdf). Used with imbalanced data to give more weight to false negatives. -Larger β weigh recall higher than precision (by placing more emphasis on false negatives) +Larger β weigh recall more than precision (by placing more emphasis on false negatives) Calculated as: 1 - sum(|y .* ŷ| + 1) / (sum(y .* ŷ + β*(1 .- y) .* ŷ + (1 - β)*y .* (1 .- ŷ)) + 1) """ -tversky_loss(ŷ, y; β=eltype(ŷ)(0.7)) = 1 - (sum(y .* ŷ) + 1) / (sum(y .* ŷ + β*(1 .- y) .* ŷ + (1 - β)*y .* (1 .- ŷ)) + 1) - -""" - flatten(x::AbstractArray) - -Transform (w, h, c, b)-shaped input into (w × h × c, b)-shaped output -by linearizing all values for each element in the batch. 
-""" -function flatten(x::AbstractArray) - return reshape(x, :, size(x)[end]) +function tversky_loss(ŷ, y; β=ofeltype(ŷ, 0.7)) + #TODO add agg + num = sum(y .* ŷ) + 1 + den = sum(y .* ŷ + β*(1 .- y) .* ŷ + (1 - β)*y .* (1 .- ŷ)) + 1 + 1 - num / den end - + """ xlogx(x) + Return `x * log(x)` for `x ≥ 0`, handling `x = 0` by taking the downward limit. """ function xlogx(x) result = x * log(x) ifelse(iszero(x), zero(result), result) end + CuArrays.@cufunc function xlogx(x) result = x * log(x) ifelse(iszero(x), zero(result), result) @@ -279,12 +253,14 @@ end """ xlogy(x, y) + Return `x * log(y)` for `y > 0` with correct limit at `x = 0`. """ function xlogy(x, y) result = x * log(y) ifelse(iszero(x), zero(result), result) end + CuArrays.@cufunc function xlogy(x, y) result = x * log(y) ifelse(iszero(x), zero(result), result) @@ -294,3 +270,51 @@ end res = xlogy.(x, y) res, Δ -> (nothing, Zygote.unbroadcast(x, xlogy.(Δ, y)), Zygote.unbroadcast(y, Δ .* x ./ y)) end + + +""" + flatten(x::AbstractArray) + +Reshape arbitrarly-shaped input into a matrix-shaped output +preserving the last dimension size. +Equivalent to `reshape(x, :, size(x)[end])`. +""" +function flatten(x::AbstractArray) + return reshape(x, :, size(x)[end]) +end + +# TODO normalise over last dimension is typically what you want to do. +# Deprecation path: `normalise(x; dims=1)` -> `normalise(x; dims)` -> `normalise(x; dims=size(x)[end])` +""" + normalise(x; dims, ϵ=1e-6) + +Normalise `x` to mean 0 and standard deviation 1 across the dimensions given by `dims`. +Defaults to normalising over columns. +`ϵ` is a small additive factor added to the denominator for numerical stability. + +```jldoctest +julia> a = reshape(collect(1:9), 3, 3) +3×3 Array{Int64,2}: + 1 4 7 + 2 5 8 + 3 6 9 + +julia> Flux.normalise(a, dims=1) +3×3 Array{Float64,2}: + -1.22474 -1.22474 -1.22474 + 0.0 0.0 0.0 + 1.22474 1.22474 1.22474 + +julia> Flux.normalise(a, dims=2) +3×3 Array{Float64,2}: + -1.22474 0.0 1.22474 + -1.22474 0.0 1.22474 + -1.22474 0.0 1.22474 +``` +""" +function normalise(x::AbstractArray; dims, ϵ=ofeltype(x, 1e-6)) + μ′ = mean(x, dims=dims) + # σ′ = std(x, dims=dims, mean=μ′, corrected=false) # use this when #478 gets merged + σ′ = std(x, dims=dims, corrected=false) + return (x .- μ′) ./ (σ′.+ ϵ) +end \ No newline at end of file diff --git a/src/utils.jl b/src/utils.jl index 53a6a7aabe..9c2970233c 100644 --- a/src/utils.jl +++ b/src/utils.jl @@ -32,6 +32,9 @@ nfan(n_out, n_in) = n_in, n_out # In case of Dense kernels: arranged as matrices nfan(dims::Tuple) = nfan(dims...) nfan(dims...) = prod(dims[1:end-2]) .* (dims[end-1], dims[end]) # In case of convolution kernels +ofeltype(x, y) = convert(float(eltype(x)), y) +epseltype(x) = eps(float(eltype(x))) + """ glorot_uniform(dims...) diff --git a/test/cuda/cuda.jl b/test/cuda/cuda.jl index 5bcfb3e7b4..28fc22ac52 100644 --- a/test/cuda/cuda.jl +++ b/test/cuda/cuda.jl @@ -33,8 +33,8 @@ cx = gpu(x) x = [-1.1491, 0.8619, 0.3127] y = [1, 1, 0.] 
-@test Flux.binarycrossentropy.(σ.(x),y) ≈ Array(Flux.binarycrossentropy.(cu(σ.(x)),cu(y))) -@test Flux.logitbinarycrossentropy.(x,y) ≈ Array(Flux.logitbinarycrossentropy.(cu(x),cu(y))) +@test Flux.bce_loss(σ.(x), y) ≈ Flux.bce_loss(cu(σ.(x)), cu(y)) +@test Flux.logitbce_loss(x, y) ≈ Flux.logitbce_loss(cu(x), cu(y)) xs = rand(5, 5) ys = Flux.onehotbatch(1:5,1:5) diff --git a/test/cuda/layers.jl b/test/cuda/layers.jl index b51320ff4c..56b40c7b3b 100644 --- a/test/cuda/layers.jl +++ b/test/cuda/layers.jl @@ -70,10 +70,11 @@ gradtest("GroupNorm", groupnorm, rand(Float32, 28,28,3,1), 3, 1) const stateless_layers = [Flux.mse, Flux.crossentropy, Flux.logitcrossentropy, - Flux.normalise] + Flux.normalise, + Flux.bce_loss, + Flux.logitbce_loss] -const stateless_layers_broadcasted = [Flux.binarycrossentropy, - Flux.logitbinarycrossentropy] +const stateless_layers_broadcasted = [] function stateless_gradtest(f, args...) @test gradient((args...) -> sum(f(args...)), args...)[1] isa CuArray diff --git a/test/layers/stateless.jl b/test/layers/stateless.jl index a61e912a14..e9db04e1a0 100644 --- a/test/layers/stateless.jl +++ b/test/layers/stateless.jl @@ -1,6 +1,6 @@ using Test using Flux: onehotbatch, mse, crossentropy, logitcrossentropy, - σ, binarycrossentropy, logitbinarycrossentropy, flatten, + σ, bce_loss, logitbce_loss, flatten, xlogx, xlogy const ϵ = 1e-7 @@ -60,26 +60,14 @@ end @test logitcrossentropy(logŷ, y) ≈ lossvalue end - @testset "weighted_crossentropy" begin - @test crossentropy(ŷ, y, weight = ones(2)) ≈ lossvalue - @test crossentropy(ŷ, y, weight = [.5, .5]) ≈ lossvalue/2 - @test crossentropy(ŷ, y, weight = [2, .5]) ≈ 1.5049660054074199 - end - - @testset "weighted_logitcrossentropy" begin - @test logitcrossentropy(logŷ, y, weight = ones(2)) ≈ lossvalue - @test logitcrossentropy(logŷ, y, weight = [.5, .5]) ≈ lossvalue/2 - @test logitcrossentropy(logŷ, y, weight = [2, .5]) ≈ 1.5049660054074199 - end - logŷ, y = randn(3), rand(3) - @testset "binarycrossentropy" begin - @test binarycrossentropy.(σ.(logŷ), y; ϵ=0) ≈ -y.*log.(σ.(logŷ)) - (1 .- y).*log.(1 .- σ.(logŷ)) - @test binarycrossentropy.(σ.(logŷ), y) ≈ -y.*log.(σ.(logŷ) .+ eps.(σ.(logŷ))) - (1 .- y).*log.(1 .- σ.(logŷ) .+ eps.(σ.(logŷ))) + @testset "bce_loss" begin + @test bce_loss(σ.(logŷ), y; ϵ=0) ≈ mean(-y.*log.(σ.(logŷ)) - (1 .- y).*log.(1 .- σ.(logŷ))) + @test bce_loss(σ.(logŷ), y) ≈ mean(-y.*log.(σ.(logŷ) .+ eps.(σ.(logŷ))) - (1 .- y).*log.(1 .- σ.(logŷ) .+ eps.(σ.(logŷ)))) end - @testset "logitbinarycrossentropy" begin - @test logitbinarycrossentropy.(logŷ, y) ≈ binarycrossentropy.(σ.(logŷ), y; ϵ=0) + @testset "logitbce_loss" begin + @test logitbce_loss(logŷ, y) ≈ bce_loss(σ.(logŷ), y; ϵ=0) end y = [1 2 3] @@ -92,21 +80,21 @@ end y = [1 2 3 4] ŷ = [5.0 6.0 7.0 8.0] - @testset "hinge" begin - @test Flux.hinge(ŷ, y) ≈ 0 - @test Flux.hinge(y, 0.5 .* y) ≈ 0.125 + @testset "hinge_loss" begin + @test Flux.hinge_loss(ŷ, y) ≈ 0 + @test Flux.hinge_loss(y, 0.5 .* y) ≈ 0.125 end - @testset "squared_hinge" begin - @test Flux.squared_hinge(ŷ, y) ≈ 0 - @test Flux.squared_hinge(y, 0.5 .* y) ≈ 0.0625 + @testset "squared_hinge_loss" begin + @test Flux.squared_hinge_loss(ŷ, y) ≈ 0 + @test Flux.squared_hinge_loss(y, 0.5 .* y) ≈ 0.0625 end y = [0.1 0.2 0.3] ŷ = [0.4 0.5 0.6] - @testset "poisson" begin - @test Flux.poisson(ŷ, y) ≈ 0.6278353988097339 - @test Flux.poisson(y, y) ≈ 0.5044459776946685 + @testset "poisson_loss" begin + @test Flux.poisson_loss(ŷ, y) ≈ 0.6278353988097339 + @test Flux.poisson_loss(y, y) ≈ 0.5044459776946685 end y = [1.0 
0.5 0.3 2.4] @@ -118,7 +106,7 @@ end @testset "tversky_loss" begin @test Flux.tversky_loss(ŷ, y) ≈ -0.06772009029345383 - @test Flux.tversky_loss(ŷ, y, β = 0.8) ≈ -0.09490740740740744 + @test Flux.tversky_loss(ŷ, y, β=0.8) ≈ -0.09490740740740744 @test Flux.tversky_loss(y, y) ≈ -0.5576923076923075 end @@ -126,8 +114,8 @@ end for T in (Float32, Float64) y = rand(T, 2) ŷ = rand(T, 2) - for f in (mse, crossentropy, logitcrossentropy, Flux.kldivergence, Flux.hinge, Flux.poisson, - Flux.mae, Flux.huber_loss, Flux.msle, Flux.squared_hinge, Flux.dice_coeff_loss, Flux.tversky_loss) + for f in (mse, crossentropy, logitcrossentropy, Flux.kldivergence, Flux.hinge_loss, Flux.poisson_loss, + Flux.mae, Flux.huber_loss, Flux.msle, Flux.squared_hinge_loss, Flux.dice_coeff_loss, Flux.tversky_loss) fwd, back = Flux.pullback(f, ŷ, y) @test fwd isa T @test eltype(back(one(T))[1]) == T From ddcecb301f20a35d96813dc08514d9c412dcd5ee Mon Sep 17 00:00:00 2001 From: CarloLucibello Date: Wed, 1 Jul 2020 12:24:37 +0200 Subject: [PATCH 2/5] cleanup --- NEWS.md | 2 ++ 1 file changed, 2 insertions(+) diff --git a/NEWS.md b/NEWS.md index 4151b2d955..25715afdec 100644 --- a/NEWS.md +++ b/NEWS.md @@ -1,4 +1,5 @@ # v0.11 + * Add [kaiming initialization](https://arxiv.org/abs/1502.01852) methods: [kaiming_uniform and kaiming_normal](https://github.com/FluxML/Flux.jl/pull/1243) * Use `DataLoader` with `NamedTuple`s, so that tensors can be accessed [by name](https://github.com/FluxML/Flux.jl/pull/1221). * Error if Dense layers weights and biases are [not arrays](https://github.com/FluxML/Flux.jl/pull/1218). @@ -18,6 +19,7 @@ * and many more fixes and additions... # v0.10.1 - v0.10.4 + See GitHub's releases. # v0.10.0 From 413c794f289717ce06362c73b15f49b1166c0823 Mon Sep 17 00:00:00 2001 From: CarloLucibello Date: Wed, 1 Jul 2020 12:39:01 +0200 Subject: [PATCH 3/5] fix doctests --- src/layers/basic.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/layers/basic.jl b/src/layers/basic.jl index 7a7ce29f13..30358ac21c 100644 --- a/src/layers/basic.jl +++ b/src/layers/basic.jl @@ -100,7 +100,7 @@ Dense(5, 2) julia> d(rand(5)) 2-element Array{Float32,1}: -0.16210233 - 0.12311903 + 0.123119034 ``` """ struct Dense{F,S<:AbstractArray,T<:AbstractArray} From c1da4b57f570c2fffb77fe8b461e63c610a48dda Mon Sep 17 00:00:00 2001 From: CarloLucibello Date: Wed, 1 Jul 2020 17:04:24 +0200 Subject: [PATCH 4/5] fix normalise tests --- src/layers/normalise.jl | 2 +- src/layers/stateless.jl | 33 ++++++--------------------------- test/cuda/cuda.jl | 7 ++++--- test/cuda/layers.jl | 2 +- 4 files changed, 12 insertions(+), 32 deletions(-) diff --git a/src/layers/normalise.jl b/src/layers/normalise.jl index 0b5e04fb14..fe15f26b38 100644 --- a/src/layers/normalise.jl +++ b/src/layers/normalise.jl @@ -113,7 +113,7 @@ LayerNorm(h::Integer) = @functor LayerNorm -(a::LayerNorm)(x) = a.diag(normalise(x)) +(a::LayerNorm)(x) = a.diag(normalise(x, dims=1)) function Base.show(io::IO, l::LayerNorm) print(io, "LayerNorm(", length(l.diag.α), ")") diff --git a/src/layers/stateless.jl b/src/layers/stateless.jl index f652b912fa..233384922c 100644 --- a/src/layers/stateless.jl +++ b/src/layers/stateless.jl @@ -286,35 +286,14 @@ end # TODO normalise over last dimension is typically what you want to do. # Deprecation path: `normalise(x; dims=1)` -> `normalise(x; dims)` -> `normalise(x; dims=size(x)[end])` """ - normalise(x; dims, ϵ=1e-6) + normalise(x; dims, ϵ=1e-5) Normalise `x` to mean 0 and standard deviation 1 across the dimensions given by `dims`. 
-Defaults to normalising over columns. `ϵ` is a small additive factor added to the denominator for numerical stability. - -```jldoctest -julia> a = reshape(collect(1:9), 3, 3) -3×3 Array{Int64,2}: - 1 4 7 - 2 5 8 - 3 6 9 - -julia> Flux.normalise(a, dims=1) -3×3 Array{Float64,2}: - -1.22474 -1.22474 -1.22474 - 0.0 0.0 0.0 - 1.22474 1.22474 1.22474 - -julia> Flux.normalise(a, dims=2) -3×3 Array{Float64,2}: - -1.22474 0.0 1.22474 - -1.22474 0.0 1.22474 - -1.22474 0.0 1.22474 -``` """ -function normalise(x::AbstractArray; dims, ϵ=ofeltype(x, 1e-6)) - μ′ = mean(x, dims=dims) - # σ′ = std(x, dims=dims, mean=μ′, corrected=false) # use this when #478 gets merged - σ′ = std(x, dims=dims, corrected=false) - return (x .- μ′) ./ (σ′.+ ϵ) +function normalise(x::AbstractArray; dims, ϵ=ofeltype(x, 1e-5)) + μ = mean(x, dims=dims) + # σ = std(x, dims=dims, mean=μ, corrected=false) # use this when #478 gets merged + σ = std(x, dims=dims, corrected=false) + return (x .- μ) ./ (σ .+ ϵ) end \ No newline at end of file diff --git a/test/cuda/cuda.jl b/test/cuda/cuda.jl index 28fc22ac52..83cdf10222 100644 --- a/test/cuda/cuda.jl +++ b/test/cuda/cuda.jl @@ -1,6 +1,7 @@ using Flux, Test using Flux.CuArrays using Flux: gpu +using Statistics: mean @info "Testing GPU Support" @@ -27,9 +28,9 @@ cm = gpu(m) x = [1.,2.,3.] cx = gpu(x) -@test Flux.crossentropy(x,x) ≈ Flux.crossentropy(cx,cx) -@test Flux.crossentropy(x,x, weight=1.0) ≈ Flux.crossentropy(cx,cx, weight=1.0) -@test Flux.crossentropy(x,x, weight=[1.0;2.0;3.0]) ≈ Flux.crossentropy(cx,cx, weight=cu([1.0;2.0;3.0])) +@test Flux.crossentropy(x,x) ≈ Flux.crossentropy(cx,cx) +@test Flux.crossentropy(x,x, agg=identity) ≈ Flux.crossentropy(cx,cx, agg=identity) |> cpu +@test Flux.crossentropy(x,x, agg=x->mean([1.0;2.0;3.0].*x)) ≈ Flux.crossentropy(cx,cx, agg=x->mean(cu([1.0;2.0;3.0]).*x)) x = [-1.1491, 0.8619, 0.3127] y = [1, 1, 0.] diff --git a/test/cuda/layers.jl b/test/cuda/layers.jl index 56b40c7b3b..4e8fbd41df 100644 --- a/test/cuda/layers.jl +++ b/test/cuda/layers.jl @@ -90,7 +90,7 @@ end for layer in stateless_layers if layer == Flux.normalise - stateless_gradtest(layer, x) + stateless_gradtest(x -> layer(x, dims=1), x) else stateless_gradtest(layer, x, y) end From b81552aea7ca66f935a30cd411c67a2210df26e9 Mon Sep 17 00:00:00 2001 From: CarloLucibello Date: Wed, 1 Jul 2020 17:13:12 +0200 Subject: [PATCH 5/5] cleanup --- src/layers/stateless.jl | 36 ------------------------------------ 1 file changed, 36 deletions(-) diff --git a/src/layers/stateless.jl b/src/layers/stateless.jl index 233384922c..ff47d58ec3 100644 --- a/src/layers/stateless.jl +++ b/src/layers/stateless.jl @@ -120,44 +120,8 @@ function logitbce_loss(ŷ, y; agg=mean) agg(@.((1-y)*ŷ - logσ(ŷ))) end # Re-definition to fix interaction with CuArrays. -<<<<<<< HEAD -CuArrays.@cufunc logitbinarycrossentropy(ŷ, y) = (1 - y)*ŷ - logσ(ŷ) - -""" - normalise(x; dims=1, ϵ=1.0e-5)) - -Normalise `x` to mean 0 and standard deviation 1 across the dimensions given by `dims`. -Defaults to normalising over columns. Regularizes the standard deviation by `ϵ`. 
- -```jldoctest -julia> a = reshape(collect(1:9), 3, 3) -3×3 Array{Int64,2}: - 1 4 7 - 2 5 8 - 3 6 9 - -julia> Flux.normalise(a) -3×3 Array{Float64,2}: - -1.22473 -1.22473 -1.22473 - 0.0 0.0 0.0 - 1.22473 1.22473 1.22473 - -julia> Flux.normalise(a, dims=2) -3×3 Array{Float64,2}: - -1.22474 0.0 1.22474 - -1.22474 0.0 1.22474 - -1.22474 0.0 1.22474 -``` -""" -function normalise(x::AbstractArray; dims=1, ϵ=1.0f-5) - μ′ = mean(x, dims = dims) - σ′ = std(x, dims = dims, mean = μ′, corrected=false) - return (x .- μ′) ./ (σ′ .+ ϵ) -end -======= # CuArrays.@cufunc logitbce_loss(ŷ, y) = (1 - y)*ŷ - logσ(ŷ) ->>>>>>> 210028b3... new loss interface """ kldivergence(ŷ, y; agg=mean)
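For reference, here is a minimal usage sketch of the `agg`-based interface documented in `docs/src/models/losses.md` above. It is illustrative only and not part of the patch series: the arrays and values are invented, and it assumes a Flux tree with these patches applied.

```julia
using Flux  # assumes a checkout with this patch series applied

ŷ = softmax(randn(Float32, 3, 5))        # predictions for 3 classes, batch of 5
y = Flux.onehotbatch(rand(1:3, 5), 1:3)  # one-hot targets; prediction comes first: loss(ŷ, y)

Flux.crossentropy(ŷ, y)                  # default aggregation: mean over the batch
Flux.crossentropy(ŷ, y, agg=sum)         # sum instead of mean
Flux.crossentropy(ŷ, y, agg=identity)    # no aggregation: one value per batch column
Flux.mse(randn(Float32, 4, 5), randn(Float32, 4, 5), agg=x -> sum(x, dims=2))  # partial reduction
```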