-
-
Notifications
You must be signed in to change notification settings - Fork 13
/
adjoint.jl
94 lines (84 loc) · 3.06 KB
/
adjoint.jl
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
using MacroTools
using MacroTools: @q, combinedef
using ChainRulesCore: AbstractZero
"""
    named(arg)

Give an anonymous argument expression a name.

- `::T` (a `::` expression with a single argument) becomes `<gensym>::T`.
- A `:kw` expression (`decl = default`) has its declaration part processed
  recursively; the default value is kept as is.
- Anything else is returned unchanged.
"""
function named(arg)
    if isexpr(arg, :kw)
        @assert length(arg.args) == 2
        lhs, rhs = arg.args
        return Expr(:kw, named(lhs), rhs)
    end
    # Only a `::T` annotation without a variable name needs a fresh symbol.
    isanonymous = isexpr(arg, :(::)) && length(arg.args) == 1
    return isanonymous ? Expr(:(::), gensym(), arg.args[1]) : arg
end
"""
    typeless(x)

Walk the expression `x` bottom-up and replace every `::` or `:kw` node by its
first argument, so `a::T` becomes `a` and `a = default` (as a `:kw` node)
becomes `a`.
"""
function typeless(x)
    return MacroTools.postwalk(x) do node
        isexpr(node, :(::), :kw) ? node.args[1] : node
    end
end
"""
    isvararg(x)

Return `true` when `x` is a `::` expression whose type annotation is named
`Vararg` (for example `xs::Vararg{T}`), and `false` otherwise.

The annotation is read from the last argument of the `::` expression, so the
one-argument form `::Vararg{T}` is answered instead of raising a `BoundsError`.
For the two-argument form `name::Type` this is the same element as before.
"""
isvararg(x) = isexpr(x, :(::)) && namify(x.args[end]) == :Vararg
# Generate the helpers `gradtuple0` … `gradtuple3`.
# `gradtupleN(x)` prepends `N` `nothing` entries to the tuple `x`; a tuple made
# only of `nothing`s (including the empty tuple) and a bare `nothing` both
# collapse to `nothing`; any other value is an error.
for n in 0:3
    fname = Symbol(:gradtuple, n)
    pad = ntuple(_ -> :nothing, n)
    @eval begin
        $fname(x::Nothing) = nothing
        $fname(::Tuple{Vararg{Nothing}}) = nothing
        $fname(x::Tuple) = ($(pad...), x...)
        $fname(x) = error("Gradient $x should be a tuple")
    end
end
# Abstract supertype used to annotate the `__context__` argument of the
# `adjoint` / `_pullback` methods generated by `gradm`.
abstract type AContext end
# Generic functions declared without methods.
# `gradm` emits methods for `ZygoteRules.adjoint` and `ZygoteRules._pullback`
# (presumably these same functions — confirm the enclosing module is ZygoteRules).
function adjoint end
function _pullback end
# No methods for `pullback` are defined in this part of the file.
function pullback end
"""
    unthunk_tangent(x)

Hook applied to the incoming tangent `Δ` by pullbacks generated with
`keepthunks = false`.

The methods defined here return `x` unchanged, except that tuples and named
tuples are mapped element-wise (recursively) through `unthunk_tangent`.
NOTE(review): presumably other packages add methods for thunk types — confirm.
"""
function unthunk_tangent end

@inline function unthunk_tangent(tangent)
    return tangent
end

@inline function unthunk_tangent(tangents::Tuple)
    return map(unthunk_tangent, tangents)
end

@inline function unthunk_tangent(tangents::NamedTuple)
    return map(unthunk_tangent, tangents)
