Skip to content

Commit

Permalink
upgrade, more matrices
Browse files Browse the repository at this point in the history
  • Loading branch information
mcabbott committed Oct 15, 2021
1 parent cf15bca commit e67f2b0
Show file tree
Hide file tree
Showing 4 changed files with 55 additions and 6 deletions.
31 changes: 28 additions & 3 deletions src/projection.jl
Original file line number Diff line number Diff line change
Expand Up @@ -306,6 +306,16 @@ function (project::ProjectTo{Adjoint})(dx::AbstractArray)
dy = eltype(dx) <: Real ? vec(dx) : adjoint(dx)
return adjoint(project.parent(dy))
end
# structural => natural standardisation, broadest possible signature
function (project::ProjectTo{Adjoint})(dx::Tangent)
if dx.parent isa Tangent
# Can't wrap a structural representation of an array in an Adjoint:
return dx
else
# This case should handle dx.parent isa AbstractZero, too
return Adjoint(project.parent(dx.parent))
end
end

function ProjectTo(x::LinearAlgebra.TransposeAbsVec)
return ProjectTo{Transpose}(; parent=ProjectTo(parent(x)))
Expand All @@ -320,14 +330,22 @@ function (project::ProjectTo{Transpose})(dx::AbstractArray)
dy = eltype(dx) <: Number ? vec(dx) : transpose(dx)
return transpose(project.parent(dy))
end
function (project::ProjectTo{Transpose})(
dx::Tangent{<:Transpose, <:NamedTuple{(:parent,), <:Tuple{AbstractVector}}},
)
return Transpose(project.parent(dx.parent))
end

# Diagonal
ProjectTo(x::Diagonal) = ProjectTo{Diagonal}(; diag=ProjectTo(x.diag))
(project::ProjectTo{Diagonal})(dx::AbstractMatrix) = Diagonal(project.diag(diag(dx)))
(project::ProjectTo{Diagonal})(dx::Diagonal) = Diagonal(project.diag(dx.diag))

(project::ProjectTo{Diagonal})(dx::Tangent{T}) where T = (@show T; Diagonal(project.diag(dx.diag)))
# (project::ProjectTo{Diagonal})(dx::Tangent{<:Diagonal, NamedTuple{(:diag,), <:Tuple{AbstractVector}}}) = Diagonal(project.diag(@show dx.diag))
# structural => natural standardisation, very conservative signature:
function (project::ProjectTo{Diagonal})(
dx::Tangent{<:Diagonal, <:NamedTuple{(:diag,), <:Tuple{AbstractVector}}},
)
return Diagonal(project.diag(dx.diag))
end

# Symmetric
for (SymHerm, chk, fun) in
Expand All @@ -346,6 +364,13 @@ for (SymHerm, chk, fun) in
dz = $chk(dy) ? dy : (dy .+ $fun(dy)) ./ 2
return $SymHerm(project.parent(dz), project.uplo)
end
function (project::ProjectTo{$SymHerm})(dx::Tangent{<:$SymHerm})
if dx.data isa Tangent
return dx
else
return $SymHerm(project.parent(dx.data))
end
end
# This is an example of a subspace which is not a subtype,
# not clear how broadly it's worthwhile to try to support this.
function (project::ProjectTo{$SymHerm})(dx::Diagonal)
Expand Down
11 changes: 8 additions & 3 deletions src/tangent_types/abstract_zero.jl
Original file line number Diff line number Diff line change
Expand Up @@ -19,9 +19,6 @@ Base.iterate(::AbstractZero, ::Any) = nothing
Base.Broadcast.broadcastable(x::AbstractZero) = Ref(x)
Base.Broadcast.broadcasted(::Type{T}) where {T<:AbstractZero} = T()

# Linear operators
Base.adjoint(z::AbstractZero) = z
Base.transpose(z::AbstractZero) = z
Base.:/(z::AbstractZero, ::Any) = z

