diff --git a/src/lib/broadcast.jl b/src/lib/broadcast.jl index 4e7a3a1cc..3affebd92 100644 --- a/src/lib/broadcast.jl +++ b/src/lib/broadcast.jl @@ -72,7 +72,9 @@ unbroadcast(x::AbstractArray, x̄::Nothing) = nothing broadcast(+, xs...), ȳ -> (nothing, map(x -> unbroadcast(x, ȳ), xs)...) @adjoint broadcasted(::typeof(-), x::Numeric, y::Numeric) = x .- y, - Δ -> (nothing, unbroadcast(x, Δ), -unbroadcast(y, Δ)) + Δ -> (nothing, unbroadcast(x, Δ), _minus(unbroadcast(y, Δ))) +_minus(Δ) = -Δ +_minus(::Nothing) = nothing @adjoint broadcasted(::typeof(*), x::Numeric, y::Numeric) = x.*y, Δ -> (nothing, unbroadcast(x, Δ .* conj.(y)), unbroadcast(y, Δ .* conj.(x))) diff --git a/test/features.jl b/test/features.jl index d683d0d94..3115a455c 100644 --- a/test/features.jl +++ b/test/features.jl @@ -570,6 +570,11 @@ end @test gradient(x -> sum(_f.(x)), [1,2,3]) == ([0.5, 0.5, 0.5],) @test gradient(x -> sum(map(_f, x)), [1,2,3]) == ([0.5, 0.5, 0.5],) + # with Bool + @test gradient(x -> sum(1 .- (x .> 0)), randn(5)) == (nothing,) + @test gradient(x -> sum((y->1-y).(x .> 0)), randn(5)) == (nothing,) + @test gradient(x -> sum(x .- (x .> 0)), randn(5)) == ([1,1,1,1,1],) + @test gradient(x -> sum(x ./ [1,2,4]), [1,2,pi]) == ([1.0, 0.5, 0.25],) @test gradient(x -> sum(map(/, x, [1,2,4])), [1,2,pi]) == ([1.0, 0.5, 0.25],)