Skip to content

Commit

Permalink
Browse files Browse the repository at this point in the history
  • Loading branch information
simeonschaub committed Sep 10, 2021
1 parent 2059a35 commit 71d9202
Show file tree
Hide file tree
Showing 3 changed files with 46 additions and 46 deletions.
14 changes: 7 additions & 7 deletions src/runtime.jl
Original file line number Diff line number Diff line change
@@ -1,14 +1,14 @@
using ChainRulesCore

@Base.aggressive_constprop accum(a, b) = a + b
@Base.aggressive_constprop accum(a::Tuple, b::Tuple) = map(accum, a, b)
@Base.aggressive_constprop @generated function accum(x::NamedTuple, y::NamedTuple)
@Base.constprop aggressive accum(a, b) = a + b
@Base.constprop aggressive accum(a::Tuple, b::Tuple) = map(accum, a, b)
@Base.constprop aggressive @generated function accum(x::NamedTuple, y::NamedTuple)
fnames = union(fieldnames(x), fieldnames(y))
gradx(f) = f in fieldnames(x) ? :(getfield(x, $(quot(f)))) : :(ZeroTangent())
grady(f) = f in fieldnames(y) ? :(getfield(y, $(quot(f)))) : :(ZeroTangent())
Expr(:tuple, [:($f=accum($(gradx(f)), $(grady(f)))) for f in fnames]...)
end
@Base.aggressive_constprop accum(a, b, c, args...) = accum(accum(a, b), c, args...)
@Base.aggressive_constprop accum(a::NoTangent, b) = b
@Base.aggressive_constprop accum(a, b::NoTangent) = a
@Base.aggressive_constprop accum(a::NoTangent, b::NoTangent) = NoTangent()
@Base.constprop aggressive accum(a, b, c, args...) = accum(accum(a, b), c, args...)
@Base.constprop aggressive accum(a::NoTangent, b) = b
@Base.constprop aggressive accum(a, b::NoTangent) = a
@Base.constprop aggressive accum(a::NoTangent, b::NoTangent) = NoTangent()
14 changes: 7 additions & 7 deletions src/stage1/forward.jl
Original file line number Diff line number Diff line change
Expand Up @@ -134,37 +134,37 @@ end
(::∂☆{N})(args::AbstractTangentBundle{N}...) where {N} = ∂☆internal{N}()(args...)

# Special case rules for performance
@Base.aggressive_constprop function (::∂☆{N})(f::ATB{N, typeof(getfield)}, x::TangentBundle{N}, s::AbstractTangentBundle{N}) where {N}
@Base.constprop aggressive function (::∂☆{N})(f::ATB{N, typeof(getfield)}, x::TangentBundle{N}, s::AbstractTangentBundle{N}) where {N}
s = primal(s)
TangentBundle{N}(getfield(primal(x), s),
map(x->lifted_getfield(x, s), x.partials))
end

@Base.aggressive_constprop function (::∂☆{N})(f::ATB{N, typeof(getfield)}, x::TaylorBundle{N}, s::AbstractTangentBundle{N}) where {N}
@Base.constprop aggressive function (::∂☆{N})(f::ATB{N, typeof(getfield)}, x::TaylorBundle{N}, s::AbstractTangentBundle{N}) where {N}
s = primal(s)
TaylorBundle{N}(getfield(primal(x), s),
map(y->lifted_getfield(y, s), x.coeffs))
end

@Base.aggressive_constprop function (::∂☆{N})(::ATB{N, typeof(getfield)}, x::CompositeBundle{N}, s::AbstractTangentBundle{N, Int}) where {N}
@Base.constprop aggressive function (::∂☆{N})(::ATB{N, typeof(getfield)}, x::CompositeBundle{N}, s::AbstractTangentBundle{N, Int}) where {N}
x.tup[primal(s)]
end

