diff --git a/src/projection.jl b/src/projection.jl index 17880116f..01ee9d93b 100644 --- a/src/projection.jl +++ b/src/projection.jl @@ -435,6 +435,10 @@ end ProjectTo(x::Diagonal) = ProjectTo{Diagonal}(; diag=ProjectTo(x.diag)) (project::ProjectTo{Diagonal})(dx::AbstractMatrix) = Diagonal(project.diag(diag(dx))) (project::ProjectTo{Diagonal})(dx::Diagonal) = Diagonal(project.diag(dx.diag)) +function (project::ProjectTo{Diagonal})(dx::AbstractArray) + ind = diagind(size(dx,1), size(dx,2), 0) + return Diagonal(project.diag(dx[ind])) +end function (project::ProjectTo{Diagonal})(dx::Tangent{<:Diagonal}) # structural => natural return dx.diag isa ArrayOrZero ? Diagonal(project.diag(dx.diag)) : dx end diff --git a/test/projection.jl b/test/projection.jl index 55e9e4064..88be8f8d0 100644 --- a/test/projection.jl +++ b/test/projection.jl @@ -307,6 +307,7 @@ struct NoSuperType end @testset "LinearAlgebra: sparse structured matrices" begin pdiag = ProjectTo(Diagonal(1:3)) @test pdiag(reshape(1:9, 3, 3)) == Diagonal([1, 5, 9]) + @test pdiag(reshape(1:9, 3, 3, 1)) == Diagonal([1, 5, 9]) @test pdiag(pdiag(reshape(1:9, 3, 3))) == pdiag(reshape(1:9, 3, 3)) @test pdiag(rand(ComplexF32, 3, 3)) isa Diagonal{Float64} @test pdiag(Diagonal(1.0:3.0)) === Diagonal(1.0:3.0)