diff --git a/Project.toml b/Project.toml
index f5c94dc2a..cd0497b6e 100644
--- a/Project.toml
+++ b/Project.toml
@@ -10,7 +10,7 @@ SpecialFunctions = "276daf66-3868-5448-9aa4-cd146d93841b"
 
 [compat]
 Cassette = "^0.2"
-FDM = "^0.5"
+FDM = "^0.6"
 julia = "^1.0"
 
 [extras]
diff --git a/src/rules/linalg/dense.jl b/src/rules/linalg/dense.jl
index 48a86de7c..6f86d25a3 100644
--- a/src/rules/linalg/dense.jl
+++ b/src/rules/linalg/dense.jl
@@ -1,3 +1,9 @@
+using LinearAlgebra: AbstractTriangular
+
+# Matrix wrapper types that we know are square and are thus potentially invertible. For
+# these we can use simpler definitions for `/` and `\`.
+const SquareMatrix{T} = Union{Diagonal{T},AbstractTriangular{T}}
+
 #####
 ##### `sum`
 #####
@@ -69,3 +75,74 @@ end
 frule(::typeof(tr), x) = (tr(x), Rule(Δx -> tr(extern(Δx))))
 
 rrule(::typeof(tr), x) = (tr(x), Rule(ΔΩ -> Diagonal(fill(ΔΩ, size(x, 1)))))
