Skip to content

Commit

Permalink
Add rrules for binary linear algebra operations (#29)
Browse files Browse the repository at this point in the history
  • Loading branch information
ararslan authored Jun 13, 2019
1 parent dc4adb0 commit 314b08a
Show file tree
Hide file tree
Showing 3 changed files with 142 additions and 1 deletion.
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ SpecialFunctions = "276daf66-3868-5448-9aa4-cd146d93841b"

[compat]
Cassette = "^0.2"
-FDM = "^0.5"
+FDM = "^0.6"
julia = "^1.0"

[extras]
Expand Down
77 changes: 77 additions & 0 deletions src/rules/linalg/dense.jl
Original file line number Diff line number Diff line change
@@ -1,3 +1,9 @@
using LinearAlgebra: AbstractTriangular

# Matrix wrapper types that we know are square and are thus potentially invertible. For
# these we can use simpler definitions for `/` and `\`.
# The alias is parametric in the element type `T`, so methods can constrain it, e.g.
# `SquareMatrix{<:Real}`.
const SquareMatrix{T} = Union{Diagonal{T},AbstractTriangular{T}}

#####
##### `sum`
#####
Expand Down Expand Up @@ -69,3 +75,74 @@ end
# Forward rule for `tr`: returns `tr(x)` and a `Rule` that takes the trace of the
# externalized input differential.
function frule(::typeof(tr), x)
    Ω = tr(x)
    ∂x = Rule() do Δx
        return tr(extern(Δx))
    end
    return (Ω, ∂x)
end

# Reverse rule for `tr`: returns `tr(x)` and a `Rule` that spreads the scalar output
# cotangent `ΔΩ` along a `Diagonal` with `size(x, 1)` entries.
function rrule(::typeof(tr), x)
    Ω = tr(x)
    ∂x = Rule() do ΔΩ
        return Diagonal(fill(ΔΩ, size(x, 1)))
    end
    return (Ω, ∂x)
end

#####
##### `*`
#####

# Reverse rule for the product of two real matrices, `Y = A * B`.
# Returns `Y` and one `Rule` per argument; each maps the output cotangent `Ȳ` to the
# cotangent of that argument: `Ȳ * B'` for `A` and `A' * Ȳ` for `B`.
function rrule(::typeof(*), A::AbstractMatrix{<:Real}, B::AbstractMatrix{<:Real})
    return A * B, (Rule(Ȳ -> Ȳ * B'), Rule(Ȳ -> A' * Ȳ))
end

#####
##### `/`
#####

# Reverse rule for `Y = A / B` where `B` is a `SquareMatrix` wrapper (`Diagonal` or a
# triangular type). Returns `Y` and a tuple of `Rule`s `(∂A, ∂B)`, each mapping the
# output cotangent `Ȳ` to the cotangent of the corresponding argument.
function rrule(::typeof(/), A::AbstractMatrix{<:Real}, B::T) where T<:SquareMatrix{<:Real}
    Y = A / B
    # Bare wrapper constructor of `B`'s type, used to rewrap the dense cotangent so it
    # has the same structured type as `B`.
    S = T.name.wrapper
    ∂A = Rule(Ȳ -> Ȳ / B')
    ∂B = Rule(Ȳ -> S(-Y' * (Ȳ / B')))
    return Y, (∂A, ∂B)
end

# Reverse rule for the general `C = A / B`, expressed through the rule for `\` via
# `A / B == (B' \ A')'`. It chains the `adjoint` rules (defined elsewhere) around the
# `\` rule, so each argument's cotangent is obtained by composing three rules.
function rrule(::typeof(/), A::AbstractVecOrMat{<:Real}, B::AbstractVecOrMat{<:Real})
    Aᵀ, dA = rrule(adjoint, A)
    Bᵀ, dB = rrule(adjoint, B)
    # Note the swapped order: `Bᵀ` is the left operand of `\`, so its rule comes first.
    Cᵀ, (dBᵀ, dAᵀ) = rrule(\, Bᵀ, Aᵀ)
    C, dC = rrule(adjoint, Cᵀ)
    # Composition applies right-to-left: undo the outer adjoint, pull back through `\`,
    # then undo the adjoint taken of the respective argument.
    ∂A = Rule(dA∘dAᵀ∘dC)
    ∂B = Rule(dB∘dBᵀ∘dC)
    return C, (∂A, ∂B)
end

#####
##### `\`
#####

# Reverse rule for `Y = A \ B` where `A` is a `SquareMatrix` wrapper (`Diagonal` or a
# triangular type). Returns `Y` and the `Rule`s `(∂A, ∂B)` mapping the output cotangent
# `Ȳ` to the cotangents of `A` and `B`.
function rrule(::typeof(\), A::T, B::AbstractVecOrMat{<:Real}) where T<:SquareMatrix{<:Real}
    # Bare wrapper constructor of `A`'s type; rewraps the cotangent for `A`.
    wrap = T.name.wrapper
    Y = A \ B
    ∂B = Rule() do Ȳ
        return A' \ Ȳ
    end
    ∂A = Rule() do Ȳ
        solved = A' \ Ȳ
        return wrap(-solved * Y')
    end
    return Y, (∂A, ∂B)
end

# Reverse rule for the general `Y = A \ B`. Returns `Y` and the `Rule`s `(∂A, ∂B)`
# mapping the output cotangent `Ȳ` to the cotangents of `A` and `B`.
function rrule(::typeof(\), A::AbstractVecOrMat{<:Real}, B::AbstractVecOrMat{<:Real})
    Y = A \ B
    ∂A = Rule() do Ȳ
        B̄ = A' \ Ȳ
        Ā = -B̄ * Y'
        # Two correction terms are accumulated into `Ā` via `_add!` (defined elsewhere).
        # The first is scaled by the residual `B - A * Y`.
        # NOTE(review): presumably these terms handle non-square (least-squares) `A`
        # and vanish for square invertible `A` — confirm against the derivation.
        _add!(Ā, (B - A * Y) * B̄' / A')
        _add!(Ā, A' \ Y * (Ȳ' - B̄'A))
    end
    ∂B = Rule(Ȳ -> A' \ Ȳ)
    return Y, (∂A, ∂B)
end

#####
##### `norm`
#####

# Reverse rule for the `p`-norm of a real array, `y = norm(A, p)`. Returns `y` and the
# `Rule`s `(∂A, ∂p)` mapping the scalar output cotangent `ȳ` to the cotangents of `A`
# and of `p`.
function rrule(::typeof(norm), A::AbstractArray{<:Real}, p::Real=2)
    y = norm(A, p)
    u = y^(1-p)
    # NOTE: `abs.(A).^p ./ A` divides by the entries of `A`, so any zero entry yields a
    # non-finite value in the cotangent for `A`.
    ∂A = Rule(ȳ -> ȳ .* u .* abs.(A).^p ./ A)
    ∂p = Rule(ȳ -> ȳ * (u * sum(a->abs(a)^p * log(abs(a)), A) - y * log(y)) / p)
    return y, (∂A, ∂p)
end

# Reverse rule for `norm` of a real scalar. The cotangent for `x` is `ȳ * sign(x)`; the
# cotangent for `p` is always `zero(x)`.
function rrule(::typeof(norm), x::Real, p::Real=2)
    return norm(x, p), (Rule(ȳ -> ȳ * sign(x)), Rule(_ -> zero(x)))
end
64 changes: 64 additions & 0 deletions test/rules/linalg/dense.jl
Original file line number Diff line number Diff line change
Expand Up @@ -70,4 +70,68 @@ end
frule_test(tr, (randn(rng, N, N), randn(rng, N, N)))
rrule_test(tr, randn(rng), (randn(rng, N, N), randn(rng, N, N)))
end
# Checks the `*` reverse rule over all combinations of the listed dimensions.
@testset "*" begin
    rng = MersenneTwister(123456)
    dims = [3,4,5]
    for n in dims, m in dims, p in dims
        n > 3 && n == m == p && continue  # don't need to test square case multiple times
        A = randn(rng, m, n)
        B = randn(rng, n, p)
        # Output cotangent, shaped like `A * B`.
        Ȳ = randn(rng, m, p)
        rrule_test(*, Ȳ, (A, randn(rng, m, n)), (B, randn(rng, n, p)))
    end
end
# Checks the `/` and `\` reverse rules on dense matrices, vectors, structured
# (`Diagonal`/triangular) operands, and mixed matrix/vector operands. In each case the
# output cotangent is drawn with the same size as the primal result `f(...)`.
@testset "$f" for f in [/, \]
    rng = MersenneTwister(42)
    for n in 3:5, m in 3:5
        A = randn(rng, m, n)
        B = randn(rng, m, n)
        Ȳ = randn(rng, size(f(A, B)))
        rrule_test(f, Ȳ, (A, randn(rng, m, n)), (B, randn(rng, m, n)))
    end
    # Vectors
    x = randn(rng, 10)
    y = randn(rng, 10)
    ȳ = randn(rng, size(f(x, y))...)
    rrule_test(f, ȳ, (x, randn(rng, 10)), (y, randn(rng, 10)))
    if f == (/)
        # For `/` the structured matrix is the right-hand operand.
        @testset "$T on the RHS" for T in (Diagonal, UpperTriangular, LowerTriangular)
            RHS = T(randn(rng, T == Diagonal ? 10 : (10, 10)))
            Y = randn(rng, 5, 10)
            Ȳ = randn(rng, size(f(Y, RHS))...)
            rrule_test(f, Ȳ, (Y, randn(rng, size(Y))), (RHS, randn(rng, size(RHS))))
        end
    else
        # For `\` the structured matrix is the left-hand operand.
        @testset "$T on LHS" for T in (Diagonal, UpperTriangular, LowerTriangular)
            LHS = T(randn(rng, T == Diagonal ? 10 : (10, 10)))
            y = randn(rng, 10)
            ȳ = randn(rng, size(f(LHS, y))...)
            rrule_test(f, ȳ, (LHS, randn(rng, size(LHS))), (y, randn(rng, 10)))
            Y = randn(rng, 10, 10)
            Ȳ = randn(rng, 10, 10)
            rrule_test(f, Ȳ, (LHS, randn(rng, size(LHS))), (Y, randn(rng, size(Y))))
        end
        @testset "Matrix $f Vector" begin
            X = randn(rng, 10, 4)
            y = randn(rng, 10)
            ȳ = randn(rng, size(f(X, y))...)
            rrule_test(f, ȳ, (X, randn(rng, size(X))), (y, randn(rng, 10)))
        end
        @testset "Vector $f Matrix" begin
            x = randn(rng, 10)
            Y = randn(rng, 10, 4)
            ȳ = randn(rng, size(f(x, Y))...)
            rrule_test(f, ȳ, (x, randn(rng, size(x))), (Y, randn(rng, size(Y))))
        end
    end
end
# Checks the `norm` reverse rule on a scalar (`dims == ()`) and on arrays of one, two
# and three dimensions, with a randomly drawn `p`.
@testset "norm" begin
    rng = MersenneTwister(3)
    for dims in [(), (5,), (3, 2), (7, 3, 2)]
        A = randn(rng, dims...)
        p = randn(rng)
        # Scalar output cotangent, since `norm` returns a scalar.
        ȳ = randn(rng)
        rrule_test(norm, ȳ, (A, randn(rng, dims...)), (p, randn(rng)))
    end
end
end

0 comments on commit 314b08a

Please sign in to comment.