Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Make ProjectTo convert Tangent back to Diagonal, etc, when safe #446

Closed
wants to merge 21 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
159 changes: 109 additions & 50 deletions src/projection.jl
Original file line number Diff line number Diff line change
Expand Up @@ -380,6 +380,8 @@ end

using LinearAlgebra: AdjointAbsVec, TransposeAbsVec, AdjOrTransAbsVec

const ArrayOrZero = Union{AbstractArray, AbstractZero}

# UniformScaling can represent its own cotangent
ProjectTo(x::UniformScaling) = ProjectTo{UniformScaling}(; λ=ProjectTo(x.λ))
ProjectTo(x::UniformScaling{Bool}) = ProjectTo(false)
Expand All @@ -401,6 +403,17 @@ function (project::ProjectTo{Adjoint})(dx::AbstractArray)
dy = eltype(dx) <: Real ? vec(dx) : adjoint(dx)
return adjoint(project.parent(dy))
end
# structural => natural standardisation, broadest possible signature
function (project::ProjectTo{Adjoint})(dx::Tangent{<:Adjoint})
if dx.parent isa ArrayOrZero
# Adjoint handles ZeroTangent, which could also be produced by project.parent
return Adjoint(project.parent(dx.parent))
else
# Can't wrap a structural representation, or a thunk, in an Adjoint.
# But do these happen?
return dx
end
end

function ProjectTo(x::LinearAlgebra.TransposeAbsVec)
return ProjectTo{Transpose}(; parent=ProjectTo(parent(x)))
Expand All @@ -415,11 +428,21 @@ function (project::ProjectTo{Transpose})(dx::AbstractArray)
dy = eltype(dx) <: Number ? vec(dx) : transpose(dx)
return transpose(project.parent(dy))
end
function (project::ProjectTo{Transpose})(dx::Tangent{<:Transpose}) # structural => natural
return dx.parent isa ArrayOrZero ? Transpose(project.parent(dx.parent)) : dx
end

# Diagonal
ProjectTo(x::Diagonal) = ProjectTo{Diagonal}(; diag=ProjectTo(x.diag))
(project::ProjectTo{Diagonal})(dx::AbstractMatrix) = Diagonal(project.diag(diag(dx)))
(project::ProjectTo{Diagonal})(dx::Diagonal) = Diagonal(project.diag(dx.diag))
function (project::ProjectTo{Diagonal})(dx::AbstractArray)
ind = diagind(size(dx,1), size(dx,2), 0)
return Diagonal(project.diag(dx[ind]))
end
function (project::ProjectTo{Diagonal})(dx::Tangent{<:Diagonal}) # structural => natural
return dx.diag isa ArrayOrZero ? Diagonal(project.diag(dx.diag)) : dx
end

# Symmetric
for (SymHerm, chk, fun) in
Expand All @@ -429,80 +452,116 @@ for (SymHerm, chk, fun) in
sub = ProjectTo(parent(x))
# Because the projector stores uplo, ProjectTo(Symmetric(rand(3,3) .> 0)) isn't automatically trivial:
sub isa ProjectTo{<:AbstractZero} && return sub
return ProjectTo{$SymHerm}(; uplo=LinearAlgebra.sym_uplo(x.uplo), parent=sub)
return ProjectTo{$SymHerm}(; uplo=LinearAlgebra.sym_uplo(x.uplo), data=sub)
end
function (project::ProjectTo{$SymHerm})(dx::AbstractArray)
dy = project.parent(dx)
dy = project.data(dx)
# Here $chk means this is efficient on same-type.
# If we could mutate dx, then that could speed up action on dx::Matrix.
dz = $chk(dy) ? dy : (dy .+ $fun(dy)) ./ 2
return $SymHerm(project.parent(dz), project.uplo)
return $SymHerm(project.data(dz), project.uplo)
end
function (project::ProjectTo{$SymHerm})(dx::Tangent{<:$SymHerm}) # structural => natural
return dx.data isa ArrayOrZero ? $SymHerm(project.data(dx.data), project.uplo) : dx
end
# This is an example of a subspace which is not a subtype,
# not clear how broadly it's worthwhile to try to support this.
function (project::ProjectTo{$SymHerm})(dx::Diagonal)
sub = project.parent # this is going to be unhappy about the size
sub_one = ProjectTo{project_type(sub)}(;
element=sub.element, axes=(sub.axes[1],)
)
return Diagonal(sub_one(dx.diag))
end
(project::ProjectTo{$SymHerm})(dx::Diagonal) = project.data(dx)
end
end

