diff --git a/src/runtime.jl b/src/runtime.jl index 0b48206d..f334cea6 100644 --- a/src/runtime.jl +++ b/src/runtime.jl @@ -1,14 +1,14 @@ using ChainRulesCore -@Base.aggressive_constprop accum(a, b) = a + b -@Base.aggressive_constprop accum(a::Tuple, b::Tuple) = map(accum, a, b) -@Base.aggressive_constprop @generated function accum(x::NamedTuple, y::NamedTuple) +@Base.constprop aggressive accum(a, b) = a + b +@Base.constprop aggressive accum(a::Tuple, b::Tuple) = map(accum, a, b) +@Base.constprop aggressive @generated function accum(x::NamedTuple, y::NamedTuple) fnames = union(fieldnames(x), fieldnames(y)) gradx(f) = f in fieldnames(x) ? :(getfield(x, $(quot(f)))) : :(ZeroTangent()) grady(f) = f in fieldnames(y) ? :(getfield(y, $(quot(f)))) : :(ZeroTangent()) Expr(:tuple, [:($f=accum($(gradx(f)), $(grady(f)))) for f in fnames]...) end -@Base.aggressive_constprop accum(a, b, c, args...) = accum(accum(a, b), c, args...) -@Base.aggressive_constprop accum(a::NoTangent, b) = b -@Base.aggressive_constprop accum(a, b::NoTangent) = a -@Base.aggressive_constprop accum(a::NoTangent, b::NoTangent) = NoTangent() +@Base.constprop aggressive accum(a, b, c, args...) = accum(accum(a, b), c, args...) +@Base.constprop aggressive accum(a::NoTangent, b) = b +@Base.constprop aggressive accum(a, b::NoTangent) = a +@Base.constprop aggressive accum(a::NoTangent, b::NoTangent) = NoTangent() diff --git a/src/stage1/forward.jl b/src/stage1/forward.jl index f2eda524..050cf31a 100644 --- a/src/stage1/forward.jl +++ b/src/stage1/forward.jl @@ -134,37 +134,37 @@ end (::∂☆{N})(args::AbstractTangentBundle{N}...) where {N} = ∂☆internal{N}()(args...) # Special case rules for performance -@Base.aggressive_constprop function (::∂☆{N})(f::ATB{N, typeof(getfield)}, x::TangentBundle{N}, s::AbstractTangentBundle{N}) where {N} +@Base.constprop aggressive function (::∂☆{N})(f::ATB{N, typeof(getfield)}, x::TangentBundle{N}, s::AbstractTangentBundle{N}) where {N} s = primal(s) TangentBundle{N}(getfield(primal(x), s), map(x->lifted_getfield(x, s), x.partials)) end -@Base.aggressive_constprop function (::∂☆{N})(f::ATB{N, typeof(getfield)}, x::TaylorBundle{N}, s::AbstractTangentBundle{N}) where {N} +@Base.constprop aggressive function (::∂☆{N})(f::ATB{N, typeof(getfield)}, x::TaylorBundle{N}, s::AbstractTangentBundle{N}) where {N} s = primal(s) TaylorBundle{N}(getfield(primal(x), s), map(y->lifted_getfield(y, s), x.coeffs)) end -@Base.aggressive_constprop function (::∂☆{N})(::ATB{N, typeof(getfield)}, x::CompositeBundle{N}, s::AbstractTangentBundle{N, Int}) where {N} +@Base.constprop aggressive function (::∂☆{N})(::ATB{N, typeof(getfield)}, x::CompositeBundle{N}, s::AbstractTangentBundle{N, Int}) where {N} x.tup[primal(s)] end -@Base.aggressive_constprop function (::∂☆{N})(::ATB{N, typeof(getfield)}, x::CompositeBundle{N, B}, s::AbstractTangentBundle{N, Symbol}) where {N, B} +@Base.constprop aggressive function (::∂☆{N})(::ATB{N, typeof(getfield)}, x::CompositeBundle{N, B}, s::AbstractTangentBundle{N, Symbol}) where {N, B} x.tup[Base.fieldindex(B, primal(s))] end -@Base.aggressive_constprop function (::∂☆{N})(f::ATB{N, typeof(getfield)}, x::ATB{N}, s::ATB{N}, inbounds::ATB{N}) where {N} +@Base.constprop aggressive function (::∂☆{N})(f::ATB{N, typeof(getfield)}, x::ATB{N}, s::ATB{N}, inbounds::ATB{N}) where {N} s = primal(s) TangentBundle{N}(getfield(primal(x), s, primal(inbounds)), map(x->lifted_getfield(x, s), x.partials)) end -@Base.aggressive_constprop function (::∂☆{N})(f::ATB{N, typeof(getfield)}, x::UniformBundle{N, <:Any, U}, s::AbstractTangentBundle{N}) where {N, U} +@Base.constprop aggressive function (::∂☆{N})(f::ATB{N, typeof(getfield)}, x::UniformBundle{N, <:Any, U}, s::AbstractTangentBundle{N}) where {N, U} UniformBundle{N,<:Any,U}(getfield(primal(x), primal(s)), x.partial) end -@Base.aggressive_constprop function (::∂☆{N})(f::ATB{N, typeof(getfield)}, x::UniformBundle{N, <:Any, U}, s::AbstractTangentBundle{N}, inbounds::AbstractTangentBundle{N}) where {N, U} +@Base.constprop aggressive function (::∂☆{N})(f::ATB{N, typeof(getfield)}, x::UniformBundle{N, <:Any, U}, s::AbstractTangentBundle{N}, inbounds::AbstractTangentBundle{N}) where {N, U} UniformBundle{N,<:Any,U}(getfield(primal(x), primal(s), primal(inbounds)), x.partial) end diff --git a/src/stage1/generated.jl b/src/stage1/generated.jl index 62547309..64b2c7cd 100644 --- a/src/stage1/generated.jl +++ b/src/stage1/generated.jl @@ -38,8 +38,8 @@ struct Protected{N} a end (p::Protected)(args...) = getfield(p, :a)(args...)[1] -@Base.aggressive_constprop (::∂⃖{N})(p::Protected{N}, args...) where {N} = getfield(p, :a)(args...) -@Base.aggressive_constprop (::∂⃖{1})(p::Protected{1}, args...) = getfield(p, :a)(args...) +@Base.constprop aggressive (::∂⃖{N})(p::Protected{N}, args...) where {N} = getfield(p, :a)(args...) +@Base.constprop aggressive (::∂⃖{1})(p::Protected{1}, args...) = getfield(p, :a)(args...) (::∂⃖{N})(p::Protected, args...) where {N} = error("TODO: Can we support this?") struct OpticBundle{T} @@ -94,30 +94,30 @@ end end struct ∂⃖weaveInnerOdd{N, O}; b̄; end -@Base.aggressive_constprop function (w::∂⃖weaveInnerOdd{N, N})(Δ) where {N} +@Base.constprop aggressive function (w::∂⃖weaveInnerOdd{N, N})(Δ) where {N} @destruct c, c̄ = w.b̄(Δ...) return (c̄, c) end -@Base.aggressive_constprop function (w::∂⃖weaveInnerOdd{N, O})(Δ) where {N, O} +@Base.constprop aggressive function (w::∂⃖weaveInnerOdd{N, O})(Δ) where {N, O} @destruct c, c̄ = w.b̄(Δ...) return (c̄, c), ∂⃖weaveInnerEven{plus1(N), O}() end struct ∂⃖weaveInnerEven{N, O}; end -@Base.aggressive_constprop function (w::∂⃖weaveInnerEven{N, O})(Δ′, x...) where {N, O} +@Base.constprop aggressive function (w::∂⃖weaveInnerEven{N, O})(Δ′, x...) where {N, O} @destruct y, ȳ = Δ′(x...) return y, ∂⃖weaveInnerOdd{plus1(N), O}(ȳ) end struct ∂⃖weaveOuterOdd{N, O}; end -@Base.aggressive_constprop function (w::∂⃖weaveOuterOdd{N, N})((Δ′′, Δ′′′)) where {N} +@Base.constprop aggressive function (w::∂⃖weaveOuterOdd{N, N})((Δ′′, Δ′′′)) where {N} return (NoTangent(), Δ′′′(Δ′′)...) end -@Base.aggressive_constprop function (w::∂⃖weaveOuterOdd{N, O})((Δ′′, Δ′′′)) where {N, O} +@Base.constprop aggressive function (w::∂⃖weaveOuterOdd{N, O})((Δ′′, Δ′′′)) where {N, O} @destruct α, ᾱ = Δ′′′(Δ′′) return (NoTangent(), α...), ∂⃖weaveOuterEven{plus1(N), O}(ᾱ) end struct ∂⃖weaveOuterEven{N, O}; ᾱ end -@Base.aggressive_constprop function (w::∂⃖weaveOuterEven{N, O})(Δ⁴...) where {N, O} +@Base.constprop aggressive function (w::∂⃖weaveOuterEven{N, O})(Δ⁴...) where {N, O} return w.ᾱ(Base.tail(Δ⁴)...), ∂⃖weaveOuterOdd{plus1(N), O}() end @@ -156,33 +156,33 @@ struct ∂⃖rruleB{N, O}; ᾱ; ȳ̄ ; end struct ∂⃖rruleC{N, O}; ȳ̄ ; Δ′′′; β̄ ; end struct ∂⃖rruleD{N, O}; γ̄; β̄ ; end -@Base.aggressive_constprop function (a::∂⃖rruleA{N, O})(Δ) where {N, O} +@Base.constprop aggressive function (a::∂⃖rruleA{N, O})(Δ) where {N, O} # TODO: Is this unthunk in the right place @destruct (α, ᾱ) = a.∂(a.ȳ, unthunk(Δ)) (α, ∂⃖rruleB{N, O}(ᾱ, a.ȳ̄)) end -@Base.aggressive_constprop function (b::∂⃖rruleB{N, O})(Δ′...) where {N, O} +@Base.constprop aggressive function (b::∂⃖rruleB{N, O})(Δ′...) where {N, O} @destruct ((Δ′′′, β), β̄) = b.ᾱ(Δ′) (β, ∂⃖rruleC{N, O}(b.ȳ̄, Δ′′′, β̄)) end -@Base.aggressive_constprop function (c::∂⃖rruleC{N, O})(Δ′′) where {N, O} +@Base.constprop aggressive function (c::∂⃖rruleC{N, O})(Δ′′) where {N, O} @destruct (γ, γ̄) = c.ȳ̄((Δ′′, c.Δ′′′)) (Base.tail(γ), ∂⃖rruleD{N, O}(γ̄, c.β̄)) end -@Base.aggressive_constprop function (d::∂⃖rruleD{N, O})(Δ⁴...) where {N, O} +@Base.constprop aggressive function (d::∂⃖rruleD{N, O})(Δ⁴...) where {N, O} (δ₁, δ₂), δ̄ = d.γ̄(ZeroTangent(), Δ⁴...) (δ₁, ∂⃖rruleA{N, O+1}(d.β̄ , δ₂, δ̄ )) end # Terminal cases -@Base.aggressive_constprop function (c::∂⃖rruleB{N, N})(Δ′...) where {N} +@Base.constprop aggressive function (c::∂⃖rruleB{N, N})(Δ′...) where {N} @destruct (Δ′′′, β) = c.ᾱ(Δ′) (β, ∂⃖rruleC{N, N}(c.ȳ̄, Δ′′′, nothing)) end -@Base.aggressive_constprop (c::∂⃖rruleC{N, N})(Δ′′) where {N} = +@Base.constprop aggressive (c::∂⃖rruleC{N, N})(Δ′′) where {N} = Base.tail(c.ȳ̄((Δ′′, c.Δ′′′))) (::∂⃖rruleD{N, N})(Δ...) where {N} = error("Should not be reached") @@ -255,9 +255,9 @@ function ChainRulesCore.rrule(::KwFunc, kwargs, f, args...) end end -@Base.aggressive_constprop function ChainRulesCore.rrule(::typeof(Core.getfield), s, field::Symbol) +@Base.constprop aggressive function ChainRulesCore.rrule(::typeof(Core.getfield), s, field::Symbol) getfield(s, field), let P = typeof(s) - @Base.aggressive_constprop Δ->begin + @Base.constprop aggressive Δ->begin nt = NamedTuple{(field,)}((Δ,)) (NoTangent(), Tangent{P, typeof(nt)}(nt), NoTangent()) end @@ -265,7 +265,7 @@ end end struct ∂⃖getfield{n, f}; end -@Base.aggressive_constprop function (::∂⃖getfield{n, f})(Δ) where {n,f} +@Base.constprop aggressive function (::∂⃖getfield{n, f})(Δ) where {n,f} if @generated return Expr(:call, tuple, NoTangent(), Expr(:call, tuple, (i == f ? :(Δ) : ZeroTangent() for i = 1:n)...), @@ -279,31 +279,31 @@ struct EvenOddEven{O, P, F, G}; f::F; g::G; end EvenOddEven{O, P}(f::F, g::G) where {O, P, F, G} = EvenOddEven{O, P, F, G}(f, g) struct EvenOddOdd{O, P, F, G}; f::F; g::G; end EvenOddOdd{O, P}(f::F, g::G) where {O, P, F, G} = EvenOddOdd{O, P, F, G}(f, g) -@Base.aggressive_constprop (o::EvenOddOdd{O, P, F, G})(Δ) where {O, P, F, G} = (o.f(Δ), EvenOddEven{plus1(O), P, F, G}(o.f, o.g)) -@Base.aggressive_constprop (e::EvenOddEven{O, P, F, G})(Δ...) where {O, P, F, G} = (e.g(Δ...), EvenOddOdd{plus1(O), P, F, G}(e.f, e.g)) -@Base.aggressive_constprop (o::EvenOddOdd{O, O})(Δ) where {O} = o.f(Δ) +@Base.constprop aggressive (o::EvenOddOdd{O, P, F, G})(Δ) where {O, P, F, G} = (o.f(Δ), EvenOddEven{plus1(O), P, F, G}(o.f, o.g)) +@Base.constprop aggressive (e::EvenOddEven{O, P, F, G})(Δ...) where {O, P, F, G} = (e.g(Δ...), EvenOddOdd{plus1(O), P, F, G}(e.f, e.g)) +@Base.constprop aggressive (o::EvenOddOdd{O, O})(Δ) where {O} = o.f(Δ) -@Base.aggressive_constprop function (::∂⃖{N})(::typeof(Core.getfield), s, field::Int) where {N} +@Base.constprop aggressive function (::∂⃖{N})(::typeof(Core.getfield), s, field::Int) where {N} getfield(s, field), EvenOddOdd{1, c_order(N)}( ∂⃖getfield{nfields(s), field}(), - @Base.aggressive_constprop (_, Δ, _)->getfield(Δ, field)) + @Base.constprop aggressive (_, Δ, _)->getfield(Δ, field)) end -@Base.aggressive_constprop function (::∂⃖{N})(::typeof(Base.getindex), s::Tuple, field::Int) where {N} +@Base.constprop aggressive function (::∂⃖{N})(::typeof(Base.getindex), s::Tuple, field::Int) where {N} getfield(s, field), EvenOddOdd{1, c_order(N)}( ∂⃖getfield{nfields(s), field}(), - @Base.aggressive_constprop (_, Δ, _)->lifted_getfield(Δ, field)) + @Base.constprop aggressive (_, Δ, _)->lifted_getfield(Δ, field)) end function (::∂⃖{N})(::typeof(Core.getfield), s, field::Symbol) where {N} getfield(s, field), let P = typeof(s) EvenOddOdd{1, c_order(N)}( - (@Base.aggressive_constprop Δ->begin + (@Base.constprop aggressive Δ->begin nt = NamedTuple{(field,)}((Δ,)) (NoTangent(), Tangent{P, typeof(nt)}(nt), NoTangent()) end), - (@Base.aggressive_constprop (_, Δs, _)->begin + (@Base.constprop aggressive (_, Δs, _)->begin isa(Δs, Union{ZeroTangent, NoTangent}) ? Δs : getfield(ChainRulesCore.backing(Δs), field) end)) end @@ -313,13 +313,13 @@ end function (::∂⃖{N})(::typeof(Base.getindex), a::Array, inds...) where {N} getindex(a, inds...), let EvenOddOdd{1, c_order(N)}( - (@Base.aggressive_constprop Δ->begin + (@Base.constprop aggressive Δ->begin Δ isa AbstractZero && return (NoTangent(), Δ, map(Returns(Δ), inds)...) BB = zero(a) BB[inds...] = Δ (NoTangent(), BB, map(x->NoTangent(), inds)...) end), - (@Base.aggressive_constprop (_, Δ, _)->begin + (@Base.constprop aggressive (_, Δ, _)->begin getindex(Δ, inds...) end)) end @@ -355,15 +355,15 @@ end struct ApplyOdd{O, P}; u; ∂⃖f; end struct ApplyEven{O, P}; u; ∂⃖∂⃖f; end -@Base.aggressive_constprop function (a::ApplyOdd{O, P})(Δ) where {O, P} +@Base.constprop aggressive function (a::ApplyOdd{O, P})(Δ) where {O, P} r, ∂⃖∂⃖f = a.∂⃖f(Δ) (a.u(r), ApplyEven{plus1(O), P}(a.u, ∂⃖∂⃖f)) end -@Base.aggressive_constprop function (a::ApplyEven{O, P})(_, _, ff, args...) where {O, P} +@Base.constprop aggressive function (a::ApplyEven{O, P})(_, _, ff, args...) where {O, P} r, ∂⃖∂⃖∂⃖f = Core._apply_iterate(iterate, a.∂⃖∂⃖f, (ff,), args...) (r, ApplyOdd{plus1(O), P}(a.u, ∂⃖∂⃖∂⃖f)) end -@Base.aggressive_constprop function (a::ApplyOdd{O, O})(Δ) where {O} +@Base.constprop aggressive function (a::ApplyOdd{O, O})(Δ) where {O} r = a.∂⃖f(Δ) a.u(r) end @@ -381,7 +381,7 @@ end Core.apply_type(head, args...), NonDiffOdd{plus1(plus1(length(args))), 1, c_order(N)}() end -@Base.aggressive_constprop lifted_getfield(x, s) = getfield(x, s) +@Base.constprop aggressive lifted_getfield(x, s) = getfield(x, s) lifted_getfield(x::ZeroTangent, s) = ZeroTangent() lifted_getfield(x::NoTangent, s) = NoTangent()