Skip to content

Commit eabdd74

Browse files
Add rules for all Symmetric/Hermitian constructors (#182)
* Expand rule for lower symmetric constructor * Expand tests to lower and complex * Add rule for Hermitian constructors * Unify Symmetric and Hermitian rules * Unify symmetric and hermitian tests * Add frule * Reformat * Add rule for conversion to matrix * Add rrule for Array * Add frules for Array and Matrix * Increment version number * Increment version number * Add methods with matrix args * Dispatch on realness * Call Matrix instead of collect * Add comments * Remove type constraints * Apply suggestions from code review Co-authored-by: willtebbutt <wt0881@my.bristol.ac.uk> * Wrap lines * Use more informative type names * Increment version number Co-authored-by: willtebbutt <wt0881@my.bristol.ac.uk>
1 parent 697e7e4 commit eabdd74

File tree

4 files changed

+116
-10
lines changed

4 files changed

+116
-10
lines changed

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
name = "ChainRules"
22
uuid = "082447d4-558c-5d27-93f4-14fc19e9eca2"
3-
version = "0.7.4"
3+
version = "0.7.5"
44

55
[deps]
66
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"

src/rulesets/LinearAlgebra/structured.jl

Lines changed: 73 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -63,18 +63,84 @@ function rrule(::typeof(*), D::Diagonal{<:Real}, V::AbstractVector{<:Real})
6363
end
6464

6565
#####
66-
##### `Symmetric`
66+
##### `Symmetric`/`Hermitian`
6767
#####
6868

69-
function rrule(::Type{<:Symmetric}, A::AbstractMatrix)
70-
function Symmetric_pullback(ȳ)
71-
return (NO_FIELDS, @thunk(_symmetric_back(ȳ)))
69+
function frule((_, ΔA, _), T::Type{<:LinearAlgebra.HermOrSym}, A::AbstractMatrix, uplo)
70+
return T(A, uplo), T(ΔA, uplo)
71+
end
72+
73+
function rrule(T::Type{<:LinearAlgebra.HermOrSym}, A::AbstractMatrix, uplo)
74+
Ω = T(A, uplo)
75+
function HermOrSym_pullback(ΔΩ)
76+
return (NO_FIELDS, @thunk(_symherm_back(T, ΔΩ, Ω.uplo)), DoesNotExist())
77+
end
78+
return Ω, HermOrSym_pullback
79+
end
80+
81+
function frule((_, ΔA), TM::Type{<:Matrix}, A::LinearAlgebra.HermOrSym)
82+
return TM(A), TM(_symherm_forward(A, ΔA))
83+
end
84+
function frule((_, ΔA), ::Type{Array}, A::LinearAlgebra.HermOrSym)
85+
return Array(A), Array(_symherm_forward(A, ΔA))
86+
end
87+
88+
function rrule(TM::Type{<:Matrix}, A::LinearAlgebra.HermOrSym)
89+
function Matrix_pullback(ΔΩ)
90+
TA = _symhermtype(A)
91+
T∂A = TA{eltype(ΔΩ),typeof(ΔΩ)}
92+
uplo = A.uplo
93+
∂A = T∂A(_symherm_back(A, ΔΩ, uplo), uplo)
94+
return NO_FIELDS, ∂A
7295
end
73-
return Symmetric(A), Symmetric_pullback
96+
return TM(A), Matrix_pullback
7497
end
98+
rrule(::Type{Array}, A::LinearAlgebra.HermOrSym) = rrule(Matrix, A)
7599

76-
_symmetric_back(ΔΩ) = UpperTriangular(ΔΩ) + transpose(LowerTriangular(ΔΩ)) - Diagonal(ΔΩ)
77-
_symmetric_back(ΔΩ::Union{Diagonal,UpperTriangular}) = ΔΩ
100+
# Get type (Symmetric or Hermitian) from type or matrix
101+
_symhermtype(::Type{<:Symmetric}) = Symmetric
102+
_symhermtype(::Type{<:Hermitian}) = Hermitian
103+
_symhermtype(A) = _symhermtype(typeof(A))
104+
105+
# for Ω = Matrix(A::HermOrSym), push forward ΔA to get ∂Ω
106+
function _symherm_forward(A, ΔA)
107+
TA = _symhermtype(A)
108+
return if ΔA isa TA
109+
ΔA
110+
else
111+
TA{eltype(ΔA),typeof(ΔA)}(ΔA, A.uplo)
112+
end
113+
end
114+
115+
# for Ω = HermOrSym(A, uplo), pull back ΔΩ to get ∂A
116+
_symherm_back(::Type{<:Symmetric}, ΔΩ, uplo) = _symmetric_back(ΔΩ, uplo)
117+
function _symherm_back(::Type{<:Hermitian}, ΔΩ::AbstractMatrix{<:Real}, uplo)
118+
return _symmetric_back(ΔΩ, uplo)
119+
end
120+
_symherm_back(::Type{<:Hermitian}, ΔΩ, uplo) = _hermitian_back(ΔΩ, uplo)
121+
_symherm_back(Ω, ΔΩ, uplo) = _symherm_back(typeof(Ω), ΔΩ, uplo)
122+
123+
function _symmetric_back(ΔΩ, uplo)
124+
L, U, D = LowerTriangular(ΔΩ), UpperTriangular(ΔΩ), Diagonal(ΔΩ)
125+
return uplo == 'U' ? U .+ transpose(L) - D : L .+ transpose(U) - D
126+
end
127+
_symmetric_back(ΔΩ::Diagonal, uplo) = ΔΩ
128+
_symmetric_back(ΔΩ::UpperTriangular, uplo) = Matrix(uplo == 'U' ? ΔΩ : transpose(ΔΩ))
129+
_symmetric_back(ΔΩ::LowerTriangular, uplo) = Matrix(uplo == 'U' ? transpose(ΔΩ) : ΔΩ)
130+
131+
function _hermitian_back(ΔΩ, uplo)
132+
L, U, rD = LowerTriangular(ΔΩ), UpperTriangular(ΔΩ), real.(Diagonal(ΔΩ))
133+
return uplo == 'U' ? U .+ L' - rD : L .+ U' - rD
134+
end
135+
_hermitian_back(ΔΩ::Diagonal, uplo) = real.(ΔΩ)
136+
function _hermitian_back(ΔΩ::LinearAlgebra.AbstractTriangular, uplo)
137+
∂UL = ΔΩ .- Diagonal(_extract_imag.(diag(ΔΩ)))
138+
return if istriu(ΔΩ)
139+
return Matrix(uplo == 'U' ? ∂UL : ∂UL')
140+
else
141+
return Matrix(uplo == 'U' ? ∂UL' : ∂UL)
142+
end
143+
end
78144

79145
#####
80146
##### `Adjoint`

src/rulesets/LinearAlgebra/utils.jl

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -34,3 +34,5 @@ function _add!(X::AbstractVecOrMat, Y::AbstractVecOrMat)
3434
return X
3535
end
3636
_add!(X, Y) = X + Y # handles all `AbstractZero` overloads
37+
38+
_extract_imag(x) = complex(0, imag(x))

test/rulesets/LinearAlgebra/structured.jl

Lines changed: 40 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -78,9 +78,47 @@
7878
end
7979
end
8080
end
81-
@testset "Symmetric(::AbstractMatrix{$T})" for T in (Float64, ComplexF64)
81+
@testset "$(SymHerm)(::AbstractMatrix{$T}, :$(uplo))" for
82+
SymHerm in (Symmetric, Hermitian),
83+
T in (Float64, ComplexF64),
84+
uplo in (:U, :L)
85+
86+
N = 3
87+
@testset "frule" begin
88+
x = randn(T, N, N)
89+
Δx = randn(T, N, N)
90+
# can't use frule_test here because it doesn't yet ignore nothing tangents
91+
Ω = SymHerm(x, uplo)
92+
Ω_ad, ∂Ω_ad = frule((Zero(), Δx, Zero()), SymHerm, x, uplo)
93+
@test Ω_ad == Ω
94+
∂Ω_fd = jvp(_fdm, z -> SymHerm(z, uplo), (x, Δx))
95+
@test ∂Ω_ad ≈ ∂Ω_fd
96+
end
97+
@testset "rrule" begin
98+
x = randn(T, N, N)
99+
∂x = randn(T, N, N)
100+
ΔΩ = randn(T, N, N)
101+
@testset "back(::$MT)" for MT in (Matrix, LowerTriangular, UpperTriangular)
102+
rrule_test(SymHerm, MT(ΔΩ), (x, ∂x), (uplo, nothing))
103+
end
104+
@testset "back(::Diagonal)" begin
105+
rrule_test(SymHerm, Diagonal(ΔΩ), (x, Diagonal(∂x)), (uplo, nothing))
106+
end
107+
end
108+
end
109+
@testset "$(f)(::$(SymHerm){$T}) with uplo=:$uplo" for f in (Matrix, Array),
110+
SymHerm in (Symmetric, Hermitian),
111+
T in (Float64, ComplexF64),
112+
uplo in (:U, :L)
113+
82114
N = 3
83-
rrule_test(Symmetric, randn(T, N, N), (randn(T, N, N), randn(T, N, N)))
115+
x = SymHerm(randn(T, N, N), uplo)
116+
Δx = randn(T, N, N)
117+
∂x = SymHerm(randn(T, N, N), uplo)
118+
ΔΩ = f(SymHerm(randn(T, N, N), uplo))
119+
frule_test(f, (x, Δx))
120+
frule_test(f, (x, SymHerm(Δx, uplo)))
121+
rrule_test(f, ΔΩ, (x, ∂x))
84122
end
85123
@testset "$f" for f in (Adjoint, adjoint, Transpose, transpose)
86124
n = 5

0 commit comments

Comments
 (0)