@@ -3,16 +3,9 @@ using IRTools.Inner: argnames!, update!
33
44ignore (T) = all (T -> T <: Type , T. parameters)
55
6-
76function _pullback (ctx:: AContext , f, args... )
8- if chainrules_blacklist (f, args... )
9- # then don't even consider using ChainRules
10- return _pullback_via_source2source (ctx, f, args... )
11- end
12-
13- res = ChainRules. rrule (f, args... )
14- if res === nothing
15- # No ChainRule defined, time to do the source tranform
7+ if chainrules_blacklist (f, args... ) || (res = ChainRules. rrule (f, args... )) === nothing
8+ # Blacklisted or no ChainRule defined, time to do the source tranform
169 return _pullback_via_source2source (ctx, f, args... )
1710 else
1811 # Can just use ChainRule answer
@@ -21,6 +14,39 @@ function _pullback(ctx::AContext, f, args...)
2114 end
2215end
2316
17+ @generated function _pullback_via_source2source (ctx:: AContext , f, args... )
18+ T = Tuple{f,args... }
19+ ignore (T) && return :(f (args... ), Pullback {$T} (()))
20+ g = try _lookup_grad (T) catch e e end
21+ ! (g isa Tuple) && return :(f (args... ), Pullback {$T} ((f,)))
22+ meta, forw, _ = g
23+ argnames! (meta, Symbol (" #self#" ), :ctx , :f , :args )
24+ forw = varargs! (meta, forw, 3 )
25+ # IRTools.verify(forw)
26+ forw = slots! (pis! (inlineable! (forw)))
27+ return update! (meta. code, forw)
28+ end
29+
30+ @generated function (j:: Pullback{T} )(Δ) where T
31+ ignore (T) && return :nothing
32+ g = try _lookup_grad (T)
33+ catch e
34+ rethrow (CompileError (T,e))
35+ end
36+ if g == nothing
37+ Δ == Nothing && return :nothing
38+ return :(error (" Non-differentiable function $(repr (j. t[1 ])) " ))
39+ end
40+ meta, _, back = g
41+ argnames! (meta, Symbol (" #self#" ), :Δ )
42+ # IRTools.verify(back)
43+ back = slots! (inlineable! (back))
44+ return update! (meta. code, back)
45+ end
46+
47+
48+
49+
2450#= ="""
2551 chainrules_blacklist(f, args...,)
2652
7399zextern (x) = ChainRules. extern (x)
74100zextern (:: ChainRules.Zero ) = nothing # Zygote loves calling things nothing
75101zextern (:: ChainRules.DNE ) = nothing # Zygote loves calling things nothing
76-
77-
78- @generated function _pullback_via_source2source (ctx:: AContext , f, args... )
79- T = Tuple{f,args... }
80- ignore (T) && return :(f (args... ), Pullback {$T} (()))
81- g = try _lookup_grad (T) catch e e end
82- ! (g isa Tuple) && return :(f (args... ), Pullback {$T} ((f,)))
83- meta, forw, _ = g
84- argnames! (meta, Symbol (" #self#" ), :ctx , :f , :args )
85- forw = varargs! (meta, forw, 3 )
86- # IRTools.verify(forw)
87- forw = slots! (pis! (inlineable! (forw)))
88- return update! (meta. code, forw)
89- end
90-
91- @generated function (j:: Pullback{T} )(Δ) where T
92- ignore (T) && return :nothing
93- g = try _lookup_grad (T)
94- catch e
95- rethrow (CompileError (T,e))
96- end
97- if g == nothing
98- Δ == Nothing && return :nothing
99- return :(error (" Non-differentiable function $(repr (j. t[1 ])) " ))
100- end
101- meta, _, back = g
102- argnames! (meta, Symbol (" #self#" ), :Δ )
103- # IRTools.verify(back)
104- back = slots! (inlineable! (back))
105- return update! (meta. code, back)
106- end
0 commit comments