Skip to content

Commit 4e1bfff

Browse files
carstenbauerjohanmon
authored andcommitted
Conversion methods sparse matrix -> special linalg type (JuliaLang#40988)
1 parent a89ed9b commit 4e1bfff

File tree

3 files changed

+30
-1
lines changed

3 files changed

+30
-1
lines changed

stdlib/SparseArrays/src/SparseArrays.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@ using LinearAlgebra
1212

1313
import Base: +, -, *, \, /, &, |, xor, ==, zero
1414
import LinearAlgebra: mul!, ldiv!, rdiv!, cholesky, adjoint!, diag, eigen, dot,
15-
issymmetric, istril, istriu, lu, tr, transpose!, tril!, triu!,
15+
issymmetric, istril, istriu, lu, tr, transpose!, tril!, triu!, isbanded,
1616
cond, diagm, factorize, ishermitian, norm, opnorm, lmul!, rmul!, tril, triu, matprod
1717

1818
import Base: acos, acosd, acot, acotd, acsch, asech, asin, asind, asinh,

stdlib/SparseArrays/src/sparsematrix.jl

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -683,6 +683,17 @@ Array(S::AbstractSparseMatrixCSC) = Matrix(S)
683683

684684
convert(T::Type{<:AbstractSparseMatrixCSC}, m::AbstractMatrix) = m isa T ? m : T(m)
685685

686+
convert(T::Type{<:Diagonal}, m::AbstractSparseMatrixCSC) = m isa T ? m :
687+
isdiag(m) ? T(m) : throw(ArgumentError("matrix cannot be represented as Diagonal"))
688+
convert(T::Type{<:SymTridiagonal}, m::AbstractSparseMatrixCSC) = m isa T ? m :
689+
issymmetric(m) && isbanded(m, -1, 1) ? T(m) : throw(ArgumentError("matrix cannot be represented as SymTridiagonal"))
690+
convert(T::Type{<:Tridiagonal}, m::AbstractSparseMatrixCSC) = m isa T ? m :
691+
isbanded(m, -1, 1) ? T(m) : throw(ArgumentError("matrix cannot be represented as Tridiagonal"))
692+
convert(T::Type{<:LowerTriangular}, m::AbstractSparseMatrixCSC) = m isa T ? m :
693+
istril(m) ? T(m) : throw(ArgumentError("matrix cannot be represented as LowerTriangular"))
694+
convert(T::Type{<:UpperTriangular}, m::AbstractSparseMatrixCSC) = m isa T ? m :
695+
istriu(m) ? T(m) : throw(ArgumentError("matrix cannot be represented as UpperTriangular"))
696+
686697
float(S::SparseMatrixCSC) = SparseMatrixCSC(size(S, 1), size(S, 2), copy(getcolptr(S)), copy(rowvals(S)), float.(nonzeros(S)))
687698
complex(S::SparseMatrixCSC) = SparseMatrixCSC(size(S, 1), size(S, 2), copy(getcolptr(S)), copy(rowvals(S)), complex(copy(nonzeros(S))))
688699

stdlib/SparseArrays/test/sparse.jl

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -78,6 +78,24 @@ end
7878
@test Array(SparseMatrixCSC{eltype(a), Int8}(a)) == Array(a)
7979
end
8080

81+
@testset "conversion to special LinearAlgebra types" begin
82+
# issue 40924
83+
@test convert(Diagonal, sparse(Diagonal(1:2))) isa Diagonal
84+
@test convert(Diagonal, sparse(Diagonal(1:2))) == Diagonal(1:2)
85+
@test convert(Tridiagonal, sparse(Tridiagonal(1:3, 4:7, 8:10))) isa Tridiagonal
86+
@test convert(Tridiagonal, sparse(Tridiagonal(1:3, 4:7, 8:10))) == Tridiagonal(1:3, 4:7, 8:10)
87+
@test convert(SymTridiagonal, sparse(SymTridiagonal(1:4, 5:7))) isa SymTridiagonal
88+
@test convert(SymTridiagonal, sparse(SymTridiagonal(1:4, 5:7))) == SymTridiagonal(1:4, 5:7)
89+
90+
lt = LowerTriangular([1.0 2.0 3.0; 4.0 5.0 6.0; 7.0 8.0 9.0])
91+
@test convert(LowerTriangular, sparse(lt)) isa LowerTriangular
92+
@test convert(LowerTriangular, sparse(lt)) == lt
93+
94+
ut = UpperTriangular([1.0 2.0 3.0; 4.0 5.0 6.0; 7.0 8.0 9.0])
95+
@test convert(UpperTriangular, sparse(ut)) isa UpperTriangular
96+
@test convert(UpperTriangular, sparse(ut)) == ut
97+
end
98+
8199
@testset "sparse matrix construction" begin
82100
@test (A = fill(1.0+im,5,5); isequal(Array(sparse(A)), A))
83101
@test_throws ArgumentError sparse([1,2,3], [1,2], [1,2,3], 3, 3)

0 commit comments

Comments
 (0)