Skip to content

Commit

Permalink
Merge #940
Browse files Browse the repository at this point in the history
940: Fix logitbinarycrossentropy on CuArrays r=MikeInnes a=matsueushi

The issue of logitbinarycrossentropy on GPU #464 can be also fixed by @janEbert's approach #926.

Co-authored-by: matsueushi <matsueushi@gmail.com>
  • Loading branch information
bors[bot] and matsueushi authored Nov 26, 2019
2 parents fbb377a + a0314ce commit 7c181fd
Show file tree
Hide file tree
Showing 2 changed files with 6 additions and 2 deletions.
3 changes: 3 additions & 0 deletions src/layers/stateless.jl
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,9 @@ but it is more numerically stable.
"""
logitbinarycrossentropy(logŷ, y) = (1 - y)*logŷ - logσ(logŷ)

# Re-definition to fix interaction with CuArrays.
CuArrays.@cufunc logitbinarycrossentropy(logŷ, y) = (1 - y)*logŷ - logσ(logŷ)

"""
normalise(x::AbstractArray; dims=1)
Expand Down
5 changes: 3 additions & 2 deletions test/cuda/cuda.jl
Original file line number Diff line number Diff line change
Expand Up @@ -31,9 +31,10 @@ cx = gpu(x)
@test Flux.crossentropy(x,x, weight=1.0) Flux.crossentropy(cx,cx, weight=1.0)
@test Flux.crossentropy(x,x, weight=[1.0;2.0;3.0]) Flux.crossentropy(cx,cx, weight=cu([1.0;2.0;3.0]))

x = σ.([-1.1491, 0.8619, 0.3127])
x = [-1.1491, 0.8619, 0.3127]
y = [1, 1, 0.]
@test Flux.binarycrossentropy.(x,y) Flux.binarycrossentropy.(cu(x),cu(y))
@test Flux.binarycrossentropy.(σ.(x),y) Flux.binarycrossentropy.(cu(σ.(x)),cu(y))
@test Flux.logitbinarycrossentropy.(x,y) Flux.logitbinarycrossentropy.(cu(x),cu(y))

xs = rand(5, 5)
ys = Flux.onehotbatch(1:5,1:5)
Expand Down

0 comments on commit 7c181fd

Please sign in to comment.