@@ -33,25 +33,18 @@ chainrules_blacklist(f, args...) = false
3333
3434# ChainRules does higher-order functions badly
3535# see https://github.com/JuliaDiff/ChainRules.jl/issues/122
36- chainrules_blacklist (:: typeof (map), args... ) = true
37- chainrules_blacklist (:: typeof (broadcast), args... ) = true
38- chainrules_blacklist (:: typeof (mapreduce), args... ) = true
39- chainrules_blacklist (:: typeof (mapfoldl), args... ) = true
40- chainrules_blacklist (:: typeof (mapfoldr), args... ) = true
36+ for f in (map, broadcast, mapreduce, mapfoldl, mapfoldr)
37+ @eval chainrules_blacklist (:: typeof ($ f), args... ) = true
38+ end
4139chainrules_blacklist (:: typeof (sum), f, x:: AbstractArray{<:Real} ) = true
4240# Except for sum(abs2, xs), that is fine
4341chainrules_blacklist (:: typeof (sum), :: typeof (abs2), x:: AbstractArray{<:Real} ) = false
4442
4543# ChainRules current Wirtinger deriviative is not compatible
4644# reconsider after https://github.com/JuliaDiff/ChainRulesCore.jl/pull/29
47- chainrules_blacklist (:: typeof (abs), :: Complex ) = true
48- chainrules_blacklist (:: typeof (abs2), :: Complex ) = true
49- chainrules_blacklist (:: typeof (conj), :: Complex ) = true
50- chainrules_blacklist (:: typeof (adjoint), :: Complex ) = true
51- chainrules_blacklist (:: typeof (hypot), :: Complex ) = true
52- chainrules_blacklist (:: typeof (angle), :: Complex ) = true
53- chainrules_blacklist (:: typeof (imag), :: Complex ) = true
54- chainrules_blacklist (:: typeof (real), :: Complex ) = true
45+ for f in (abs, abs2, conj, adjoint, hypot, angle, imag, real)
46+ @eval chainrules_blacklist (:: typeof ($ f), :: Complex ) = true
47+ end
5548
5649# Sum of nonarrays doesn't really work
5750# Fixed in https://github.com/JuliaDiff/ChainRules.jl/pull/124
@@ -66,8 +59,8 @@ Converts a ChainRules pullback into a Zygote pullback.
6659`pb` should be a ChainRules pullback, as returned from the second return value of `rrule`
6760"""==#
6861function _pullback_via_chainrules (pb)
69- # The less optimized version of this code is
70- # cback2zback (pb) = (Δs...) -> zextern.(pb(Δs...))
62+ # This is the optimized version of
63+ # _pullback_via_chainrules (pb) = (Δs...) -> zextern.(pb(Δs...))
7164 function zback (Δs... )
7265 ∂s = pb (Δs... )
7366 ntuple (length (∂s)) do ii
0 commit comments