# Triangular
for UL in (:UpperTriangular, :LowerTriangular, :UnitUpperTriangular, :UnitLowerTriangular) # UpperHessenberg
for UL in (:UpperTriangular, :LowerTriangular, :UpperHessenberg)
@eval begin
ProjectTo(x::$UL) = ProjectTo{$UL}(; parent=ProjectTo(parent(x)))
(project::ProjectTo{$UL})(dx::AbstractArray) = $UL(project.parent(dx))
function (project::ProjectTo{$UL})(dx::Diagonal)
sub = project.parent
sub_one = ProjectTo{project_type(sub)}(;
element=sub.element, axes=(sub.axes[1],)
)
return Diagonal(sub_one(dx.diag))
Comment on lines -454 to -463
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

To explain what's going on here:

  • First, this used to include UnitUpperTriangular for which it was wrong. The gradient of that has to be zero on the diagonal, not one. So that has moved to its own case, where it does UnitUpperTriangular(dx) .- I instead, which makes in fact an UpperTriangular, not a subtype.
  • Second, the handling of (project::ProjectTo{$UL})(dx::Diagonal) is much simplified. Instead of inventing the projector needed and handling the diagonal by hand, it exploits the fact that map(ProjectTo{Float32}, ::Diagonal) already knows what to do.

The second idea greatly simplifies many more exotic examples below, such as (project::ProjectTo{Tridiagonal})(dx::Bidiagonal) = project.full(dx).

ProjectTo(x::$UL) = ProjectTo{$UL}(; data=ProjectTo(parent(x)))
(project::ProjectTo{$UL})(dx::AbstractArray) = $UL(project.data(dx))
function (project::ProjectTo{$UL})(dx::Tangent{<:$UL}) # structural => natural
return dx.data isa ArrayOrZero ? $UL(project.data(dx.data)) : dx
end
end
end
for UUL in (:UnitUpperTriangular, :UnitLowerTriangular)
UL = Symbol(string(UUL)[5:end])
@eval begin
ProjectTo(x::$UUL) = ProjectTo{$UUL}(; data=ProjectTo(parent(x)))
function (project::ProjectTo{$UUL})(dx::AbstractArray)
dy = project.data(dx)
# Since x's diagonal is fixed to 1, dx must be zero there:
return $UUL(dy) - I # makes an UpperTriangular, etc.
end
# No type perfectly encodes the gradient of UnitUpperTriangular.
# To avoid unnecessary copies of what projection produces,
# allow any UpperTriangular through:
(project::ProjectTo{$UUL})(dx::$UL) = project.data(dx)
end
end
# Subspaces which aren't subtypes, like Diagonal inside Symmetric above:
(project::ProjectTo{UpperTriangular})(dx::Diagonal) = project.data(dx)
(project::ProjectTo{LowerTriangular})(dx::Diagonal) = project.data(dx)

(project::ProjectTo{UpperHessenberg})(dx::Diagonal) = project.data(dx)
(project::ProjectTo{UpperHessenberg})(dx::UpperTriangular) = project.data(dx)

# Weird -- not exhaustive!
# one strategy is to recurse into the struct:
ProjectTo(x::Bidiagonal{T}) where {T<:Number} = generic_projector(x)
function (project::ProjectTo{Bidiagonal})(dx::AbstractMatrix)
uplo = LinearAlgebra.sym_uplo(project.uplo)
dv = project.dv(diag(dx))
ev = project.ev(uplo === :U ? diag(dx, 1) : diag(dx, -1))
return Bidiagonal(dv, ev, uplo)
(project::ProjectTo{UnitUpperTriangular})(dx::Diagonal) = NoTangent()
(project::ProjectTo{UnitLowerTriangular})(dx::Diagonal) = NoTangent()

# Multidiagonal
# For all of these, the eltypes must all match, so store one full-size projector for simplicity.
function ProjectTo(x::Bidiagonal)
full = invoke(ProjectTo, Tuple{AbstractArray{T2}} where {T2<:Number}, x)
# full isa ProjectTo{<:AbstractZero} && return full # never happens, invoke misses the Bool method
ProjectTo(zero(eltype(x))) isa ProjectTo{<:AbstractZero} && return ProjectTo(false) # better short-circuit
return ProjectTo{Bidiagonal}(; full = full, uplo = LinearAlgebra.sym_uplo(x.uplo))
end
function ProjectTo(x::Tridiagonal)
full = invoke(ProjectTo, Tuple{AbstractArray{T2}} where {T2<:Number}, x)
ProjectTo(zero(eltype(x))) isa ProjectTo{<:AbstractZero} && return ProjectTo(false)
return ProjectTo{Tridiagonal}(; full = full)
end
function ProjectTo(x::SymTridiagonal)
full = invoke(ProjectTo, Tuple{AbstractArray{T2}} where {T2<:Number}, x)
ProjectTo(zero(eltype(x))) isa ProjectTo{<:AbstractZero} && return ProjectTo(false)
return ProjectTo{SymTridiagonal}(; full = full)
end
# Own type: `project.full` can convert eltype mantaining strucure
function (project::ProjectTo{Bidiagonal})(dx::Bidiagonal)
if project.uplo == dx.uplo
return generic_projection(project, dx) # fast path
else
uplo = LinearAlgebra.sym_uplo(project.uplo)
dv = project.dv(diag(dx))
ev = fill!(similar(dv, length(dv) - 1), 0)
return Bidiagonal(dv, ev, uplo)
if LinearAlgebra.sym_uplo(dx.uplo) == project.uplo
return project.full(dx)
else # make a dummy array, better type-stability than returning a Diagonal
return project.full(Bidiagonal(dx.dv, zero(dx.ev), project.uplo))
end
end