@Base.aggressive_constprop function (::∂☆{N})(::ATB{N, typeof(getfield)}, x::CompositeBundle{N, B}, s::AbstractTangentBundle{N, Symbol}) where {N, B}
@Base.constprop aggressive function (::∂☆{N})(::ATB{N, typeof(getfield)}, x::CompositeBundle{N, B}, s::AbstractTangentBundle{N, Symbol}) where {N, B}
x.tup[Base.fieldindex(B, primal(s))]
end

@Base.aggressive_constprop function (::∂☆{N})(f::ATB{N, typeof(getfield)}, x::ATB{N}, s::ATB{N}, inbounds::ATB{N}) where {N}
@Base.constprop aggressive function (::∂☆{N})(f::ATB{N, typeof(getfield)}, x::ATB{N}, s::ATB{N}, inbounds::ATB{N}) where {N}
s = primal(s)
TangentBundle{N}(getfield(primal(x), s, primal(inbounds)),
map(x->lifted_getfield(x, s), x.partials))
end

@Base.aggressive_constprop function (::∂☆{N})(f::ATB{N, typeof(getfield)}, x::UniformBundle{N, <:Any, U}, s::AbstractTangentBundle{N}) where {N, U}
@Base.constprop aggressive function (::∂☆{N})(f::ATB{N, typeof(getfield)}, x::UniformBundle{N, <:Any, U}, s::AbstractTangentBundle{N}) where {N, U}
UniformBundle{N,<:Any,U}(getfield(primal(x), primal(s)), x.partial)
end

@Base.aggressive_constprop function (::∂☆{N})(f::ATB{N, typeof(getfield)}, x::UniformBundle{N, <:Any, U}, s::AbstractTangentBundle{N}, inbounds::AbstractTangentBundle{N}) where {N, U}
@Base.constprop aggressive function (::∂☆{N})(f::ATB{N, typeof(getfield)}, x::UniformBundle{N, <:Any, U}, s::AbstractTangentBundle{N}, inbounds::AbstractTangentBundle{N}) where {N, U}
UniformBundle{N,<:Any,U}(getfield(primal(x), primal(s), primal(inbounds)), x.partial)
end

Expand Down
64 changes: 32 additions & 32 deletions src/stage1/generated.jl
Original file line number Diff line number Diff line change
Expand Up @@ -38,8 +38,8 @@ struct Protected{N}
a
end
(p::Protected)(args...) = getfield(p, :a)(args...)[1]
@Base.aggressive_constprop (::∂⃖{N})(p::Protected{N}, args...) where {N} = getfield(p, :a)(args...)
@Base.aggressive_constprop (::∂⃖{1})(p::Protected{1}, args...) = getfield(p, :a)(args...)
@Base.constprop aggressive (::∂⃖{N})(p::Protected{N}, args...) where {N} = getfield(p, :a)(args...)
@Base.constprop aggressive (::∂⃖{1})(p::Protected{1}, args...) = getfield(p, :a)(args...)
(::∂⃖{N})(p::Protected, args...) where {N} = error("TODO: Can we support this?")

struct OpticBundle{T}
Expand Down Expand Up @@ -94,30 +94,30 @@ end
end

struct ∂⃖weaveInnerOdd{N, O}; b̄; end
@Base.aggressive_constprop function (w::∂⃖weaveInnerOdd{N, N})(Δ) where {N}
@Base.constprop aggressive function (w::∂⃖weaveInnerOdd{N, N})(Δ) where {N}
@destruct c, c̄ = w....)
return (c̄, c)
end
@Base.aggressive_constprop function (w::∂⃖weaveInnerOdd{N, O})(Δ) where {N, O}
@Base.constprop aggressive function (w::∂⃖weaveInnerOdd{N, O})(Δ) where {N, O}
@destruct c, c̄ = w....)
return (c̄, c), ∂⃖weaveInnerEven{plus1(N), O}()
end
struct ∂⃖weaveInnerEven{N, O}; end
@Base.aggressive_constprop function (w::∂⃖weaveInnerEven{N, O})(Δ′, x...) where {N, O}
@Base.constprop aggressive function (w::∂⃖weaveInnerEven{N, O})(Δ′, x...) where {N, O}
@destruct y, ȳ = Δ′(x...)
return y, ∂⃖weaveInnerOdd{plus1(N), O}(ȳ)
end