+
+#####
+##### `*`
+#####
+
+function rrule(::typeof(*), A::AbstractMatrix{<:Real}, B::AbstractMatrix{<:Real})
+    return A * B, (Rule(Ȳ -> Ȳ * B'), Rule(Ȳ -> A' * Ȳ))
+end
+
+#####
+##### `/`
+#####
+
+function rrule(::typeof(/), A::AbstractMatrix{<:Real}, B::T) where T<:SquareMatrix{<:Real}
+    Y = A / B
+    S = T.name.wrapper
+    ∂A = Rule(Ȳ -> Ȳ / B')
+    ∂B = Rule(Ȳ -> S(-Y' * (Ȳ / B')))
+    return Y, (∂A, ∂B)
+end
+
+function rrule(::typeof(/), A::AbstractVecOrMat{<:Real}, B::AbstractVecOrMat{<:Real})
+    Aᵀ, dA = rrule(adjoint, A)
+    Bᵀ, dB = rrule(adjoint, B)
+    Cᵀ, (dBᵀ, dAᵀ) = rrule(\, Bᵀ, Aᵀ)
+    C, dC = rrule(adjoint, Cᵀ)
+    ∂A = Rule(dA∘dAᵀ∘dC)
+    ∂B = Rule(dA∘dBᵀ∘dC)
+    return C, (∂A, ∂B)
+end
+
+#####
+##### `\`
+#####
+
+function rrule(::typeof(\), A::T, B::AbstractVecOrMat{<:Real}) where T<:SquareMatrix{<:Real}
+    Y = A \ B
+    S = T.name.wrapper
+    ∂A = Rule(Ȳ -> S(-(A' \ Ȳ) * Y'))
+    ∂B = Rule(Ȳ -> A' \ Ȳ)
+    return Y, (∂A, ∂B)
+end
+
+function rrule(::typeof(\), A::AbstractVecOrMat{<:Real}, B::AbstractVecOrMat{<:Real})
+    Y = A \ B
+    ∂A = Rule() do Ȳ
+        B̄ = A' \ Ȳ
+        Ā = -B̄ * Y'
+        _add!(Ā, (B - A * Y) * B̄' / A')
+        _add!(Ā, A' \ Y * (Ȳ' - B̄'A))
+        Ā
+    end
+    ∂B = Rule(Ȳ -> A' \ Ȳ)
+    return Y, (∂A, ∂B)
+end
+
+#####
+##### `norm`
+#####
+
+function rrule(::typeof(norm), A::AbstractArray{<:Real}, p::Real=2)
+    y = norm(A, p)
+    u = y^(1-p)
+    ∂A = Rule(ȳ -> ȳ .* u .* abs.(A).^p ./ A)
+    ∂p = Rule(ȳ -> ȳ * (u * sum(a->abs(a)^p * log(abs(a)), A) - y * log(y)) / p)
+    return y, (∂A, ∂p)
+end
+
+function rrule(::typeof(norm), x::Real, p::Real=2)
+    return norm(x, p), (Rule(ȳ -> ȳ * sign(x)), Rule(_ -> zero(x)))
+end
diff --git a/test/rules/linalg/dense.jl b/test/rules/linalg/dense.jl
index 097fd598d..0eb61cc6f 100644
--- a/test/rules/linalg/dense.jl
+++ b/test/rules/linalg/dense.jl
@@ -70,4 +70,68 @@ end
         frule_test(tr, (randn(rng, N, N), randn(rng, N, N)))
         rrule_test(tr, randn(rng), (randn(rng, N, N), randn(rng, N, N)))
     end
+    @testset "*" begin
+        rng = MersenneTwister(123456)
+        dims = [3,4,5]
+        for n in dims, m in dims, p in dims
+            n > 3 && n == m == p && continue  # don't need to test square case multiple times
+            A = randn(rng, m, n)
+            B = randn(rng, n, p)
+            Ȳ = randn(rng, m, p)
+            rrule_test(*, Ȳ, (A, randn(rng, m, n)), (B, randn(rng, n, p)))
+        end
+    end
+    @testset "$f" for f in [/, \]
+        rng = MersenneTwister(42)
+        for n in 3:5, m in 3:5
+            A = randn(rng, m, n)
+            B = randn(rng, m, n)
+            Ȳ = randn(rng, size(f(A, B)))
+            rrule_test(f, Ȳ, (A, randn(rng, m, n)), (B, randn(rng, m, n)))
+        end
+        # Vectors
+        x = randn(rng, 10)
+        y = randn(rng, 10)
+        ȳ = randn(rng, size(f(x, y))...)
+        rrule_test(f, ȳ, (x, randn(rng, 10)), (y, randn(rng, 10)))
+        if f == (/)
+            @testset "$T on the RHS" for T in (Diagonal, UpperTriangular, LowerTriangular)
+                RHS = T(randn(rng, T == Diagonal ? 10 : (10, 10)))
+                Y = randn(rng, 5, 10)
+                Ȳ = randn(rng, size(f(Y, RHS))...)
+                rrule_test(f, Ȳ, (Y, randn(rng, size(Y))), (RHS, randn(rng, size(RHS))))
+            end
+        else
+            @testset "$T on LHS" for T in (Diagonal, UpperTriangular, LowerTriangular)
+                LHS = T(randn(rng, T == Diagonal ? 10 : (10, 10)))
+                y = randn(rng, 10)
+                ȳ = randn(rng, size(f(LHS, y))...)
+                rrule_test(f, ȳ, (LHS, randn(rng, size(LHS))), (y, randn(rng, 10)))
+                Y = randn(rng, 10, 10)
+                Ȳ = randn(rng, 10, 10)
+                rrule_test(f, Ȳ, (LHS, randn(rng, size(LHS))), (Y, randn(rng, size(Y))))
+            end
+            @testset "Matrix $f Vector" begin
+                X = randn(rng, 10, 4)
+                y = randn(rng, 10)
+                ȳ = randn(rng, size(f(X, y))...)
+                rrule_test(f, ȳ, (X, randn(rng, size(X))), (y, randn(rng, 10)))
+            end
+            @testset "Vector $f Matrix" begin
+                x = randn(rng, 10)
+                Y = randn(rng, 10, 4)
+                ȳ = randn(rng, size(f(x, Y))...)
+                rrule_test(f, ȳ, (x, randn(rng, size(x))), (Y, randn(rng, size(Y))))
+            end
+        end
+    end
+    @testset "norm" begin
+        rng = MersenneTwister(3)
+        for dims in [(), (5,), (3, 2), (7, 3, 2)]
+            A = randn(rng, dims...)
+            p = randn(rng)
+            ȳ = randn(rng)
+            rrule_test(norm, ȳ, (A, randn(rng, dims...)), (p, randn(rng)))
+        end
+    end
 end