From 4f7d5d1aacc7e64b35e5039a78411641daa6a875 Mon Sep 17 00:00:00 2001 From: Michael Abbott <32575566+mcabbott@users.noreply.github.com> Date: Wed, 29 Sep 2021 10:07:43 -0400 Subject: [PATCH] fix 1086 --- src/lib/broadcast.jl | 4 +++- test/features.jl | 5 +++++ 2 files changed, 8 insertions(+), 1 deletion(-) diff --git a/src/lib/broadcast.jl b/src/lib/broadcast.jl index 4e7a3a1cc..3affebd92 100644 --- a/src/lib/broadcast.jl +++ b/src/lib/broadcast.jl @@ -72,7 +72,9 @@ unbroadcast(x::AbstractArray, x̄::Nothing) = nothing broadcast(+, xs...), ȳ -> (nothing, map(x -> unbroadcast(x, ȳ), xs)...) @adjoint broadcasted(::typeof(-), x::Numeric, y::Numeric) = x .- y, - Δ -> (nothing, unbroadcast(x, Δ), -unbroadcast(y, Δ)) + Δ -> (nothing, unbroadcast(x, Δ), _minus(unbroadcast(y, Δ))) +_minus(Δ) = -Δ +_minus(::Nothing) = nothing @adjoint broadcasted(::typeof(*), x::Numeric, y::Numeric) = x.*y, Δ -> (nothing, unbroadcast(x, Δ .* conj.(y)), unbroadcast(y, Δ .* conj.(x))) diff --git a/test/features.jl b/test/features.jl index d683d0d94..3115a455c 100644 --- a/test/features.jl +++ b/test/features.jl @@ -570,6 +570,11 @@ end @test gradient(x -> sum(_f.(x)), [1,2,3]) == ([0.5, 0.5, 0.5],) @test gradient(x -> sum(map(_f, x)), [1,2,3]) == ([0.5, 0.5, 0.5],) + # with Bool + @test gradient(x -> sum(1 .- (x .> 0)), randn(5)) == (nothing,) + @test gradient(x -> sum((y->1-y).(x .> 0)), randn(5)) == (nothing,) + @test gradient(x -> sum(x .- (x .> 0)), randn(5)) == ([1,1,1,1,1],) + @test gradient(x -> sum(x ./ [1,2,4]), [1,2,pi]) == ([1.0, 0.5, 0.25],) @test gradient(x -> sum(map(/, x, [1,2,4])), [1,2,pi]) == ([1.0, 0.5, 0.25],)