Skip to content

Commit

Permalink
Add rrule for mean
Browse files Browse the repository at this point in the history
  • Loading branch information
ararslan committed Jun 19, 2019
1 parent e4319c8 commit 1459fe6
Show file tree
Hide file tree
Showing 5 changed files with 36 additions and 1 deletion.
1 change: 1 addition & 0 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ Cassette = "7057c7e9-c182-5462-911a-8362d720325c"
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
NaNMath = "77ba4419-2d1f-58cd-9bb1-8ffee604a2e3"
SpecialFunctions = "276daf66-3868-5448-9aa4-cd146d93841b"
Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"

[compat]
Cassette = "^0.2"
Expand Down
1 change: 1 addition & 0 deletions src/ChainRules.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ module ChainRules

using Cassette
using LinearAlgebra
using Statistics
using Base.Broadcast: materialize, materialize!, broadcasted, Broadcasted, broadcastable

import NaNMath, SpecialFunctions, LinearAlgebra, LinearAlgebra.BLAS
Expand Down
22 changes: 22 additions & 0 deletions src/rules/mapreduce.jl
Original file line number Diff line number Diff line change
Expand Up @@ -62,3 +62,25 @@ function rrule(::typeof(sum), ::typeof(abs2), x::AbstractArray{<:Real}; dims=:)
∂x = Rule(ȳ -> 2.* x)
return y, (DNERule(), ∂x)
end

#####
##### `mean`
#####

_denom(x, dims::Colon) = length(x)
_denom(x, dims::Integer) = size(x, dims)
_denom(x, dims) = mapreduce(i->size(x, i), Base.mul_prod, unique(dims), init=1)

# TODO: We have `mean(f, x; dims)` as of 1.3.0-DEV.36

function rrule(::typeof(mean), x::AbstractArray{<:Real}; dims=:)
_, dx = rrule(sum, x; dims=dims)
n = _denom(x, dims)
return mean(x; dims=dims), Rule(ȳ -> dx(ȳ) / n)
end

function rrule(::typeof(mean), f, x::AbstractArray{<:Real})
_, (_, dx) = rrule(sum, f, x)
n = _denom(x, :)
return mean(f, x), (DNERule(), Rule(ȳ -> dx(ȳ) / n))
end
11 changes: 11 additions & 0 deletions test/rules/mapreduce.jl
Original file line number Diff line number Diff line change
Expand Up @@ -64,4 +64,15 @@
@test x̄_ad x̄_fd atol=1e-9 rtol=1e-9
end
end
@testset "mean" begin
rng = MersenneTwister(999)
n = 9
rrule_test(mean, randn(rng), (abs2, nothing), (randn(rng, n), randn(rng, n)))
X = randn(rng, n, n)
y, dX = rrule(mean, X; dims=1)
= randn(rng, size(y))
X̄_ad = dX(ȳ)
X̄_fd = j′vp(central_fdm(5, 1), x->mean(x, dims=1), ȳ, X)
@test X̄_ad X̄_fd rtol=1e-9 atol=1e-9
end
end
2 changes: 1 addition & 1 deletion test/runtests.jl
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
# TODO: more tests!

using ChainRules, Test, FDM, LinearAlgebra, Random
using ChainRules, Test, FDM, LinearAlgebra, Random, Statistics
using ChainRules: extern, accumulate, accumulate!, store!, @scalar_rule,
Wirtinger, wirtinger_primal, wirtinger_conjugate, add_wirtinger, mul_wirtinger,
Zero, add_zero, mul_zero, One, add_one, mul_one, Casted, cast, add_casted, mul_casted,
Expand Down

0 comments on commit 1459fe6

Please sign in to comment.