Skip to content

Commit

Permalink
Applied the suggestions
Browse files Browse the repository at this point in the history
Co-authored-by: Carlo Lucibello <carlo.lucibello@gmail.com>
  • Loading branch information
shikhargoswami and CarloLucibello authored Feb 5, 2021
1 parent 1987693 commit 63e4d98
Showing 1 changed file with 2 additions and 3 deletions.
5 changes: 2 additions & 3 deletions src/losses/functions.jl
Original file line number Diff line number Diff line change
Expand Up @@ -458,7 +458,7 @@ true
See also: [`Losses.focal_loss`](@ref) for multi-class setting
"""
function binary_focal_loss(ŷ, y; agg=mean, γ=ofeltype(ŷ, 2.0), ϵ=epseltype(ŷ))
function binary_focal_loss(ŷ, y; agg=mean, γ=2, ϵ=epseltype(ŷ))
=.+ ϵ
p_t = y .*+ (1 .- y) .* (1 .- ŷ)
ce = -log.(p_t)
Expand Down Expand Up @@ -501,11 +501,10 @@ true
See also: [`Losses.binary_focal_loss`](@ref) for binary (not one-hot) labels
"""
function focal_loss(ŷ, y; dims=1, agg=mean, γ=ofeltype(ŷ, 2.0), ϵ=epseltype(ŷ))
function focal_loss(ŷ, y; dims=1, agg=mean, γ=2, ϵ=epseltype(ŷ))
=.+ ϵ
agg(sum(@. -y * (1 - ŷ)^γ * log(ŷ); dims=dims))
end
```@meta
DocTestFilters = nothing
```

0 comments on commit 63e4d98

Please sign in to comment.