@@ -3,7 +3,86 @@ using IRTools.Inner: argnames!, update!
33
44ignore (T) = all (T -> T <: Type , T. parameters)
55
6- @generated function _pullback (ctx:: AContext , f, args... )
6+
7+ function _pullback (ctx:: AContext , f, args... )
8+ if chainrules_blacklist (f, args... )
9+ # then don't even consider using ChainRules
10+ return _pullback_via_source2source (ctx, f, args... )
11+ end
12+
13+ res = ChainRules. rrule (f, args... )
14+ if res === nothing
15+ # No ChainRule defined, time to do the source tranform
16+ return _pullback_via_source2source (ctx, f, args... )
17+ else
18+ # Can just use ChainRule answer
19+ y, pb = res
20+ return y, _pullback_via_chainrules (pb)
21+ end
22+ end
23+
24+ #= ="""
25+ chainrules_blacklist(f, args...,)
26+
27+ This is used to disable the use of ChainRule's definitions
28+ for particular functions/methods.
29+
30+ It is not required if a Zygote rule has already been defined directly.
31+ """==#
32+ chainrules_blacklist (f, args... ) = false
33+
34+ # ChainRules does higher-order functions badly
35+ # see https://github.com/JuliaDiff/ChainRules.jl/issues/122
36+ chainrules_blacklist (:: typeof (map), args... ) = true
37+ chainrules_blacklist (:: typeof (broadcast), args... ) = true
38+ chainrules_blacklist (:: typeof (mapreduce), args... ) = true
39+ chainrules_blacklist (:: typeof (mapfoldl), args... ) = true
40+ chainrules_blacklist (:: typeof (mapfoldr), args... ) = true
41+ chainrules_blacklist (:: typeof (sum), f, x:: AbstractArray{<:Real} ) = true
42+ # Except for sum(abs2, xs), that is fine
43+ chainrules_blacklist (:: typeof (sum), :: typeof (abs2), x:: AbstractArray{<:Real} ) = false
44+
45+ # ChainRules current Wirtinger deriviative is not compatible
46+ # reconsider after https://github.com/JuliaDiff/ChainRulesCore.jl/pull/29
47+ chainrules_blacklist (:: typeof (abs), :: Complex ) = true
48+ chainrules_blacklist (:: typeof (abs2), :: Complex ) = true
49+ chainrules_blacklist (:: typeof (conj), :: Complex ) = true
50+ chainrules_blacklist (:: typeof (adjoint), :: Complex ) = true
51+ chainrules_blacklist (:: typeof (hypot), :: Complex ) = true
52+ chainrules_blacklist (:: typeof (angle), :: Complex ) = true
53+ chainrules_blacklist (:: typeof (imag), :: Complex ) = true
54+ chainrules_blacklist (:: typeof (real), :: Complex ) = true
55+
56+ # Sum of nonarrays doesn't really work
57+ # Fixed in https://github.com/JuliaDiff/ChainRules.jl/pull/124
58+ chainrules_blacklist (:: typeof (sum), x) = true
59+ chainrules_blacklist (:: typeof (sum), x:: AbstractArray{<:Real} ) = false
60+
61+
62+ #= ="""
63+ _pullback_via_chainrules(pb)
64+
65+ Converts a ChainRules pullback into a Zygote pullback.
66+ `pb` should be a ChainRules pullback, as returned from the second return value of `rrule`
67+ """==#
68+ function _pullback_via_chainrules (pb)
69+ # The less optimized version of this code is
70+ # cback2zback(pb) = (Δs...) -> zextern.(pb(Δs...))
71+ function zback (Δs... )
72+ ∂s = pb (Δs... )
73+ ntuple (length (∂s)) do ii
74+ ∂ = ∂s[ii]
75+ zextern (∂)
76+ end
77+ end
78+ end
79+
80+ zextern (x) = ChainRules. extern (x)
81+ zextern (:: ChainRules.Zero ) = nothing # Zygote loves calling things nothing
82+ zextern (:: ChainRules.DNE ) = nothing # Zygote loves calling things nothing
83+
84+
85+ @generated function _pullback_via_source2source (ctx:: AContext , f, args... )
786 T = Tuple{f,args... }
887 ignore (T) && return :(f (args... ), Pullback {$T} (()))
988 g = try _lookup_grad (T) catch e e end
0 commit comments