Skip to content

Commit

Permalink
Add rrule for map and expand the testing framework
Browse files Browse the repository at this point in the history
This implements `rrule(map, f, xs...)` and expands `rrule_test` to allow
non-differentiable arguments. For such cases, the user need only pass
`nothing` as the argument's assumed sensitivity.
  • Loading branch information
ararslan committed Jun 18, 2019
1 parent db1eb6b commit 4e7a61c
Show file tree
Hide file tree
Showing 6 changed files with 87 additions and 18 deletions.
1 change: 1 addition & 0 deletions src/ChainRules.jl
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ include("rules.jl")
include("rules/base.jl")
include("rules/array.jl")
include("rules/broadcast.jl")
include("rules/mapreduce.jl")
include("rules/linalg/utils.jl")
include("rules/linalg/blas.jl")
include("rules/linalg/dense.jl")
Expand Down
8 changes: 0 additions & 8 deletions src/rules/linalg/dense.jl
Original file line number Diff line number Diff line change
Expand Up @@ -4,14 +4,6 @@ using LinearAlgebra: AbstractTriangular
# these we can use simpler definitions for `/` and `\`.
const SquareMatrix{T} = Union{Diagonal{T},AbstractTriangular{T}}

#####
##### `sum`
#####

frule(::typeof(sum), x) = (sum(x), Rule(sum))

rrule(::typeof(sum), x) = (sum(x), Rule(cast))

#####
##### `dot`
#####
Expand Down
29 changes: 29 additions & 0 deletions src/rules/mapreduce.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
#####
##### `map`
#####

function rrule(::typeof(map), f, xs...)
y = map(f, xs...)
∂xs = ntuple(length(xs)) do i
Rule() do
map(ȳ, xs...) do ȳi, xis...
r = rrule(f, xis...)
if r === nothing
throw(ArgumentError("can't differentiate `map` with `$f`; no `rrule` " *
"is defined for `$f$xis`"))
end
_, ∂xis = r
extern(∂xis[i](ȳi))
end
end
end
return y, (DNERule(), ∂xs...)
end

#####
##### `sum`
#####

frule(::typeof(sum), x) = (sum(x), Rule(sum))

rrule(::typeof(sum), x) = (sum(x), Rule(cast))
11 changes: 11 additions & 0 deletions test/rules/mapreduce.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
@testset "Maps and Reductions" begin
@testset "map" begin
rng = MersenneTwister(42)
n = 10
x = randn(rng, n)
vx = randn(rng, n)
= randn(rng, n)
rrule_test(map, ȳ, (sin, nothing), (x, vx))
rrule_test(map, ȳ, (+, nothing), (x, vx), (randn(rng, n), randn(rng, n)))
end
end
3 changes: 2 additions & 1 deletion test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ using ChainRules, Test, FDM, LinearAlgebra, Random
using ChainRules: extern, accumulate, accumulate!, store!, @scalar_rule,
Wirtinger, wirtinger_primal, wirtinger_conjugate, add_wirtinger, mul_wirtinger,
Zero, add_zero, mul_zero, One, add_one, mul_one, Casted, cast, add_casted, mul_casted,
DNE, Thunk, Casted
DNE, Thunk, Casted, DNERule
using Base.Broadcast: broadcastable

include("test_util.jl")
Expand All @@ -15,6 +15,7 @@ include("test_util.jl")
@testset "rules" begin
include(joinpath("rules", "base.jl"))
include(joinpath("rules", "array.jl"))
include(joinpath("rules", "mapreduce.jl"))
@testset "linalg" begin
include(joinpath("rules", "linalg", "dense.jl"))
include(joinpath("rules", "linalg", "structured.jl"))
Expand Down
53 changes: 44 additions & 9 deletions test/test_util.jl
Original file line number Diff line number Diff line change
Expand Up @@ -50,23 +50,58 @@ function rrule_test(f, ȳ, (x, x̄)::Tuple{Any, Any}; rtol=1e-9, atol=1e-9, fdm
test_accumulation(Zero(), dx, ȳ, x̄_ad)
end

function _make_fdm_call(fdm, f, ȳ, xs, ignores)
sig = Expr(:tuple)
call = Expr(:call, f)
newxs = Any[]
arginds = Int[]
i = 1
for (x, ignore) in zip(xs, ignores)
if ignore
push!(call.args, x)
else
push!(call.args, Symbol(:x, i))
push!(sig.args, Symbol(:x, i))
push!(newxs, x)
push!(arginds, i)
end
i += 1
end
fdexpr = :(j′vp($fdm, $sig -> $call, $ȳ, $(newxs...)))
fd = eval(fdexpr)
fd isa Tuple || (fd = (fd,))
args = Any[nothing for _ in 1:length(xs)]
for (dx, ind) in zip(fd, arginds)
args[ind] = dx
end
return (args...,)
end

function rrule_test(f, ȳ, xx̄s::Tuple{Any, Any}...; rtol=1e-9, atol=1e-9, fdm=_fdm, kwargs...)
# Check correctness of evaluation.
xs, x̄s = collect(zip(xx̄s...))
Ω, Δx_rules = ChainRules.rrule(f, xs...)
@test f(xs...) == Ω
y, rules = rrule(f, xs...)
@test f(xs...) == y

# Correctness testing via finite differencing.
Δxs_ad = map(Δx_rule->Δx_rule(ȳ), Δx_rules)
Δxs_fd = j′vp(fdm, f, ȳ, xs...)
for (Δx_ad, Δx_fd) in zip(Δxs_ad, Δxs_fd)
@test isapprox(Δx_ad, Δx_fd; rtol=rtol, atol=atol, kwargs...)
x̄s_ad = map(rules) do rule
rule isa DNERule ? DNE() : rule(ȳ)
end
x̄s_fd = _make_fdm_call(fdm, f, ȳ, xs, x̄s .== nothing)
for (x̄_ad, x̄_fd) in zip(x̄s_ad, x̄s_fd)
if x̄_fd === nothing
# The way we've structured the above, this tests that the rule is a DNERule
@test x̄_ad isa DNE
else
@test isapprox(x̄_ad, x̄_fd; rtol=rtol, atol=atol, kwargs...)
end
end

# Assuming the above to be correct, check that other ChainRules mechanisms are correct.
for (x̄, Δx_rule, Δx_ad) in zip(x̄s, Δx_rules, Δxs_ad)
test_accumulation(x̄, Δx_rule, ȳ, Δx_ad)
test_accumulation(Zero(), Δx_rule, ȳ, Δx_ad)
for (x̄, rule, x̄_ad) in zip(x̄s, rules, x̄s_ad)
=== nothing && continue
test_accumulation(x̄, rule, ȳ, x̄_ad)
test_accumulation(Zero(), rule, ȳ, x̄_ad)
end
end

Expand Down

0 comments on commit 4e7a61c

Please sign in to comment.