Skip to content

Commit 951dac6

Browse files
committed
compact choosing pullback mechanism code
1 parent 70a8aff commit 951dac6

File tree

1 file changed

+35
-40
lines changed

1 file changed

+35
-40
lines changed

src/compiler/interface2.jl

Lines changed: 35 additions & 40 deletions
Original file line numberDiff line numberDiff line change
@@ -3,16 +3,9 @@ using IRTools.Inner: argnames!, update!
33

44
ignore(T) = all(T -> T <: Type, T.parameters)
55

6-
76
function _pullback(ctx::AContext, f, args...)
8-
if chainrules_blacklist(f, args...)
9-
# then don't even consider using ChainRules
10-
return _pullback_via_source2source(ctx, f, args...)
11-
end
12-
13-
res = ChainRules.rrule(f, args...)
14-
if res === nothing
15-
# No ChainRule defined, time to do the source tranform
7+
if chainrules_blacklist(f, args...) || (res = ChainRules.rrule(f, args...)) === nothing
8+
# Blacklisted or no ChainRule defined, time to do the source tranform
169
return _pullback_via_source2source(ctx, f, args...)
1710
else
1811
# Can just use ChainRule answer
@@ -21,6 +14,39 @@ function _pullback(ctx::AContext, f, args...)
2114
end
2215
end
2316

17+
@generated function _pullback_via_source2source(ctx::AContext, f, args...)
18+
T = Tuple{f,args...}
19+
ignore(T) && return :(f(args...), Pullback{$T}(()))
20+
g = try _lookup_grad(T) catch e e end
21+
!(g isa Tuple) && return :(f(args...), Pullback{$T}((f,)))
22+
meta, forw, _ = g
23+
argnames!(meta, Symbol("#self#"), :ctx, :f, :args)
24+
forw = varargs!(meta, forw, 3)
25+
# IRTools.verify(forw)
26+
forw = slots!(pis!(inlineable!(forw)))
27+
return update!(meta.code, forw)
28+
end
29+
30+
@generated function (j::Pullback{T})(Δ) where T
31+
ignore(T) && return :nothing
32+
g = try _lookup_grad(T)
33+
catch e
34+
rethrow(CompileError(T,e))
35+
end
36+
if g == nothing
37+
Δ == Nothing && return :nothing
38+
return :(error("Non-differentiable function $(repr(j.t[1]))"))
39+
end
40+
meta, _, back = g
41+
argnames!(meta, Symbol("#self#"), )
42+
# IRTools.verify(back)
43+
back = slots!(inlineable!(back))
44+
return update!(meta.code, back)
45+
end
46+
47+
48+
49+
2450
#=="""
2551
chainrules_blacklist(f, args...,)
2652
@@ -73,34 +99,3 @@ end
7399
zextern(x) = ChainRules.extern(x)
74100
zextern(::ChainRules.Zero) = nothing # Zygote loves calling things nothing
75101
zextern(::ChainRules.DNE) = nothing # Zygote loves calling things nothing
76-
77-
78-
@generated function _pullback_via_source2source(ctx::AContext, f, args...)
79-
T = Tuple{f,args...}
80-
ignore(T) && return :(f(args...), Pullback{$T}(()))
81-
g = try _lookup_grad(T) catch e e end
82-
!(g isa Tuple) && return :(f(args...), Pullback{$T}((f,)))
83-
meta, forw, _ = g
84-
argnames!(meta, Symbol("#self#"), :ctx, :f, :args)
85-
forw = varargs!(meta, forw, 3)
86-
# IRTools.verify(forw)
87-
forw = slots!(pis!(inlineable!(forw)))
88-
return update!(meta.code, forw)
89-
end
90-
91-
@generated function (j::Pullback{T})(Δ) where T
92-
ignore(T) && return :nothing
93-
g = try _lookup_grad(T)
94-
catch e
95-
rethrow(CompileError(T,e))
96-
end
97-
if g == nothing
98-
Δ == Nothing && return :nothing
99-
return :(error("Non-differentiable function $(repr(j.t[1]))"))
100-
end
101-
meta, _, back = g
102-
argnames!(meta, Symbol("#self#"), )
103-
# IRTools.verify(back)
104-
back = slots!(inlineable!(back))
105-
return update!(meta.code, back)
106-
end

0 commit comments

Comments
 (0)