From 48804fdb36e152d9371663f49d54473927646303 Mon Sep 17 00:00:00 2001
From: Alex Arslan
Date: Tue, 18 Jun 2019 22:29:06 -0700
Subject: [PATCH] Add rrule for mean

---
 Project.toml            |  1 +
 src/ChainRules.jl       |  1 +
 src/rules/mapreduce.jl  | 22 ++++++++++++++++++++++
 test/rules/mapreduce.jl | 11 +++++++++++
 test/runtests.jl        |  2 +-
 5 files changed, 36 insertions(+), 1 deletion(-)

diff --git a/Project.toml b/Project.toml
index cd0497b6e..7dc3acd5d 100644
--- a/Project.toml
+++ b/Project.toml
@@ -7,6 +7,7 @@ Cassette = "7057c7e9-c182-5462-911a-8362d720325c"
 LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
 NaNMath = "77ba4419-2d1f-58cd-9bb1-8ffee604a2e3"
 SpecialFunctions = "276daf66-3868-5448-9aa4-cd146d93841b"
+Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
 
 [compat]
 Cassette = "^0.2"
diff --git a/src/ChainRules.jl b/src/ChainRules.jl
index 00a6ff49a..2277699c7 100644
--- a/src/ChainRules.jl
+++ b/src/ChainRules.jl
@@ -3,6 +3,7 @@ module ChainRules
 using Cassette
 using LinearAlgebra
 using LinearAlgebra.BLAS
+using Statistics
 using Base.Broadcast: materialize, materialize!, broadcasted, Broadcasted, broadcastable
 
 if VERSION < v"1.3.0-DEV.142"
diff --git a/src/rules/mapreduce.jl b/src/rules/mapreduce.jl
index b0e6a006b..69a9b9b2b 100644
--- a/src/rules/mapreduce.jl
+++ b/src/rules/mapreduce.jl
@@ -62,3 +62,25 @@ function rrule(::typeof(sum), ::typeof(abs2), x::AbstractArray{<:Real}; dims=:)
     ∂x = Rule(ȳ -> 2ȳ .* x)
     return y, (DNERule(), ∂x)
 end
+
+#####
+##### `mean`
+#####
+# `_denom(x, dims)` is the divisor `n` that the `mean` rules below apply to the `sum` rule.
+_denom(x, dims::Colon) = length(x)  # `dims=:` — total number of elements of `x`
+_denom(x, dims::Integer) = size(x, dims)  # one dimension — the extent of that dimension
+_denom(x, dims) = mapreduce(i->size(x, i), Base.mul_prod, unique(dims), init=1)
+# ^ any other `dims`: product of the extents of the distinct entries of `dims` (1 if empty)
+# TODO: We have `mean(f, x; dims)` as of 1.3.0-DEV.36
+# (until then the `mean(f, x)` method below takes no `dims` keyword)
+function rrule(::typeof(mean), x::AbstractArray{<:Real}; dims=:)
+    _, dx = rrule(sum, x; dims=dims)  # only the `sum` rule is kept; its primal is dropped
+    n = _denom(x, dims)  # divisor for the given `dims`
+    return mean(x; dims=dims), Rule(ȳ -> dx(ȳ) / n)  # (primal, rule w.r.t. `x`)
+end
+
+function rrule(::typeof(mean), f, x::AbstractArray{<:Real})
+    _, (_, dx) = rrule(sum, f, x)  # keep the second rule (presumably for `x`), skip the first
+    n = _denom(x, :)  # no `dims` here, so the divisor is always `length(x)`
+    return mean(f, x), (DNERule(), Rule(ȳ -> dx(ȳ) / n))  # `DNERule()` is returned for `f`
+end
diff --git a/test/rules/mapreduce.jl b/test/rules/mapreduce.jl
index 61aff89eb..8ee0a1e1d 100644
--- a/test/rules/mapreduce.jl
+++ b/test/rules/mapreduce.jl
@@ -64,4 +64,15 @@
             @test x̄_ad ≈ x̄_fd atol=1e-9 rtol=1e-9
         end
     end
+    @testset "mean" begin
+        rng = MersenneTwister(999)  # fixed seed keeps the random inputs reproducible
+        n = 9  # length of the test vectors and side length of the test matrix `X`
+        rrule_test(mean, randn(rng), (abs2, nothing), (randn(rng, n), randn(rng, n)))
+        X = randn(rng, n, n)
+        y, dX = rrule(mean, X; dims=1)  # exercises the method with the `dims` keyword
+        ȳ = randn(rng, size(y))  # random input for the rule, same size as `y`
+        X̄_ad = dX(ȳ)  # result produced by the rule under test
+        X̄_fd = j′vp(central_fdm(5, 1), x->mean(x, dims=1), ȳ, X)
+        @test X̄_ad ≈ X̄_fd rtol=1e-9 atol=1e-9  # rule must match the `j′vp` reference value
+    end
 end
diff --git a/test/runtests.jl b/test/runtests.jl
index df2281c13..b1b4d5005 100644
--- a/test/runtests.jl
+++ b/test/runtests.jl
@@ -1,6 +1,6 @@
 # TODO: more tests!
 
-using ChainRules, Test, FDM, LinearAlgebra, Random
+using ChainRules, Test, FDM, LinearAlgebra, Random, Statistics
 using ChainRules: extern, accumulate, accumulate!, store!, @scalar_rule,
     Wirtinger, wirtinger_primal, wirtinger_conjugate, add_wirtinger, mul_wirtinger,
     Zero, add_zero, mul_zero, One, add_one, mul_one, Casted, cast, add_casted, mul_casted,