diff --git a/src/ChainRules.jl b/src/ChainRules.jl index cf812063b..4aeb988bf 100644 --- a/src/ChainRules.jl +++ b/src/ChainRules.jl @@ -13,6 +13,7 @@ include("rules.jl") include("rules/base.jl") include("rules/array.jl") include("rules/broadcast.jl") +include("rules/mapreduce.jl") include("rules/linalg/utils.jl") include("rules/linalg/blas.jl") include("rules/linalg/dense.jl") diff --git a/src/rules/linalg/dense.jl b/src/rules/linalg/dense.jl index 6f86d25a3..9eb3ee168 100644 --- a/src/rules/linalg/dense.jl +++ b/src/rules/linalg/dense.jl @@ -4,14 +4,6 @@ using LinearAlgebra: AbstractTriangular # these we can use simpler definitions for `/` and `\`. const SquareMatrix{T} = Union{Diagonal{T},AbstractTriangular{T}} -##### -##### `sum` -##### - -frule(::typeof(sum), x) = (sum(x), Rule(sum)) - -rrule(::typeof(sum), x) = (sum(x), Rule(cast)) - ##### ##### `dot` ##### diff --git a/src/rules/mapreduce.jl b/src/rules/mapreduce.jl new file mode 100644 index 000000000..25bcfd3fa --- /dev/null +++ b/src/rules/mapreduce.jl @@ -0,0 +1,29 @@ +##### +##### `map` +##### + +function rrule(::typeof(map), f, xs...) + y = map(f, xs...) + ∂xs = ntuple(length(xs)) do i + Rule() do ȳ + map(ȳ, xs...) do ȳi, xis... + r = rrule(f, xis...) + if r === nothing + throw(ArgumentError("can't differentiate `map` with `$f`; no `rrule` " * + "is defined for `$f$xis`")) + end + _, ∂xis = r + extern(∂xis[i](ȳi)) + end + end + end + return y, (DNERule(), ∂xs...) +end + +##### +##### `sum` +##### + +frule(::typeof(sum), x) = (sum(x), Rule(sum)) + +rrule(::typeof(sum), x) = (sum(x), Rule(cast)) diff --git a/test/rules/mapreduce.jl b/test/rules/mapreduce.jl new file mode 100644 index 000000000..6d88b8c30 --- /dev/null +++ b/test/rules/mapreduce.jl @@ -0,0 +1,11 @@ +@testset "Maps and Reductions" begin + @testset "map" begin + rng = MersenneTwister(42) + n = 10 + x = randn(rng, n) + vx = randn(rng, n) + ȳ = randn(rng, n) + rrule_test(map, ȳ, (sin, nothing), (x, vx)) + rrule_test(map, ȳ, (+, nothing), (x, vx), (randn(rng, n), randn(rng, n))) + end +end diff --git a/test/runtests.jl b/test/runtests.jl index 3f474becc..df2281c13 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -4,7 +4,7 @@ using ChainRules, Test, FDM, LinearAlgebra, Random using ChainRules: extern, accumulate, accumulate!, store!, @scalar_rule, Wirtinger, wirtinger_primal, wirtinger_conjugate, add_wirtinger, mul_wirtinger, Zero, add_zero, mul_zero, One, add_one, mul_one, Casted, cast, add_casted, mul_casted, - DNE, Thunk, Casted + DNE, Thunk, Casted, DNERule using Base.Broadcast: broadcastable include("test_util.jl") @@ -15,6 +15,7 @@ include("test_util.jl") @testset "rules" begin include(joinpath("rules", "base.jl")) include(joinpath("rules", "array.jl")) + include(joinpath("rules", "mapreduce.jl")) @testset "linalg" begin include(joinpath("rules", "linalg", "dense.jl")) include(joinpath("rules", "linalg", "structured.jl")) diff --git a/test/test_util.jl b/test/test_util.jl index a45c121f2..f6cd9992a 100644 --- a/test/test_util.jl +++ b/test/test_util.jl @@ -50,23 +50,58 @@ function rrule_test(f, ȳ, (x, x̄)::Tuple{Any, Any}; rtol=1e-9, atol=1e-9, fdm test_accumulation(Zero(), dx, ȳ, x̄_ad) end +function _make_fdm_call(fdm, f, ȳ, xs, ignores) + sig = Expr(:tuple) + call = Expr(:call, f) + newxs = Any[] + arginds = Int[] + i = 1 + for (x, ignore) in zip(xs, ignores) + if ignore + push!(call.args, x) + else + push!(call.args, Symbol(:x, i)) + push!(sig.args, Symbol(:x, i)) + push!(newxs, x) + push!(arginds, i) + end + i += 1 + end + fdexpr = :(j′vp($fdm, $sig -> $call, $ȳ, $(newxs...))) + fd = eval(fdexpr) + fd isa Tuple || (fd = (fd,)) + args = Any[nothing for _ in 1:length(xs)] + for (dx, ind) in zip(fd, arginds) + args[ind] = dx + end + return (args...,) +end + function rrule_test(f, ȳ, xx̄s::Tuple{Any, Any}...; rtol=1e-9, atol=1e-9, fdm=_fdm, kwargs...) # Check correctness of evaluation. xs, x̄s = collect(zip(xx̄s...)) - Ω, Δx_rules = ChainRules.rrule(f, xs...) - @test f(xs...) == Ω + y, rules = rrule(f, xs...) + @test f(xs...) == y # Correctness testing via finite differencing. - Δxs_ad = map(Δx_rule->Δx_rule(ȳ), Δx_rules) - Δxs_fd = j′vp(fdm, f, ȳ, xs...) - for (Δx_ad, Δx_fd) in zip(Δxs_ad, Δxs_fd) - @test isapprox(Δx_ad, Δx_fd; rtol=rtol, atol=atol, kwargs...) + x̄s_ad = map(rules) do rule + rule isa DNERule ? DNE() : rule(ȳ) + end + x̄s_fd = _make_fdm_call(fdm, f, ȳ, xs, x̄s .== nothing) + for (x̄_ad, x̄_fd) in zip(x̄s_ad, x̄s_fd) + if x̄_fd === nothing + # The way we've structured the above, this tests that the rule is a DNERule + @test x̄_ad isa DNE + else + @test isapprox(x̄_ad, x̄_fd; rtol=rtol, atol=atol, kwargs...) + end end # Assuming the above to be correct, check that other ChainRules mechanisms are correct. - for (x̄, Δx_rule, Δx_ad) in zip(x̄s, Δx_rules, Δxs_ad) - test_accumulation(x̄, Δx_rule, ȳ, Δx_ad) - test_accumulation(Zero(), Δx_rule, ȳ, Δx_ad) + for (x̄, rule, x̄_ad) in zip(x̄s, rules, x̄s_ad) + x̄ === nothing && continue + test_accumulation(x̄, rule, ȳ, x̄_ad) + test_accumulation(Zero(), rule, ȳ, x̄_ad) end end