end
"""
    gradm(ex, mut = false, keepthunks = false)

Build the expression behind `@adjoint` / `@adjoint!`.

`ex` must be a function definition (`f(args...) = body`, optionally with a
`where {…}` clause; long-form definitions are normalised with `shortdef`).
The returned expression defines:

- a `ZygoteRules.adjoint` method with the user's signature, preceded by a
  `__context__::AContext` argument and the function object, whose body is `body`;
- a `ZygoteRules._pullback` method for positional calls;
- a `ZygoteRules._pullback` method for keyword calls (dispatching on
  `Core.kwftype` of the function type);

and evaluates to `nothing`.

# Arguments
- `ex`: the function-definition expression.
- `mut`: when `false`, the generated `back` closure gets extra methods that
  short-circuit `nothing` and `AbstractZero` tangents without calling the user's
  pullback; when `true` those methods are omitted.
- `keepthunks`: when `false`, the tangent is passed through `unthunk_tangent`
  before the user's pullback is called; when `true` it is passed as is.

# Throws
- `ErrorException("Need a function definition")` if `ex` is not a function
  definition.
"""
function gradm(ex, mut = false, keepthunks = false)
    @capture(shortdef(ex), (name_(args__) = body_) |
                           (name_(args__) where {Ts__} = body_)) ||
        error("Need a function definition")
    # Pull a leading `; kwargs` block out of the argument list. `!isempty` (rather
    # than `length(args) > 1`) also covers definitions that have only keyword
    # arguments, which would otherwise leave the `:parameters` node in a
    # non-leading position of the generated signatures.
    kw = !isempty(args) && isexpr(args[1], :parameters) ? esc(popfirst!(args)) : nothing
    # `(f::T)(args...)` with an explicit name: the user's pullback result is passed
    # through without a prepended `nothing` (see `gradtuple0` / `gradtuple2` below).
    isclosure = isexpr(name, :(::)) && length(name.args) > 1
    # `f` is the (escaped) name bound to the function object and `T` its type:
    #   (::T)(…)  -> fresh name, T
    #   (f::T)(…) -> f, T
    #   f(…)      -> fresh name, Core.Typeof(f)
    f, T = isexpr(name, :(::)) ?
        (length(name.args) == 1 ? (esc(gensym()), esc(name.args[1])) : esc.(name.args)) :
        (esc(gensym()), :(Core.Typeof($(esc(name)))))
    kT = :(Core.kwftype($T))
    Ts === nothing && (Ts = [])
    # Every argument needs a name so it can be forwarded to `adjoint`.
    args = named.(args)
    argnames = Any[typeless(arg) for arg in args]
    # A trailing `xs::Vararg{…}` argument is forwarded splatted.
    !isempty(args) && isvararg(args[end]) && (argnames[end] = :($(argnames[end])...,))
    args = esc.(args)
    argnames = esc.(argnames)
    Ts = esc.(Ts)
    cx = :($(esc(:__context__))::AContext)
    fargs = kw === nothing ? [cx, :($f::$T), args...] : [kw, cx, :($f::$T), args...]
    # Number of `nothing`s prepended to the user's gradient tuple: one for the
    # function object unless `isclosure`, plus two more in the keyword-call method
    # (which has two extra leading arguments, `::$kT` and `kw`).
    gradtuple = isclosure ? gradtuple0 : gradtuple1
    gradtuplekw = isclosure ? gradtuple2 : gradtuple3
    adj = @q @inline ZygoteRules.adjoint($(fargs...)) where $(Ts...) = $(esc(body))
    maybe_unthunked_Δ = keepthunks ? :Δ : :(unthunk_tangent(Δ))
    return quote
        $adj
        @inline function ZygoteRules._pullback($cx, $f::$T, $(args...)) where $(Ts...)
            y, _back = adjoint(__context__, $f, $(argnames...))
            $(mut ? nothing : quote
                back(::Nothing) = nothing
                back(Δ::AbstractZero) = $gradtuple(ntuple(_ -> Δ, $(length(args))))
            end)
            back(Δ) = $gradtuple(_back($maybe_unthunked_Δ))
            return y, back
        end
        @inline function ZygoteRules._pullback($cx, ::$kT, kw, $f::$T, $(args...)) where $(Ts...)
            y, _back = adjoint(__context__, $f, $(argnames...); kw...)
            $(mut ? nothing : quote
                back(::Nothing) = nothing
                back(Δ::AbstractZero) = $gradtuplekw(ntuple(_ -> Δ, $(length(args))))
            end)
            back(Δ) = $gradtuplekw(_back($maybe_unthunked_Δ))
            return y, back
        end
        nothing
    end
end
"""
    @adjoint <function definition>

Expand the function definition `ex` through `gradm` with `mut = false` and
`keepthunks = false`: the generated pullback short-circuits `nothing` and
`AbstractZero` tangents and passes other tangents through `unthunk_tangent`
before calling the user-supplied pullback.
"""
macro adjoint(ex)
    mut, keepthunks = false, false
    return gradm(ex, mut, keepthunks)
end
"""
    @adjoint! <function definition>

Expand the function definition `ex` through `gradm` with `mut = true` and
`keepthunks = false`: unlike `@adjoint`, the generated pullback has no
short-circuit methods for `nothing` / `AbstractZero` tangents, so the
user-supplied pullback is called for every tangent (after `unthunk_tangent`).
"""
macro adjoint!(ex)
    mut, keepthunks = true, false
    return gradm(ex, mut, keepthunks)
end