From eabdd74785a4588f349b51da82f174f8daf2d32f Mon Sep 17 00:00:00 2001 From: Seth Axen Date: Sun, 5 Jul 2020 17:43:46 -0700 Subject: [PATCH] Add rules for all Symmetric/Hermitian constructors (#182) * Expand rule for lower symmetric constructor * Expand tests to lower and complex * Add rule for Hermitian constructors * Unify Symmetric and Hermitian rules * Unify symmetric and hermitian tests * Add frule * Reformat * Add rule for conversion to matrix * Add rrule for Array * Add frules for Array and Matrix * Increment version number * Increment version number * Add methods with matrix args * Dispatch on realness * Call Matrix instead of collect * Add coments * Remove type constraints * Apply suggestions from code review Co-authored-by: willtebbutt * Wrap lines * Use more informative type names * Increment version number Co-authored-by: willtebbutt --- Project.toml | 2 +- src/rulesets/LinearAlgebra/structured.jl | 80 +++++++++++++++++++++-- src/rulesets/LinearAlgebra/utils.jl | 2 + test/rulesets/LinearAlgebra/structured.jl | 42 +++++++++++- 4 files changed, 116 insertions(+), 10 deletions(-) diff --git a/Project.toml b/Project.toml index 3187acf10..b99ae1b73 100644 --- a/Project.toml +++ b/Project.toml @@ -1,6 +1,6 @@ name = "ChainRules" uuid = "082447d4-558c-5d27-93f4-14fc19e9eca2" -version = "0.7.4" +version = "0.7.5" [deps] ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" diff --git a/src/rulesets/LinearAlgebra/structured.jl b/src/rulesets/LinearAlgebra/structured.jl index f61a7a7fb..677039743 100644 --- a/src/rulesets/LinearAlgebra/structured.jl +++ b/src/rulesets/LinearAlgebra/structured.jl @@ -63,18 +63,84 @@ function rrule(::typeof(*), D::Diagonal{<:Real}, V::AbstractVector{<:Real}) end ##### -##### `Symmetric` +##### `Symmetric`/`Hermitian` ##### -function rrule(::Type{<:Symmetric}, A::AbstractMatrix) - function Symmetric_pullback(ȳ) - return (NO_FIELDS, @thunk(_symmetric_back(ȳ))) +function frule((_, ΔA, _), T::Type{<:LinearAlgebra.HermOrSym}, A::AbstractMatrix, uplo) + return T(A, uplo), T(ΔA, uplo) +end + +function rrule(T::Type{<:LinearAlgebra.HermOrSym}, A::AbstractMatrix, uplo) + Ω = T(A, uplo) + function HermOrSym_pullback(ΔΩ) + return (NO_FIELDS, @thunk(_symherm_back(T, ΔΩ, Ω.uplo)), DoesNotExist()) + end + return Ω, HermOrSym_pullback +end + +function frule((_, ΔA), TM::Type{<:Matrix}, A::LinearAlgebra.HermOrSym) + return TM(A), TM(_symherm_forward(A, ΔA)) +end +function frule((_, ΔA), ::Type{Array}, A::LinearAlgebra.HermOrSym) + return Array(A), Array(_symherm_forward(A, ΔA)) +end + +function rrule(TM::Type{<:Matrix}, A::LinearAlgebra.HermOrSym) + function Matrix_pullback(ΔΩ) + TA = _symhermtype(A) + T∂A = TA{eltype(ΔΩ),typeof(ΔΩ)} + uplo = A.uplo + ∂A = T∂A(_symherm_back(A, ΔΩ, uplo), uplo) + return NO_FIELDS, ∂A end - return Symmetric(A), Symmetric_pullback + return TM(A), Matrix_pullback end +rrule(::Type{Array}, A::LinearAlgebra.HermOrSym) = rrule(Matrix, A) -_symmetric_back(ΔΩ) = UpperTriangular(ΔΩ) + transpose(LowerTriangular(ΔΩ)) - Diagonal(ΔΩ) -_symmetric_back(ΔΩ::Union{Diagonal,UpperTriangular}) = ΔΩ +# Get type (Symmetric or Hermitian) from type or matrix +_symhermtype(::Type{<:Symmetric}) = Symmetric +_symhermtype(::Type{<:Hermitian}) = Hermitian +_symhermtype(A) = _symhermtype(typeof(A)) + +# for Ω = Matrix(A::HermOrSym), push forward ΔA to get ∂Ω +function _symherm_forward(A, ΔA) + TA = _symhermtype(A) + return if ΔA isa TA + ΔA + else + TA{eltype(ΔA),typeof(ΔA)}(ΔA, A.uplo) + end +end + +# for Ω = HermOrSym(A, uplo), pull back ΔΩ to get ∂A +_symherm_back(::Type{<:Symmetric}, ΔΩ, uplo) = _symmetric_back(ΔΩ, uplo) +function _symherm_back(::Type{<:Hermitian}, ΔΩ::AbstractMatrix{<:Real}, uplo) + return _symmetric_back(ΔΩ, uplo) +end +_symherm_back(::Type{<:Hermitian}, ΔΩ, uplo) = _hermitian_back(ΔΩ, uplo) +_symherm_back(Ω, ΔΩ, uplo) = _symherm_back(typeof(Ω), ΔΩ, uplo) + +function _symmetric_back(ΔΩ, uplo) + L, U, D = LowerTriangular(ΔΩ), UpperTriangular(ΔΩ), Diagonal(ΔΩ) + return uplo == 'U' ? U .+ transpose(L) - D : L .+ transpose(U) - D +end +_symmetric_back(ΔΩ::Diagonal, uplo) = ΔΩ +_symmetric_back(ΔΩ::UpperTriangular, uplo) = Matrix(uplo == 'U' ? ΔΩ : transpose(ΔΩ)) +_symmetric_back(ΔΩ::LowerTriangular, uplo) = Matrix(uplo == 'U' ? transpose(ΔΩ) : ΔΩ) + +function _hermitian_back(ΔΩ, uplo) + L, U, rD = LowerTriangular(ΔΩ), UpperTriangular(ΔΩ), real.(Diagonal(ΔΩ)) + return uplo == 'U' ? U .+ L' - rD : L .+ U' - rD +end +_hermitian_back(ΔΩ::Diagonal, uplo) = real.(ΔΩ) +function _hermitian_back(ΔΩ::LinearAlgebra.AbstractTriangular, uplo) + ∂UL = ΔΩ .- Diagonal(_extract_imag.(diag(ΔΩ))) + return if istriu(ΔΩ) + return Matrix(uplo == 'U' ? ∂UL : ∂UL') + else + return Matrix(uplo == 'U' ? ∂UL' : ∂UL) + end +end ##### ##### `Adjoint` diff --git a/src/rulesets/LinearAlgebra/utils.jl b/src/rulesets/LinearAlgebra/utils.jl index caf661447..64dcdf7dc 100644 --- a/src/rulesets/LinearAlgebra/utils.jl +++ b/src/rulesets/LinearAlgebra/utils.jl @@ -34,3 +34,5 @@ function _add!(X::AbstractVecOrMat, Y::AbstractVecOrMat) return X end _add!(X, Y) = X + Y # handles all `AbstractZero` overloads + +_extract_imag(x) = complex(0, imag(x)) diff --git a/test/rulesets/LinearAlgebra/structured.jl b/test/rulesets/LinearAlgebra/structured.jl index 46fd77920..8ca1f0284 100644 --- a/test/rulesets/LinearAlgebra/structured.jl +++ b/test/rulesets/LinearAlgebra/structured.jl @@ -78,9 +78,47 @@ end end end - @testset "Symmetric(::AbstractMatrix{$T})" for T in (Float64, ComplexF64) + @testset "$(SymHerm)(::AbstractMatrix{$T}, :$(uplo))" for + SymHerm in (Symmetric, Hermitian), + T in (Float64, ComplexF64), + uplo in (:U, :L) + + N = 3 + @testset "frule" begin + x = randn(T, N, N) + Δx = randn(T, N, N) + # can't use frule_test here because it doesn't yet ignore nothing tangents + Ω = SymHerm(x, uplo) + Ω_ad, ∂Ω_ad = frule((Zero(), Δx, Zero()), SymHerm, x, uplo) + @test Ω_ad == Ω + ∂Ω_fd = jvp(_fdm, z -> SymHerm(z, uplo), (x, Δx)) + @test ∂Ω_ad ≈ ∂Ω_fd + end + @testset "rrule" begin + x = randn(T, N, N) + ∂x = randn(T, N, N) + ΔΩ = randn(T, N, N) + @testset "back(::$MT)" for MT in (Matrix, LowerTriangular, UpperTriangular) + rrule_test(SymHerm, MT(ΔΩ), (x, ∂x), (uplo, nothing)) + end + @testset "back(::Diagonal)" begin + rrule_test(SymHerm, Diagonal(ΔΩ), (x, Diagonal(∂x)), (uplo, nothing)) + end + end + end + @testset "$(f)(::$(SymHerm){$T}) with uplo=:$uplo" for f in (Matrix, Array), + SymHerm in (Symmetric, Hermitian), + T in (Float64, ComplexF64), + uplo in (:U, :L) + N = 3 - rrule_test(Symmetric, randn(T, N, N), (randn(T, N, N), randn(T, N, N))) + x = SymHerm(randn(T, N, N), uplo) + Δx = randn(T, N, N) + ∂x = SymHerm(randn(T, N, N), uplo) + ΔΩ = f(SymHerm(randn(T, N, N), uplo)) + frule_test(f, (x, Δx)) + frule_test(f, (x, SymHerm(Δx, uplo))) + rrule_test(f, ΔΩ, (x, ∂x)) end @testset "$f" for f in (Adjoint, adjoint, Transpose, transpose) n = 5