This repository has been archived by the owner on Apr 18, 2023. It is now read-only.
To make Nabla use ChainRules for sensitivities with full feature parity with the current implementation, we'll need to port the sensitivity definitions here (i.e. the `∇` methods) and turn them into `rrule` methods in ChainRules.
How to port
Porting them over is fairly straightforward. Nabla's `∇` methods take an `Arg{i}` argument that specifies which of the function's arguments the current method differentiates with respect to. For example,
```julia
f(a::Int, b::Int) = a + 2b + 1

# Derivative for the first argument, `a`
∇(::typeof(f), ::Type{Arg{1}}, p, y, ȳ, a::Int, b::Int) = ȳ

# Derivative for the second argument, `b`
∇(::typeof(f), ::Type{Arg{2}}, p, y, ȳ, a::Int, b::Int) = 2ȳ
```
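Because `Arg{i}` is an ordinary type parameter, choosing which derivative to compute is plain multiple dispatch. The following is a minimal self-contained sketch of that mechanism. `Arg` here is a hypothetical stand-in for Nabla's type (Nabla is not loaded), and `sensitivities` is a hypothetical driver, not part of Nabla's API, that calls `∇` once per argument.

```julia
# Hypothetical stand-in for Nabla's `Arg{i}` marker type (Nabla is not loaded here).
struct Arg{N} end

f(a::Int, b::Int) = a + 2b + 1

# One ∇ method per differentiable argument, selected by dispatch on `Type{Arg{i}}`.
# The `p` argument is unused in this sketch and is passed as `nothing`.
∇(::typeof(f), ::Type{Arg{1}}, p, y, ȳ, a::Int, b::Int) = ȳ
∇(::typeof(f), ::Type{Arg{2}}, p, y, ȳ, a::Int, b::Int) = 2ȳ

# Hypothetical driver: run the forward pass once, then request the
# sensitivity of each argument in turn by passing `Arg{i}`.
function sensitivities(fn, ȳ, args...)
    y = fn(args...)
    return ntuple(i -> ∇(fn, Arg{i}, nothing, y, ȳ, args...), length(args))
end

sensitivities(f, 1.0, 3, 4)  # (1.0, 2.0)
```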
ChainRules `rrule` methods include the derivatives for all arguments in a single method, so the two `∇` methods for `f` translate to
```julia
function rrule(::typeof(f), a::Int, b::Int)
    y = f(a, b)
    ∂a = Rule(ȳ -> ȳ)
    ∂b = Rule(ȳ -> 2ȳ)
    return y, (∂a, ∂b)
end
```
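The `Rule` wrapper described in this issue comes from an early ChainRules API, so the sketch below defines a minimal hypothetical stand-in with the same shape instead of loading ChainRules. It shows how a caller would unpack the result of `rrule` and apply each rule to an output sensitivity `ȳ`.

```julia
# Hypothetical stand-in for the `Rule` wrapper described in this issue
# (not the ChainRules type itself).
struct Rule{F,U}
    f::F   # pullback: maps the output sensitivity ȳ to this argument's sensitivity
    u::U   # optional custom update (x̄, ȳ) -> x̄; `nothing` selects the generic fallback
end
Rule(f) = Rule(f, nothing)
(r::Rule)(ȳ) = r.f(ȳ)   # applying a rule runs its wrapped pullback

f(a::Int, b::Int) = a + 2b + 1

function rrule(::typeof(f), a::Int, b::Int)
    y = f(a, b)
    ∂a = Rule(ȳ -> ȳ)
    ∂b = Rule(ȳ -> 2ȳ)
    return y, (∂a, ∂b)
end

# Forward pass and both rules come from a single call.
y, (∂a, ∂b) = rrule(f, 3, 4)
(y, ∂a(1.0), ∂b(1.0))   # (12, 1.0, 2.0)
```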
There are cases where a `∇` is purposefully not defined for a given `Arg{i}`; that denotes that there is no derivative with respect to that argument. In ChainRules, we express that by returning a `DNERule()` in place of the `Rule`. So if `f` in the example above were differentiable only with respect to `b`, the `rrule` would instead look like
```julia
function rrule(::typeof(f), a::Int, b::Int)
    y = f(a, b)
    ∂b = Rule(ȳ -> 2ȳ)
    return y, (DNERule(), ∂b)
end
```
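A caller then has to treat the two kinds of entries differently. In this sketch, again with hypothetical stand-ins for `Rule` and `DNERule` rather than the ChainRules types, a hypothetical `pullback` helper dispatches on the entry type and returns `nothing` for arguments that have no derivative.

```julia
# Hypothetical stand-ins for the types described in this issue.
struct Rule{F,U}
    f::F   # pullback: maps ȳ to this argument's sensitivity
    u::U   # optional custom update; unused in this sketch
end
Rule(f) = Rule(f, nothing)
(r::Rule)(ȳ) = r.f(ȳ)
struct DNERule end   # marks an argument with no derivative

f(a::Int, b::Int) = a + 2b + 1

function rrule(::typeof(f), a::Int, b::Int)
    y = f(a, b)
    ∂b = Rule(ȳ -> 2ȳ)
    return y, (DNERule(), ∂b)
end

# Hypothetical consumer: apply real rules, report `nothing` where no derivative exists.
pullback(rule::Rule, ȳ) = rule(ȳ)
pullback(::DNERule, ȳ) = nothing

y, rules = rrule(f, 3, 4)
map(r -> pullback(r, 1.0), rules)   # (nothing, 2.0)
```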
Note also that the derivatives for different arguments can share intermediate computations. Those go in the body of the `rrule` method itself, and the closures inside the `Rule`s capture the variables defined there.
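As an illustration, take a hypothetical function `h(a, b) = sin(a * b)`: both partial derivatives need `cos(a * b)`, so the sketch computes it once in the body of `rrule` and lets both closures capture it. `Rule` is again a minimal stand-in rather than the ChainRules type.

```julia
# Hypothetical stand-in for the `Rule` wrapper described in this issue.
struct Rule{F,U}
    f::F   # pullback: maps ȳ to this argument's sensitivity
    u::U   # optional custom update; unused in this sketch
end
Rule(f) = Rule(f, nothing)
(r::Rule)(ȳ) = r.f(ȳ)

# Hypothetical example function.
h(a::Real, b::Real) = sin(a * b)

function rrule(::typeof(h), a::Real, b::Real)
    ab = a * b
    y = sin(ab)
    c = cos(ab)                  # shared intermediate, computed once
    ∂a = Rule(ȳ -> ȳ * c * b)    # ∂h/∂a = cos(ab) * b
    ∂b = Rule(ȳ -> ȳ * c * a)    # ∂h/∂b = cos(ab) * a
    return y, (∂a, ∂b)
end

y, (∂a, ∂b) = rrule(h, 0.5, 2.0)
```

A central finite difference such as `(h(0.5 + 1e-6, 2.0) - h(0.5 - 1e-6, 2.0)) / 2e-6` should agree closely with `∂a(1.0)`, which is a cheap way to sanity-check a ported rule.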
There are some cases where Nabla defines custom methods for updating the tape with a given sensitivity. Those are expressed as methods of `∇` that take the tape value `x̄` as the first argument. ChainRules does this differently: if you have a special way to accumulate a sensitivity into a given value, you pass `Rule` a second argument, a function that takes the arguments `(x̄, ȳ)`. This function is used by the `ChainRules.accumulate!(value, rule, args...)` method. As in Nabla, if no such special update method exists, a generic fallback is used.
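The sketch below models that two-path behavior under stated assumptions: `Rule` is a minimal stand-in whose optional second field holds the `(x̄, ȳ)` update function, `accumulate_into!` is a hypothetical substitute for `ChainRules.accumulate!` (not that method's actual implementation), and the update function is assumed to return the updated `x̄`. A custom update can mutate `x̄` in place, while the fallback simply adds the rule's output.

```julia
# Hypothetical stand-in for the `Rule` wrapper described in this issue.
struct Rule{F,U}
    f::F   # pullback: maps ȳ to this argument's sensitivity
    u::U   # optional custom update (x̄, ȳ) -> x̄; `nothing` selects the fallback
end
Rule(f) = Rule(f, nothing)
(r::Rule)(ȳ) = r.f(ȳ)

# Hypothetical substitute for ChainRules.accumulate!.
# Fallback (no custom update): out-of-place add; callers must use the return value.
accumulate_into!(x̄, r::Rule{<:Any,Nothing}, ȳ) = x̄ + r(ȳ)
# Custom update supplied: delegate to it (assumed to return the updated x̄).
accumulate_into!(x̄, r::Rule, ȳ) = r.u(x̄, ȳ)

# Rules for y = 2x on arrays. The first adds 2ȳ into x̄ in place,
# avoiding the temporary array that the fallback allocates.
inplace_rule = Rule(ȳ -> 2 .* ȳ, (x̄, ȳ) -> (x̄ .+= 2 .* ȳ; x̄))
plain_rule = Rule(ȳ -> 2 .* ȳ)

x̄ = zeros(3)
out = accumulate_into!(x̄, inplace_rule, [1.0, 2.0, 3.0])   # mutates and returns x̄
x̄2 = accumulate_into!(zeros(3), plain_rule, [1.0, 2.0, 3.0])  # returns a new array
```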
Progress
Below is a list of all of the basic `∇` methods. At the time of writing, more items are finished than are checked off; whenever you confirm that ChainRules includes a corresponding `rrule` method and that its sensitivity definition looks correct, please check the item off.
Note that this list does not include methods which update the tape directly!
*(::AbstractArray{#s12<:Number,N} where N, ::AbstractArray{#s12<:Number,N} where N)
*(::Adjoint, ::Adjoint)
*(::Adjoint, ::StridedMatrix)
*(::Number, ::Number)
*(::StridedMatrix, ::Adjoint)
*(::StridedMatrix, ::StridedMatrix)
*(::StridedMatrix, ::Transpose)
*(::Transpose, ::StridedMatrix)
*(::Transpose, ::Transpose)
*(::Union{Number, AbstractArray{#s12,N} where N where #s12<:Number}, ::Union{Number, AbstractArray{#s12,N} where N where #s12<:Number})
+(::AbstractArray{#s12<:Number,N} where N, ::AbstractArray{#s12<:Number,N} where N)
+(::AbstractArray{#s12<:Number,N} where N, ::UniformScaling{T<:Number})
+(::Number)
+(::Number, ::Number)
+(::UniformScaling{T<:Number}, ::AbstractArray{#s12<:Number,N} where N)
-(::AbstractArray{#s12<:Number,N} where N)
-(::AbstractArray{#s12<:Number,N} where N, ::AbstractArray{#s12<:Number,N} where N)
-(::Number)
-(::Number, ::Number)
/(::AbstractArray{#s12<:Number,N} where N, ::AbstractArray{#s12<:Number,N} where N)
/(::Number, ::Number)
/(::Union{Number, AbstractArray{#s12,N} where N where #s12<:Number}, ::Union{Number, AbstractArray{#s12,N} where N where #s12<:Number})
Cholesky(::Union{LowerTriangular, UpperTriangular}, ::Union{Char, Symbol}, ::Integer)
Diagonal(::AbstractArray{#s12<:Number,1})
Diagonal(::AbstractArray{#s12<:Number,2})
LinearAlgebra.BLAS.asum(::Any)
LinearAlgebra.BLAS.asum(::Integer, ::Any, ::Integer)
LinearAlgebra.BLAS.dot(::Int64, ::StridedVector, ::Int64, ::StridedVector, ::Int64)
LinearAlgebra.BLAS.gemm(::Char, ::Char, ::StridedMatrix, ::StridedMatrix) (Implement sensitivities for BLAS.gemm, JuliaDiff/ChainRules.jl#25)
LinearAlgebra.BLAS.gemm(::Char, ::Char, ::T<:Number, ::StridedMatrix, ::StridedMatrix) (Implement sensitivities for BLAS.gemm, JuliaDiff/ChainRules.jl#25)
LinearAlgebra.BLAS.gemv(::Char, ::StridedMatrix, ::StridedVector)
LinearAlgebra.BLAS.gemv(::Char, ::T<:Number, ::StridedMatrix, ::StridedVector)
LinearAlgebra.BLAS.nrm2(::Any)
LinearAlgebra.BLAS.nrm2(::Integer, ::Any, ::Integer)
LinearAlgebra.BLAS.symm(::Char, ::Char, ::StridedMatrix, ::StridedVector)
LinearAlgebra.BLAS.symm(::Char, ::Char, ::T<:Number, ::StridedMatrix, ::StridedVector)
LinearAlgebra.BLAS.symv(::Char, ::StridedMatrix, ::StridedVector)
LinearAlgebra.BLAS.symv(::Char, ::T<:Number, ::StridedMatrix, ::StridedVector)
LinearAlgebra.BLAS.trmm(::Char, ::Char, ::Char, ::Char, ::T<:Number, ::StridedMatrix, ::StridedVector)
LinearAlgebra.BLAS.trmm(::Char, ::Char, ::Char, ::Char, ::T<:Number, ::StridedMatrix, ::StridedVector)
LinearAlgebra.BLAS.trsm(::Char, ::Char, ::Char, ::Char, ::T<:Number, ::StridedMatrix, ::StridedMatrix)
LinearAlgebra.BLAS.trsm(::Char, ::Char, ::Char, ::Char, ::T<:Number, ::StridedMatrix, ::StridedVector)
LinearAlgebra.BLAS.trsv(::Char, ::Char, ::Char, ::StridedMatrix, ::StridedVector)
LinearAlgebra.cholesky(::AbstractArray{T<:Number,2}) (Add an rrule for the Cholesky decomposition, JuliaDiff/ChainRules.jl#44)
LinearAlgebra.det(::AbstractArray{#s12<:Number,N} where N)
LinearAlgebra.det(::Diagonal{#s77<:Number,V} where V<:AbstractArray{#s77<:Number,1})
LinearAlgebra.det(::LowerTriangular{#s77<:Number,S} where S<:AbstractArray{#s77<:Number,2})
LinearAlgebra.det(::UpperTriangular{#s77<:Number,S} where S<:AbstractArray{#s77<:Number,2})
LinearAlgebra.diag(::AbstractArray{#s12<:Number,2})
LinearAlgebra.diag(::AbstractArray{#s12<:Number,2}, ::Integer)
LinearAlgebra.dot(::AbstractArray{#s12<:Number,N} where N, ::AbstractArray{#s12<:Number,N} where N)
LinearAlgebra.logdet(::AbstractArray{#s12<:Number,N} where N)
LinearAlgebra.logdet(::Diagonal{#s77<:Number,V} where V<:AbstractArray{#s77<:Number,1})
LinearAlgebra.logdet(::LowerTriangular{#s77<:Number,S} where S<:AbstractArray{#s77<:Number,2})
LinearAlgebra.logdet(::UpperTriangular{#s77<:Number,S} where S<:AbstractArray{#s77<:Number,2})
LinearAlgebra.norm(::AbstractArray{#s12<:Number,N} where N) (Add rrules for binary linear algebra operations, JuliaDiff/ChainRules.jl#29)
LinearAlgebra.norm(::AbstractArray{#s12<:Number,N} where N, ::Number) (Add rrules for binary linear algebra operations, JuliaDiff/ChainRules.jl#29)
LinearAlgebra.norm(::Number) (Add rrules for binary linear algebra operations, JuliaDiff/ChainRules.jl#29)
LinearAlgebra.norm(::Number, ::Number) (Add rrules for binary linear algebra operations, JuliaDiff/ChainRules.jl#29)
LinearAlgebra.svd(::AbstractArray{T,2}) (Add SVD factorization rrule, JuliaDiff/ChainRules.jl#31)
LinearAlgebra.tr(::AbstractArray{#s12<:Number,N} where N)
LowerTriangular(::AbstractArray{#s12<:Number,2})
SpecialFunctions.airyai(::Number)
SpecialFunctions.airyaiprime(::Number)
SpecialFunctions.airybi(::Number)
SpecialFunctions.airybiprime(::Number)
SpecialFunctions.besseli(::Number, ::Number)
SpecialFunctions.besselj(::Number, ::Number)
SpecialFunctions.besselj0(::Number)
SpecialFunctions.besselj1(::Number)
SpecialFunctions.besselk(::Number, ::Number)
SpecialFunctions.bessely(::Number, ::Number)
SpecialFunctions.bessely0(::Number)
SpecialFunctions.bessely1(::Number)
SpecialFunctions.beta(::Number, ::Number)
SpecialFunctions.dawson(::Number)
SpecialFunctions.digamma(::Number)
SpecialFunctions.erf(::Number)
SpecialFunctions.erfc(::Number)
SpecialFunctions.erfcinv(::Number)
SpecialFunctions.erfcx(::Number)
SpecialFunctions.erfi(::Number)
SpecialFunctions.erfinv(::Number)
SpecialFunctions.gamma(::Number)
SpecialFunctions.invdigamma(::Number)
SpecialFunctions.lbeta(::Number, ::Number)
SpecialFunctions.lgamma(::Number)
SpecialFunctions.polygamma(::Number, ::Number)
SpecialFunctions.trigamma(::Number)
Statistics.mean(::AbstractArray{#s74<:Number,N} where N) (Add a few more reduction rrules, JuliaDiff/ChainRules.jl#59)
Statistics.mean(::Function, ::AbstractArray{#s77<:Number,N} where N) (Add a few more reduction rrules, JuliaDiff/ChainRules.jl#59)
UpperTriangular(::AbstractArray{#s12<:Number,2})
\(::AbstractArray{#s12<:Number,N} where N, ::AbstractArray{#s12<:Number,N} where N)
\(::Number, ::Number)
\(::Union{Number, AbstractArray{#s12,N} where N where #s12<:Number}, ::Union{Number, AbstractArray{#s12,N} where N where #s12<:Number})
^(::Number, ::Number)
abs(::Number)
abs2(::Number)
acos(::Number)
acosd(::Number)
acosh(::Number)
acot(::Number)
acotd(::Number)
acoth(::Number)
acsc(::Number)
acscd(::Number)
acsch(::Number)
adjoint(::AbstractArray{#s12<:Number,N} where N)
adjoint(::Number)
asec(::Number)
asecd(::Number)
asech(::Number)
asin(::Number)
asind(::Number)
asinh(::Number)
atand(::Number)
atanh(::Number)
broadcast(::Any, ::Vararg{Any,N})
cbrt(::Number)
copy(::Any)
cos(::Number)
cosd(::Number)
cosh(::Number)
cospi(::Number)
cot(::Number)
cotd(::Number)
coth(::Number)
csc(::Number)
cscd(::Number)
csch(::Number)
deg2rad(::Number)
exp(::AbstractArray{T,2})
exp(::Number)
exp10(::Number)
exp2(::Number)
expm1(::Number)
fill(::Any, ::Vararg{Any,N}) (Port more array-related derivatives from Nabla, JuliaDiff/ChainRules.jl#45)
float(::Any)
getindex(::Any, ::Vararg{Any,N})
getproperty(::Cholesky{T,S} where S<:(AbstractArray{T,2} where T), ::Symbol) (Add an rrule for the Cholesky decomposition, JuliaDiff/ChainRules.jl#44)
getproperty(::SVD{T,Tr,M} where M<:(AbstractArray{T,N} where N) where Tr, ::Symbol) (Add SVD factorization rrule, JuliaDiff/ChainRules.jl#31)
hcat(::Vararg{AbstractArray,N}) (Port more array-related derivatives from Nabla, JuliaDiff/ChainRules.jl#45)
hypot(::Number, ::Number)
identity(::Any) (Add a few more reduction rrules, JuliaDiff/ChainRules.jl#59)
inv(::AbstractArray{#s12<:Number,N} where N)
inv(::Number)
kron(::AbstractArray{#s12<:Number,N} where N, ::AbstractArray{#s12<:Number,N} where N)
log(::Number)
log10(::Number)
log2(::Number)
map(::Function, ::Vararg{AbstractArray{#s12,N} where N where #s12<:Number,N}) (Add rrule for map and expand the testing framework, JuliaDiff/ChainRules.jl#56)
mapfoldl(::Any, ::Union{typeof(+), typeof(add_sum)}, ::Union{Number, AbstractArray{#s12,N} where N where #s12<:Number}) (Add a few more reduction rrules, JuliaDiff/ChainRules.jl#59)
mapfoldr(::Any, ::Union{typeof(+), typeof(add_sum)}, ::Union{Number, AbstractArray{#s12,N} where N where #s12<:Number}) (Add a few more reduction rrules, JuliaDiff/ChainRules.jl#59)
mapreduce(::Any, ::Union{typeof(+), typeof(add_sum)}, ::AbstractArray{#s38<:Number,N} where N) (Add a few more reduction rrules, JuliaDiff/ChainRules.jl#59)
max(::Number, ::Number)
min(::Number, ::Number)
rad2deg(::Number)
reshape(::AbstractArray{#s12<:Number,N} where N, ::Vararg{Any,N}) (Port more array-related derivatives from Nabla, JuliaDiff/ChainRules.jl#45)
sec(::Number)
secd(::Number)
sech(::Number)
sin(::Number)
sind(::Number)
sinh(::Number)
sinpi(::Number)
sqrt(::Number)
sum(::AbstractArray{#s68<:Number,N} where N) (Add a few more reduction rrules, JuliaDiff/ChainRules.jl#59)
sum(::Function, ::AbstractArray{#s64<:Number,N} where N) (Add a few more reduction rrules, JuliaDiff/ChainRules.jl#59)
sum(::typeof(abs2), ::AbstractArray{#s73<:Number,N} where N) (Add a few more reduction rrules, JuliaDiff/ChainRules.jl#59)
tan(::Number)
tand(::Number)
tanh(::Number)
transpose(::AbstractArray{#s12<:Number,N} where N)
transpose(::Number)
vcat(::Vararg{AbstractArray,N}) (Port more array-related derivatives from Nabla, JuliaDiff/ChainRules.jl#45)