From d55a7d77455870967dc2cb7e94824ba9fb4dc41a Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Jinguo=20Liu=20=28=E5=88=98=E9=87=91=E5=9B=BD=29?= Date: Wed, 17 Apr 2024 09:59:07 +0800 Subject: [PATCH] add PermMatrixCSC (#78) * save * update * update * fix tess * bump version * remove unused ci * improve static performance of perm matrix * improve test coverage * improve test coverage * fix inbounds * fix tests * update * update --- .github/workflows/CI.yml | 9 ++- src/LuxurySparse.jl | 2 + src/PermMatrix.jl | 123 ++++++++++++++++++++++++++------------- src/SSparseMatrixCSC.jl | 80 ++++++++----------------- src/arraymath.jl | 47 +++++++-------- src/broadcast.jl | 49 +++++----------- src/conversions.jl | 84 ++++++++++++++------------ src/iterate.jl | 13 ++++- src/kronecker.jl | 55 ++++++++--------- src/linalg.jl | 97 +++++++++++++++--------------- src/promotions.jl | 14 +++-- src/staticize.jl | 5 ++ test/IMatrix.jl | 6 -- test/PermMatrix.jl | 11 +++- test/PermMatrixCSC.jl | 114 ++++++++++++++++++++++++++++++++++++ test/broadcast.jl | 16 ++++- test/iterate.jl | 1 + test/kronecker.jl | 11 ++-- test/runtests.jl | 1 + test/staticize.jl | 17 ++++++ 20 files changed, 470 insertions(+), 285 deletions(-) create mode 100644 test/PermMatrixCSC.jl diff --git a/.github/workflows/CI.yml b/.github/workflows/CI.yml index 759083f..161bb67 100644 --- a/.github/workflows/CI.yml +++ b/.github/workflows/CI.yml @@ -1,7 +1,11 @@ name: CI on: - - push - - pull_request + push: + branches: + - master + pull_request: + branches: + - master jobs: test: name: Julia ${{ matrix.version }} - ${{ matrix.os }} - ${{ matrix.arch }} - ${{ github.event_name }} @@ -10,7 +14,6 @@ jobs: fail-fast: false matrix: version: - - '1.6' - '1' - 'nightly' os: diff --git a/src/LuxurySparse.jl b/src/LuxurySparse.jl index c84a24d..d4afcf3 100644 --- a/src/LuxurySparse.jl +++ b/src/LuxurySparse.jl @@ -6,12 +6,14 @@ using SparseArrays: SparseMatrixCSC using SparseArrays.HigherOrderFns using Base: @propagate_inbounds using LinearAlgebra +import SparseArrays: findnz, nnz using LinearAlgebra: StructuredMatrixStyle using Base.Broadcast: BroadcastStyle, AbstractArrayStyle, Broadcasted, DefaultArrayStyle, materialize! # static types export SDPermMatrix, SPermMatrix, PermMatrix, pmrand, + SDPermMatrixCSC, SPermMatrixCSC, PermMatrixCSC, pmcscrand, SDSparseMatrixCSC, SSparseMatrixCSC, SparseMatrixCSC, sprand, SparseMatrixCOO, SDMatrix, SDVector, diff --git a/src/PermMatrix.jl b/src/PermMatrix.jl index f5d1d32..732a13b 100644 --- a/src/PermMatrix.jl +++ b/src/PermMatrix.jl @@ -1,3 +1,4 @@ +abstract type AbstractPermMatrix{Tv, Ti} <: AbstractMatrix{Tv} end """ PermMatrix{Tv, Ti}(perm::AbstractVector{Ti}, vals::AbstractVector{Tv}) where {Tv, Ti<:Integer} PermMatrix(perm::Vector{Ti}, vals::Vector{Tv}) where {Tv, Ti} @@ -24,7 +25,7 @@ julia> PermMatrix([2,1,4,3], rand(4)) ``` """ struct PermMatrix{Tv,Ti<:Integer,Vv<:AbstractVector{Tv},Vi<:AbstractVector{Ti}} <: - AbstractMatrix{Tv} + AbstractPermMatrix{Tv,Ti} perm::Vi # new orders vals::Vv # multiplied values. @@ -42,26 +43,74 @@ struct PermMatrix{Tv,Ti<:Integer,Vv<:AbstractVector{Tv},Vi<:AbstractVector{Ti}} new{Tv,Ti,Vv,Vi}(perm, vals) end end - -function PermMatrix{Tv,Ti}(perm, vals) where {Tv,Ti<:Integer} - PermMatrix{Tv,Ti,Vector{Tv},Vector{Ti}}(Vector{Ti}(perm), Vector{Tv}(vals)) +basetype(pm::PermMatrix) = PermMatrix +Base.getindex(M::PermMatrix{Tv}, i::Integer, j::Integer) where {Tv} = + M.perm[i] == j ? M.vals[i] : zero(Tv) +function Base.setindex!(M::PermMatrix, val, i::Integer, j::Integer) + @assert M.perm[i] == j "Can not set index due to the absense of entry: ($i, $j)" + @inbounds M.vals[i] = val end -function PermMatrix( - perm::Vi, - vals::Vv, -) where {Tv,Ti<:Integer,Vv<:AbstractVector{Tv},Vi<:AbstractVector{Ti}} - PermMatrix{Tv,Ti,Vv,Vi}(perm, vals) +# the column major version of `PermMatrix` +struct PermMatrixCSC{Tv,Ti<:Integer,Vv<:AbstractVector{Tv},Vi<:AbstractVector{Ti}} <: + AbstractPermMatrix{Tv,Ti} + perm::Vi # new orders + vals::Vv # multiplied values. + + function PermMatrixCSC{Tv,Ti,Vv,Vi}( + perm::Vi, + vals::Vv, + ) where {Tv,Ti<:Integer,Vv<:AbstractVector{Tv},Vi<:AbstractVector{Ti}} + if length(perm) != length(vals) + throw( + DimensionMismatch( + "permutation ($(length(perm))) and multiply ($(length(vals))) length mismatch.", + ), + ) + end + new{Tv,Ti,Vv,Vi}(perm, vals) + end +end +basetype(pm::PermMatrixCSC) = PermMatrixCSC +@propagate_inbounds function Base.getindex(M::PermMatrixCSC{Tv}, i::Integer, j::Integer) where {Tv} + @boundscheck 0 < j <= size(M, 2) + @inbounds M.perm[j] == i ? M.vals[j] : zero(Tv) end +function Base.setindex!(M::PermMatrixCSC, val, i::Integer, j::Integer) + @assert M.perm[j] == i "Can not set index due to the absense of entry: ($i, $j)" + @inbounds M.vals[j] = val +end + +for MT in [:PermMatrix, :PermMatrixCSC] + @eval begin + function $MT{Tv,Ti}(perm, vals) where {Tv,Ti<:Integer} + $MT{Tv,Ti,Vector{Tv},Vector{Ti}}(Vector{Ti}(perm), Vector{Tv}(vals)) + end -Base.:(==)(d1::PermMatrix, d2::PermMatrix) = SparseMatrixCSC(d1) == SparseMatrixCSC(d2) -Base.isapprox(d1::PermMatrix, d2::PermMatrix; kwargs...) = isapprox(SparseMatrixCSC(d1), SparseMatrixCSC(d2); kwargs...) -Base.zero(pm::PermMatrix) = PermMatrix(pm.perm, zero(pm.vals)) + function $MT( + perm::Vi, + vals::Vv, + ) where {Tv,Ti<:Integer,Vv<:AbstractVector{Tv},Vi<:AbstractVector{Ti}} + $MT{Tv,Ti,Vv,Vi}(perm, vals) + end + end +end +Base.zero(pm::AbstractPermMatrix) = basetype(pm)(pm.perm, zero(pm.vals)) +Base.similar(x::AbstractPermMatrix{Tv,Ti}) where {Tv,Ti} = + typeof(x)(copy(x.perm), similar(x.vals)) +Base.similar(x::AbstractPermMatrix{Tv,Ti}, ::Type{T}) where {Tv,Ti,T} = + basetype(x){T,Ti}(copy(x.perm), similar(x.vals, T)) + +################# Comparison ################## +Base.:(==)(d1::AbstractPermMatrix, d2::AbstractPermMatrix) = SparseMatrixCSC(d1) == SparseMatrixCSC(d2) +Base.isapprox(d1::AbstractPermMatrix, d2::AbstractPermMatrix; kwargs...) = isapprox(SparseMatrixCSC(d1), SparseMatrixCSC(d2); kwargs...) +Base.copyto!(A::AbstractPermMatrix, B::AbstractPermMatrix) = + (copyto!(A.perm, B.perm); copyto!(A.vals, B.vals); A) ################# Array Functions ################## -Base.size(M::PermMatrix) = (length(M.perm), length(M.perm)) -function Base.size(A::PermMatrix, d::Integer) +Base.size(M::AbstractPermMatrix) = (length(M.perm), length(M.perm)) +function Base.size(A::AbstractPermMatrix, d::Integer) if d < 1 throw(ArgumentError("dimension must be ≥ 1, got $d")) elseif d <= 2 @@ -70,18 +119,6 @@ function Base.size(A::PermMatrix, d::Integer) return 1 end end -Base.getindex(M::PermMatrix{Tv}, i::Integer, j::Integer) where {Tv} = - M.perm[i] == j ? M.vals[i] : zero(Tv) -function Base.setindex!(M::PermMatrix, val, i::Integer, j::Integer) - if M.perm[i] == j - @inbounds M.vals[i] = val - else - throw(BoundsError(M, (i, j))) - end -end - -Base.copyto!(A::PermMatrix, B::PermMatrix) = - (copyto!(A.perm, B.perm); copyto!(A.vals, B.vals); A) """ pmrand(T::Type, n::Int) -> PermMatrix @@ -105,20 +142,26 @@ function pmrand end pmrand(::Type{T}, n::Int) where {T} = PermMatrix(randperm(n), randn(T, n)) pmrand(n::Int) = pmrand(Float64, n) -Base.similar(x::PermMatrix{Tv,Ti}) where {Tv,Ti} = - PermMatrix{Tv,Ti}(copy(x.perm), similar(x.vals)) -Base.similar(x::PermMatrix{Tv,Ti}, ::Type{T}) where {Tv,Ti,T} = - PermMatrix{T,Ti}(copy(x.perm), similar(x.vals, T)) - -# TODO: rewrite this -# function show(io::IO, M::PermMatrix) -# println("PermMatrix") -# for item in zip(M.perm, M.vals) -# i, p = item -# println("- ($i) * $p") -# end -# end +pmcscrand(::Type{T}, n::Int) where {T} = PermMatrixCSC(randperm(n), randn(T, n)) +pmcscrand(n::Int) = pmcscrand(Float64, n) + +Base.show(io::IO, ::MIME"text/plain", M::AbstractPermMatrix) = show(io, M) +function Base.show(io::IO, M::AbstractPermMatrix) + n = size(M, 1) + println(io, typeof(M)) + nmax = 20 + for (k, (i, j, p)) in enumerate(IterNz(M)) + if k <= nmax || k > n-nmax + print(io, "($i, $j) = $p") + k < n && println(io) + elseif k == nmax+1 + println(io, "...") + end + end +end +Base.hash(pm::AbstractPermMatrix) = hash((pm.perm, pm.vals)) ######### sparse array interfaces ######### -nnz(M::PermMatrix) = length(M.vals) +nnz(M::AbstractPermMatrix) = length(M.vals) findnz(M::PermMatrix) = (collect(1:size(M, 1)), M.perm, M.vals) +findnz(M::PermMatrixCSC) = (M.perm, collect(1:size(M, 1)), M.vals) diff --git a/src/SSparseMatrixCSC.jl b/src/SSparseMatrixCSC.jl index 0226ae2..5f6658d 100644 --- a/src/SSparseMatrixCSC.jl +++ b/src/SSparseMatrixCSC.jl @@ -1,62 +1,30 @@ -@static if VERSION < v"1.4.0" +""" + SSparseMatrixCSC{Tv,Ti<:Integer, NNZ, NP} <: AbstractSparseMatrix{Tv,Ti} - """ - SSparseMatrixCSC{Tv,Ti<:Integer, NNZ, NP} <: AbstractSparseMatrix{Tv,Ti} +static version of SparseMatrixCSC +""" +struct SSparseMatrixCSC{Tv,Ti<:Integer,NNZ,NP} <: + SparseArrays.AbstractSparseMatrixCSC{Tv,Ti} + m::Int # Number of rows + n::Int # Number of columns + colptr::SVector{NP,Ti} # Column i is in colptr[i]:(colptr[i+1]-1) + rowval::SVector{NNZ,Ti} # Row values of nonzeros + nzval::SVector{NNZ,Tv} # Nonzero values - static version of SparseMatrixCSC - """ - struct SSparseMatrixCSC{Tv,Ti<:Integer,NNZ,NP} <: AbstractSparseMatrix{Tv,Ti} - m::Int # Number of rows - n::Int # Number of columns - colptr::SVector{NP,Ti} # Column i is in colptr[i]:(colptr[i+1]-1) - rowval::SVector{NNZ,Ti} # Row values of nonzeros - nzval::SVector{NNZ,Tv} # Nonzero values - - function SSparseMatrixCSC{Tv,Ti,NNZ,NP}( - m::Integer, - n::Integer, - colptr::SVector{NP,Ti}, - rowval::SVector{NNZ,Ti}, - nzval::SVector{NNZ,Tv}, - ) where {Tv,Ti<:Integer,NNZ,NP} - m < 0 && throw(ArgumentError("number of rows (m) must be ≥ 0, got $m")) - n < 0 && throw(ArgumentError("number of columns (n) must be ≥ 0, got $n")) - new(Int(m), Int(n), colptr, rowval, nzval) - end + function SSparseMatrixCSC{Tv,Ti,NNZ,NP}( + m::Integer, + n::Integer, + colptr::SVector{NP,Ti}, + rowval::SVector{NNZ,Ti}, + nzval::SVector{NNZ,Tv}, + ) where {Tv,Ti<:Integer,NNZ,NP} + m < 0 && throw(ArgumentError("number of rows (m) must be ≥ 0, got $m")) + n < 0 && throw(ArgumentError("number of columns (n) must be ≥ 0, got $n")) + new(Int(m), Int(n), colptr, rowval, nzval) end - -else - # NOTE: from 1.4.0, by subtyping AbstractSparseMatrixCSC, things like sparse broadcast - # should just work. - - """ - SSparseMatrixCSC{Tv,Ti<:Integer, NNZ, NP} <: AbstractSparseMatrix{Tv,Ti} - - static version of SparseMatrixCSC - """ - struct SSparseMatrixCSC{Tv,Ti<:Integer,NNZ,NP} <: - SparseArrays.AbstractSparseMatrixCSC{Tv,Ti} - m::Int # Number of rows - n::Int # Number of columns - colptr::SVector{NP,Ti} # Column i is in colptr[i]:(colptr[i+1]-1) - rowval::SVector{NNZ,Ti} # Row values of nonzeros - nzval::SVector{NNZ,Tv} # Nonzero values - - function SSparseMatrixCSC{Tv,Ti,NNZ,NP}( - m::Integer, - n::Integer, - colptr::SVector{NP,Ti}, - rowval::SVector{NNZ,Ti}, - nzval::SVector{NNZ,Tv}, - ) where {Tv,Ti<:Integer,NNZ,NP} - m < 0 && throw(ArgumentError("number of rows (m) must be ≥ 0, got $m")) - n < 0 && throw(ArgumentError("number of columns (n) must be ≥ 0, got $n")) - new(Int(m), Int(n), colptr, rowval, nzval) - end - end - SparseArrays.getcolptr(M::SSparseMatrixCSC) = M.colptr - SparseArrays.rowvals(M::SSparseMatrixCSC) = M.rowval -end # @static +end +SparseArrays.getcolptr(M::SSparseMatrixCSC) = M.colptr +SparseArrays.rowvals(M::SSparseMatrixCSC) = M.rowval function SSparseMatrixCSC( m::Integer, diff --git a/src/arraymath.jl b/src/arraymath.jl index d26045d..471eb2f 100644 --- a/src/arraymath.jl +++ b/src/arraymath.jl @@ -9,17 +9,18 @@ Base.imag(M::IMatrix{T}) where {T} = Diagonal(zeros(T, M.n)) # PermMatrix for func in (:conj, :real, :imag) - @eval (Base.$func)(M::PermMatrix) = PermMatrix(M.perm, ($func)(M.vals)) + @eval (Base.$func)(M::AbstractPermMatrix) = basetype(M)(M.perm, ($func)(M.vals)) end -Base.copy(M::PermMatrix) = PermMatrix(copy(M.perm), copy(M.vals)) +Base.copy(M::AbstractPermMatrix) = basetype(M)(copy(M.perm), copy(M.vals)) +Base.conj!(M::AbstractPermMatrix) = (conj!(M.vals); M) -function Base.transpose(M::PermMatrix) +function Base.transpose(M::AbstractPermMatrix) new_perm = fast_invperm(M.perm) - return PermMatrix(new_perm, M.vals[new_perm]) + return basetype(M)(new_perm, M.vals[new_perm]) end -Base.adjoint(S::PermMatrix{<:Real}) = transpose(S) -Base.adjoint(S::PermMatrix{<:Complex}) = conj(transpose(S)) +Base.adjoint(S::AbstractPermMatrix{<:Real}) = transpose(S) +Base.adjoint(S::AbstractPermMatrix{<:Complex}) = conj!(transpose(S)) # scalar Base.:*(A::IMatrix{T}, B::Number) where {T} = Diagonal(fill(promote_type(T, eltype(B))(B), A.n)) @@ -27,14 +28,14 @@ Base.:*(B::Number, A::IMatrix{T}) where {T} = Diagonal(fill(promote_type(T, elty Base.:/(A::IMatrix{T}, B::Number) where {T} = Diagonal(fill(promote_type(T, eltype(B))(1 / B), A.n)) -Base.:*(A::PermMatrix, B::Number) = PermMatrix(A.perm, A.vals * B) -Base.:*(B::Number, A::PermMatrix) = A * B -Base.:/(A::PermMatrix, B::Number) = PermMatrix(A.perm, A.vals / B) +Base.:*(A::AbstractPermMatrix, B::Number) = basetype(A)(A.perm, A.vals * B) +Base.:*(B::Number, A::AbstractPermMatrix) = A * B +Base.:/(A::AbstractPermMatrix, B::Number) = basetype(A)(A.perm, A.vals / B) #+(A::PermMatrix, B::PermMatrix) = PermMatrix(A.dv+B.dv, A.ev+B.ev) #-(A::PermMatrix, B::PermMatrix) = PermMatrix(A.dv-B.dv, A.ev-B.ev) for op in [:+, :-] - for MT in [:IMatrix, :PermMatrix] + for MT in [:IMatrix, :AbstractPermMatrix] @eval begin # IMatrix, PermMatrix - SparseMatrixCSC Base.$op(A::$MT, B::SparseMatrixCSC) = $op(SparseMatrixCSC(A), B) @@ -45,12 +46,12 @@ for op in [:+, :-] # IMatrix, PermMatrix - Diagonal Base.$op(d1::IMatrix, d2::Diagonal) = Diagonal($op(diag(d1), d2.diag)) Base.$op(d1::Diagonal, d2::IMatrix) = Diagonal($op(d1.diag, diag(d2))) - Base.$op(d1::PermMatrix, d2::Diagonal) = $op(SparseMatrixCSC(d1), d2) - Base.$op(d1::Diagonal, d2::PermMatrix) = $op(d1, SparseMatrixCSC(d2)) + Base.$op(d1::AbstractPermMatrix, d2::Diagonal) = $op(SparseMatrixCSC(d1), d2) + Base.$op(d1::Diagonal, d2::AbstractPermMatrix) = $op(d1, SparseMatrixCSC(d2)) # PermMatrix - IMatrix - Base.$op(A::PermMatrix, B::IMatrix) = $op(SparseMatrixCSC(A), SparseMatrixCSC(B)) - Base.$op(A::IMatrix, B::PermMatrix) = $op(SparseMatrixCSC(A), SparseMatrixCSC(B)) - Base.$op(A::PermMatrix, B::PermMatrix) = $op(SparseMatrixCSC(A), SparseMatrixCSC(B)) + Base.$op(A::AbstractPermMatrix, B::IMatrix) = $op(SparseMatrixCSC(A), SparseMatrixCSC(B)) + Base.$op(A::IMatrix, B::AbstractPermMatrix) = $op(SparseMatrixCSC(A), SparseMatrixCSC(B)) + Base.$op(A::AbstractPermMatrix, B::AbstractPermMatrix) = $op(SparseMatrixCSC(A), SparseMatrixCSC(B)) end end # NOTE: promote to integer @@ -59,22 +60,22 @@ Base.:+(d1::IMatrix{Ta}, d2::IMatrix{Tb}) where {Ta,Tb} = Base.:-(d1::IMatrix{Ta}, d2::IMatrix{Tb}) where {Ta,Tb} = d1 == d2 ? spzeros(promote_type(Ta, Tb), d1.n, d1.n) : throw(DimensionMismatch()) -for MT in [:IMatrix, :PermMatrix] +for MT in [:IMatrix, :AbstractPermMatrix] @eval Base.:(==)(A::$MT, B::SparseMatrixCSC) = SparseMatrixCSC(A) == B @eval Base.:(==)(A::SparseMatrixCSC, B::$MT) = A == SparseMatrixCSC(B) end Base.:(==)(d1::IMatrix, d2::Diagonal) = all(isone, d2.diag) Base.:(==)(d1::Diagonal, d2::IMatrix) = all(isone, d1.diag) -Base.:(==)(d1::PermMatrix, d2::Diagonal) = SparseMatrixCSC(d1) == SparseMatrixCSC(d2) -Base.:(==)(d1::Diagonal, d2::PermMatrix) = SparseMatrixCSC(d1) == SparseMatrixCSC(d2) -Base.:(==)(A::IMatrix, B::PermMatrix) = SparseMatrixCSC(A) == SparseMatrixCSC(B) -Base.:(==)(A::PermMatrix, B::IMatrix) = SparseMatrixCSC(A) == SparseMatrixCSC(B) +Base.:(==)(d1::AbstractPermMatrix, d2::Diagonal) = SparseMatrixCSC(d1) == SparseMatrixCSC(d2) +Base.:(==)(d1::Diagonal, d2::AbstractPermMatrix) = SparseMatrixCSC(d1) == SparseMatrixCSC(d2) +Base.:(==)(A::IMatrix, B::AbstractPermMatrix) = SparseMatrixCSC(A) == SparseMatrixCSC(B) +Base.:(==)(A::AbstractPermMatrix, B::IMatrix) = SparseMatrixCSC(A) == SparseMatrixCSC(B) -for MT in [:IMatrix, :PermMatrix] +for MT in [:IMatrix, :AbstractPermMatrix] @eval Base.isapprox(A::$MT, B::SparseMatrixCSC; kwargs...) = isapprox(SparseMatrixCSC(A), B) @eval Base.isapprox(A::SparseMatrixCSC, B::$MT; kwargs...) = isapprox(A, SparseMatrixCSC(B)) @eval Base.isapprox(d1::$MT, d2::Diagonal; kwargs...) = isapprox(diag(d1), d2.diag) @eval Base.isapprox(d1::Diagonal, d2::$MT; kwargs...) = isapprox(d1.diag, diag(d2)) end -Base.isapprox(A::IMatrix, B::PermMatrix; kwargs...) = isapprox(SparseMatrixCSC(A), SparseMatrixCSC(B); kwargs...) -Base.isapprox(A::PermMatrix, B::IMatrix; kwargs...) = isapprox(SparseMatrixCSC(A), SparseMatrixCSC(B); kwargs...) +Base.isapprox(A::IMatrix, B::AbstractPermMatrix; kwargs...) = isapprox(SparseMatrixCSC(A), SparseMatrixCSC(B); kwargs...) +Base.isapprox(A::AbstractPermMatrix, B::IMatrix; kwargs...) = isapprox(SparseMatrixCSC(A), SparseMatrixCSC(B); kwargs...) diff --git a/src/broadcast.jl b/src/broadcast.jl index 77c5c0a..faccf73 100644 --- a/src/broadcast.jl +++ b/src/broadcast.jl @@ -1,8 +1,3 @@ -@static if VERSION < v"1.2" - Base.size(bc::Broadcasted) = map(length, axes(bc)) - Base.length(bc::Broadcasted) = prod(size(bc)) -end - # patches LinearAlgebra.fzero(S::IMatrix) = zero(eltype(S)) @@ -42,12 +37,10 @@ Broadcast.broadcasted( ) = Diagonal(fill(a, b.n)) # specialize perm matrix -function _broadcast_perm_prod(A::PermMatrix, B::AbstractMatrix) +function _broadcast_perm_prod(A::AbstractPermMatrix, B::AbstractMatrix) dest = similar(A, Base.promote_op(*, eltype(A), eltype(B))) - i = 1 - @inbounds for j in dest.perm - dest[i, j] = A[i, j] * B[i, j] - i += 1 + @inbounds for (i, j, a) in IterNz(A) + dest[i, j] = a * B[i, j] end return dest end @@ -55,40 +48,30 @@ end Broadcast.broadcasted( ::AbstractArrayStyle{2}, ::typeof(*), - A::PermMatrix, + A::AbstractPermMatrix, B::AbstractMatrix, ) = _broadcast_perm_prod(A, B) Broadcast.broadcasted( ::AbstractArrayStyle{2}, ::typeof(*), A::AbstractMatrix, - B::PermMatrix, + B::AbstractPermMatrix, ) = _broadcast_perm_prod(B, A) -Broadcast.broadcasted(::AbstractArrayStyle{2}, ::typeof(*), A::PermMatrix, B::PermMatrix) = +Broadcast.broadcasted(::AbstractArrayStyle{2}, ::typeof(*), A::AbstractPermMatrix, B::AbstractPermMatrix) = _broadcast_perm_prod(A, B) -Broadcast.broadcasted(::AbstractArrayStyle{2}, ::typeof(*), A::PermMatrix, B::IMatrix) = +Broadcast.broadcasted(::AbstractArrayStyle{2}, ::typeof(*), A::AbstractPermMatrix, B::IMatrix) = Diagonal(A) -Broadcast.broadcasted(::AbstractArrayStyle{2}, ::typeof(*), A::IMatrix, B::PermMatrix) = +Broadcast.broadcasted(::AbstractArrayStyle{2}, ::typeof(*), A::IMatrix, B::AbstractPermMatrix) = Diagonal(B) -function _broadcast_diag_perm_prod(A::Diagonal, B::PermMatrix) - dest = similar(A) - i = 1 - @inbounds for j in B.perm - if i == j - dest[i, i] = A[i, i] * B[i, i] - else - dest[i, i] = 0 - end - i += 1 - end - return dest +function _broadcast_diag_perm_prod(A::Diagonal, B::AbstractPermMatrix) + Diagonal(A.diag .* getindex.(Ref(B), 1:size(A, 1), 1:size(A, 2))) end -Broadcast.broadcasted(::AbstractArrayStyle{2}, ::typeof(*), A::PermMatrix, B::Diagonal) = +Broadcast.broadcasted(::AbstractArrayStyle{2}, ::typeof(*), A::AbstractPermMatrix, B::Diagonal) = _broadcast_diag_perm_prod(B, A) -Broadcast.broadcasted(::AbstractArrayStyle{2}, ::typeof(*), A::Diagonal, B::PermMatrix) = +Broadcast.broadcasted(::AbstractArrayStyle{2}, ::typeof(*), A::Diagonal, B::AbstractPermMatrix) = _broadcast_diag_perm_prod(A, B) # TODO: commit this upstream @@ -110,13 +93,13 @@ Broadcast.broadcasted( Broadcast.broadcasted( ::AbstractArrayStyle{2}, ::typeof(*), - a::PermMatrix, + a::AbstractPermMatrix, b::Number, -) = PermMatrix(a.perm, a.vals .* b) +) = basetype(a)(a.perm, a.vals .* b) Broadcast.broadcasted( ::AbstractArrayStyle{2}, ::typeof(*), a::Number, - b::PermMatrix, -) = PermMatrix(b.perm, a .* b.vals) + b::AbstractPermMatrix, +) = basetype(b)(b.perm, a .* b.vals) diff --git a/src/conversions.jl b/src/conversions.jl index 80d2630..e01b0f5 100644 --- a/src/conversions.jl +++ b/src/conversions.jl @@ -7,7 +7,7 @@ function IMatrix(A::AbstractMatrix{T}) where {T} end ################## To Diagonal ###################### -Diagonal(A::PermMatrix) = Diagonal(diag(A)) +Diagonal(A::AbstractPermMatrix) = Diagonal(diag(A)) Diagonal(A::IMatrix{T}) where {T} = Diagonal{T}(ones(T, A.n)) Diagonal{T}(A::IMatrix) where {T} = Diagonal{T}(ones(T, A.n)) @@ -16,23 +16,13 @@ SparseMatrixCSC{Tv,Ti}(A::IMatrix) where {Tv,Ti<:Integer} = SparseMatrixCSC{Tv,Ti}(I, A.n, A.n) SparseMatrixCSC{Tv}(A::IMatrix) where {Tv} = SparseMatrixCSC{Tv,Int}(A) SparseMatrixCSC(A::IMatrix{T}) where {T} = SparseMatrixCSC{T,Int}(I, A.n, A.n) -function SparseMatrixCSC(M::PermMatrix) +function SparseMatrixCSC(M::AbstractPermMatrix) n = size(M, 1) - #SparseMatrixCSC(n, n, collect(1:n+1), M.perm, M.vals) - order = invperm(M.perm) - SparseMatrixCSC(n, n, collect(1:n+1), order, M.vals[order]) + MC = PermMatrixCSC(M) + SparseMatrixCSC(n, n, collect(1:n+1), MC.perm, MC.vals) end -@static if VERSION < v"1.3-" - - function SparseMatrixCSC(D::Diagonal{T}) where {T} - m = length(D.diag) - return SparseMatrixCSC(m, m, Vector(1:(m+1)), Vector(1:m), Vector{T}(D.diag)) - end - -end - -SparseMatrixCSC{Tv,Ti}(M::PermMatrix{Tv,Ti}) where {Tv,Ti} = SparseMatrixCSC(M) +SparseMatrixCSC{Tv,Ti}(M::AbstractPermMatrix{Tv,Ti}) where {Tv,Ti} = SparseMatrixCSC(M) SparseMatrixCSC(coo::SparseMatrixCOO) = sparse(coo.is, coo.js, coo.vs, coo.m, coo.n) ################## To Dense ###################### @@ -47,7 +37,15 @@ function Matrix{T}(X::PermMatrix) where {T} end return Mf end -Matrix(X::PermMatrix{T}) where {T} = Matrix{T}(X) +function Matrix{T}(X::PermMatrixCSC) where {T} + n = size(X, 1) + Mf = zeros(T, n, n) + @simd for j = 1:n + @inbounds Mf[X.perm[j], j] = X.vals[j] + end + return Mf +end +Matrix(X::AbstractPermMatrix{T}) where {T} = Matrix{T}(X) function Matrix(coo::SparseMatrixCOO{T}) where {T} mat = zeros(T, coo.m, coo.n) @@ -58,13 +56,22 @@ function Matrix(coo::SparseMatrixCOO{T}) where {T} end ################## To PermMatrix ###################### -PermMatrix{Tv,Ti}(A::IMatrix) where {Tv,Ti} = - PermMatrix{Tv,Ti}(Vector{Ti}(1:A.n), ones(Tv, A.n)) -PermMatrix{Tv}(X::IMatrix) where {Tv} = PermMatrix{Tv,Int}(X) -PermMatrix(X::IMatrix{T}) where {T} = PermMatrix{T,Int}(X) - -PermMatrix{Tv,Ti}(A::PermMatrix) where {Tv,Ti} = - PermMatrix(Vector{Ti}(A.perm), Vector{Tv}(A.vals)) +function PermMatrix(pc::PermMatrixCSC) + order = fast_invperm(pc.perm) + PermMatrix(order, pc.vals[order]) +end +function PermMatrixCSC(pc::PermMatrix) + order = fast_invperm(pc.perm) + PermMatrixCSC(order, pc.vals[order]) +end +for MT in [:PermMatrix, :PermMatrixCSC] + @eval $MT{Tv,Ti}(A::IMatrix) where {Tv,Ti} = + $MT{Tv,Ti}(Vector{Ti}(1:A.n), ones(Tv, A.n)) + @eval $MT{Tv}(X::IMatrix) where {Tv} = $MT{Tv,Int}(X) + @eval $MT(X::IMatrix{T}) where {T} = $MT{T,Int}(X) + @eval $MT{Tv,Ti}(A::$MT) where {Tv,Ti} = + $MT(Vector{Ti}(A.perm), Vector{Tv}(A.vals)) +end # NOTE: bad implementation! function _findnz(A::AbstractMatrix) @@ -75,23 +82,26 @@ end _findnz(A::AbstractSparseArray) = findnz(A) function PermMatrix{Tv,Ti}(A::AbstractMatrix) where {Tv,Ti} + PermMatrix(PermMatrixCSC(A)) +end +function PermMatrixCSC{Tv,Ti}(A::AbstractMatrix) where {Tv,Ti} i, j, v = _findnz(A) j == collect(1:size(A, 2)) || throw(ArgumentError("This is not a PermMatrix")) - order = invperm(i) - PermMatrix{Tv,Ti}(Vector{Ti}(order), Vector{Tv}(v[order])) + PermMatrixCSC{Tv,Ti}(Vector{Ti}(i), Vector{Tv}(v)) end -PermMatrix(A::AbstractMatrix{T}) where {T} = PermMatrix{T,Int}(A) -PermMatrix(A::SparseMatrixCSC{Tv,Ti}) where {Tv,Ti} = PermMatrix{Tv,Ti}(A) # inherit indice type -PermMatrix{Tv,Ti}(A::Diagonal{Tv}) where {Tv,Ti} = - PermMatrix(Vector{Ti}(1:size(A, 1)), A.diag) -#PermMatrix(A::Diagonal{T}) where T = PermMatrix{T, Int}(A) -# lazy implementation -function PermMatrix{Tv,Ti,Vv,Vi}( - A::AbstractMatrix, -) where {Tv,Ti<:Integer,Vv<:AbstractVector{Tv},Vi<:AbstractVector{Ti}} - pm = PermMatrix(PermMatrix{Tv,Ti}(A)) - PermMatrix(Vi(pm.perm), Vv(pm.vals)) + +for MT in [:PermMatrix, :PermMatrixCSC] + @eval $MT(A::AbstractMatrix{T}) where {T} = $MT{T,Int}(A) + @eval $MT(A::SparseMatrixCSC{Tv,Ti}) where {Tv,Ti} = $MT{Tv,Ti}(A) # inherit indice type + @eval $MT{Tv,Ti}(A::Diagonal{Tv}) where {Tv,Ti} = $MT(Vector{Ti}(1:size(A, 1)), A.diag) + @eval function $MT{Tv,Ti,Vv,Vi}( + A::AbstractMatrix, + ) where {Tv,Ti<:Integer,Vv<:AbstractVector{Tv},Vi<:AbstractVector{Ti}} + pm = $MT(PermMatrix{Tv,Ti}(A)) + PermMatrix(Vi(pm.perm), Vv(pm.vals)) + end end +# lazy implementation ############## To SparseMatrixCOO ############## function SparseMatrixCOO(A::Matrix{Tv}; atol = 1e-12) where {Tv} @@ -111,5 +121,5 @@ function SparseMatrixCOO(A::Matrix{Tv}; atol = 1e-12) where {Tv} SparseMatrixCOO(is, js, vs, m, n) end -Base.convert(T::Type{<:PermMatrix}, m::AbstractMatrix) = m isa T ? m : T(m) +Base.convert(T::Type{<:AbstractPermMatrix}, m::AbstractMatrix) = m isa T ? m : T(m) Base.convert(T::Type{<:IMatrix}, m::AbstractMatrix) = m isa T ? m : T(m) diff --git a/src/iterate.jl b/src/iterate.jl index b5fcbad..d2d61ca 100644 --- a/src/iterate.jl +++ b/src/iterate.jl @@ -7,7 +7,7 @@ Base.length(nz::IterNz{<:AbstractSparseMatrix}) = nnz(nz.A) Base.length(nz::IterNz{<:Adjoint}) = length(IterNz(nz.A.parent)) Base.length(nz::IterNz{<:Transpose}) = length(IterNz(nz.A.parent)) Base.length(nz::IterNz{<:Diagonal}) = size(nz.A, 1) -Base.length(nz::IterNz{<:PermMatrix}) = size(nz.A, 1) +Base.length(nz::IterNz{<:AbstractPermMatrix}) = size(nz.A, 1) Base.length(nz::IterNz{<:IMatrix}) = size(nz.A, 1) Base.eltype(nz::IterNz) = eltype(nz.A) @@ -44,6 +44,17 @@ function Base.iterate(it::IterNz{<:PermMatrix}, state) return (state, (@inbounds it.A.perm[state]), (@inbounds it.A.vals[state])), state end +# PermMatrixCSC +function Base.iterate(it::IterNz{<:PermMatrixCSC}) + 0 == length(it) && return nothing + return ((@inbounds it.A.perm[1]), 1, (@inbounds it.A.vals[1])), 1 +end +function Base.iterate(it::IterNz{<:PermMatrixCSC}, state) + state == length(it) && return nothing + state += 1 + return ((@inbounds it.A.perm[state]), state, (@inbounds it.A.vals[state])), state +end + # AbstractMatrix function Base.iterate(it::IterNz{<:AbstractMatrix}) 0 == length(it) && return nothing diff --git a/src/kronecker.jl b/src/kronecker.jl index 81ed79a..eb2b608 100644 --- a/src/kronecker.jl +++ b/src/kronecker.jl @@ -32,10 +32,10 @@ LinearAlgebra.kron(A::IMatrix{<:Number}, B::Diagonal{<:Number}) = A.n == 1 ? B : LinearAlgebra.kron(B::Diagonal{<:Number}, A::IMatrix) = A.n == 1 ? B : Diagonal(irepeat(B.diag, A.n)) ####### diagonal kron ######## -LinearAlgebra.kron(A::StridedMatrix{<:Number}, B::Diagonal{<:Number}) = kron(A, PermMatrix(B)) -LinearAlgebra.kron(A::Diagonal{<:Number}, B::StridedMatrix{<:Number}) = kron(PermMatrix(A), B) -LinearAlgebra.kron(A::Diagonal{<:Number}, B::SparseMatrixCSC{<:Number}) = kron(PermMatrix(A), B) -LinearAlgebra.kron(A::SparseMatrixCSC{<:Number}, B::Diagonal{<:Number}) = kron(A, PermMatrix(B)) +LinearAlgebra.kron(A::StridedMatrix{<:Number}, B::Diagonal{<:Number}) = kron(A, PermMatrixCSC(B)) +LinearAlgebra.kron(A::Diagonal{<:Number}, B::StridedMatrix{<:Number}) = kron(PermMatrixCSC(A), B) +LinearAlgebra.kron(A::Diagonal{<:Number}, B::SparseMatrixCSC{<:Number}) = kron(PermMatrixCSC(A), B) +LinearAlgebra.kron(A::SparseMatrixCSC{<:Number}, B::Diagonal{<:Number}) = kron(A, PermMatrixCSC(B)) function LinearAlgebra.kron(A::AbstractMatrix{Tv}, B::IMatrix) where {Tv<:Number} B.n == 1 && return A @@ -127,7 +127,7 @@ function LinearAlgebra.kron(A::SparseMatrixCSC{T}, B::IMatrix) where {T<:Number} SparseMatrixCSC(mA * B.n, nA * B.n, colptr, rowval, nzval) end -function LinearAlgebra.kron(A::PermMatrix{T}, B::IMatrix) where {T<:Number} +function LinearAlgebra.kron(A::AbstractPermMatrix{T}, B::IMatrix) where {T<:Number} nA = size(A, 1) nB = size(B, 1) nB == 1 && return A @@ -142,10 +142,10 @@ function LinearAlgebra.kron(A::PermMatrix{T}, B::IMatrix) where {T<:Number} vals[start+j] = val end end - PermMatrix(perm, vals) + basetype(A)(perm, vals) end -function LinearAlgebra.kron(A::IMatrix, B::PermMatrix{Tv,Ti}) where {Tv<:Number,Ti<:Integer} +function LinearAlgebra.kron(A::IMatrix, B::AbstractPermMatrix{Tv,Ti}) where {Tv<:Number,Ti<:Integer} nA = size(A, 1) nB = size(B, 1) nA == 1 && return B @@ -158,14 +158,14 @@ function LinearAlgebra.kron(A::IMatrix, B::PermMatrix{Tv,Ti}) where {Tv<:Number, vals[start+j] = B.vals[j] end end - PermMatrix(perm, vals) + basetype(B)(perm, vals) end - -function LinearAlgebra.kron(A::StridedMatrix{Tv}, B::PermMatrix{Tb}) where {Tv<:Number,Tb<:Number} +function LinearAlgebra.kron(A::StridedMatrix{Tv}, B::AbstractPermMatrix{Tb}) where {Tv<:Number,Tb<:Number} mA, nA = size(A) nB = size(B, 1) - perm = fast_invperm(B.perm) + BC = PermMatrixCSC(B) + perm, vals = BC.perm, BC.vals nzval = Vector{promote_type(Tv, Tb)}(undef, mA * nA * nB) rowval = Vector{Int}(undef, mA * nA * nB) colptr = collect(1:mA:nA*nB*mA+1) @@ -173,7 +173,7 @@ function LinearAlgebra.kron(A::StridedMatrix{Tv}, B::PermMatrix{Tb}) where {Tv<: @inbounds for j = 1:nA @inbounds for j2 = 1:nB p2 = perm[j2] - val2 = B.vals[p2] + val2 = vals[j2] ir = p2 @inbounds @simd for i = 1:mA nzval[z] = A[i, j] * val2 # merge @@ -186,18 +186,18 @@ function LinearAlgebra.kron(A::StridedMatrix{Tv}, B::PermMatrix{Tb}) where {Tv<: SparseMatrixCSC(mA * nB, nA * nB, colptr, rowval, nzval) end -function LinearAlgebra.kron(A::PermMatrix{Ta}, B::StridedMatrix{Tb}) where {Tb<:Number,Ta<:Number} +function LinearAlgebra.kron(A::AbstractPermMatrix{Ta}, B::StridedMatrix{Tb}) where {Tb<:Number,Ta<:Number} mB, nB = size(B) nA = size(A, 1) - perm = fast_invperm(A.perm) + AC = PermMatrixCSC(A) + perm, vals = AC.perm, AC.vals nzval = Vector{promote_type(Ta, Tb)}(undef, mB * nA * nB) rowval = Vector{Int}(undef, mB * nA * nB) colptr = collect(1:mB:nA*nB*mB+1) z = 0 @inbounds for j = 1:nA - colbase = (j - 1) * nB p1 = perm[j] - val2 = A.vals[p1] + val2 = vals[j] ir = (p1 - 1) * mB for j2 = 1:nB @inbounds @simd for i2 = 1:mB @@ -210,7 +210,8 @@ function LinearAlgebra.kron(A::PermMatrix{Ta}, B::StridedMatrix{Tb}) where {Tb<: SparseMatrixCSC(nA * mB, nA * nB, colptr, rowval, nzval) end -function LinearAlgebra.kron(A::PermMatrix{<:Number}, B::PermMatrix{<:Number}) +function LinearAlgebra.kron(A::AbstractPermMatrix{<:Number}, B::AbstractPermMatrix{<:Number}) + @assert basetype(A) == basetype(B) nA = size(A, 1) nB = size(B, 1) vals = kron(A.vals, B.vals) @@ -222,17 +223,18 @@ function LinearAlgebra.kron(A::PermMatrix{<:Number}, B::PermMatrix{<:Number}) perm[start+j] = permAi + B.perm[j] end end - PermMatrix(perm, vals) + basetype(A)(perm, vals) end -LinearAlgebra.kron(A::PermMatrix{<:Number}, B::Diagonal{<:Number}) = kron(A, PermMatrix(B)) -LinearAlgebra.kron(A::Diagonal{<:Number}, B::PermMatrix{<:Number}) = kron(PermMatrix(A), B) +LinearAlgebra.kron(A::AbstractPermMatrix{<:Number}, B::Diagonal{<:Number}) = kron(A, basetype(A)(B)) +LinearAlgebra.kron(A::Diagonal{<:Number}, B::AbstractPermMatrix{<:Number}) = kron(basetype(B)(A), B) -function LinearAlgebra.kron(A::PermMatrix{Ta}, B::SparseMatrixCSC{Tb}) where {Ta<:Number,Tb<:Number} +function LinearAlgebra.kron(A::AbstractPermMatrix{Ta}, B::SparseMatrixCSC{Tb}) where {Ta<:Number,Tb<:Number} nA = size(A, 1) mB, nB = size(B) nV = nnz(B) - perm = fast_invperm(A.perm) + AC = PermMatrixCSC(A) + perm, vals = AC.perm, AC.vals nzval = Vector{promote_type(Ta, Tb)}(undef, nA * nV) rowval = Vector{Int}(undef, nA * nV) colptr = Vector{Int}(undef, nA * nB + 1) @@ -240,7 +242,7 @@ function LinearAlgebra.kron(A::PermMatrix{Ta}, B::SparseMatrixCSC{Tb}) where {Ta @inbounds @simd for i = 1:nA start_row = (i - 1) * nV start_ri = (perm[i] - 1) * mB - v0 = A.vals[perm[i]] + v0 = vals[i] @inbounds @simd for j = 1:nV nzval[start_row+j] = B.nzval[j] * v0 rowval[start_row+j] = B.rowval[j] + start_ri @@ -254,11 +256,12 @@ function LinearAlgebra.kron(A::PermMatrix{Ta}, B::SparseMatrixCSC{Tb}) where {Ta SparseMatrixCSC(mB * nA, nB * nA, colptr, rowval, nzval) end -function LinearAlgebra.kron(A::SparseMatrixCSC{T}, B::PermMatrix{Tb}) where {T<:Number,Tb<:Number} +function LinearAlgebra.kron(A::SparseMatrixCSC{T}, B::AbstractPermMatrix{Tb}) where {T<:Number,Tb<:Number} nB = size(B, 1) mA, nA = size(A) nV = nnz(A) - perm = fast_invperm(B.perm) + BC = PermMatrixCSC(B) + perm, vals = BC.perm, BC.vals rowval = Vector{Int}(undef, nB * nV) colptr = Vector{Int}(undef, nA * nB + 1) nzval = Vector{promote_type(T, Tb)}(undef, nB * nV) @@ -269,7 +272,7 @@ function LinearAlgebra.kron(A::SparseMatrixCSC{T}, B::PermMatrix{Tb}) where {T<: rend = A.colptr[i+1] - 1 @inbounds for k = 1:nB irow = perm[k] - bval = B.vals[irow] + bval = vals[k] irow_nB = irow - nB @inbounds @simd for r = rstart:rend rowval[z] = A.rowval[r] * nB + irow_nB diff --git a/src/linalg.jl b/src/linalg.jl index 81ef4a8..cf696cb 100644 --- a/src/linalg.jl +++ b/src/linalg.jl @@ -3,15 +3,14 @@ Base.inv(M::IMatrix) = M LinearAlgebra.det(M::IMatrix) = 1 LinearAlgebra.diag(M::IMatrix{T}) where {T} = ones(T, M.n) LinearAlgebra.logdet(M::IMatrix) = 0 -Base.sqrt(x::PermMatrix) = sqrt(Matrix(x)) +Base.sqrt(x::AbstractPermMatrix) = sqrt(Matrix(x)) Base.sqrt(x::IMatrix) = x -Base.exp(x::PermMatrix) = exp(Matrix(x)) +Base.exp(x::AbstractPermMatrix) = exp(Matrix(x)) Base.exp(x::IMatrix) = ℯ * x -#det(M::PermMatrix) = parity(M.perm)*prod(M.vals) -function Base.inv(M::PermMatrix) +function Base.inv(M::AbstractPermMatrix) new_perm = fast_invperm(M.perm) - return PermMatrix(new_perm, 1.0 ./ M.vals[new_perm]) + return basetype(M)(new_perm, 1.0 ./ M.vals[new_perm]) end ####### multiply ########### @@ -24,7 +23,7 @@ Base.:*(A::IMatrix, B::AbstractVector) = ) for MATTYPE in - [:AbstractMatrix, :StridedMatrix, :Diagonal, :SparseMatrixCSC, :Matrix, :PermMatrix] + [:AbstractMatrix, :StridedMatrix, :Diagonal, :SparseMatrixCSC, :Matrix, :AbstractPermMatrix] @eval Base.:*(A::IMatrix, B::$MATTYPE) = A.n == size(B, 1) ? B : throw( @@ -61,12 +60,12 @@ Base.:*(A::IMatrix, B::IMatrix) = ########## Multiplication ############# -function LinearAlgebra.mul!(Y::AbstractVector, A::PermMatrix, X::AbstractVector, alpha::Number, beta::Number) - length(X) == size(A, 2) || throw(DimensionMismatch("input X length does not match PermMatrix A")) - length(Y) == size(A, 2) || throw(DimensionMismatch("output Y length does not match PermMatrix A")) +function LinearAlgebra.mul!(Y::AbstractVector, A::AbstractPermMatrix, X::AbstractVector, alpha::Number, beta::Number) + length(X) == size(A, 2) || throw(DimensionMismatch("input X length does not match permutation matrix A")) + length(Y) == size(A, 2) || throw(DimensionMismatch("output Y length does not match permutation matrix A")) - @inbounds for I in eachindex(X) - Y[I] = A.vals[I] * X[A.perm[I]] * alpha + beta * Y[I] + @inbounds for (i, j, p) in IterNz(A) + Y[i] = p * X[j] * alpha + beta * Y[i] end return Y end @@ -75,54 +74,56 @@ end function Base.:*(D::Diagonal{Td}, A::PermMatrix{Ta}) where {Td,Ta} PermMatrix(A.perm, A.vals .* D.diag) end - +function Base.:*(D::Diagonal{Td}, A::PermMatrixCSC{Ta}) where {Td,Ta} + PermMatrixCSC(A.perm, view(D.diag, A.perm) .* A.vals) +end function Base.:*(A::PermMatrix{Ta}, D::Diagonal{Td}) where {Td,Ta} PermMatrix(A.perm, A.vals .* view(D.diag, A.perm)) end +function Base.:*(A::PermMatrixCSC{Ta}, D::Diagonal{Td}) where {Td,Ta} + PermMatrixCSC(A.perm, A.vals .* D.diag) +end # to self function Base.:*(A::PermMatrix, B::PermMatrix) + @assert basetype(A) == basetype(B) + size(A, 1) == size(B, 1) || throw(DimensionMismatch()) + basetype(A)(B.perm[A.perm], A.vals .* view(B.vals, A.perm)) +end + +function Base.:*(A::PermMatrixCSC, B::PermMatrixCSC) + @assert basetype(A) == basetype(B) size(A, 1) == size(B, 1) || throw(DimensionMismatch()) - PermMatrix(B.perm[A.perm], A.vals .* view(B.vals, A.perm)) + basetype(A)(A.perm[B.perm], B.vals .* view(A.vals, B.perm)) end # to matrix -function Base.:*(A::PermMatrix, X::AbstractMatrix) +function LinearAlgebra.mul!(C::AbstractMatrix, A::AbstractPermMatrix, X::AbstractMatrix, alpha::Number, beta::Number) size(X, 1) == size(A, 2) || throw(DimensionMismatch()) - return A.vals .* view(X,A.perm,:) # this may be inefficient for sparse CSC matrix. + AR = PermMatrix(A) + C .= C .* beta .+ AR.vals .* view(X, AR.perm, :) .* alpha end - -function Base.:*(X::AbstractMatrix, A::PermMatrix) - mX, nX = size(X) - nX == size(A, 1) || throw(DimensionMismatch()) - perm = fast_invperm(A.perm) - return transpose(view(A.vals, perm)) .* view(X, :, perm) +function LinearAlgebra.mul!(C::AbstractMatrix, X::AbstractMatrix, A::AbstractPermMatrix, alpha::Number, beta::Number) + size(X, 2) == size(A, 1) || throw(DimensionMismatch()) + AC = PermMatrixCSC(A) + C .= C .* beta .+ reshape(AC.vals, 1, :) .* view(X, :, AC.perm) .* alpha end # NOTE: this is just a temperory fix for v0.7. We should overload mul! in # the future (when we start to drop v0.6) to enable buildin lazy evaluation. -Base.:*(x::Adjoint{<:Any,<:AbstractVector}, D::PermMatrix) = Matrix(x) * D -Base.:*(x::Transpose{<:Any,<:AbstractVector}, D::PermMatrix) = Matrix(x) * D -Base.:*(A::Adjoint{<:Any,<:AbstractArray}, D::PermMatrix) = Adjoint(adjoint(D) * parent(A)) -Base.:*(A::Transpose{<:Any,<:AbstractArray}, D::PermMatrix) = Transpose(transpose(D) * parent(A)) -Base.:*(A::Adjoint{<:Any,<:PermMatrix}, D::PermMatrix) = adjoint(parent(A)) * D -Base.:*(A::Transpose{<:Any,<:PermMatrix}, D::PermMatrix) = transpose(parent(A)) * D -Base.:*(A::PermMatrix, D::Adjoint{<:Any,<:PermMatrix}) = A * adjoint(parent(D)) -Base.:*(A::PermMatrix, D::Transpose{<:Any,<:PermMatrix}) = A * transpose(parent(D)) - -# for MAT in [:AbstractArray, :Matrix, :SparseMatrixCSC, :PermMatrix] -# @eval begin -# *(A::Adjoint{<:Any, <:$MAT}, D::PermMatrix) = copy(A) * D -# *(A::Transpose{<:Any, <:$MAT}, D::PermMatrix) = copy(A) * D -# *(A::PermMatrix, D::Adjoint{<:Any, <:$MAT}) = A * copy(D) -# *(A::PermMatrix, D::Transpose{<:Any, <:$MAT}) = A * copy(D) -# end -# end +Base.:*(x::Adjoint{<:Any,<:AbstractVector}, D::AbstractPermMatrix) = Matrix(x) * D +Base.:*(x::Transpose{<:Any,<:AbstractVector}, D::AbstractPermMatrix) = Matrix(x) * D +Base.:*(A::Adjoint{<:Any,<:AbstractArray}, D::AbstractPermMatrix) = Adjoint(adjoint(D) * parent(A)) +Base.:*(A::Transpose{<:Any,<:AbstractArray}, D::AbstractPermMatrix) = Transpose(transpose(D) * parent(A)) +Base.:*(A::Adjoint{<:Any,<:AbstractPermMatrix}, D::AbstractPermMatrix) = adjoint(parent(A)) * D +Base.:*(A::Transpose{<:Any,<:AbstractPermMatrix}, D::AbstractPermMatrix) = transpose(parent(A)) * D +Base.:*(A::AbstractPermMatrix, D::Adjoint{<:Any,<:AbstractPermMatrix}) = A * adjoint(parent(D)) +Base.:*(A::AbstractPermMatrix, D::Transpose{<:Any,<:AbstractPermMatrix}) = A * transpose(parent(D)) ############### Transpose, Adjoint for IMatrix ############### for MAT in - [:AbstractArray, :AbstractVector, :Matrix, :SparseMatrixCSC, :PermMatrix, :IMatrix] + [:AbstractArray, :AbstractVector, :Matrix, :SparseMatrixCSC, :AbstractPermMatrix, :IMatrix] @eval Base.:*(A::Adjoint{<:Any,<:$MAT}, D::IMatrix) = Adjoint(D * parent(A)) @eval Base.:*(A::Transpose{<:Any,<:$MAT}, D::IMatrix) = Transpose(D * parent(A)) if MAT != :AbstactVector @@ -132,17 +133,18 @@ for MAT in end # to sparse -function Base.:*(A::PermMatrix, X::SparseMatrixCSC) +function Base.:*(A::AbstractPermMatrix, X::SparseMatrixCSC) nA = size(A, 1) mX, nX = size(X) mX == nA || throw(DimensionMismatch()) - perm = fast_invperm(A.perm) + AC = PermMatrixCSC(A) + perm, vals = AC.perm, AC.vals nzval = similar(X.nzval) rowval = similar(X.rowval) @inbounds for j = 1:nX @inbounds for k = X.colptr[j]:X.colptr[j+1]-1 r = perm[X.rowval[k]] - nzval[k] = X.nzval[k] * A.vals[r] + nzval[k] = X.nzval[k] * vals[X.rowval[k]] rowval[k] = r end end @@ -150,11 +152,12 @@ function Base.:*(A::PermMatrix, X::SparseMatrixCSC) SparseMatrixCSC(sp')' end -function Base.:*(X::SparseMatrixCSC, A::PermMatrix) +function Base.:*(X::SparseMatrixCSC, A::AbstractPermMatrix) nA = size(A, 1) mX, nX = size(X) nX == nA || throw(DimensionMismatch()) - perm = fast_invperm(A.perm) + AC = PermMatrixCSC(A) + perm, vals = AC.perm, AC.vals nzval = similar(X.nzval) colptr = similar(X.colptr) rowval = similar(X.rowval) @@ -162,7 +165,7 @@ function Base.:*(X::SparseMatrixCSC, A::PermMatrix) z = 1 @inbounds for j = 1:nA pk = perm[j] - va = A.vals[pk] + va = vals[j] @inbounds @simd for k = X.colptr[pk]:X.colptr[pk+1]-1 nzval[z] = X.nzval[k] * va rowval[z] = X.rowval[k] @@ -182,7 +185,7 @@ Base.:*(B::Int, A::SparseMatrixCOO) = lmul!(B, copy(A)) Base.:/(A::SparseMatrixCOO, B::Int) = rdiv!(copy(A), B) Base.:-(ii::IMatrix) = (-1) * ii -Base.:-(pm::PermMatrix) = (-1) * pm +Base.:-(pm::AbstractPermMatrix) = (-1) * pm for FUNC in [:randn!, :rand!] @eval function Random.$FUNC(m::Diagonal) @@ -195,7 +198,7 @@ for FUNC in [:randn!, :rand!] return m end - @eval function Random.$FUNC(m::PermMatrix) + @eval function Random.$FUNC(m::AbstractPermMatrix) $FUNC(m.vals) return m end diff --git a/src/promotions.jl b/src/promotions.jl index 426b169..b6bf658 100644 --- a/src/promotions.jl +++ b/src/promotions.jl @@ -3,18 +3,20 @@ Base.promote_rule(::Type{SparseMatrixCSC{Tv,Ti}}, ::Type{Matrix{T}}) where {Tv,T Matrix{promote_type(T, Tv)} # IMatrix -Base.promote_rule( - ::Type{IMatrix{T}}, - ::Type{PermMatrix{Tv,Ti,Vv,Vi}}, -) where {T,Tv,Ti,Vv,Vi} = (TT = promote_type(T, Tv); PermMatrix{TT,Ti,Vector{TT},Vi}) +for MT in [:PermMatrix, :PermMatrixCSC] + @eval Base.promote_rule( + ::Type{IMatrix{T}}, + ::Type{$MT{Tv,Ti,Vv,Vi}}, + ) where {T,Tv,Ti,Vv,Vi} = (TT = promote_type(T, Tv); $MT{TT,Ti,Vector{TT},Vi}) +end Base.promote_rule(::Type{IMatrix{T}}, ::Type{SparseMatrixCSC{Tv,Ti}}) where {T,Tv,Ti} = SparseMatrixCSC{promote_type(T, Tv),Ti} Base.promote_rule(::Type{IMatrix{TA}}, ::Type{Matrix{TB}}) where {TA,TB} = Array{TB,2} # PermMatrix Base.promote_rule( - ::Type{PermMatrix{TvA,TiA}}, + ::Type{<:AbstractPermMatrix{TvA,TiA}}, ::Type{SparseMatrixCSC{TvB,TiB}}, ) where {TvA,TiA,TvB,TiB} = SparseMatrixCSC{promote_type(TvA, TvB),promote_type(TiA, TiB)} -Base.promote_rule(::Type{PermMatrix{Tv,Ti}}, ::Type{Matrix{T}}) where {Tv,Ti,T} = +Base.promote_rule(::Type{<:AbstractPermMatrix{Tv,Ti}}, ::Type{Matrix{T}}) where {Tv,Ti,T} = Array{promote_type(Tv, T),2} diff --git a/src/staticize.jl b/src/staticize.jl index e75c80b..fdfce58 100644 --- a/src/staticize.jl +++ b/src/staticize.jl @@ -4,6 +4,8 @@ const SDDiagonal{T} = Union{Diagonal{T},SDiagonal{N,T} where N} const SDVector{T} = Union{Vector{T},SVector{N,T} where N} const SDPermMatrix{Tv,Ti<:Integer} = PermMatrix{Tv,Ti,<:SDVector{Tv},<:SDVector{Ti}} const SPermMatrix{N,Tv,Ti<:Integer} = PermMatrix{Tv,Ti,<:SVector{N,Tv},<:SVector{N,Ti}} +const SDPermMatrixCSC{Tv,Ti<:Integer} = PermMatrixCSC{Tv,Ti,<:SDVector{Tv},<:SDVector{Ti}} +const SPermMatrixCSC{N,Tv,Ti<:Integer} = PermMatrixCSC{Tv,Ti,<:SVector{N,Tv},<:SVector{N,Ti}} const SDSparseMatrixCSC{Tv,Ti} = Union{SparseMatrixCSC{Tv,Ti},SSparseMatrixCSC{Tv,Ti}} ######### staticize ########## @@ -20,6 +22,8 @@ staticize(A::AbstractVector) = SVector{length(A)}(A) staticize(A::Diagonal) = SDiagonal{size(A, 1)}((A.diag...,)) staticize(A::PermMatrix) = PermMatrix(SVector{size(A, 1)}(A.perm), SVector{size(A, 1)}(A.vals)) +staticize(A::PermMatrixCSC) = + PermMatrixCSC(SVector{size(A, 1)}(A.perm), SVector{size(A, 1)}(A.vals)) function staticize(A::SparseMatrixCSC) iszero(A) && return SSparseMatrixCSC( A.m, @@ -48,6 +52,7 @@ dynamicize(A::SMatrix) = Matrix(A) dynamicize(A::SVector) = Vector(A) dynamicize(A::SDiagonal) = Diagonal(Vector(A.diag)) dynamicize(A::PermMatrix) = PermMatrix(Vector(A.perm), Vector(A.vals)) +dynamicize(A::PermMatrixCSC) = PermMatrixCSC(Vector(A.perm), Vector(A.vals)) function dynamicize(A::SSparseMatrixCSC) SparseMatrixCSC(A.m, A.n, Vector(A.colptr), Vector(A.rowval), Vector(A.nzval)) end diff --git a/test/IMatrix.jl b/test/IMatrix.jl index 7581567..67284f8 100644 --- a/test/IMatrix.jl +++ b/test/IMatrix.jl @@ -52,12 +52,6 @@ end end @test imag(p1) == zeros(4, 4) @test p1' == Matrix(I, 4, 4) - - # This will be lazy evaluated in 0.7+ - @static if VERSION < v"0.7-" - @test typeof(p1') == typeof(p1) - end - @test ishermitian(p1) end diff --git a/test/PermMatrix.jl b/test/PermMatrix.jl index f5479a3..c43afcd 100644 --- a/test/PermMatrix.jl +++ b/test/PermMatrix.jl @@ -13,7 +13,13 @@ sp = sprand(4, 4, 0.3) v = [0.5, 0.3im, 0.2, 1.0] @testset "basic" begin + @test_throws DimensionMismatch PermMatrix([1, 4, 2, 3], [0.1, 0.2, 0.4im]) + @test_throws ArgumentError size(p1, 0) + @test size(p1, 3) == 1 + @test [zip(findnz(p1)...)...] == [IterNz(p1)...] @test p1 == copy(p1) + @test hash(p1) == hash(copy(p1)) + @test hash(p1) != hash(p2) @test eltype(p1) == ComplexF64 @test eltype(p2) == Float64 @test eltype(p3) == Float64 @@ -29,6 +35,7 @@ v = [0.5, 0.3im, 0.2, 1.0] @test p1[1, 1] === 0.1 + 0.0im copyto!(p0, p1) @test p0 == p1 + @test PermMatrix([0.0 -1.0im; 1.0im 0.0im]) ≈ [0.0 -1.0im; 1.0im 0.0im] end @testset "linalg" begin @@ -88,7 +95,7 @@ end @testset "setindex" begin pm = PermMatrix([3, 2, 4, 1], [0.0, 0.0, 0.0, 0.0]) pm[3, 4] = 1.0 - @test_throws BoundsError pm[3, 1] = 1.0 + @test_throws AssertionError pm[3, 1] = 1.0 @test pm[3, 4] == 1.0 end @@ -102,4 +109,4 @@ end A = randn(ComplexF64, 4, 4) pm = PermMatrix([3, 2, 4, 1], [0.2im, 0.6im, 0.1, 0.3]) @test A * pm ≈ A * Matrix(pm) -end \ No newline at end of file +end diff --git a/test/PermMatrixCSC.jl b/test/PermMatrixCSC.jl new file mode 100644 index 0000000..0ad045b --- /dev/null +++ b/test/PermMatrixCSC.jl @@ -0,0 +1,114 @@ +using Test, Random +import LuxurySparse: PermMatrixCSC, pmcscrand +import LuxurySparse +using SparseArrays: sprand, SparseMatrixCSC +using LinearAlgebra + +Random.seed!(2) +p1 = PermMatrixCSC([1, 4, 2, 3], [0.1, 0.2, 0.4im, 0.5]) +p2 = PermMatrixCSC([2, 1, 4, 3], [0.1, 0.2, 0.4, 0.5]) +#p3 = PermMatrix([4,1,2,3],[0.5, 0.4im, 0.3, 0.2]) +p3 = pmcscrand(4) +sp = sprand(4, 4, 0.3) +v = [0.5, 0.3im, 0.2, 1.0] + +@testset "basic" begin + @test_throws DimensionMismatch PermMatrixCSC([1, 4, 2, 3], [0.1, 0.2, 0.4im]) + @test_throws ArgumentError size(p1, 0) + @test size(p1, 3) == 1 + @test [zip(findnz(p1)...)...] == [IterNz(p1)...] + @test p1 == copy(p1) + @test hash(p1) == hash(copy(p1)) + @test hash(p1) != hash(p2) + @test eltype(p1) == ComplexF64 + @test eltype(p2) == Float64 + @test eltype(p3) == Float64 + @test size(p1) == (4, 4) + @test size(p3) == (4, 4) + @test size(p1, 1) == size(p1, 2) == 4 + @test Matrix(p1) ≈ transpose([0.1 0 0 0; 0 0 0 0.2; 0 0.4im 0 0; 0 0 0.5 0]) + p0 = similar(p1) + @test p0.perm == p1.perm + @test p0.perm !== p1.perm + @test p0.vals !== p1.vals + @test p1[2, 2] === 0.0im + @test p1[1, 1] === 0.1 + 0.0im + copyto!(p0, p1) + @test p0 == p1 + @test PermMatrix([0.0 -1.0im; 1.0im 0.0im]) ≈ [0.0 -1.0im; 1.0im 0.0im] +end + +@testset "linalg" begin + @test inv(p1) ≈ inv(Matrix(p1)) + @test transpose(p1) ≈ transpose(Matrix(p1)) + @test inv(p1) * p1 ≈ Matrix(I, 4, 4) + @test p1 * transpose(p1) ≈ diagm(0 => p1.vals[invperm(p1.perm)] .^ 2) + #@test p1*adjoint(p1) == diagm(0=>abs.(p1.vals).^2) + #@test all(isapprox.(adjoint(p3), transpose(conj(Matrix(p3))))) + @test p1 * p1' == diagm(0 => abs.(p1.vals[invperm(p1.perm)]) .^ 2) + @test all(isapprox.(p3', transpose(conj(Matrix(p3))))) +end + +@testset "mul" begin + @test p3 * p2 ≈ SparseMatrixCSC(p3) * p2 ≈ Matrix(p3) * p2 + + # Multiply vector + @test p3 * v == Matrix(p3) * v + @test v' * p3 == v' * Matrix(p3) + @test vec(collect(1:4)' * p3) ≈ p3.perm .* p3.vals + + # Diagonal matrices + Dv = Diagonal(v) + @test p3 * Dv == Matrix(p3) * Dv + @test Dv * p3 == Dv * Matrix(p3) +end + +@testset "elementary" begin + @test all(isapprox.(conj(p1), conj(Matrix(p1)))) + @test all(isapprox.(real(p1), real(Matrix(p1)))) + @test all(isapprox.(imag(p1), imag(Matrix(p1)))) +end + +@testset "basicmath" begin + @test p1 * 2 == Matrix(p1) * 2 + @test p1 / 2 == Matrix(p1) / 2 +end + +@testset "memorysafe" begin + @test p1 == PermMatrixCSC([1, 4, 2, 3], [0.1, 0.2, 0.4im, 0.5]) + @test p2 == PermMatrixCSC([2, 1, 4, 3], [0.1, 0.2, 0.4, 0.5]) + @test v == [0.5, 0.3im, 0.2, 1.0] +end + +@testset "sparse" begin + Random.seed!(2) + pm = pmrand(10) + out = zeros(10, 10) + @test LuxurySparse.nnz(pm) == 10 + @test LuxurySparse.findnz(pm)[3] == pm.vals +end + +@testset "identity sparse" begin + p1 = Diagonal(randn(10)) + @test LuxurySparse.nnz(p1) == 10 + @test LuxurySparse.findnz(p1)[3] == p1.diag +end + +@testset "setindex" begin + pm = PermMatrix([3, 2, 4, 1], [0.0, 0.0, 0.0, 0.0]) + pm[3, 4] = 1.0 + @test_throws AssertionError pm[3, 1] = 1.0 + @test pm[3, 4] == 1.0 +end + +@testset "broadcast" begin + pm = PermMatrix([3, 2, 4, 1], [0.2, 0.6, 0.1, 0.3]) + res = pm .* 3im + @test res == PermMatrix([3, 2, 4, 1], [0.2, 0.6, 0.1, 0.3] .* 3im) && res isa PermMatrix +end + +@testset "fix dense-perm multiplication" begin + A = randn(ComplexF64, 4, 4) + pm = PermMatrix([3, 2, 4, 1], [0.2im, 0.6im, 0.1, 0.3]) + @test A * pm ≈ A * Matrix(pm) +end diff --git a/test/broadcast.jl b/test/broadcast.jl index 7aa3425..4879b70 100644 --- a/test/broadcast.jl +++ b/test/broadcast.jl @@ -5,7 +5,7 @@ using SparseArrays @testset "broadcast *" begin - @testset "Diagonal .* $(nameof(typeof(M)))" for M in Any[pmrand(3)] + @testset "Diagonal .* $(nameof(typeof(M)))" for M in [[pmrand(3)]..., pmcscrand(3)] M1 = Diagonal(rand(3)) out = M1 .* M @test typeof(out) <: Diagonal @@ -29,11 +29,21 @@ using SparseArrays out = M .* M1 @test typeof(out) <: PermMatrix @test out ≈ M .* Matrix(M1) + + M1 = pmcscrand(3) + out = M1 .* M + @test typeof(out) <: PermMatrixCSC + @test out ≈ Matrix(M1) .* M + + out = M .* M1 + !(M isa PermMatrix) && @test typeof(out) <: PermMatrixCSC + @test out ≈ M .* Matrix(M1) end @testset "IMatrix .* $(nameof(typeof(M)))" for M in Any[ rand(3, 3), pmrand(3), + pmcscrand(3), sprand(3, 3, 0.5), ] eye = IMatrix(3) @@ -77,6 +87,10 @@ end M1 = pmrand(3) @test M1 .- M ≈ Matrix(M1) .- M @test M .- M1 ≈ M .- Matrix(M1) + + M1 = pmcscrand(3) + @test M1 .- M ≈ Matrix(M1) .- M + @test M .- M1 ≈ M .- Matrix(M1) end @testset "IMatrix .* $(nameof(typeof(M)))" for M in Any[ diff --git a/test/iterate.jl b/test/iterate.jl index c3c9d39..81b7ca0 100644 --- a/test/iterate.jl +++ b/test/iterate.jl @@ -3,6 +3,7 @@ using Test, LuxurySparse, SparseArrays, LinearAlgebra @testset "iterate" begin for M in Any[ pmrand(10), + pmcscrand(10), Diagonal(randn(10)), IMatrix(10), randn(10, 10), diff --git a/test/kronecker.jl b/test/kronecker.jl index a2a8792..9df3c20 100644 --- a/test/kronecker.jl +++ b/test/kronecker.jl @@ -1,5 +1,5 @@ using Test, Random, SparseArrays, LinearAlgebra -import LuxurySparse: IMatrix, PermMatrix +import LuxurySparse: IMatrix, PermMatrix, PermMatrixCSC, basetype, AbstractPermMatrix @testset "kron" begin Random.seed!(2) @@ -8,12 +8,15 @@ import LuxurySparse: IMatrix, PermMatrix sp = sprand(ComplexF64, 4, 4, 0.5) ds = rand(ComplexF64, 4, 4) pm = PermMatrix([2, 3, 4, 1], randn(4)) - pm = PermMatrix([2, 3, 4, 1], randn(4)) + pmc = PermMatrixCSC([2, 3, 4, 1], randn(4)) v = [0.5, 0.3im, 0.2, 1.0] dv = Diagonal(v) - for source in Any[p1, sp, ds, dv, pm], - target in Any[p1, sp, ds, dv, pm] + for source in Any[p1, sp, ds, dv, pm, pmc], + target in Any[p1, sp, ds, dv, pm, pmc] + if source isa AbstractPermMatrix && target isa AbstractPermMatrix && basetype(source) != basetype(target) + continue + end lres = kron(source, target) rres = kron(target, source) flres = kron(Matrix(source), Matrix(target)) diff --git a/test/runtests.jl b/test/runtests.jl index 8ee7851..17fca67 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -8,6 +8,7 @@ end @testset "PermMatrix" begin include("PermMatrix.jl") + include("PermMatrixCSC.jl") end @testset "SparseMatrixCOO" begin diff --git a/test/staticize.jl b/test/staticize.jl index 0ec14f5..93edac6 100644 --- a/test/staticize.jl +++ b/test/staticize.jl @@ -9,6 +9,7 @@ using StaticArrays: SVector, SMatrix Random.seed!(2) @testset "staticize" begin + @test staticize(1) == 1 # permmatrix m = pmrand(ComplexF64, 4) sm = m |> staticize @@ -24,6 +25,22 @@ Random.seed!(2) @test dm.perm == m.perm @test dm.vals == m.vals + # permmatrixcsc + m = pmcscrand(ComplexF64, 4) + println(m) + sm = m |> staticize + @test sm isa SPermMatrixCSC{4,ComplexF64} + @test sm.perm isa SVector + @test sm.vals isa SVector + @test sm.perm == m.perm + @test sm.vals == m.vals + dm = sm |> dynamicize + @test dm isa PermMatrixCSC{ComplexF64} + @test dm.perm isa Vector + @test dm.vals isa Vector + @test dm.perm == m.perm + @test dm.vals == m.vals + # csc m = sprand(ComplexF64, 4, 4, 0.5) sm = m |> staticize