Skip to content

Commit

Permalink
Add rules for all Symmetric/Hermitian constructors (#182)
Browse files Browse the repository at this point in the history
* Expand rule for lower symmetric constructor

* Expand tests to lower and complex

* Add rule for Hermitian constructors

* Unify Symmetric and Hermitian rules

* Unify symmetric and hermitian tests

* Add frule

* Reformat

* Add rule for conversion to matrix

* Add rrule for Array

* Add frules for Array and Matrix

* Increment version number

* Increment version number

* Add methods with matrix args

* Dispatch on realness

* Call Matrix instead of collect

* Add coments

* Remove type constraints

* Apply suggestions from code review

Co-authored-by: willtebbutt <wt0881@my.bristol.ac.uk>

* Wrap lines

* Use more informative type names

* Increment version number

Co-authored-by: willtebbutt <wt0881@my.bristol.ac.uk>
  • Loading branch information
sethaxen and willtebbutt authored Jul 6, 2020
1 parent 697e7e4 commit eabdd74
Show file tree
Hide file tree
Showing 4 changed files with 116 additions and 10 deletions.
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
name = "ChainRules"
uuid = "082447d4-558c-5d27-93f4-14fc19e9eca2"
version = "0.7.4"
version = "0.7.5"

[deps]
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
Expand Down
80 changes: 73 additions & 7 deletions src/rulesets/LinearAlgebra/structured.jl
Original file line number Diff line number Diff line change
Expand Up @@ -63,18 +63,84 @@ function rrule(::typeof(*), D::Diagonal{<:Real}, V::AbstractVector{<:Real})
end

#####
##### `Symmetric`
##### `Symmetric`/`Hermitian`
#####

function rrule(::Type{<:Symmetric}, A::AbstractMatrix)
function Symmetric_pullback(ȳ)
return (NO_FIELDS, @thunk(_symmetric_back(ȳ)))
function frule((_, ΔA, _), T::Type{<:LinearAlgebra.HermOrSym}, A::AbstractMatrix, uplo)
return T(A, uplo), T(ΔA, uplo)
end

function rrule(T::Type{<:LinearAlgebra.HermOrSym}, A::AbstractMatrix, uplo)
Ω = T(A, uplo)
function HermOrSym_pullback(ΔΩ)
return (NO_FIELDS, @thunk(_symherm_back(T, ΔΩ, Ω.uplo)), DoesNotExist())
end
return Ω, HermOrSym_pullback
end

function frule((_, ΔA), TM::Type{<:Matrix}, A::LinearAlgebra.HermOrSym)
return TM(A), TM(_symherm_forward(A, ΔA))
end
function frule((_, ΔA), ::Type{Array}, A::LinearAlgebra.HermOrSym)
return Array(A), Array(_symherm_forward(A, ΔA))
end

function rrule(TM::Type{<:Matrix}, A::LinearAlgebra.HermOrSym)
function Matrix_pullback(ΔΩ)
TA = _symhermtype(A)
T∂A = TA{eltype(ΔΩ),typeof(ΔΩ)}
uplo = A.uplo
∂A = T∂A(_symherm_back(A, ΔΩ, uplo), uplo)
return NO_FIELDS, ∂A
end
return Symmetric(A), Symmetric_pullback
return TM(A), Matrix_pullback
end
rrule(::Type{Array}, A::LinearAlgebra.HermOrSym) = rrule(Matrix, A)

_symmetric_back(ΔΩ) = UpperTriangular(ΔΩ) + transpose(LowerTriangular(ΔΩ)) - Diagonal(ΔΩ)
_symmetric_back(ΔΩ::Union{Diagonal,UpperTriangular}) = ΔΩ
# Get type (Symmetric or Hermitian) from type or matrix
_symhermtype(::Type{<:Symmetric}) = Symmetric
_symhermtype(::Type{<:Hermitian}) = Hermitian
_symhermtype(A) = _symhermtype(typeof(A))

# for Ω = Matrix(A::HermOrSym), push forward ΔA to get ∂Ω
function _symherm_forward(A, ΔA)
TA = _symhermtype(A)
return if ΔA isa TA
ΔA
else
TA{eltype(ΔA),typeof(ΔA)}(ΔA, A.uplo)
end
end

# for Ω = HermOrSym(A, uplo), pull back ΔΩ to get ∂A
_symherm_back(::Type{<:Symmetric}, ΔΩ, uplo) = _symmetric_back(ΔΩ, uplo)
function _symherm_back(::Type{<:Hermitian}, ΔΩ::AbstractMatrix{<:Real}, uplo)
return _symmetric_back(ΔΩ, uplo)
end
_symherm_back(::Type{<:Hermitian}, ΔΩ, uplo) = _hermitian_back(ΔΩ, uplo)
_symherm_back(Ω, ΔΩ, uplo) = _symherm_back(typeof(Ω), ΔΩ, uplo)

function _symmetric_back(ΔΩ, uplo)
L, U, D = LowerTriangular(ΔΩ), UpperTriangular(ΔΩ), Diagonal(ΔΩ)
return uplo == 'U' ? U .+ transpose(L) - D : L .+ transpose(U) - D
end
_symmetric_back(ΔΩ::Diagonal, uplo) = ΔΩ
_symmetric_back(ΔΩ::UpperTriangular, uplo) = Matrix(uplo == 'U' ? ΔΩ : transpose(ΔΩ))
_symmetric_back(ΔΩ::LowerTriangular, uplo) = Matrix(uplo == 'U' ? transpose(ΔΩ) : ΔΩ)

function _hermitian_back(ΔΩ, uplo)
L, U, rD = LowerTriangular(ΔΩ), UpperTriangular(ΔΩ), real.(Diagonal(ΔΩ))
return uplo == 'U' ? U .+ L' - rD : L .+ U' - rD
end
_hermitian_back(ΔΩ::Diagonal, uplo) = real.(ΔΩ)
function _hermitian_back(ΔΩ::LinearAlgebra.AbstractTriangular, uplo)
∂UL = ΔΩ .- Diagonal(_extract_imag.(diag(ΔΩ)))
return if istriu(ΔΩ)
return Matrix(uplo == 'U' ? ∂UL : ∂UL')
else
return Matrix(uplo == 'U' ? ∂UL' : ∂UL)
end
end

#####
##### `Adjoint`
Expand Down
2 changes: 2 additions & 0 deletions src/rulesets/LinearAlgebra/utils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -34,3 +34,5 @@ function _add!(X::AbstractVecOrMat, Y::AbstractVecOrMat)
return X
end
_add!(X, Y) = X + Y # handles all `AbstractZero` overloads

_extract_imag(x) = complex(0, imag(x))
42 changes: 40 additions & 2 deletions test/rulesets/LinearAlgebra/structured.jl
Original file line number Diff line number Diff line change
Expand Up @@ -78,9 +78,47 @@
end
end
end
@testset "Symmetric(::AbstractMatrix{$T})" for T in (Float64, ComplexF64)
@testset "$(SymHerm)(::AbstractMatrix{$T}, :$(uplo))" for
SymHerm in (Symmetric, Hermitian),
T in (Float64, ComplexF64),
uplo in (:U, :L)

N = 3
@testset "frule" begin
x = randn(T, N, N)
Δx = randn(T, N, N)
# can't use frule_test here because it doesn't yet ignore nothing tangents
Ω = SymHerm(x, uplo)
Ω_ad, ∂Ω_ad = frule((Zero(), Δx, Zero()), SymHerm, x, uplo)
@test Ω_ad == Ω
∂Ω_fd = jvp(_fdm, z -> SymHerm(z, uplo), (x, Δx))
@test ∂Ω_ad ∂Ω_fd
end
@testset "rrule" begin
x = randn(T, N, N)
∂x = randn(T, N, N)
ΔΩ = randn(T, N, N)
@testset "back(::$MT)" for MT in (Matrix, LowerTriangular, UpperTriangular)
rrule_test(SymHerm, MT(ΔΩ), (x, ∂x), (uplo, nothing))
end
@testset "back(::Diagonal)" begin
rrule_test(SymHerm, Diagonal(ΔΩ), (x, Diagonal(∂x)), (uplo, nothing))
end
end
end
@testset "$(f)(::$(SymHerm){$T}) with uplo=:$uplo" for f in (Matrix, Array),
SymHerm in (Symmetric, Hermitian),
T in (Float64, ComplexF64),
uplo in (:U, :L)

N = 3
rrule_test(Symmetric, randn(T, N, N), (randn(T, N, N), randn(T, N, N)))
x = SymHerm(randn(T, N, N), uplo)
Δx = randn(T, N, N)
∂x = SymHerm(randn(T, N, N), uplo)
ΔΩ = f(SymHerm(randn(T, N, N), uplo))
frule_test(f, (x, Δx))
frule_test(f, (x, SymHerm(Δx, uplo)))
rrule_test(f, ΔΩ, (x, ∂x))
end
@testset "$f" for f in (Adjoint, adjoint, Transpose, transpose)
n = 5
Expand Down

2 comments on commit eabdd74

@sethaxen
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@JuliaRegistrator
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Registration pull request created: JuliaRegistries/General/17501

After the above pull request is merged, it is recommended that a tag is created on this repository for the registered package version.

This will be done automatically if the Julia TagBot GitHub Action is installed, or can be done manually through the github interface, or via:

git tag -a v0.7.5 -m "<description of version>" eabdd74785a4588f349b51da82f174f8daf2d32f
git push origin v0.7.5

Please sign in to comment.