diff --git a/NEWS.md b/NEWS.md
index db7853995e..8001bf7c6e 100644
--- a/NEWS.md
+++ b/NEWS.md
@@ -2,6 +2,7 @@

 ## v0.12.0

+* Added [Focal Loss function](https://github.com/FluxML/Flux.jl/pull/1489) to the Losses module.
 * The Dense layer now supports inputs with [multiple batch dimensions](https://github.com/FluxML/Flux.jl/pull/1405).
 * Dense and Conv layers no longer perform [implicit type conversion](https://github.com/FluxML/Flux.jl/pull/1394).
 * Excise datasets in favour of other providers in the julia ecosystem.
diff --git a/docs/src/models/losses.md b/docs/src/models/losses.md
index 5d8eb48700..d35ebd3673 100644
--- a/docs/src/models/losses.md
+++ b/docs/src/models/losses.md
@@ -39,4 +39,6 @@ Flux.Losses.hinge_loss
 Flux.Losses.squared_hinge_loss
 Flux.Losses.dice_coeff_loss
 Flux.Losses.tversky_loss
+Flux.Losses.binary_focal_loss
+Flux.Losses.focal_loss
 ```
diff --git a/src/losses/Losses.jl b/src/losses/Losses.jl
index e781affaf6..bf944f9231 100644
--- a/src/losses/Losses.jl
+++ b/src/losses/Losses.jl
@@ -18,7 +18,8 @@ export mse, mae, msle,
        dice_coeff_loss,
        poisson_loss,
        hinge_loss, squared_hinge_loss,
-       ctc_loss
+       ctc_loss,
+       binary_focal_loss, focal_loss

 include("utils.jl")
 include("functions.jl")
diff --git a/src/losses/functions.jl b/src/losses/functions.jl
index d7fd666b63..20d45b7a80 100644
--- a/src/losses/functions.jl
+++ b/src/losses/functions.jl
@@ -429,7 +429,82 @@ function tversky_loss(ŷ, y; β = ofeltype(ŷ, 0.7))
     1 - num / den
 end

+"""
+    binary_focal_loss(ŷ, y; agg=mean, γ=2, ϵ=epseltype(ŷ))
+
+Return the [binary focal loss](https://arxiv.org/pdf/1708.02002.pdf),
+which can be used in two-class classification tasks with highly imbalanced classes.
+The input, `ŷ`, is expected to be normalized (i.e. [`softmax`](@ref) output).
+
+For `γ == 0`, the loss is mathematically equivalent to [`Losses.binarycrossentropy`](@ref).
+
+# Example
+```jldoctest
+julia> y = [0  1  0
+            1  0  1]
+2×3 Array{Int64,2}:
+ 0  1  0
+ 1  0  1
+
+julia> ŷ = [0.268941  0.5  0.268941
+            0.731059  0.5  0.731059]
+2×3 Array{Float64,2}:
+ 0.268941  0.5  0.268941
+ 0.731059  0.5  0.731059
+
+julia> Flux.binary_focal_loss(ŷ, y) ≈ 0.0728675615927385
+true
+```
+
+See also: [`Losses.focal_loss`](@ref) for the multi-class setting.
+"""
+function binary_focal_loss(ŷ, y; agg=mean, γ=2, ϵ=epseltype(ŷ))
+    ŷ = ŷ .+ ϵ
+    p_t = y .* ŷ + (1 .- y) .* (1 .- ŷ)
+    ce = -log.(p_t)
+    weight = (1 .- p_t) .^ γ
+    loss = weight .* ce
+    agg(loss)
+end
+
+"""
+    focal_loss(ŷ, y; dims=1, agg=mean, γ=2, ϵ=epseltype(ŷ))
+
+Return the [focal loss](https://arxiv.org/pdf/1708.02002.pdf),
+which can be used in classification tasks with highly imbalanced classes.
+It down-weights well-classified examples and focuses on hard examples.
+The input, `ŷ`, is expected to be normalized (i.e. [`softmax`](@ref) output).
+
+The modulating factor, `γ`, controls the down-weighting strength.
+For `γ == 0`, the loss is mathematically equivalent to [`Losses.crossentropy`](@ref).
+
+# Example
+```jldoctest
+julia> y = [1  0  0  0  1
+            0  1  0  1  0
+            0  0  1  0  0]
+3×5 Array{Int64,2}:
+ 1  0  0  0  1
+ 0  1  0  1  0
+ 0  0  1  0  0
+
+julia> ŷ = softmax(reshape(-7:7, 3, 5) .* 1f0)
+3×5 Array{Float32,2}:
+ 0.0900306  0.0900306  0.0900306  0.0900306  0.0900306
+ 0.244728   0.244728   0.244728   0.244728   0.244728
+ 0.665241   0.665241   0.665241   0.665241   0.665241
+
+julia> Flux.focal_loss(ŷ, y) ≈ 1.1277571935622628
+true
+```
+
+See also: [`Losses.binary_focal_loss`](@ref) for binary (not one-hot) labels.
+"""
+function focal_loss(ŷ, y; dims=1, agg=mean, γ=2, ϵ=epseltype(ŷ))
+    ŷ = ŷ .+ ϵ
+    agg(sum(@. -y * (1 - ŷ)^γ * log(ŷ); dims=dims))
+end
+
 ```@meta
 DocTestFilters = nothing
 ```
diff --git a/test/cuda/losses.jl b/test/cuda/losses.jl
index 0913b0eb6a..a0f7f47d80 100644
--- a/test/cuda/losses.jl
+++ b/test/cuda/losses.jl
@@ -1,4 +1,4 @@
-using Flux.Losses: crossentropy, binarycrossentropy, logitbinarycrossentropy
+using Flux.Losses: crossentropy, binarycrossentropy, logitbinarycrossentropy, binary_focal_loss, focal_loss

 @testset "Losses" begin

@@ -14,6 +14,17 @@ y = [1, 1, 0.]
 @test binarycrossentropy(σ.(x), y) ≈ binarycrossentropy(gpu(σ.(x)), gpu(y))
 @test logitbinarycrossentropy(x, y) ≈ logitbinarycrossentropy(gpu(x), gpu(y))

+x = [0.268941 0.5 0.268941
+     0.731059 0.5 0.731059]
+y = [0 1 0
+     1 0 1]
+@test binary_focal_loss(x, y) ≈ binary_focal_loss(gpu(x), gpu(y))
+
+x = softmax(reshape(-7:7, 3, 5) .* 1f0)
+y = [1 0 0 0 1
+     0 1 0 1 0
+     0 0 1 0 0]
+@test focal_loss(x, y) ≈ focal_loss(gpu(x), gpu(y))

 @testset "GPU grad tests" begin
   x = rand(Float32, 3,3)
diff --git a/test/losses.jl b/test/losses.jl
index 6f7a5c8407..9abc03abb8 100644
--- a/test/losses.jl
+++ b/test/losses.jl
@@ -13,7 +13,8 @@ const ALL_LOSSES = [Flux.Losses.mse, Flux.Losses.mae, Flux.Losses.msle,
                     Flux.Losses.tversky_loss,
                     Flux.Losses.dice_coeff_loss,
                     Flux.Losses.poisson_loss,
-                    Flux.Losses.hinge_loss, Flux.Losses.squared_hinge_loss]
+                    Flux.Losses.hinge_loss, Flux.Losses.squared_hinge_loss,
+                    Flux.Losses.binary_focal_loss, Flux.Losses.focal_loss]


 @testset "xlogx & xlogy" begin
@@ -174,3 +175,34 @@
     end
   end
 end
+
+@testset "binary_focal_loss" begin
+  y = [0  1  0
+       1  0  1]
+  ŷ = [0.268941 0.5 0.268941
+       0.731059 0.5 0.731059]
+
+  y1 = [1 0
+        0 1]
+  ŷ1 = [0.6 0.3
+        0.4 0.7]
+  @test Flux.binary_focal_loss(ŷ, y) ≈ 0.0728675615927385
+  @test Flux.binary_focal_loss(ŷ1, y1) ≈ 0.05691642237852222
+  @test Flux.binary_focal_loss(ŷ, y; γ=0.0) ≈ Flux.binarycrossentropy(ŷ, y)
+end
+
+@testset "focal_loss" begin
+  y = [1 0 0 0 1
+       0 1 0 1 0
+       0 0 1 0 0]
+  ŷ = softmax(reshape(-7:7, 3, 5) .* 1f0)
+  y1 = [1 0
+        0 0
+        0 1]
+  ŷ1 = [0.4 0.2
+        0.5 0.5
+        0.1 0.3]
+  @test Flux.focal_loss(ŷ, y) ≈ 1.1277571935622628
+  @test Flux.focal_loss(ŷ1, y1) ≈ 0.45990566879720157
+  @test Flux.focal_loss(ŷ, y; γ=0.0) ≈ Flux.crossentropy(ŷ, y)
+end
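
A quick way to exercise the new API from the REPL — a minimal sketch, not part of the patch, assuming Flux with this PR applied (`softmax` and `onehotbatch` are existing Flux exports):

```julia
using Flux
using Flux.Losses: binary_focal_loss, focal_loss

# Binary case: each column of ŷ holds the two class probabilities.
ŷ = [0.268941  0.5  0.268941
     0.731059  0.5  0.731059]
y = [0 1 0
     1 0 1]
binary_focal_loss(ŷ, y)     # ≈ 0.0729 with the defaults agg=mean, γ=2

# Multi-class case: softmax probabilities against one-hot targets.
ŷm = softmax(randn(Float32, 3, 5); dims=1)
ym = Flux.onehotbatch(rand(1:3, 5), 1:3)
focal_loss(ŷm, ym)          # raising γ further down-weights easy examples
```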
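The `γ == 0` equivalence claimed in the docstrings can also be checked directly; this mirrors the new tests rather than adding coverage, and the match is approximate rather than exact because both focal losses shift `ŷ` by `ϵ` before taking logs:

```julia
using Flux

ŷ = softmax(reshape(-7:7, 3, 5) .* 1f0)
y = [1 0 0 0 1
     0 1 0 1 0
     0 0 1 0 0]

# With γ = 0 the modulating factor (1 - p_t)^γ is identically 1,
# so the focal loss reduces to ordinary cross-entropy.
@assert Flux.focal_loss(ŷ, y; γ=0) ≈ Flux.crossentropy(ŷ, y)
```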