Update of #36 and #35 (#93)

Merged: 20 commits, Sep 20, 2023

2 changes: 1 addition & 1 deletion Project.toml
@@ -1,7 +1,7 @@
 name = "AbstractDifferentiation"
 uuid = "c29ec348-61ec-40c8-8164-b8c60e9d9f3d"
 authors = ["Mohamed Tarek <mohamed82008@gmail.com> and contributors"]
-version = "0.5.3"
+version = "0.6.0-DEV"
 
 [deps]
 ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
2 changes: 1 addition & 1 deletion README.md
@@ -92,7 +92,7 @@ This operation goes by a few names. Refer to the [ChainRules documentation](http
 
 The following functions can be used to request the pullback operator/function with or without the function value. In order to request the pullback function `pb_f` of a function `f` at the inputs `xs`, you can use either of:
 - `pb_f = AD.pullback_function(ab::AD.AbstractBackend, f, xs...)`: returns the pullback function `pb_f` of the function `f` at the inputs `xs`. `pb_f` is a function that accepts the co-tangents `vs` as input which is a tuple of length equal to the number of outputs of `f`. If `f` has a single output, `pb_f` can also accept a single input instead of a 1-tuple.
-- `value_and_pb_f = AD.value_and_pullback_function(ab::AD.AbstractBackend, f, xs...)`: returns a function `value_and_pb_f` which accepts the co-tangent `vs` as input which is a tuple of length equal to the number of outputs of `f`. If `f` has a single output, `value_and_pb_f` can accept a single input instead of a 1-tuple. `value_and_pb_f` returns a 2-tuple, namely the value `f(xs...)` and output of the pullback operator.
+- `value_and_pb_f = AD.value_and_pullback_function(ab::AD.AbstractBackend, f, xs...)`: computes the function value `v = f(xs...)` and returns a 2-tuple containing the value `v` and a function `pb_f` that accepts the co-tangent `vs` as input, which is a tuple of length equal to the number of outputs of `f`. If `f` has a single output, `pb_f` can accept a single input instead of a 1-tuple.
 
 ### Lazy operators
 
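To make the documented change above concrete, here is a minimal sketch of the new two-step calling convention. It assumes FiniteDifferences.jl is installed so that the extension's `AD.FiniteDifferencesBackend` is available (the zero-argument constructor is assumed to pick a default finite-difference method); the function `f` and the inputs are purely illustrative.

```julia
import AbstractDifferentiation as AD
using FiniteDifferences  # loads the FiniteDifferences extension backend

backend = AD.FiniteDifferencesBackend()   # assumed zero-arg constructor from the extension
f(x) = [sum(abs2, x), prod(x)]            # a single (vector-valued) output

# New convention: the primal value and the pullback closure come back together.
value, pb_f = AD.value_and_pullback_function(backend, f, [1.0, 2.0, 3.0])

# `f` has a single output, so the co-tangent can be passed directly.
Δ = [1.0, 0.0]
(grad_x,) = pb_f(Δ)   # one pulled-back co-tangent per input of `f`
```

This matches the updated README entry: the call now returns the 2-tuple `(value, pullback)` instead of a single closure that produces both.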
16 changes: 11 additions & 5 deletions ext/AbstractDifferentiationChainRulesCoreExt.jl
@@ -3,11 +3,17 @@ module AbstractDifferentiationChainRulesCoreExt
 import AbstractDifferentiation as AD
 using ChainRulesCore: ChainRulesCore
 
-AD.@primitive function pullback_function(ba::AD.ReverseRuleConfigBackend, f, xs...)
-    _, back = ChainRulesCore.rrule_via_ad(AD.ruleconfig(ba), f, xs...)
-    pullback(vs) = Base.tail(back(vs))
-    pullback(vs::Tuple{Any}) = Base.tail(back(first(vs)))
-    return pullback
+AD.@primitive function value_and_pullback_function(ba::AD.ReverseRuleConfigBackend, f, xs...)
+    value, back = ChainRulesCore.rrule_via_ad(AD.ruleconfig(ba), f, xs...)
+    function rrule_pullback(vs)
+        _vs = if vs isa Tuple && !(value isa Tuple)
+            only(vs)
+        else
+            vs
+        end
+        return Base.tail(back(_vs))
+    end
+    return value, rrule_pullback
+end
 end
 
 end # module
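The extension above is what `AD.ZygoteBackend()` (which wraps a `ReverseRuleConfigBackend`) dispatches to, so a user-supplied `rrule` is picked up through `rrule_via_ad` and the primal is obtained once, up front. A minimal sketch, assuming Zygote.jl is installed; the function `g` is illustrative only:

```julia
import AbstractDifferentiation as AD
using Zygote  # supplies the rule config wrapped by AD.ZygoteBackend()

backend = AD.ZygoteBackend()
g(x) = sum(abs2, x)

value, pb = AD.value_and_pullback_function(backend, g, [1.0, 2.0, 3.0])
# value == 14.0; `g` has a single (scalar) output, so the co-tangent may be
# passed either bare or wrapped in a 1-tuple (the `only(vs)` branch unwraps the latter).
(grad_x,) = pb(1.0)
(grad_x_again,) = pb((1.0,))
# grad_x == grad_x_again == [2.0, 4.0, 6.0]
```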
13 changes: 13 additions & 0 deletions ext/AbstractDifferentiationFiniteDifferencesExt.jl
@@ -19,6 +19,10 @@ function AD.jacobian(ba::AD.FiniteDifferencesBackend, f, xs...)
     return FiniteDifferences.jacobian(ba.method, f, xs...)
 end
 
+function AD.gradient(ba::AD.FiniteDifferencesBackend, f, xs...)
+    return FiniteDifferences.grad(ba.method, f, xs...)
+end
+
 function AD.pushforward_function(ba::AD.FiniteDifferencesBackend, f, xs...)
     return function pushforward(vs)
         ws = FiniteDifferences.jvp(ba.method, f, tuple.(xs, vs)...)
@@ -32,6 +36,15 @@ function AD.pullback_function(ba::AD.FiniteDifferencesBackend, f, xs...)
     end
 end
 
+# Ensure consistency with `value_and_pullback` function
+function AD.value_and_pullback_function(ba::AD.FiniteDifferencesBackend, f, xs...)
+    value = f(xs...)
+    function fd_pullback(vs)
+        return FiniteDifferences.j′vp(ba.method, f, vs, xs...)
+    end
+    return value, fd_pullback
+end
+
 # Better performance: issue #87
 function AD.derivative(ba::AD.FiniteDifferencesBackend, f::TF, x::Real) where {TF<:Function}
     return (ba.method(f, x),)
16 changes: 9 additions & 7 deletions ext/AbstractDifferentiationTrackerExt.jl
@@ -16,16 +16,18 @@ AD.primal_value(x::Tracker.TrackedReal) = Tracker.data(x)
 AD.primal_value(x::Tracker.TrackedArray) = Tracker.data(x)
 AD.primal_value(x::AbstractArray{<:Tracker.TrackedReal}) = Tracker.data.(x)
 
-AD.@primitive function pullback_function(ba::AD.TrackerBackend, f, xs...)
-    value, back = Tracker.forward(f, xs...)
-    function pullback(ws)
-        if ws isa Tuple && !(value isa Tuple)
-            map(Tracker.data, back(only(ws)))
+AD.@primitive function value_and_pullback_function(ba::AD.TrackerBackend, f, xs...)
+    _value, back = Tracker.forward(f, xs...)
+    value = map(Tracker.data, _value)
+    function tracker_pullback(ws)
+        _ws = if ws isa Tuple && !(value isa Tuple)
+            only(ws)
         else
-            map(Tracker.data, back(ws))
+            ws
         end
+        return map(Tracker.data, back(_ws))
     end
-    return pullback
+    return value, tracker_pullback
 end
 
 function AD.derivative(::AD.TrackerBackend, f, xs::Number...)
65 changes: 25 additions & 40 deletions src/AbstractDifferentiation.jl
@@ -221,48 +221,28 @@ end
 end
 
 function pullback_function(ab::AbstractBackend, f, xs...)
-    return (ws) -> begin
-        return gradient(lowest(ab), (xs...,) -> begin
-            vs = f(xs...)
-            if ws isa Tuple
-                @assert length(vs) == length(ws)
-                return sum(Base.splat(_dot), zip(ws, vs))
-            else
-                return _dot(vs, ws)
-            end
-        end, xs...)
-    end
+    _, pbf = value_and_pullback_function(ab, f, xs...)
+    return pbf
 end
 function value_and_pullback_function(
     ab::AbstractBackend,
     f,
     xs...,
 )
-    return (ws) -> begin
-        local value
-        primalcalled = false
-        if ab isa AbstractFiniteDifference
-            value = primal_value(ab, nothing, f, xs)
-            primalcalled = true
-        end
-        if ws === nothing
-            vs = f(xs...)
-            if !primalcalled
-                value = primal_value(lowest(ab), vs, f, xs)
-                primalcalled = true
-            end
-            return value, nothing
-        end
-        pb = pullback_function(lowest(ab), (_xs...,) -> begin
+    value = f(xs...)
+    function pullback_function(ws)
+        function pullback_gradient_function(_xs...)
             vs = f(_xs...)
-            if !primalcalled
-                value = primal_value(lowest(ab), vs, f, xs)
-                primalcalled = true
+            if ws isa Tuple
+                @assert length(vs) == length(ws)
+                return sum(Base.splat(_dot), zip(ws, vs))
+            else
+                return _dot(vs, ws)
             end
-            return vs
-        end, xs...)(ws)
-        return value, pb
+        end
+        return gradient(lowest(ab), pullback_gradient_function, xs...)
     end
+    return value, pullback_function
 end
 
 struct LazyDerivative{B, F, X}
@@ -494,6 +474,12 @@ macro primitive(expr)
     name = fdef[:name]
     if name == :pushforward_function
         return define_pushforward_function_and_friends(fdef) |> esc
+    elseif name == :value_and_pullback_function
+        return define_value_and_pullback_function_and_friends(fdef) |> esc
+    elseif name == :jacobian
+        return define_jacobian_and_friends(fdef) |> esc
+    elseif name == :primal_value
+        return define_primal_value(fdef) |> esc
     elseif name == :pullback_function
         return define_pullback_function_and_friends(fdef) |> esc
     else
@@ -537,30 +523,29 @@ function define_pushforward_function_and_friends(fdef)
     return funcs
 end
 
-function define_pullback_function_and_friends(fdef)
-    fdef[:name] = :($(AbstractDifferentiation).pullback_function)
+function define_value_and_pullback_function_and_friends(fdef)
+    fdef[:name] = :($(AbstractDifferentiation).value_and_pullback_function)
     args = fdef[:args]
     funcs = quote
         $(ExprTools.combinedef(fdef))
         function $(AbstractDifferentiation).jacobian($(args...),)
-            value_and_pbf = $(value_and_pullback_function)($(args...),)
-            value, _ = value_and_pbf(nothing)
+            value, pbf = $(value_and_pullback_function)($(args...),)
             identity_like = $(identity_matrix_like)(value)
             if eltype(identity_like) <: Tuple{Vararg{AbstractMatrix}}
                 return map(identity_like) do identity_like_i
                     return mapreduce(vcat, $(_eachcol).(identity_like_i)...) do (cols...)
-                        value_and_pbf(cols)[2]'
+                        pbf(cols)'
                     end
                 end
             elseif eltype(identity_like) <: AbstractMatrix
                 # needed for Hessian computation:
                 # value is a (grad,). Then, identity_like is a (matrix,).
                 # cols loops over columns of the matrix
                 return vcat.(mapslices(identity_like[1], dims=1) do cols
-                    adjoint.(value_and_pbf((cols,))[2])
+                    adjoint.(pbf((cols,)))
                 end ...)
             else
-                return adjoint.(value_and_pbf(identity_like)[2])
+                return adjoint.(pbf(identity_like))
             end
         end
     end
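The default `value_and_pullback_function` above builds the pullback as the gradient (taken with the lowered backend) of the scalar obtained by dotting the co-tangents against `f(xs...)`. The identity it relies on can be checked independently of the package internals; a standalone sketch using Zygote directly, with an illustrative `f` and co-tangent `w`:

```julia
using LinearAlgebra: dot
using Zygote

f(x) = [sum(abs2, x), prod(x)]
x = [1.0, 2.0, 3.0]
w = [1.0, 0.5]   # co-tangent for the single vector output of f

# The pullback as the gradient of z -> <w, f(z)> ...
vjp, = Zygote.gradient(z -> dot(w, f(z)), x)

# ... agrees with pulling w back through the Jacobian of f.
J, = Zygote.jacobian(f, x)
@assert vjp ≈ J' * w
```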
29 changes: 13 additions & 16 deletions test/defaults.jl
@@ -33,15 +33,14 @@ struct FDMBackend3{A} <: AD.AbstractFiniteDifference
 end
 FDMBackend3() = FDMBackend3(central_fdm(5, 1))
 const fdm_backend3 = FDMBackend3()
-AD.@primitive function pullback_function(ab::FDMBackend3, f, xs...)
-    return function (vs)
+AD.@primitive function value_and_pullback_function(ab::FDMBackend3, f, xs...)
+    value = f(xs...)
+    function fd3_pullback(vs)
         # Supports only single output
-        if vs isa AbstractVector
-            return FDM.j′vp(ab.alg, f, vs, xs...)
-        else
-            return FDM.j′vp(ab.alg, f, only(vs), xs...)
-        end
+        _vs = vs isa AbstractVector ? vs : only(vs)
+        return FDM.j′vp(ab.alg, f, _vs, xs...)
     end
+    return value, fd3_pullback
 end
 ##
 
@@ -90,16 +89,14 @@ AD.primal_value(::ForwardDiffBackend2, ::Any, f, xs) = ForwardDiff.value.(f(xs..
 ## Zygote
 struct ZygoteBackend1 <: AD.AbstractReverseMode end
 const zygote_backend1 = ZygoteBackend1()
-AD.@primitive function pullback_function(ab::ZygoteBackend1, f, xs...)
-    return function (vs)
-        # Supports only single output
-        _, back = Zygote.pullback(f, xs...)
-        if vs isa AbstractVector
-            back(vs)
-        else
-            back(only(vs))
-        end
+AD.@primitive function value_and_pullback_function(ab::ZygoteBackend1, f, xs...)
+    # Supports only single output
+    value, back = Zygote.pullback(f, xs...)
+    function zygote_pullback(vs)
+        _vs = vs isa AbstractVector ? vs : only(vs)
+        return back(_vs)
     end
+    return value, zygote_pullback
 end
 
 @testset "defaults" begin
13 changes: 13 additions & 0 deletions test/ruleconfig.jl
@@ -1,4 +1,5 @@
 using AbstractDifferentiation
+using ChainRulesCore
 using Test
 using Zygote
 
@@ -52,4 +53,16 @@ using Zygote
        end
        @test AD.jacobian(ad, f, [1, 2, 3], 3) == ([6.0 0.0 0.0; 0.0 6.0 0.0; 0.0 0.0 6.0], [2.0, 4.0, 6.0])
    end
+
+    # issue #57
+    @testset "primal computation in rrule" begin
+        function myfunc(x)
+            @info "This should not be logged if I have an rrule"
+            x
+        end
+        ChainRulesCore.rrule(::typeof(myfunc), x) = (x, (y -> (NoTangent(), y)))
+
+        @test_logs Zygote.gradient(myfunc, 1) # nothing is logged
+        @test_logs AD.derivative(AD.ZygoteBackend(), myfunc, 1) # nothing is logged
+    end
 end
9 changes: 6 additions & 3 deletions test/test_utils.jl
@@ -247,7 +247,8 @@ function test_j′vp(backend; multiple_inputs=true, rng=Random.GLOBAL_RNG, test_
     w = rand(rng, length(fjac(xvec, yvec)))
     if multiple_inputs
         pb1 = AD.pullback_function(backend, fjac, xvec, yvec)(w)
-        valvec, pb2 = AD.value_and_pullback_function(backend, fjac, xvec, yvec)(w)
+        valvec, pbf2 = AD.value_and_pullback_function(backend, fjac, xvec, yvec)
+        pb2 = pbf2(w)
 
         if test_types
             @test valvec isa Vector{Float64}
@@ -263,8 +264,10 @@
        @test yvec == yvec2
    end
 
-    valvec1, pb1 = AD.value_and_pullback_function(backend, x -> fjac(x, yvec), xvec)(w)
-    valvec2, pb2 = AD.value_and_pullback_function(backend, y -> fjac(xvec, y), yvec)(w)
+    valvec1, pbf1 = AD.value_and_pullback_function(backend, x -> fjac(x, yvec), xvec)
+    pb1 = pbf1(w)
+    valvec2, pbf2 = AD.value_and_pullback_function(backend, y -> fjac(xvec, y), yvec)
+    pb2 = pbf2(w)
     if test_types
         @test valvec1 isa Vector{Float64}
         @test valvec2 isa Vector{Float64}