From 65df9875bd74af5be3f09e9d8aaf576296dd4226 Mon Sep 17 00:00:00 2001 From: Gaurav Dhingra Date: Wed, 18 Nov 2020 02:32:19 +0530 Subject: [PATCH] add rules for vector-matrix and matrix-vector product (#305) * add rules for vector-matrix and matrix-vector product Fixes #276 * fix bug in test function writing * add separate dispatch for Vector * Matrix * fix tests for Matrix*Vector, Vector*Matrix * fix test * Assert about size * bump version Co-authored-by: Lyndon White --- Project.toml | 2 +- src/rulesets/Base/arraymath.jl | 26 ++++++++++++++++++++++++-- test/rulesets/Base/arraymath.jl | 21 ++++++++++++++++++++- 3 files changed, 45 insertions(+), 4 deletions(-) diff --git a/Project.toml b/Project.toml index 562430172..8eb6e4a35 100644 --- a/Project.toml +++ b/Project.toml @@ -1,6 +1,6 @@ name = "ChainRules" uuid = "082447d4-558c-5d27-93f4-14fc19e9eca2" -version = "0.7.32" +version = "0.7.33" [deps] ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" diff --git a/src/rulesets/Base/arraymath.jl b/src/rulesets/Base/arraymath.jl index e8eae0acc..023ccf7f4 100644 --- a/src/rulesets/Base/arraymath.jl +++ b/src/rulesets/Base/arraymath.jl @@ -21,8 +21,8 @@ end function rrule( ::typeof(*), - A::AbstractMatrix{<:CommutativeMulNumber}, - B::AbstractMatrix{<:CommutativeMulNumber}, + A::AbstractVecOrMat{<:CommutativeMulNumber}, + B::AbstractVecOrMat{<:CommutativeMulNumber}, ) function times_pullback(Ȳ) return ( @@ -40,6 +40,28 @@ function rrule( return A * B, times_pullback end +function rrule( + ::typeof(*), + A::AbstractVector{<:CommutativeMulNumber}, + B::AbstractMatrix{<:CommutativeMulNumber}, +) + function times_pullback(Ȳ) + @assert size(B, 1) === 1 # otherwise primal would have failed. + return ( + NO_FIELDS, + InplaceableThunk( + @thunk(Ȳ * vec(B')), + X̄ -> mul!(X̄, Ȳ, vec(B'), true, true) + ), + InplaceableThunk( + @thunk(A' * Ȳ), + X̄ -> mul!(X̄, A', Ȳ, true, true) + ) + ) + end + return A * B, times_pullback +end + function rrule( ::typeof(*), A::CommutativeMulNumber, B::AbstractArray{<:CommutativeMulNumber} ) diff --git a/test/rulesets/Base/arraymath.jl b/test/rulesets/Base/arraymath.jl index 8bc996860..4c205e784 100644 --- a/test/rulesets/Base/arraymath.jl +++ b/test/rulesets/Base/arraymath.jl @@ -20,8 +20,20 @@ rrule_test(*, ⋆(dims), ⋆₂(dims), ⋆₂()) end + @testset "AbstractMatrix-AbstractVector n=$n, m=$m" for n in (2, 3), m in (4, 5) + @testset "Array" begin + rrule_test(*, ⋆(n), n ⋆₂ m, ⋆₂(m)) + end + end + + @testset "AbstractVector-AbstractMatrix n=$n, m=$m" for n in (2, 3), m in (4, 5) + @testset "Array" begin + rrule_test(*, n ⋆ m, ⋆₂(n), 1 ⋆₂ m) + end + end + @testset "AbstractMatrix-AbstractMatrix" begin - @testset "n=$n, m=$m, p=$p" for n in (2, 5), m in (2, 4), p in (2, 3) + @testset "Matrix * Matrix n=$n, m=$m, p=$p" for n in (2, 5), m in (2, 4), p in (2, 3) @testset "Array" begin rrule_test(*, n⋆p, (n⋆₂m), (m⋆₂p)) end @@ -46,6 +58,13 @@ end end end + + @testset "Covector * Vector n=$n" for n in (3, 5) + @testset "$f" for f in (adjoint, transpose) + # This should be same as dot product and give a scalar + rrule_test(*, ⋆(), f.(⋆₂(n)), ⋆₂(n)) + end + end end