Skip to content

Commit

Permalink
add PermMatrixCSC (#78)
Browse files Browse the repository at this point in the history
* save

* update

* update

* fix tess

* bump version

* remove unused ci

* improve static performance of perm matrix

* improve test coverage

* improve test coverage

* fix inbounds

* fix tests

* update

* update
  • Loading branch information
GiggleLiu authored Apr 17, 2024
1 parent 857742b commit d55a7d7
Show file tree
Hide file tree
Showing 20 changed files with 470 additions and 285 deletions.
9 changes: 6 additions & 3 deletions .github/workflows/CI.yml
Original file line number Diff line number Diff line change
@@ -1,7 +1,11 @@
name: CI
on:
- push
- pull_request
push:
branches:
- master
pull_request:
branches:
- master
jobs:
test:
name: Julia ${{ matrix.version }} - ${{ matrix.os }} - ${{ matrix.arch }} - ${{ github.event_name }}
Expand All @@ -10,7 +14,6 @@ jobs:
fail-fast: false
matrix:
version:
- '1.6'
- '1'
- 'nightly'
os:
Expand Down
2 changes: 2 additions & 0 deletions src/LuxurySparse.jl
Original file line number Diff line number Diff line change
Expand Up @@ -6,12 +6,14 @@ using SparseArrays: SparseMatrixCSC
using SparseArrays.HigherOrderFns
using Base: @propagate_inbounds
using LinearAlgebra
import SparseArrays: findnz, nnz
using LinearAlgebra: StructuredMatrixStyle
using Base.Broadcast:
BroadcastStyle, AbstractArrayStyle, Broadcasted, DefaultArrayStyle, materialize!

# static types
export SDPermMatrix, SPermMatrix, PermMatrix, pmrand,
SDPermMatrixCSC, SPermMatrixCSC, PermMatrixCSC, pmcscrand,
SDSparseMatrixCSC, SSparseMatrixCSC, SparseMatrixCSC, sprand,
SparseMatrixCOO,
SDMatrix, SDVector,
Expand Down
123 changes: 83 additions & 40 deletions src/PermMatrix.jl
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
abstract type AbstractPermMatrix{Tv, Ti} <: AbstractMatrix{Tv} end
"""
PermMatrix{Tv, Ti}(perm::AbstractVector{Ti}, vals::AbstractVector{Tv}) where {Tv, Ti<:Integer}
PermMatrix(perm::Vector{Ti}, vals::Vector{Tv}) where {Tv, Ti}
Expand All @@ -24,7 +25,7 @@ julia> PermMatrix([2,1,4,3], rand(4))
```
"""
struct PermMatrix{Tv,Ti<:Integer,Vv<:AbstractVector{Tv},Vi<:AbstractVector{Ti}} <:
AbstractMatrix{Tv}
AbstractPermMatrix{Tv,Ti}
perm::Vi # new orders
vals::Vv # multiplied values.

Expand All @@ -42,26 +43,74 @@ struct PermMatrix{Tv,Ti<:Integer,Vv<:AbstractVector{Tv},Vi<:AbstractVector{Ti}}
new{Tv,Ti,Vv,Vi}(perm, vals)
end
end

function PermMatrix{Tv,Ti}(perm, vals) where {Tv,Ti<:Integer}
PermMatrix{Tv,Ti,Vector{Tv},Vector{Ti}}(Vector{Ti}(perm), Vector{Tv}(vals))
basetype(pm::PermMatrix) = PermMatrix
Base.getindex(M::PermMatrix{Tv}, i::Integer, j::Integer) where {Tv} =
M.perm[i] == j ? M.vals[i] : zero(Tv)
function Base.setindex!(M::PermMatrix, val, i::Integer, j::Integer)
@assert M.perm[i] == j "Can not set index due to the absense of entry: ($i, $j)"
@inbounds M.vals[i] = val
end

function PermMatrix(
perm::Vi,
vals::Vv,
) where {Tv,Ti<:Integer,Vv<:AbstractVector{Tv},Vi<:AbstractVector{Ti}}
PermMatrix{Tv,Ti,Vv,Vi}(perm, vals)
# the column major version of `PermMatrix`
struct PermMatrixCSC{Tv,Ti<:Integer,Vv<:AbstractVector{Tv},Vi<:AbstractVector{Ti}} <:
AbstractPermMatrix{Tv,Ti}
perm::Vi # new orders
vals::Vv # multiplied values.

function PermMatrixCSC{Tv,Ti,Vv,Vi}(
perm::Vi,
vals::Vv,
) where {Tv,Ti<:Integer,Vv<:AbstractVector{Tv},Vi<:AbstractVector{Ti}}
if length(perm) != length(vals)
throw(
DimensionMismatch(
"permutation ($(length(perm))) and multiply ($(length(vals))) length mismatch.",
),
)
end
new{Tv,Ti,Vv,Vi}(perm, vals)
end
end
basetype(pm::PermMatrixCSC) = PermMatrixCSC
@propagate_inbounds function Base.getindex(M::PermMatrixCSC{Tv}, i::Integer, j::Integer) where {Tv}
@boundscheck 0 < j <= size(M, 2)
@inbounds M.perm[j] == i ? M.vals[j] : zero(Tv)
end
function Base.setindex!(M::PermMatrixCSC, val, i::Integer, j::Integer)
@assert M.perm[j] == i "Can not set index due to the absense of entry: ($i, $j)"
@inbounds M.vals[j] = val
end

for MT in [:PermMatrix, :PermMatrixCSC]
@eval begin
function $MT{Tv,Ti}(perm, vals) where {Tv,Ti<:Integer}
$MT{Tv,Ti,Vector{Tv},Vector{Ti}}(Vector{Ti}(perm), Vector{Tv}(vals))
end

Base.:(==)(d1::PermMatrix, d2::PermMatrix) = SparseMatrixCSC(d1) == SparseMatrixCSC(d2)
Base.isapprox(d1::PermMatrix, d2::PermMatrix; kwargs...) = isapprox(SparseMatrixCSC(d1), SparseMatrixCSC(d2); kwargs...)
Base.zero(pm::PermMatrix) = PermMatrix(pm.perm, zero(pm.vals))
function $MT(
perm::Vi,
vals::Vv,
) where {Tv,Ti<:Integer,Vv<:AbstractVector{Tv},Vi<:AbstractVector{Ti}}
$MT{Tv,Ti,Vv,Vi}(perm, vals)
end
end
end
Base.zero(pm::AbstractPermMatrix) = basetype(pm)(pm.perm, zero(pm.vals))
Base.similar(x::AbstractPermMatrix{Tv,Ti}) where {Tv,Ti} =
typeof(x)(copy(x.perm), similar(x.vals))
Base.similar(x::AbstractPermMatrix{Tv,Ti}, ::Type{T}) where {Tv,Ti,T} =
basetype(x){T,Ti}(copy(x.perm), similar(x.vals, T))

################# Comparison ##################
Base.:(==)(d1::AbstractPermMatrix, d2::AbstractPermMatrix) = SparseMatrixCSC(d1) == SparseMatrixCSC(d2)
Base.isapprox(d1::AbstractPermMatrix, d2::AbstractPermMatrix; kwargs...) = isapprox(SparseMatrixCSC(d1), SparseMatrixCSC(d2); kwargs...)
Base.copyto!(A::AbstractPermMatrix, B::AbstractPermMatrix) =
(copyto!(A.perm, B.perm); copyto!(A.vals, B.vals); A)

################# Array Functions ##################

Base.size(M::PermMatrix) = (length(M.perm), length(M.perm))
function Base.size(A::PermMatrix, d::Integer)
Base.size(M::AbstractPermMatrix) = (length(M.perm), length(M.perm))
function Base.size(A::AbstractPermMatrix, d::Integer)
if d < 1
throw(ArgumentError("dimension must be ≥ 1, got $d"))
elseif d <= 2
Expand All @@ -70,18 +119,6 @@ function Base.size(A::PermMatrix, d::Integer)
return 1
end
end
Base.getindex(M::PermMatrix{Tv}, i::Integer, j::Integer) where {Tv} =
M.perm[i] == j ? M.vals[i] : zero(Tv)
function Base.setindex!(M::PermMatrix, val, i::Integer, j::Integer)
if M.perm[i] == j
@inbounds M.vals[i] = val
else
throw(BoundsError(M, (i, j)))
end
end

Base.copyto!(A::PermMatrix, B::PermMatrix) =
(copyto!(A.perm, B.perm); copyto!(A.vals, B.vals); A)

"""
pmrand(T::Type, n::Int) -> PermMatrix
Expand All @@ -105,20 +142,26 @@ function pmrand end
pmrand(::Type{T}, n::Int) where {T} = PermMatrix(randperm(n), randn(T, n))
pmrand(n::Int) = pmrand(Float64, n)

Base.similar(x::PermMatrix{Tv,Ti}) where {Tv,Ti} =
PermMatrix{Tv,Ti}(copy(x.perm), similar(x.vals))
Base.similar(x::PermMatrix{Tv,Ti}, ::Type{T}) where {Tv,Ti,T} =
PermMatrix{T,Ti}(copy(x.perm), similar(x.vals, T))

# TODO: rewrite this
# function show(io::IO, M::PermMatrix)
# println("PermMatrix")
# for item in zip(M.perm, M.vals)
# i, p = item
# println("- ($i) * $p")
# end
# end
pmcscrand(::Type{T}, n::Int) where {T} = PermMatrixCSC(randperm(n), randn(T, n))
pmcscrand(n::Int) = pmcscrand(Float64, n)

Base.show(io::IO, ::MIME"text/plain", M::AbstractPermMatrix) = show(io, M)
function Base.show(io::IO, M::AbstractPermMatrix)
n = size(M, 1)
println(io, typeof(M))
nmax = 20
for (k, (i, j, p)) in enumerate(IterNz(M))
if k <= nmax || k > n-nmax
print(io, "($i, $j) = $p")
k < n && println(io)
elseif k == nmax+1
println(io, "...")
end
end
end
Base.hash(pm::AbstractPermMatrix) = hash((pm.perm, pm.vals))

######### sparse array interfaces #########
nnz(M::PermMatrix) = length(M.vals)
nnz(M::AbstractPermMatrix) = length(M.vals)
findnz(M::PermMatrix) = (collect(1:size(M, 1)), M.perm, M.vals)
findnz(M::PermMatrixCSC) = (M.perm, collect(1:size(M, 1)), M.vals)
80 changes: 24 additions & 56 deletions src/SSparseMatrixCSC.jl
Original file line number Diff line number Diff line change
@@ -1,62 +1,30 @@
@static if VERSION < v"1.4.0"
"""
SSparseMatrixCSC{Tv,Ti<:Integer, NNZ, NP} <: AbstractSparseMatrix{Tv,Ti}
"""
SSparseMatrixCSC{Tv,Ti<:Integer, NNZ, NP} <: AbstractSparseMatrix{Tv,Ti}
static version of SparseMatrixCSC
"""
struct SSparseMatrixCSC{Tv,Ti<:Integer,NNZ,NP} <:
SparseArrays.AbstractSparseMatrixCSC{Tv,Ti}
m::Int # Number of rows
n::Int # Number of columns
colptr::SVector{NP,Ti} # Column i is in colptr[i]:(colptr[i+1]-1)
rowval::SVector{NNZ,Ti} # Row values of nonzeros
nzval::SVector{NNZ,Tv} # Nonzero values

static version of SparseMatrixCSC
"""
struct SSparseMatrixCSC{Tv,Ti<:Integer,NNZ,NP} <: AbstractSparseMatrix{Tv,Ti}
m::Int # Number of rows
n::Int # Number of columns
colptr::SVector{NP,Ti} # Column i is in colptr[i]:(colptr[i+1]-1)
rowval::SVector{NNZ,Ti} # Row values of nonzeros
nzval::SVector{NNZ,Tv} # Nonzero values

function SSparseMatrixCSC{Tv,Ti,NNZ,NP}(
m::Integer,
n::Integer,
colptr::SVector{NP,Ti},
rowval::SVector{NNZ,Ti},
nzval::SVector{NNZ,Tv},
) where {Tv,Ti<:Integer,NNZ,NP}
m < 0 && throw(ArgumentError("number of rows (m) must be ≥ 0, got $m"))
n < 0 && throw(ArgumentError("number of columns (n) must be ≥ 0, got $n"))
new(Int(m), Int(n), colptr, rowval, nzval)
end
function SSparseMatrixCSC{Tv,Ti,NNZ,NP}(
m::Integer,
n::Integer,
colptr::SVector{NP,Ti},
rowval::SVector{NNZ,Ti},
nzval::SVector{NNZ,Tv},
) where {Tv,Ti<:Integer,NNZ,NP}
m < 0 && throw(ArgumentError("number of rows (m) must be ≥ 0, got $m"))
n < 0 && throw(ArgumentError("number of columns (n) must be ≥ 0, got $n"))
new(Int(m), Int(n), colptr, rowval, nzval)
end

else
# NOTE: from 1.4.0, by subtyping AbstractSparseMatrixCSC, things like sparse broadcast
# should just work.

"""
SSparseMatrixCSC{Tv,Ti<:Integer, NNZ, NP} <: AbstractSparseMatrix{Tv,Ti}
static version of SparseMatrixCSC
"""
struct SSparseMatrixCSC{Tv,Ti<:Integer,NNZ,NP} <:
SparseArrays.AbstractSparseMatrixCSC{Tv,Ti}
m::Int # Number of rows
n::Int # Number of columns
colptr::SVector{NP,Ti} # Column i is in colptr[i]:(colptr[i+1]-1)
rowval::SVector{NNZ,Ti} # Row values of nonzeros
nzval::SVector{NNZ,Tv} # Nonzero values

function SSparseMatrixCSC{Tv,Ti,NNZ,NP}(
m::Integer,
n::Integer,
colptr::SVector{NP,Ti},
rowval::SVector{NNZ,Ti},
nzval::SVector{NNZ,Tv},
) where {Tv,Ti<:Integer,NNZ,NP}
m < 0 && throw(ArgumentError("number of rows (m) must be ≥ 0, got $m"))
n < 0 && throw(ArgumentError("number of columns (n) must be ≥ 0, got $n"))
new(Int(m), Int(n), colptr, rowval, nzval)
end
end
SparseArrays.getcolptr(M::SSparseMatrixCSC) = M.colptr
SparseArrays.rowvals(M::SSparseMatrixCSC) = M.rowval
end # @static
end
SparseArrays.getcolptr(M::SSparseMatrixCSC) = M.colptr
SparseArrays.rowvals(M::SSparseMatrixCSC) = M.rowval

function SSparseMatrixCSC(
m::Integer,
Expand Down
47 changes: 24 additions & 23 deletions src/arraymath.jl
Original file line number Diff line number Diff line change
Expand Up @@ -9,32 +9,33 @@ Base.imag(M::IMatrix{T}) where {T} = Diagonal(zeros(T, M.n))

# PermMatrix
for func in (:conj, :real, :imag)
@eval (Base.$func)(M::PermMatrix) = PermMatrix(M.perm, ($func)(M.vals))
@eval (Base.$func)(M::AbstractPermMatrix) = basetype(M)(M.perm, ($func)(M.vals))
end
Base.copy(M::PermMatrix) = PermMatrix(copy(M.perm), copy(M.vals))
Base.copy(M::AbstractPermMatrix) = basetype(M)(copy(M.perm), copy(M.vals))
Base.conj!(M::AbstractPermMatrix) = (conj!(M.vals); M)

function Base.transpose(M::PermMatrix)
function Base.transpose(M::AbstractPermMatrix)
new_perm = fast_invperm(M.perm)
return PermMatrix(new_perm, M.vals[new_perm])
return basetype(M)(new_perm, M.vals[new_perm])
end

Base.adjoint(S::PermMatrix{<:Real}) = transpose(S)
Base.adjoint(S::PermMatrix{<:Complex}) = conj(transpose(S))
Base.adjoint(S::AbstractPermMatrix{<:Real}) = transpose(S)
Base.adjoint(S::AbstractPermMatrix{<:Complex}) = conj!(transpose(S))

# scalar
Base.:*(A::IMatrix{T}, B::Number) where {T} = Diagonal(fill(promote_type(T, eltype(B))(B), A.n))
Base.:*(B::Number, A::IMatrix{T}) where {T} = Diagonal(fill(promote_type(T, eltype(B))(B), A.n))
Base.:/(A::IMatrix{T}, B::Number) where {T} =
Diagonal(fill(promote_type(T, eltype(B))(1 / B), A.n))

Base.:*(A::PermMatrix, B::Number) = PermMatrix(A.perm, A.vals * B)
Base.:*(B::Number, A::PermMatrix) = A * B
Base.:/(A::PermMatrix, B::Number) = PermMatrix(A.perm, A.vals / B)
Base.:*(A::AbstractPermMatrix, B::Number) = basetype(A)(A.perm, A.vals * B)
Base.:*(B::Number, A::AbstractPermMatrix) = A * B
Base.:/(A::AbstractPermMatrix, B::Number) = basetype(A)(A.perm, A.vals / B)
#+(A::PermMatrix, B::PermMatrix) = PermMatrix(A.dv+B.dv, A.ev+B.ev)
#-(A::PermMatrix, B::PermMatrix) = PermMatrix(A.dv-B.dv, A.ev-B.ev)

for op in [:+, :-]
for MT in [:IMatrix, :PermMatrix]
for MT in [:IMatrix, :AbstractPermMatrix]
@eval begin
# IMatrix, PermMatrix - SparseMatrixCSC
Base.$op(A::$MT, B::SparseMatrixCSC) = $op(SparseMatrixCSC(A), B)
Expand All @@ -45,12 +46,12 @@ for op in [:+, :-]
# IMatrix, PermMatrix - Diagonal
Base.$op(d1::IMatrix, d2::Diagonal) = Diagonal($op(diag(d1), d2.diag))
Base.$op(d1::Diagonal, d2::IMatrix) = Diagonal($op(d1.diag, diag(d2)))
Base.$op(d1::PermMatrix, d2::Diagonal) = $op(SparseMatrixCSC(d1), d2)
Base.$op(d1::Diagonal, d2::PermMatrix) = $op(d1, SparseMatrixCSC(d2))
Base.$op(d1::AbstractPermMatrix, d2::Diagonal) = $op(SparseMatrixCSC(d1), d2)
Base.$op(d1::Diagonal, d2::AbstractPermMatrix) = $op(d1, SparseMatrixCSC(d2))
# PermMatrix - IMatrix
Base.$op(A::PermMatrix, B::IMatrix) = $op(SparseMatrixCSC(A), SparseMatrixCSC(B))
Base.$op(A::IMatrix, B::PermMatrix) = $op(SparseMatrixCSC(A), SparseMatrixCSC(B))
Base.$op(A::PermMatrix, B::PermMatrix) = $op(SparseMatrixCSC(A), SparseMatrixCSC(B))
Base.$op(A::AbstractPermMatrix, B::IMatrix) = $op(SparseMatrixCSC(A), SparseMatrixCSC(B))
Base.$op(A::IMatrix, B::AbstractPermMatrix) = $op(SparseMatrixCSC(A), SparseMatrixCSC(B))
Base.$op(A::AbstractPermMatrix, B::AbstractPermMatrix) = $op(SparseMatrixCSC(A), SparseMatrixCSC(B))
end
end
# NOTE: promote to integer
Expand All @@ -59,22 +60,22 @@ Base.:+(d1::IMatrix{Ta}, d2::IMatrix{Tb}) where {Ta,Tb} =
Base.:-(d1::IMatrix{Ta}, d2::IMatrix{Tb}) where {Ta,Tb} =
d1 == d2 ? spzeros(promote_type(Ta, Tb), d1.n, d1.n) : throw(DimensionMismatch())

for MT in [:IMatrix, :PermMatrix]
for MT in [:IMatrix, :AbstractPermMatrix]
@eval Base.:(==)(A::$MT, B::SparseMatrixCSC) = SparseMatrixCSC(A) == B
@eval Base.:(==)(A::SparseMatrixCSC, B::$MT) = A == SparseMatrixCSC(B)
end
Base.:(==)(d1::IMatrix, d2::Diagonal) = all(isone, d2.diag)
Base.:(==)(d1::Diagonal, d2::IMatrix) = all(isone, d1.diag)
Base.:(==)(d1::PermMatrix, d2::Diagonal) = SparseMatrixCSC(d1) == SparseMatrixCSC(d2)
Base.:(==)(d1::Diagonal, d2::PermMatrix) = SparseMatrixCSC(d1) == SparseMatrixCSC(d2)
Base.:(==)(A::IMatrix, B::PermMatrix) = SparseMatrixCSC(A) == SparseMatrixCSC(B)
Base.:(==)(A::PermMatrix, B::IMatrix) = SparseMatrixCSC(A) == SparseMatrixCSC(B)
Base.:(==)(d1::AbstractPermMatrix, d2::Diagonal) = SparseMatrixCSC(d1) == SparseMatrixCSC(d2)
Base.:(==)(d1::Diagonal, d2::AbstractPermMatrix) = SparseMatrixCSC(d1) == SparseMatrixCSC(d2)
Base.:(==)(A::IMatrix, B::AbstractPermMatrix) = SparseMatrixCSC(A) == SparseMatrixCSC(B)
Base.:(==)(A::AbstractPermMatrix, B::IMatrix) = SparseMatrixCSC(A) == SparseMatrixCSC(B)

for MT in [:IMatrix, :PermMatrix]
for MT in [:IMatrix, :AbstractPermMatrix]
@eval Base.isapprox(A::$MT, B::SparseMatrixCSC; kwargs...) = isapprox(SparseMatrixCSC(A), B)
@eval Base.isapprox(A::SparseMatrixCSC, B::$MT; kwargs...) = isapprox(A, SparseMatrixCSC(B))
@eval Base.isapprox(d1::$MT, d2::Diagonal; kwargs...) = isapprox(diag(d1), d2.diag)
@eval Base.isapprox(d1::Diagonal, d2::$MT; kwargs...) = isapprox(d1.diag, diag(d2))
end
Base.isapprox(A::IMatrix, B::PermMatrix; kwargs...) = isapprox(SparseMatrixCSC(A), SparseMatrixCSC(B); kwargs...)
Base.isapprox(A::PermMatrix, B::IMatrix; kwargs...) = isapprox(SparseMatrixCSC(A), SparseMatrixCSC(B); kwargs...)
Base.isapprox(A::IMatrix, B::AbstractPermMatrix; kwargs...) = isapprox(SparseMatrixCSC(A), SparseMatrixCSC(B); kwargs...)
Base.isapprox(A::AbstractPermMatrix, B::IMatrix; kwargs...) = isapprox(SparseMatrixCSC(A), SparseMatrixCSC(B); kwargs...)
Loading

0 comments on commit d55a7d7

Please sign in to comment.