Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add rules for all Symmetric/Hermitian constructors #182

Merged
merged 24 commits into from
Jul 6, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
name = "ChainRules"
uuid = "082447d4-558c-5d27-93f4-14fc19e9eca2"
version = "0.7.4"
version = "0.7.5"

[deps]
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
Expand Down
80 changes: 73 additions & 7 deletions src/rulesets/LinearAlgebra/structured.jl
Original file line number Diff line number Diff line change
Expand Up @@ -63,18 +63,84 @@ function rrule(::typeof(*), D::Diagonal{<:Real}, V::AbstractVector{<:Real})
end

#####
##### `Symmetric`
##### `Symmetric`/`Hermitian`
#####

function rrule(::Type{<:Symmetric}, A::AbstractMatrix)
function Symmetric_pullback(ȳ)
return (NO_FIELDS, @thunk(_symmetric_back(ȳ)))
function frule((_, ΔA, _), T::Type{<:LinearAlgebra.HermOrSym}, A::AbstractMatrix, uplo)
return T(A, uplo), T(ΔA, uplo)
end

function rrule(T::Type{<:LinearAlgebra.HermOrSym}, A::AbstractMatrix, uplo)
Ω = T(A, uplo)
function HermOrSym_pullback(ΔΩ)
return (NO_FIELDS, @thunk(_symherm_back(T, ΔΩ, Ω.uplo)), DoesNotExist())
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Do we need a thunk here? The cotangent w.r.t. A is the only meaningful, so I would imagine that it would always get used somewhere, but my intuition for this isn't great, perhaps yours is better.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Probably not? I had thought that the convention was to always thunk in reverse unless 1) there's only one cotangent (not counting the NO_FIELDS) or 2) sometimes for scalar functions (e.g. those made with @scalar_rule). I personally think modifying (1) to be "there's only one cotangent that a user could reasonably want" makes perfect sense.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@oxinabox do you agree? If so I'll open a PR to the docs to update guidance.

end
return Ω, HermOrSym_pullback
end

function frule((_, ΔA), TM::Type{<:Matrix}, A::LinearAlgebra.HermOrSym)
return TM(A), TM(_symherm_forward(A, ΔA))
end
function frule((_, ΔA), ::Type{Array}, A::LinearAlgebra.HermOrSym)
return Array(A), Array(_symherm_forward(A, ΔA))
end

function rrule(TM::Type{<:Matrix}, A::LinearAlgebra.HermOrSym)
function Matrix_pullback(ΔΩ)
TA = _symhermtype(A)
T∂A = TA{eltype(ΔΩ),typeof(ΔΩ)}
uplo = A.uplo
∂A = T∂A(_symherm_back(A, ΔΩ, uplo), uplo)
return NO_FIELDS, ∂A
end
return Symmetric(A), Symmetric_pullback
return TM(A), Matrix_pullback
end
rrule(::Type{Array}, A::LinearAlgebra.HermOrSym) = rrule(Matrix, A)

_symmetric_back(ΔΩ) = UpperTriangular(ΔΩ) + transpose(LowerTriangular(ΔΩ)) - Diagonal(ΔΩ)
_symmetric_back(ΔΩ::Union{Diagonal,UpperTriangular}) = ΔΩ
# Get type (Symmetric or Hermitian) from type or matrix
_symhermtype(::Type{<:Symmetric}) = Symmetric
_symhermtype(::Type{<:Hermitian}) = Hermitian
_symhermtype(A) = _symhermtype(typeof(A))

# for Ω = Matrix(A::HermOrSym), push forward ΔA to get ∂Ω
function _symherm_forward(A, ΔA)
TA = _symhermtype(A)
return if ΔA isa TA
ΔA
else
TA{eltype(ΔA),typeof(ΔA)}(ΔA, A.uplo)
end
end

# for Ω = HermOrSym(A, uplo), pull back ΔΩ to get ∂A
_symherm_back(::Type{<:Symmetric}, ΔΩ, uplo) = _symmetric_back(ΔΩ, uplo)
function _symherm_back(::Type{<:Hermitian}, ΔΩ::AbstractMatrix{<:Real}, uplo)
return _symmetric_back(ΔΩ, uplo)
end
_symherm_back(::Type{<:Hermitian}, ΔΩ, uplo) = _hermitian_back(ΔΩ, uplo)
_symherm_back(Ω, ΔΩ, uplo) = _symherm_back(typeof(Ω), ΔΩ, uplo)

