Skip to content

Commit

Permalink
Add constructors for SparseMatrixCSC from UniformScaling.
Browse files Browse the repository at this point in the history
  • Loading branch information
Sacha0 committed Oct 28, 2017
1 parent d171d4d commit 8fd2325
Show file tree
Hide file tree
Showing 2 changed files with 39 additions and 21 deletions.
47 changes: 26 additions & 21 deletions base/sparse/sparsematrix.jl
Original file line number Diff line number Diff line change
Expand Up @@ -1459,6 +1459,13 @@ Note the difference from [`speye`](@ref).
spones(S::SparseMatrixCSC{T}) where {T} =
SparseMatrixCSC(S.m, S.n, copy(S.colptr), copy(S.rowval), ones(T, S.colptr[end]-1))

function one(S::SparseMatrixCSC{T}) where T
m,n = size(S)
if m != n; throw(DimensionMismatch("multiplicative identity only defined for square matrices")); end
speye(T, m)
end


"""
spzeros([type,]m[,n])
Expand Down Expand Up @@ -1529,29 +1536,27 @@ if not specified.
`sparse(α*I, m, n)` can be used to efficiently create a sparse
multiple `α` of the identity matrix.
"""
speye(::Type{T}, m::Integer, n::Integer) where {T} = speye_scaled(T, oneunit(T), m, n)

function one(S::SparseMatrixCSC{T}) where T
m,n = size(S)
if m != n; throw(DimensionMismatch("multiplicative identity only defined for square matrices")); end
speye(T, m)
end

speye_scaled(diag, m::Integer, n::Integer) = speye_scaled(typeof(diag), diag, m, n)

function speye_scaled(::Type{T}, diag, m::Integer, n::Integer) where T
((m < 0) || (n < 0)) && throw(ArgumentError("invalid array dimensions"))
if iszero(diag)
return SparseMatrixCSC(m, n, ones(Int, n+1), Vector{Int}(0), Vector{T}(0))
end
nnz = min(m,n)
colptr = Vector{Int}(1+n)
colptr[1:nnz+1] = 1:nnz+1
colptr[nnz+2:end] = nnz+1
SparseMatrixCSC(Int(m), Int(n), colptr, Vector{Int}(1:nnz), fill!(Vector{T}(nnz), diag))
speye(::Type{T}, m::Integer, n::Integer) where {T} = SparseMatrixCSC{T}(UniformScaling(one(T)), Dims((m, n)))
sparse(s::UniformScaling, m::Integer, n::Integer=m) = SparseMatrixCSC(s, Dims((m, n)))

## SparseMatrixCSC construction from UniformScaling
SparseMatrixCSC{Tv,Ti}(s::UniformScaling, m::Integer, n::Integer) where {Tv,Ti} = SparseMatrixCSC{Tv,Ti}(s, Dims((m, n)))
SparseMatrixCSC{Tv}(s::UniformScaling, m::Integer, n::Integer) where {Tv} = SparseMatrixCSC{Tv}(s, Dims((m, n)))
SparseMatrixCSC(s::UniformScaling, m::Integer, n::Integer) = SparseMatrixCSC(s, Dims((m, n)))
SparseMatrixCSC{Tv}(s::UniformScaling, dims::Dims{2}) where {Tv} = SparseMatrixCSC{Tv,Int}(s, dims)
SparseMatrixCSC(s::UniformScaling, dims::Dims{2}) = SparseMatrixCSC{eltype(s)}(s, dims)
function SparseMatrixCSC{Tv,Ti}(s::UniformScaling, dims::Dims{2}) where {Tv,Ti}
@boundscheck first(dims) < 0 && throw(ArgumentError("first dimension invalid ($(first(dims)) < 0)"))
@boundscheck last(dims) < 0 && throw(ArgumentError("second dimension invalid ($(last(dims)) < 0)"))
iszero(s.λ) && return spzeros(Tv, Ti, dims...)
m, n, k = dims..., min(dims...)
nzval = fill!(Vector{Tv}(k), Tv(s.λ))
rowval = copy!(Vector{Ti}(k), 1:k)
colptr = copy!(Vector{Ti}(n + 1), 1:(k + 1))
for i in (k + 2):(n + 1) colptr[i] = (k + 1) end
SparseMatrixCSC{Tv,Ti}(dims..., colptr, rowval, nzval)
end

sparse(S::UniformScaling, m::Integer, n::Integer=m) = speye_scaled(S.λ, m, n)

# TODO: More appropriate location?
conj!(A::SparseMatrixCSC) = (@inbounds broadcast!(conj, A.nzval, A.nzval); A)
Expand Down
13 changes: 13 additions & 0 deletions test/sparse/sparse.jl
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,19 @@ end
@test sparse(Any[1,2,3], Any[1,2,3], Any[1,1,1], 5, 4) == sparse([1,2,3], [1,2,3], [1,1,1], 5, 4)
end

@testset "SparseMatrixCSC construction from UniformScaling" begin
@test_throws ArgumentError SparseMatrixCSC(I, -1, 3)
@test_throws ArgumentError SparseMatrixCSC(I, 3, -1)
@test SparseMatrixCSC(2I, 3, 3)::SparseMatrixCSC{Int,Int} == 2*eye(3)
@test SparseMatrixCSC(2I, 3, 4)::SparseMatrixCSC{Int,Int} == 2*eye(3, 4)
@test SparseMatrixCSC(2I, 4, 3)::SparseMatrixCSC{Int,Int} == 2*eye(4, 3)
@test SparseMatrixCSC(2.0I, 3, 3)::SparseMatrixCSC{Float64,Int} == 2*eye(3)
@test SparseMatrixCSC{Real}(2I, 3, 3)::SparseMatrixCSC{Real,Int} == 2*eye(3)
@test SparseMatrixCSC{Float64}(2I, 3, 3)::SparseMatrixCSC{Float64,Int} == 2*eye(3)
@test SparseMatrixCSC{Float64,Int32}(2I, 3, 3)::SparseMatrixCSC{Float64,Int32} == 2*eye(3)
@test SparseMatrixCSC{Float64,Int32}(0I, 3, 3)::SparseMatrixCSC{Float64,Int32} == spzeros(Float64, Int32, 3, 3)
end

se33 = speye(3)
do33 = ones(3)

Expand Down

0 comments on commit 8fd2325

Please sign in to comment.