ProjectTo(x::SymTridiagonal{T}) where {T<:Number} = generic_projector(x)
function (project::ProjectTo{SymTridiagonal})(dx::AbstractMatrix)
dv = project.dv(diag(dx))
ev = project.ev((diag(dx, 1) .+ diag(dx, -1)) ./ 2)
(project::ProjectTo{Tridiagonal})(dx::Tridiagonal) = project.full(dx)
(project::ProjectTo{SymTridiagonal})(dx::SymTridiagonal) = project.full(dx)
# AbstractArray
(project::ProjectTo{Bidiagonal})(dx::AbstractArray) = Bidiagonal(project.full(dx), project.uplo)
(project::ProjectTo{Tridiagonal})(dx::AbstractArray) = Tridiagonal(project.full(dx))
(project::ProjectTo{SymTridiagonal})(dx::Symmetric) = SymTridiagonal(project.full(dx))
function (project::ProjectTo{SymTridiagonal})(dx::AbstractArray)
dz = project.full(dx)
dv = diag(dz)
ev = (diag(dz, 1) .+ diag(dz, -1)) ./ 2
return SymTridiagonal(dv, ev)
end
(project::ProjectTo{SymTridiagonal})(dx::SymTridiagonal) = generic_projection(project, dx)

# another strategy is just to use the AbstractArray method
function ProjectTo(x::Tridiagonal{T}) where {T<:Number}
notparent = invoke(ProjectTo, Tuple{AbstractArray{T2}} where {T2<:Number}, x)
return ProjectTo{Tridiagonal}(; notparent=notparent)
# Subspaces which aren't subtypes:
(project::ProjectTo{Bidiagonal})(dx::Diagonal) = project.full(dx)
(project::ProjectTo{Tridiagonal})(dx::Diagonal) = project.full(dx)
(project::ProjectTo{Tridiagonal})(dx::Bidiagonal) = project.full(dx)
(project::ProjectTo{SymTridiagonal})(dx::Diagonal) = project.full(dx)
# structural => natural
function (project::ProjectTo{Bidiagonal})(dx::Tangent{<:Bidiagonal})
dx.dv isa ArrayOrZero && dx.ev isa ArrayOrZero || return dx
return project.full(Bidiagonal(dx.dv, dx.ev, project.uplo)) # will return a Diagonal when ev::AbstractZero
end
function (project::ProjectTo{Tridiagonal})(dx::AbstractArray)
dy = project.notparent(dx)
return Tridiagonal(dy)
function (project::ProjectTo{Tridiagonal})(dx::Tangent{<:Tridiagonal})
dx.dl isa ArrayOrZero && dx.d isa ArrayOrZero && dx.du isa ArrayOrZero || return dx
return project.full(Tridiagonal(dx.dl, dx.d, dx.du))
end
# Note that backing(::Tridiagonal) doesn't work, https://github.com/JuliaDiff/ChainRulesCore.jl/issues/392
function (project::ProjectTo{SymTridiagonal})(dx::Tangent{<:SymTridiagonal})
dx.dv isa ArrayOrZero && dx.ev isa ArrayOrZero || return dx
return project.full(SymTridiagonal(dx.dv, dx.ev))
end


#####
##### `SparseArrays`
Expand Down Expand Up @@ -598,6 +657,6 @@ function (project::ProjectTo{SparseMatrixCSC})(dx::SparseMatrixCSC)
m, n = size(dx)
return SparseMatrixCSC(m, n, dx.colptr, dx.rowval, nzval)
else
invoke(project, Tuple{AbstractArray}, dx)
return invoke(project, Tuple{AbstractArray}, dx)
end
end
66 changes: 60 additions & 6 deletions src/tangent_types/abstract_zero.jl
Original file line number Diff line number Diff line change
Expand Up @@ -19,9 +19,6 @@ Base.iterate(::AbstractZero, ::Any) = nothing
Base.Broadcast.broadcastable(x::AbstractZero) = Ref(x)
Base.Broadcast.broadcasted(::Type{T}) where {T<:AbstractZero} = T()

