Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

add PermMatrixCSC #78

Merged
merged 13 commits into from
Apr 17, 2024
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
Next Next commit
improve test coverage
  • Loading branch information
GiggleLiu committed Feb 26, 2024
commit 8fffba7fade02eb1652011095c5946fff4969428
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ StaticArrays = "90137ffa-7385-5640-81b9-e52037218182"

[compat]
StaticArrays = "1"
julia = "1"
julia = "1.8"

[extras]
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
Expand Down
6 changes: 3 additions & 3 deletions src/PermMatrix.jl
Original file line number Diff line number Diff line change
Expand Up @@ -165,6 +165,6 @@ end
Base.hash(pm::AbstractPermMatrix) = hash((pm.perm, pm.vals))

######### sparse array interfaces #########
nnz(M::AbstractPermMatrix) = length(M.vals)
findnz(M::PermMatrix) = (collect(1:size(M, 1)), M.perm, M.vals)
findnz(M::PermMatrixCSC) = (M.perm, collect(1:size(M, 1)), M.vals)
SparseArrays.nnz(M::AbstractPermMatrix) = length(M.vals)
SparseArrays.findnz(M::PermMatrix) = (collect(1:size(M, 1)), M.perm, M.vals)
SparseArrays.findnz(M::PermMatrixCSC) = (M.perm, collect(1:size(M, 1)), M.vals)
80 changes: 24 additions & 56 deletions src/SSparseMatrixCSC.jl
Original file line number Diff line number Diff line change
@@ -1,62 +1,30 @@
@static if VERSION < v"1.4.0"
"""
SSparseMatrixCSC{Tv,Ti<:Integer, NNZ, NP} <: AbstractSparseMatrix{Tv,Ti}

"""
SSparseMatrixCSC{Tv,Ti<:Integer, NNZ, NP} <: AbstractSparseMatrix{Tv,Ti}
static version of SparseMatrixCSC
"""
struct SSparseMatrixCSC{Tv,Ti<:Integer,NNZ,NP} <:
SparseArrays.AbstractSparseMatrixCSC{Tv,Ti}
m::Int # Number of rows
n::Int # Number of columns
colptr::SVector{NP,Ti} # Column i is in colptr[i]:(colptr[i+1]-1)
rowval::SVector{NNZ,Ti} # Row values of nonzeros
nzval::SVector{NNZ,Tv} # Nonzero values

static version of SparseMatrixCSC
"""
struct SSparseMatrixCSC{Tv,Ti<:Integer,NNZ,NP} <: AbstractSparseMatrix{Tv,Ti}
m::Int # Number of rows
n::Int # Number of columns
colptr::SVector{NP,Ti} # Column i is in colptr[i]:(colptr[i+1]-1)
rowval::SVector{NNZ,Ti} # Row values of nonzeros
nzval::SVector{NNZ,Tv} # Nonzero values

function SSparseMatrixCSC{Tv,Ti,NNZ,NP}(
m::Integer,
n::Integer,
colptr::SVector{NP,Ti},
rowval::SVector{NNZ,Ti},
nzval::SVector{NNZ,Tv},
) where {Tv,Ti<:Integer,NNZ,NP}
m < 0 && throw(ArgumentError("number of rows (m) must be ≥ 0, got $m"))
n < 0 && throw(ArgumentError("number of columns (n) must be ≥ 0, got $n"))
new(Int(m), Int(n), colptr, rowval, nzval)
end
function SSparseMatrixCSC{Tv,Ti,NNZ,NP}(
m::Integer,
n::Integer,
colptr::SVector{NP,Ti},
rowval::SVector{NNZ,Ti},
nzval::SVector{NNZ,Tv},
) where {Tv,Ti<:Integer,NNZ,NP}
m < 0 && throw(ArgumentError("number of rows (m) must be ≥ 0, got $m"))
n < 0 && throw(ArgumentError("number of columns (n) must be ≥ 0, got $n"))
new(Int(m), Int(n), colptr, rowval, nzval)
end

else
# NOTE: from 1.4.0, by subtyping AbstractSparseMatrixCSC, things like sparse broadcast
# should just work.

"""
SSparseMatrixCSC{Tv,Ti<:Integer, NNZ, NP} <: AbstractSparseMatrix{Tv,Ti}

static version of SparseMatrixCSC
"""
struct SSparseMatrixCSC{Tv,Ti<:Integer,NNZ,NP} <:
SparseArrays.AbstractSparseMatrixCSC{Tv,Ti}
m::Int # Number of rows
n::Int # Number of columns
colptr::SVector{NP,Ti} # Column i is in colptr[i]:(colptr[i+1]-1)
rowval::SVector{NNZ,Ti} # Row values of nonzeros
nzval::SVector{NNZ,Tv} # Nonzero values