struct ∂⃖weaveOuterOdd{N, O}; end
@Base.aggressive_constprop function (w::∂⃖weaveOuterOdd{N, N})((Δ′′, Δ′′′)) where {N}
@Base.constprop aggressive function (w::∂⃖weaveOuterOdd{N, N})((Δ′′, Δ′′′)) where {N}
return (NoTangent(), Δ′′′(Δ′′)...)
end
@Base.aggressive_constprop function (w::∂⃖weaveOuterOdd{N, O})((Δ′′, Δ′′′)) where {N, O}
@Base.constprop aggressive function (w::∂⃖weaveOuterOdd{N, O})((Δ′′, Δ′′′)) where {N, O}
@destruct α, ᾱ = Δ′′′(Δ′′)
return (NoTangent(), α...), ∂⃖weaveOuterEven{plus1(N), O}(ᾱ)
end
struct ∂⃖weaveOuterEven{N, O}; ᾱ end
@Base.aggressive_constprop function (w::∂⃖weaveOuterEven{N, O})(Δ⁴...) where {N, O}
@Base.constprop aggressive function (w::∂⃖weaveOuterEven{N, O})(Δ⁴...) where {N, O}
return w.ᾱ(Base.tail(Δ⁴)...), ∂⃖weaveOuterOdd{plus1(N), O}()
end

Expand Down Expand Up @@ -156,33 +156,33 @@ struct ∂⃖rruleB{N, O}; ᾱ; ȳ̄ ; end
struct ∂⃖rruleC{N, O}; ȳ̄ ; Δ′′′; β̄ ; end
struct ∂⃖rruleD{N, O}; γ̄; β̄ ; end

@Base.aggressive_constprop function (a::∂⃖rruleA{N, O})(Δ) where {N, O}
@Base.constprop aggressive function (a::∂⃖rruleA{N, O})(Δ) where {N, O}
# TODO: Is this unthunk in the right place
@destruct (α, ᾱ) = a.(a.ȳ, unthunk(Δ))
(α, ∂⃖rruleB{N, O}(ᾱ, a.ȳ̄))
end

@Base.aggressive_constprop function (b::∂⃖rruleB{N, O})(Δ′...) where {N, O}
@Base.constprop aggressive function (b::∂⃖rruleB{N, O})(Δ′...) where {N, O}
@destruct ((Δ′′′, β), β̄) = b.ᾱ(Δ′)
(β, ∂⃖rruleC{N, O}(b.ȳ̄, Δ′′′, β̄))
end

@Base.aggressive_constprop function (c::∂⃖rruleC{N, O})(Δ′′) where {N, O}
@Base.constprop aggressive function (c::∂⃖rruleC{N, O})(Δ′′) where {N, O}
@destruct (γ, γ̄) = c.ȳ̄((Δ′′, c.Δ′′′))
(Base.tail(γ), ∂⃖rruleD{N, O}(γ̄, c.β̄))
end

@Base.aggressive_constprop function (d::∂⃖rruleD{N, O})(Δ⁴...) where {N, O}
@Base.constprop aggressive function (d::∂⃖rruleD{N, O})(Δ⁴...) where {N, O}
(δ₁, δ₂), δ̄ = d.γ̄(ZeroTangent(), Δ⁴...)
(δ₁, ∂⃖rruleA{N, O+1}(d.β̄ , δ₂, δ̄ ))
end