# Linear operators
Base.adjoint(z::AbstractZero) = z
Base.transpose(z::AbstractZero) = z
Base.:/(z::AbstractZero, ::Any) = z

Base.convert(::Type{T}, x::AbstractZero) where {T<:Number} = zero(T)
Expand All @@ -30,14 +27,71 @@ Base.convert(::Type{T}, x::AbstractZero) where {T<:Number} = zero(T)
(::Type{Complex})(x::AbstractZero, y::Real) = Complex(false, y)
(::Type{Complex})(x::Real, y::AbstractZero) = Complex(x, false)

Base.getindex(z::AbstractZero, args...) = z

Base.getindex(z::AbstractZero, ind...) = z
Base.view(z::AbstractZero, ind...) = z
Base.sum(z::AbstractZero; dims=:) = z
Base.reshape(z::AbstractZero, size...) = z
Base.reverse(z::AbstractZero, args...; kwargs...) = z

(::Type{<:UniformScaling})(z::AbstractZero) = z
# LinearAlgebra
LinearAlgebra.adjoint(z::AbstractZero, ind...) = z
LinearAlgebra.transpose(z::AbstractZero, ind...) = z

for T in (
:UniformScaling, :Adjoint, :Transpose, :Diagonal,
:UpperTriangular, :LowerTriangular, :UpperHessenberg,
:UnitUpperTriangular, :UnitLowerTriangular,
)
VERSION < v"1.4" && T == :UpperHessenberg && continue # not defined in 1.0
@eval LinearAlgebra.$T(z::AbstractZero) = z
end

LinearAlgebra.Symmetric(z::AbstractZero, uplo=:U) = z
LinearAlgebra.Hermitian(z::AbstractZero, uplo=:U) = z

LinearAlgebra.Bidiagonal(dv::AbstractVector, ev::AbstractZero, uplo::Symbol) = Diagonal(dv)
function LinearAlgebra.Bidiagonal(dv::AbstractZero, ev::AbstractVector, uplo::Symbol)
dv = fill!(similar(ev, length(ev) + 1), 0) # can't avoid making a dummy array
return Bidiagonal(dv, convert(typeof(dv), ev), uplo)
end
LinearAlgebra.Bidiagonal(dv::AbstractZero, ev::AbstractZero, uplo::Symbol) = NoTangent()

# one Zero:
LinearAlgebra.Tridiagonal(dl::AbstractZero, d::AbstractVector, du::AbstractVector) = Bidiagonal(_promote_vectors(d, du)..., :U)
LinearAlgebra.Tridiagonal(dl::AbstractVector, d::AbstractVector, du::AbstractZero) = Bidiagonal(_promote_vectors(d, dl)..., :L)
function LinearAlgebra.Tridiagonal(dl::AbstractVector, d::AbstractZero, du::AbstractVector)
d = fill!(similar(dl, length(dl) + 1), 0)
return Tridiagonal(convert(typeof(d), dl), d, convert(typeof(d), du))
end
# two Zeros:
LinearAlgebra.Tridiagonal(dl::AbstractZero, d::AbstractVector, du::AbstractZero) = Diagonal(d)
LinearAlgebra.Tridiagonal(dl::AbstractZero, d::AbstractZero, du::AbstractVector) = Bidiagonal(d, du, :U)
LinearAlgebra.Tridiagonal(dl::AbstractVector, d::AbstractZero, du::AbstractZero) = Bidiagonal(d, dl, :L)
# three Zeros:
LinearAlgebra.Tridiagonal(dl::AbstractZero, d::AbstractZero, du::AbstractZero) = NoTangent()

LinearAlgebra.SymTridiagonal(dv::AbstractVector, ev::AbstractZero) = Diagonal(dv)
function LinearAlgebra.SymTridiagonal(dv::AbstractZero, ev::AbstractVector)
dv = fill!(similar(ev, length(ev) + 1), 0)
return SymTridiagonal(dv, convert(typeof(dv), ev))
end
LinearAlgebra.SymTridiagonal(dv::AbstractZero, ev::AbstractZero) = NoTangent()

# These types all demand exactly same-type vectors, but may get e.g. Fill, Vector.
_promote_vectors(x::T, y::T) where {T<:AbstractVector} = (x, y)
function _promote_vectors(x::AbstractVector, y::AbstractVector)
T = Base._return_type(+, Tuple{typeof(x), typeof(y)})
if isconcretetype(T)
return convert(T, x), convert(T, y)
else
if VERSION > v"1.4"
short = map(first ∘ promote, x, y)
else # on 1.0 and friends, neither map nor zip stop early. So we improvise
short = [promote(x[i], y[i])[1] for i in intersect(axes(x, 1), axes(y, 1))]
end
return convert(typeof(short), x), convert(typeof(short), y)
end
end

"""
ZeroTangent() <: AbstractZero
Expand Down
Loading