Skip to content

RuleConfig and Zygote support #49

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 8 commits into from
Feb 8, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 3 additions & 1 deletion Project.toml
Original file line number Diff line number Diff line change
@@ -1,15 +1,17 @@
name = "AbstractDifferentiation"
uuid = "c29ec348-61ec-40c8-8164-b8c60e9d9f3d"
authors = ["Mohamed Tarek <mohamed82008@gmail.com> and contributors"]
version = "0.4.0"
version = "0.4.1"

[deps]
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
Compat = "34da2185-b29b-5c13-b0c7-acf172513d20"
ExprTools = "e2ba6199-217a-4e67-a87a-7c52f15ade04"
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
Requires = "ae029012-a4dd-5104-9daa-d747884805df"

[compat]
ChainRulesCore = "1"
Compat = "3"
ExprTools = "0.1"
ForwardDiff = "0.10"
Expand Down
7 changes: 7 additions & 0 deletions src/AbstractDifferentiation.jl
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
module AbstractDifferentiation

using LinearAlgebra, ExprTools, Requires, Compat
using ChainRulesCore: RuleConfig, rrule_via_ad

export AD

Expand Down Expand Up @@ -643,11 +644,17 @@ end
@inline asarray(x) = [x]
@inline asarray(x::AbstractArray) = x

include("ruleconfig.jl")
function __init__()
@require ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210" include("forwarddiff.jl")
@require ReverseDiff = "37e2e3b7-166d-5795-8a7a-e32c996b4267" include("reversediff.jl")
@require FiniteDifferences = "26cc04aa-876d-5657-8c51-4c34ba976000" include("finitedifferences.jl")
@require Tracker = "9f7883ad-71c0-57eb-9f7f-b5c9e6d3789c" include("tracker.jl")
@require Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f" begin
@static if VERSION >= v"1.6"
ZygoteBackend() = ReverseRuleConfigBackend(Zygote.ZygoteRuleConfig())
end
end
end

end
19 changes: 19 additions & 0 deletions src/ruleconfig.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
"""
ReverseRuleConfigBackend

AD backend that uses reverse mode with any ChainRules-compatible reverse-mode AD package.
"""
struct ReverseRuleConfigBackend{RC <: RuleConfig} <: AbstractReverseMode
ruleconfig::RC
end

AD.@primitive function pullback_function(ab::ReverseRuleConfigBackend, f, xs...)
return (vs) -> begin
_, back = rrule_via_ad(ab.ruleconfig, f, xs...)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What's reason for moving this in the function body? Isn't it better to obtain back only once outside of the function instead of doing it at every invocation of the pullback function?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is a good point.

if vs isa Tuple && length(vs) === 1
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Does this compile away in the same way as, I would assume, if vs isa Tuple{Any}?

return Base.tail(back(vs[1]))
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It seems this rules out support for functors f as it removes the derivatives wrt to f? But maybe generally the design of AbstractDifferentiation does not support them?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, this is so. AbstractDifferentiation supports a strictly smaller set of applications than, say, Zygote. It's all about scalars and arrays. And even there, I'm pretty sure results will be inconsistent if one tries to, say, take the gradient of a structured array. And functors are not supported.

else
return Base.tail(back(vs))
end
end
end
33 changes: 33 additions & 0 deletions test/ruleconfig.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
using AbstractDifferentiation
using Test
using Zygote

@testset "ReverseRuleConfigBackend(ZygoteRuleConfig())" begin
backends = [@inferred(AD.ZygoteBackend())]
@testset for backend in backends
@testset "Derivative" begin
test_derivatives(backend)
end
@testset "Gradient" begin
test_gradients(backend)
end
@testset "Jacobian" begin
test_jacobians(backend)
end
@testset "jvp" begin
test_jvp(backend)
end
@testset "j′vp" begin
test_j′vp(backend)
end
@testset "Lazy Derivative" begin
test_lazy_derivatives(backend)
end
@testset "Lazy Gradient" begin
test_lazy_gradients(backend)
end
@testset "Lazy Jacobian" begin
test_lazy_jacobians(backend)
end
end
end
3 changes: 3 additions & 0 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -8,4 +8,7 @@ using Test
include("reversediff.jl")
include("finitedifferences.jl")
include("tracker.jl")
@static if VERSION >= v"1.6"
include("ruleconfig.jl")
end
end