# Terminal cases
@Base.aggressive_constprop function (c::∂⃖rruleB{N, N})(Δ′...) where {N}
@Base.constprop aggressive function (c::∂⃖rruleB{N, N})(Δ′...) where {N}
@destruct (Δ′′′, β) = c.ᾱ(Δ′)
(β, ∂⃖rruleC{N, N}(c.ȳ̄, Δ′′′, nothing))
end
@Base.aggressive_constprop (c::∂⃖rruleC{N, N})(Δ′′) where {N} =
@Base.constprop aggressive (c::∂⃖rruleC{N, N})(Δ′′) where {N} =
Base.tail(c.ȳ̄((Δ′′, c.Δ′′′)))
(::∂⃖rruleD{N, N})(Δ...) where {N} = error("Should not be reached")

Expand Down Expand Up @@ -255,17 +255,17 @@ function ChainRulesCore.rrule(::KwFunc, kwargs, f, args...)
end
end

@Base.aggressive_constprop function ChainRulesCore.rrule(::typeof(Core.getfield), s, field::Symbol)
@Base.constprop aggressive function ChainRulesCore.rrule(::typeof(Core.getfield), s, field::Symbol)
getfield(s, field), let P = typeof(s)
@Base.aggressive_constprop Δ->begin
@Base.constprop aggressive Δ->begin
nt = NamedTuple{(field,)}((Δ,))
(NoTangent(), Tangent{P, typeof(nt)}(nt), NoTangent())
end
end
end

struct ∂⃖getfield{n, f}; end
@Base.aggressive_constprop function (::∂⃖getfield{n, f})(Δ) where {n,f}
@Base.constprop aggressive function (::∂⃖getfield{n, f})(Δ) where {n,f}
if @generated
return Expr(:call, tuple, NoTangent(),
Expr(:call, tuple, (i == f ? :(Δ) : ZeroTangent() for i = 1:n)...),
Expand All @@ -279,31 +279,31 @@ struct EvenOddEven{O, P, F, G}; f::F; g::G; end
EvenOddEven{O, P}(f::F, g::G) where {O, P, F, G} = EvenOddEven{O, P, F, G}(f, g)
struct EvenOddOdd{O, P, F, G}; f::F; g::G; end
EvenOddOdd{O, P}(f::F, g::G) where {O, P, F, G} = EvenOddOdd{O, P, F, G}(f, g)
@Base.aggressive_constprop (o::EvenOddOdd{O, P, F, G})(Δ) where {O, P, F, G} = (o.f(Δ), EvenOddEven{plus1(O), P, F, G}(o.f, o.g))
@Base.aggressive_constprop (e::EvenOddEven{O, P, F, G})(Δ...) where {O, P, F, G} = (e.g...), EvenOddOdd{plus1(O), P, F, G}(e.f, e.g))
@Base.aggressive_constprop (o::EvenOddOdd{O, O})(Δ) where {O} = o.f(Δ)
@Base.constprop aggressive (o::EvenOddOdd{O, P, F, G})(Δ) where {O, P, F, G} = (o.f(Δ), EvenOddEven{plus1(O), P, F, G}(o.f, o.g))
@Base.constprop aggressive (e::EvenOddEven{O, P, F, G})(Δ...) where {O, P, F, G} = (e.g...), EvenOddOdd{plus1(O), P, F, G}(e.f, e.g))
@Base.constprop aggressive (o::EvenOddOdd{O, O})(Δ) where {O} = o.f(Δ)


@Base.aggressive_constprop function (::∂⃖{N})(::typeof(Core.getfield), s, field::Int) where {N}
@Base.constprop aggressive function (::∂⃖{N})(::typeof(Core.getfield), s, field::Int) where {N}
getfield(s, field), EvenOddOdd{1, c_order(N)}(
∂⃖getfield{nfields(s), field}(),
@Base.aggressive_constprop (_, Δ, _)->getfield(Δ, field))
@Base.constprop aggressive (_, Δ, _)->getfield(Δ, field))
end