function _symmetric_back(ΔΩ, uplo)
L, U, D = LowerTriangular(ΔΩ), UpperTriangular(ΔΩ), Diagonal(ΔΩ)
return uplo == 'U' ? U .+ transpose(L) - D : L .+ transpose(U) - D
end
_symmetric_back(ΔΩ::Diagonal, uplo) = ΔΩ
_symmetric_back(ΔΩ::UpperTriangular, uplo) = Matrix(uplo == 'U' ? ΔΩ : transpose(ΔΩ))
_symmetric_back(ΔΩ::LowerTriangular, uplo) = Matrix(uplo == 'U' ? transpose(ΔΩ) : ΔΩ)

function _hermitian_back(ΔΩ, uplo)
L, U, rD = LowerTriangular(ΔΩ), UpperTriangular(ΔΩ), real.(Diagonal(ΔΩ))
return uplo == 'U' ? U .+ L' - rD : L .+ U' - rD
end
_hermitian_back(ΔΩ::Diagonal, uplo) = real.(ΔΩ)
function _hermitian_back(ΔΩ::LinearAlgebra.AbstractTriangular, uplo)
∂UL = ΔΩ .- Diagonal(_extract_imag.(diag(ΔΩ)))
return if istriu(ΔΩ)
return Matrix(uplo == 'U' ? ∂UL : ∂UL')
else
return Matrix(uplo == 'U' ? ∂UL' : ∂UL)
end
end

#####
##### `Adjoint`
Expand Down
2 changes: 2 additions & 0 deletions src/rulesets/LinearAlgebra/utils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -34,3 +34,5 @@ function _add!(X::AbstractVecOrMat, Y::AbstractVecOrMat)
return X
end
_add!(X, Y) = X + Y # handles all `AbstractZero` overloads

_extract_imag(x) = complex(0, imag(x))
42 changes: 40 additions & 2 deletions test/rulesets/LinearAlgebra/structured.jl
Original file line number Diff line number Diff line change
Expand Up @@ -78,9 +78,47 @@
end
end
end
@testset "Symmetric(::AbstractMatrix{$T})" for T in (Float64, ComplexF64)
@testset "$(SymHerm)(::AbstractMatrix{$T}, :$(uplo))" for
SymHerm in (Symmetric, Hermitian),
T in (Float64, ComplexF64),
uplo in (:U, :L)

N = 3
@testset "frule" begin
x = randn(T, N, N)
Δx = randn(T, N, N)
# can't use frule_test here because it doesn't yet ignore nothing tangents
Ω = SymHerm(x, uplo)
Ω_ad, ∂Ω_ad = frule((Zero(), Δx, Zero()), SymHerm, x, uplo)
@test Ω_ad == Ω
∂Ω_fd = jvp(_fdm, z -> SymHerm(z, uplo), (x, Δx))
@test ∂Ω_ad ≈ ∂Ω_fd
end
@testset "rrule" begin
x = randn(T, N, N)
∂x = randn(T, N, N)
ΔΩ = randn(T, N, N)
@testset "back(::$MT)" for MT in (Matrix, LowerTriangular, UpperTriangular)
rrule_test(SymHerm, MT(ΔΩ), (x, ∂x), (uplo, nothing))
end
@testset "back(::Diagonal)" begin
rrule_test(SymHerm, Diagonal(ΔΩ), (x, Diagonal(∂x)), (uplo, nothing))
end
end
end
@testset "$(f)(::$(SymHerm){$T}) with uplo=:$uplo" for f in (Matrix, Array),
SymHerm in (Symmetric, Hermitian),
T in (Float64, ComplexF64),
uplo in (:U, :L)

N = 3
rrule_test(Symmetric, randn(T, N, N), (randn(T, N, N), randn(T, N, N)))
x = SymHerm(randn(T, N, N), uplo)
Δx = randn(T, N, N)
∂x = SymHerm(randn(T, N, N), uplo)
ΔΩ = f(SymHerm(randn(T, N, N), uplo))
frule_test(f, (x, Δx))
frule_test(f, (x, SymHerm(Δx, uplo)))
rrule_test(f, ΔΩ, (x, ∂x))
end
@testset "$f" for f in (Adjoint, adjoint, Transpose, transpose)
n = 5
Expand Down