Skip to content

Commit

Permalink
Merge #1129
Browse files Browse the repository at this point in the history
1129: Added dropgrad in huber_loss r=CarloLucibello a=HenriDeh

Workaround to prevent `iterate(::nothing)` when working with CuArrays. See issue #1128

Co-authored-by: HenriDeh <47037088+HenriDeh@users.noreply.github.com>
  • Loading branch information
bors[bot] and HenriDeh authored Jun 6, 2020
2 parents 9ebbe8c + ac94754 commit d9b0747
Showing 1 changed file with 2 additions and 1 deletion.
3 changes: 2 additions & 1 deletion src/layers/stateless.jl
Original file line number Diff line number Diff line change
Expand Up @@ -46,9 +46,10 @@ given the prediction `ŷ` and true values `y`.
Huber loss = |
| δ * (|ŷ - y| - 0.5 * δ), otherwise
"""
#TODO: remove dropgrad when Zygote can handle this function with CuArrays
function huber_loss(ŷ, y; δ=eltype(ŷ)(1))
abs_error = abs.(ŷ .- y)
temp = abs_error .< δ
temp = Zygote.dropgrad(abs_error .< δ)
x = eltype(ŷ)(0.5)
hub_loss = sum(((abs_error.^2) .* temp) .* x .+ δ*(abs_error .- x*δ) .* (1 .- temp)) * 1 // length(y)
end
Expand Down

0 comments on commit d9b0747

Please sign in to comment.