@Base.aggressive_constprop function (::∂⃖{N})(::typeof(Base.getindex), s::Tuple, field::Int) where {N}
@Base.constprop aggressive function (::∂⃖{N})(::typeof(Base.getindex), s::Tuple, field::Int) where {N}
getfield(s, field), EvenOddOdd{1, c_order(N)}(
∂⃖getfield{nfields(s), field}(),
@Base.aggressive_constprop (_, Δ, _)->lifted_getfield(Δ, field))
@Base.constprop aggressive (_, Δ, _)->lifted_getfield(Δ, field))
end

function (::∂⃖{N})(::typeof(Core.getfield), s, field::Symbol) where {N}
getfield(s, field), let P = typeof(s)
EvenOddOdd{1, c_order(N)}(
(@Base.aggressive_constprop Δ->begin
(@Base.constprop aggressive Δ->begin
nt = NamedTuple{(field,)}((Δ,))
(NoTangent(), Tangent{P, typeof(nt)}(nt), NoTangent())
end),
(@Base.aggressive_constprop (_, Δs, _)->begin
(@Base.constprop aggressive (_, Δs, _)->begin
isa(Δs, Union{ZeroTangent, NoTangent}) ? Δs : getfield(ChainRulesCore.backing(Δs), field)
end))
end
Expand All @@ -313,13 +313,13 @@ end
function (::∂⃖{N})(::typeof(Base.getindex), a::Array, inds...) where {N}
getindex(a, inds...), let
EvenOddOdd{1, c_order(N)}(
(@Base.aggressive_constprop Δ->begin
(@Base.constprop aggressive Δ->begin
Δ isa AbstractZero && return (NoTangent(), Δ, map(Returns(Δ), inds)...)
BB = zero(a)
BB[inds...] = Δ
(NoTangent(), BB, map(x->NoTangent(), inds)...)
end),
(@Base.aggressive_constprop (_, Δ, _)->begin
(@Base.constprop aggressive (_, Δ, _)->begin
getindex(Δ, inds...)
end))
end
Expand Down Expand Up @@ -355,15 +355,15 @@ end

struct ApplyOdd{O, P}; u; ∂⃖f; end
struct ApplyEven{O, P}; u; ∂⃖∂⃖f; end
@Base.aggressive_constprop function (a::ApplyOdd{O, P})(Δ) where {O, P}
@Base.constprop aggressive function (a::ApplyOdd{O, P})(Δ) where {O, P}
r, ∂⃖∂⃖f = a.∂⃖f(Δ)
(a.u(r), ApplyEven{plus1(O), P}(a.u, ∂⃖∂⃖f))
end
@Base.aggressive_constprop function (a::ApplyEven{O, P})(_, _, ff, args...) where {O, P}
@Base.constprop aggressive function (a::ApplyEven{O, P})(_, _, ff, args...) where {O, P}
r, ∂⃖∂⃖∂⃖f = Core._apply_iterate(iterate, a.∂⃖∂⃖f, (ff,), args...)
(r, ApplyOdd{plus1(O), P}(a.u, ∂⃖∂⃖∂⃖f))
end
@Base.aggressive_constprop function (a::ApplyOdd{O, O})(Δ) where {O}
@Base.constprop aggressive function (a::ApplyOdd{O, O})(Δ) where {O}
r = a.∂⃖f(Δ)
a.u(r)
end
Expand All @@ -381,7 +381,7 @@ end
Core.apply_type(head, args...), NonDiffOdd{plus1(plus1(length(args))), 1, c_order(N)}()
end

@Base.aggressive_constprop lifted_getfield(x, s) = getfield(x, s)
@Base.constprop aggressive lifted_getfield(x, s) = getfield(x, s)
lifted_getfield(x::ZeroTangent, s) = ZeroTangent()
lifted_getfield(x::NoTangent, s) = NoTangent()

Expand Down

0 comments on commit 71d9202

Please sign in to comment.