Base.convert(::Type{T}, x::AbstractZero) where {T<:Number} = zero(T)
Expand All @@ -35,6 +32,14 @@ Base.getindex(z::AbstractZero, k) = z
Base.view(z::AbstractZero, ind...) = z
Base.sum(z::AbstractZero; dims=:) = z

# LinearAlgebra
for f in (:adjoint, :transpose, :Adjoint, :Transpose, :Diagonal)
@eval LinearAlgebra.$f(z::AbstractZero) = z
end
for f in (:Symmetric, :Hermitian)
@eval LinearAlgebra.$f(z::AbstractZero, uplo=:U) = z
end

"""
ZeroTangent() <: AbstractZero
Expand Down
10 changes: 10 additions & 0 deletions test/projection.jl
Original file line number Diff line number Diff line change
Expand Up @@ -170,6 +170,10 @@ struct NoSuperType end
@test padj_complex(transpose([4, 5, 6 + 7im])) == [4 5 6 + 7im]
@test padj_complex(adjoint([4, 5, 6 + 7im])) == [4 5 6 - 7im]

# structural => natural
@test padj(Tangent{adjT}(; parent=ones(3) .+ im)) isa adjT
@test_skip padj(Tangent{Any}(; parent=ones(3))) isa adjT # only for Adjoint now

# evil test case
if VERSION >= v"1.7-" # up to 1.6 Vector[[1,2,3]]' is an error, not sure why it's called
xs = adj(Any[Any[1, 2, 3], Any[4 + im, 5 - im, 6 + im, 7 - im]])
Expand Down Expand Up @@ -204,6 +208,10 @@ struct NoSuperType end
@test psymm(psymm(reshape(1:9, 3, 3))) == psymm(reshape(1:9, 3, 3))
@test psymm(rand(ComplexF32, 3, 3, 1)) isa Symmetric{Float64}
@test ProjectTo(Symmetric(randn(3, 3) .> 0))(randn(3, 3)) == NoTangent() # Bool
# structural => natural
dx = Tangent{typeof(Symmetric(rand(3, 3)))}(; data=[1 2 3; 4 5 6; 7 8 9im])
@test psymm(dx) isa Symmetric{Float64}
@test psymm(Tangent{typeof(Symmetric(rand(3, 3)))}(; )) isa AbstractZero

pherm = ProjectTo(Hermitian(rand(3, 3) .+ im, :L))
# NB, projection onto Hermitian subspace, not application of Hermitian constructor
Expand All @@ -230,6 +238,8 @@ struct NoSuperType end
@test pdiag(Diagonal(1.0:3.0)) === Diagonal(1.0:3.0)
@test ProjectTo(Diagonal(randn(3) .> 0))(randn(3, 3)) == NoTangent()
@test ProjectTo(Diagonal(randn(3) .> 0))(Diagonal(rand(3))) == NoTangent()
# structural => natural
@test pdiag(Tangent{typeof(Diagonal(1:3))}(; diag=ones(3) .+ im)) isa Diagonal{Float64}

pbi = ProjectTo(Bidiagonal(rand(3, 3), :L))
@test pbi(reshape(1:9, 3, 3)) == [1.0 0.0 0.0; 2.0 5.0 0.0; 0.0 6.0 9.0]
Expand Down
9 changes: 9 additions & 0 deletions test/tangent_types/abstract_zero.jl
Original file line number Diff line number Diff line change
Expand Up @@ -113,4 +113,13 @@
@test convert(Int64, NoTangent()) == 0
@test convert(Float64, NoTangent()) == 0.0
end

@testset "LinearAlgebra constructors" begin
@test adjoint(ZeroTangent()) === ZeroTangent()
@test transpose(ZeroTangent()) === ZeroTangent()
@test Adjoint(ZeroTangent()) === ZeroTangent()
@test Transpose(ZeroTangent()) === ZeroTangent()
@test Symmetric(ZeroTangent()) === ZeroTangent()
@test Hermitian(ZeroTangent(), :U) === ZeroTangent()
end
end

0 comments on commit e67f2b0

Please sign in to comment.