function SSparseMatrixCSC{Tv,Ti,NNZ,NP}(
m::Integer,
n::Integer,
colptr::SVector{NP,Ti},
rowval::SVector{NNZ,Ti},
nzval::SVector{NNZ,Tv},
) where {Tv,Ti<:Integer,NNZ,NP}
m < 0 && throw(ArgumentError("number of rows (m) must be ≥ 0, got $m"))
n < 0 && throw(ArgumentError("number of columns (n) must be ≥ 0, got $n"))
new(Int(m), Int(n), colptr, rowval, nzval)
end
end
SparseArrays.getcolptr(M::SSparseMatrixCSC) = M.colptr
SparseArrays.rowvals(M::SSparseMatrixCSC) = M.rowval
end # @static
end
SparseArrays.getcolptr(M::SSparseMatrixCSC) = M.colptr
SparseArrays.rowvals(M::SSparseMatrixCSC) = M.rowval

function SSparseMatrixCSC(
m::Integer,
Expand Down
5 changes: 0 additions & 5 deletions src/broadcast.jl
Original file line number Diff line number Diff line change
@@ -1,8 +1,3 @@
@static if VERSION < v"1.2"
Base.size(bc::Broadcasted) = map(length, axes(bc))
Base.length(bc::Broadcasted) = prod(size(bc))
end

# patches
LinearAlgebra.fzero(S::IMatrix) = zero(eltype(S))

Expand Down
9 changes: 0 additions & 9 deletions src/conversions.jl
Original file line number Diff line number Diff line change
Expand Up @@ -22,15 +22,6 @@ function SparseMatrixCSC(M::AbstractPermMatrix)
SparseMatrixCSC(n, n, collect(1:n+1), MC.perm, MC.vals)
end

@static if VERSION < v"1.3-"

function SparseMatrixCSC(D::Diagonal{T}) where {T}
m = length(D.diag)
return SparseMatrixCSC(m, m, Vector(1:(m+1)), Vector(1:m), Vector{T}(D.diag))
end

end

SparseMatrixCSC{Tv,Ti}(M::AbstractPermMatrix{Tv,Ti}) where {Tv,Ti} = SparseMatrixCSC(M)
SparseMatrixCSC(coo::SparseMatrixCOO) = sparse(coo.is, coo.js, coo.vs, coo.m, coo.n)

Expand Down
6 changes: 0 additions & 6 deletions test/IMatrix.jl
Original file line number Diff line number Diff line change
Expand Up @@ -52,12 +52,6 @@ end
end
@test imag(p1) == zeros(4, 4)
@test p1' == Matrix(I, 4, 4)

# This will be lazy evaluated in 0.7+
@static if VERSION < v"0.7-"
@test typeof(p1') == typeof(p1)
end

@test ishermitian(p1)
end

Expand Down
4 changes: 4 additions & 0 deletions test/PermMatrix.jl
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,10 @@ sp = sprand(4, 4, 0.3)
v = [0.5, 0.3im, 0.2, 1.0]

@testset "basic" begin
@test_throws DimensionMismatch PermMatrix([1, 4, 2, 3], [0.1, 0.2, 0.4im])
@test_throws ArgumentError size(p1, 0)
@test size(p1, 3) == 1
@test [zip(findnz(p1)...)...] == [IterNz(p1)...]
@test p1 == copy(p1)
@test hash(p1) == hash(copy(p1))
@test hash(p1) != hash(p2)
Expand Down
4 changes: 4 additions & 0 deletions test/PermMatrixCSC.jl
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,10 @@ sp = sprand(4, 4, 0.3)
v = [0.5, 0.3im, 0.2, 1.0]

@testset "basic" begin
@test_throws DimensionMismatch PermMatrixCSC([1, 4, 2, 3], [0.1, 0.2, 0.4im])
@test_throws ArgumentError size(p1, 0)
@test size(p1, 3) == 1
@test [zip(findnz(p1)...)...] == [IterNz(p1)...]
@test p1 == copy(p1)
@test hash(p1) == hash(copy(p1))
@test hash(p1) != hash(p2)
Expand Down
17 changes: 17 additions & 0 deletions test/staticize.jl
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ using StaticArrays: SVector, SMatrix
Random.seed!(2)

@testset "staticize" begin
@test staticize(1) == 1
# permmatrix
m = pmrand(ComplexF64, 4)
sm = m |> staticize
Expand All @@ -24,6 +25,22 @@ Random.seed!(2)
@test dm.perm == m.perm
@test dm.vals == m.vals

# permmatrixcsc
m = pmcscrand(ComplexF64, 4)
println(m)
sm = m |> staticize
@test sm isa SPermMatrixCSC{4,ComplexF64}
@test sm.perm isa SVector
@test sm.vals isa SVector
@test sm.perm == m.perm
@test sm.vals == m.vals
dm = sm |> dynamicize
@test dm isa PermMatrixCSC{ComplexF64}
@test dm.perm isa Vector
@test dm.vals isa Vector
@test dm.perm == m.perm
@test dm.vals == m.vals

# csc
m = sprand(ComplexF64, 4, 4, 0.5)
sm = m |> staticize
Expand Down
Loading