Skip to content

Commit bc8ac2c

Browse files
committed
use metaprogramming in blacklist
Update src/compiler/interface2.jl fix missing eval
1 parent 74396b4 commit bc8ac2c

File tree

1 file changed

+8
-15
lines changed

1 file changed

+8
-15
lines changed

src/compiler/interface2.jl

Lines changed: 8 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -33,25 +33,18 @@ chainrules_blacklist(f, args...) = false
3333

3434
# ChainRules does higher-order functions badly
3535
# see https://github.com/JuliaDiff/ChainRules.jl/issues/122
36-
chainrules_blacklist(::typeof(map), args...) = true
37-
chainrules_blacklist(::typeof(broadcast), args...) = true
38-
chainrules_blacklist(::typeof(mapreduce), args...) = true
39-
chainrules_blacklist(::typeof(mapfoldl), args...) = true
40-
chainrules_blacklist(::typeof(mapfoldr), args...) = true
36+
for f in (map, broadcast, mapreduce, mapfoldl, mapfoldr)
37+
@eval chainrules_blacklist(::typeof($f), args...) = true
38+
end
4139
chainrules_blacklist(::typeof(sum), f, x::AbstractArray{<:Real}) = true
4240
# Except for sum(abs2, xs), that is fine
4341
chainrules_blacklist(::typeof(sum), ::typeof(abs2), x::AbstractArray{<:Real}) = false
4442

4543
# ChainRules current Wirtinger deriviative is not compatible
4644
# reconsider after https://github.com/JuliaDiff/ChainRulesCore.jl/pull/29
47-
chainrules_blacklist(::typeof(abs), ::Complex) = true
48-
chainrules_blacklist(::typeof(abs2), ::Complex) = true
49-
chainrules_blacklist(::typeof(conj), ::Complex) = true
50-
chainrules_blacklist(::typeof(adjoint), ::Complex) = true
51-
chainrules_blacklist(::typeof(hypot), ::Complex) = true
52-
chainrules_blacklist(::typeof(angle), ::Complex) = true
53-
chainrules_blacklist(::typeof(imag), ::Complex) = true
54-
chainrules_blacklist(::typeof(real), ::Complex) = true
45+
for f in (abs, abs2, conj, adjoint, hypot, angle, imag, real)
46+
@eval chainrules_blacklist(::typeof($f), ::Complex) = true
47+
end
5548

5649
# Sum of nonarrays doesn't really work
5750
# Fixed in https://github.com/JuliaDiff/ChainRules.jl/pull/124
@@ -66,8 +59,8 @@ Converts a ChainRules pullback into a Zygote pullback.
6659
`pb` should be a ChainRules pullback, as returned from the second return value of `rrule`
6760
"""==#
6861
function _pullback_via_chainrules(pb)
69-
# The less optimized version of this code is
70-
# cback2zback(pb) = (Δs...) -> zextern.(pb(Δs...))
62+
# This is the optimized version of
63+
# _pullback_via_chainrules(pb) = (Δs...) -> zextern.(pb(Δs...))
7164
function zback(Δs...)
7265
∂s = pb(Δs...)
7366
ntuple(length(∂s)) do ii

0 commit comments

Comments
 (0)