Skip to content

Commit

Permalink
update
Browse files Browse the repository at this point in the history
  • Loading branch information
GiggleLiu committed Feb 25, 2024
1 parent 51e52ff commit f28ce2b
Show file tree
Hide file tree
Showing 10 changed files with 163 additions and 19 deletions.
14 changes: 10 additions & 4 deletions src/PermMatrix.jl
Original file line number Diff line number Diff line change
Expand Up @@ -150,10 +150,16 @@ pmcscrand(n::Int) = pmcscrand(Float64, n)

Base.show(io::IO, ::MIME"text/plain", M::AbstractPermMatrix) = show(io, M)
function Base.show(io::IO, M::AbstractPermMatrix)
println(io, "PermMatrix")
for ((i, j), p) in IterNz(M)
print(io, "($i, $j) = $p")
i < length(M.perm) && println(io)
n = size(M, 1)
println(io, typeof(M))
nmax = 20
for (k, (i, j, p)) in enumerate(IterNz(M))
if k <= nmax || k > n-nmax
print(io, "($i, $j) = $p")
k < n && println(io)
elseif k == nmax+1
println(io, "...")
end
end
end

Expand Down
4 changes: 2 additions & 2 deletions src/broadcast.jl
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@ Broadcast.broadcasted(
# specialize perm matrix
function _broadcast_perm_prod(A::AbstractPermMatrix, B::AbstractMatrix)
dest = similar(A, Base.promote_op(*, eltype(A), eltype(B)))
@inbounds for ((i, j), a) in IterNz(A)
@inbounds for (i, j, a) in IterNz(A)
dest[i, j] = a * B[i, j]
end
return dest
Expand All @@ -71,7 +71,7 @@ Broadcast.broadcasted(::AbstractArrayStyle{2}, ::typeof(*), A::IMatrix, B::Abstr
Diagonal(B)

function _broadcast_diag_perm_prod(A::Diagonal, B::AbstractPermMatrix)
Diagonal(A.diag .* getindex.(Ref(B), 1:size(A, 1)))
Diagonal(A.diag .* getindex.(Ref(B), 1:size(A, 1), 1:size(A, 2)))
end

Broadcast.broadcasted(::AbstractArrayStyle{2}, ::typeof(*), A::AbstractPermMatrix, B::Diagonal) =
Expand Down
4 changes: 2 additions & 2 deletions src/iterate.jl
Original file line number Diff line number Diff line change
Expand Up @@ -47,12 +47,12 @@ end
# PermMatrixCSC
function Base.iterate(it::IterNz{<:PermMatrixCSC})
0 == length(it) && return nothing
return ((@inbounds it.A.perm[1], 1), (@inbounds it.A.vals[1])), 1
return ((@inbounds it.A.perm[1]), 1, (@inbounds it.A.vals[1])), 1
end
function Base.iterate(it::IterNz{<:PermMatrixCSC}, state)
state == length(it) && return nothing
state += 1
return ((@inbounds it.A.perm[state], state), (@inbounds it.A.vals[state])), state
return ((@inbounds it.A.perm[state]), state, (@inbounds it.A.vals[state])), state
end

# AbstractMatrix
Expand Down
6 changes: 6 additions & 0 deletions src/kronecker.jl
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,12 @@ LinearAlgebra.kron(A::IMatrix{Ta}, B::IMatrix{Tb}) where {Ta<:Number,Tb<:Number}
LinearAlgebra.kron(A::IMatrix{<:Number}, B::Diagonal{<:Number}) = A.n == 1 ? B : Diagonal(orepeat(B.diag, A.n))
LinearAlgebra.kron(B::Diagonal{<:Number}, A::IMatrix) = A.n == 1 ? B : Diagonal(irepeat(B.diag, A.n))

####### diagonal kron ########
LinearAlgebra.kron(A::StridedMatrix{<:Number}, B::Diagonal{<:Number}) = kron(A, PermMatrixCSC(B))
LinearAlgebra.kron(A::Diagonal{<:Number}, B::StridedMatrix{<:Number}) = kron(PermMatrixCSC(A), B)
LinearAlgebra.kron(A::Diagonal{<:Number}, B::SparseMatrixCSC{<:Number}) = kron(PermMatrixCSC(A), B)
LinearAlgebra.kron(A::SparseMatrixCSC{<:Number}, B::Diagonal{<:Number}) = kron(A, PermMatrixCSC(B))

function LinearAlgebra.kron(A::AbstractMatrix{Tv}, B::IMatrix) where {Tv<:Number}
B.n == 1 && return A
mA, nA = size(A)
Expand Down
18 changes: 12 additions & 6 deletions src/linalg.jl
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,7 @@ function LinearAlgebra.mul!(Y::AbstractVector, A::AbstractPermMatrix, X::Abstrac
length(X) == size(A, 2) || throw(DimensionMismatch("input X length does not match permutation matrix A"))
length(Y) == size(A, 2) || throw(DimensionMismatch("output Y length does not match permutation matrix A"))

@inbounds for ((i, j), p) in IterNz(A)
@inbounds for (i, j, p) in IterNz(A)
Y[i] = p * X[j] * alpha + beta * Y[i]
end
return Y
Expand All @@ -85,22 +85,28 @@ function Base.:*(A::PermMatrixCSC{Ta}, D::Diagonal{Td}) where {Td,Ta}
end

# to self
function Base.:*(A::AbstractPermMatrix, B::AbstractPermMatrix)
function Base.:*(A::PermMatrix, B::PermMatrix)
@assert basetype(A) == basetype(B)
size(A, 1) == size(B, 1) || throw(DimensionMismatch())
PermMatrix(B.perm[A.perm], A.vals .* view(B.vals, A.perm))
basetype(A)(B.perm[A.perm], A.vals .* view(B.vals, A.perm))
end

function Base.:*(A::PermMatrixCSC, B::PermMatrixCSC)
@assert basetype(A) == basetype(B)
size(A, 1) == size(B, 1) || throw(DimensionMismatch())
basetype(A)(A.perm[B.perm], B.vals .* view(A.vals, B.perm))
end

# to matrix
function LinearAlgebra.:mul!(C::AbstractMatrix, A::AbstractPermMatrix, X::AbstractMatrix, alpha::Number, beta::Number)
function LinearAlgebra.mul!(C::AbstractMatrix, A::AbstractPermMatrix, X::AbstractMatrix, alpha::Number, beta::Number)
size(X, 1) == size(A, 2) || throw(DimensionMismatch())
AR = PermMatrix(A)
C .= C .* beta .+ AR.vals .* view(X,AR.perm,:) .* alpha
C .= C .* beta .+ AR.vals .* view(X, AR.perm, :) .* alpha
end
function LinearAlgebra.mul!(C::AbstractMatrix, X::AbstractMatrix, A::AbstractPermMatrix, alpha::Number, beta::Number)
size(X, 2) == size(A, 1) || throw(DimensionMismatch())
AC = PermMatrixCSC(A)
C .= C .* beta .+ reshape(AC.vals, 1, :) .* view(X, :, perm) .* alpha
C .= C .* beta .+ reshape(AC.vals, 1, :) .* view(X, :, AC.perm) .* alpha
end

# NOTE: this is just a temperory fix for v0.7. We should overload mul! in
Expand Down
107 changes: 107 additions & 0 deletions test/PermMatrixCSC.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,107 @@
using Test, Random
import LuxurySparse: PermMatrixCSC, pmcscrand
import LuxurySparse
using SparseArrays: sprand, SparseMatrixCSC
using LinearAlgebra

Random.seed!(2)
p1 = PermMatrixCSC([1, 4, 2, 3], [0.1, 0.2, 0.4im, 0.5])
p2 = PermMatrixCSC([2, 1, 4, 3], [0.1, 0.2, 0.4, 0.5])
#p3 = PermMatrix([4,1,2,3],[0.5, 0.4im, 0.3, 0.2])
p3 = pmcscrand(4)
sp = sprand(4, 4, 0.3)
v = [0.5, 0.3im, 0.2, 1.0]

@testset "basic" begin
@test p1 == copy(p1)
@test eltype(p1) == ComplexF64
@test eltype(p2) == Float64
@test eltype(p3) == Float64
@test size(p1) == (4, 4)
@test size(p3) == (4, 4)
@test size(p1, 1) == size(p1, 2) == 4
@test Matrix(p1) transpose([0.1 0 0 0; 0 0 0 0.2; 0 0.4im 0 0; 0 0 0.5 0])
p0 = similar(p1)
@test p0.perm == p1.perm
@test p0.perm !== p1.perm
@test p0.vals !== p1.vals
@test p1[2, 2] === 0.0im
@test p1[1, 1] === 0.1 + 0.0im
copyto!(p0, p1)
@test p0 == p1
end

@testset "linalg" begin
@test inv(p1) inv(Matrix(p1))
@test transpose(p1) transpose(Matrix(p1))
@test inv(p1) * p1 Matrix(I, 4, 4)
@test p1 * transpose(p1) diagm(0 => p1.vals[invperm(p1.perm)] .^ 2)
#@test p1*adjoint(p1) == diagm(0=>abs.(p1.vals).^2)
#@test all(isapprox.(adjoint(p3), transpose(conj(Matrix(p3)))))
@test p1 * p1' == diagm(0 => abs.(p1.vals[invperm(p1.perm)]) .^ 2)
@test all(isapprox.(p3', transpose(conj(Matrix(p3)))))
end

@testset "mul" begin
@test p3 * p2 SparseMatrixCSC(p3) * p2 Matrix(p3) * p2

# Multiply vector
@test p3 * v == Matrix(p3) * v
@test v' * p3 == v' * Matrix(p3)
@test vec(collect(1:4)' * p3) p3.perm .* p3.vals

# Diagonal matrices
Dv = Diagonal(v)
@test p3 * Dv == Matrix(p3) * Dv
@test Dv * p3 == Dv * Matrix(p3)
end

@testset "elementary" begin
@test all(isapprox.(conj(p1), conj(Matrix(p1))))
@test all(isapprox.(real(p1), real(Matrix(p1))))
@test all(isapprox.(imag(p1), imag(Matrix(p1))))
end

@testset "basicmath" begin
@test p1 * 2 == Matrix(p1) * 2
@test p1 / 2 == Matrix(p1) / 2
end

@testset "memorysafe" begin
@test p1 == PermMatrixCSC([1, 4, 2, 3], [0.1, 0.2, 0.4im, 0.5])
@test p2 == PermMatrixCSC([2, 1, 4, 3], [0.1, 0.2, 0.4, 0.5])
@test v == [0.5, 0.3im, 0.2, 1.0]
end

@testset "sparse" begin
Random.seed!(2)
pm = pmrand(10)
out = zeros(10, 10)
@test LuxurySparse.nnz(pm) == 10
@test LuxurySparse.findnz(pm)[3] == pm.vals
end

@testset "identity sparse" begin
p1 = Diagonal(randn(10))
@test LuxurySparse.nnz(p1) == 10
@test LuxurySparse.findnz(p1)[3] == p1.diag
end

@testset "setindex" begin
pm = PermMatrix([3, 2, 4, 1], [0.0, 0.0, 0.0, 0.0])
pm[3, 4] = 1.0
@test_throws BoundsError pm[3, 1] = 1.0
@test pm[3, 4] == 1.0
end

@testset "broadcast" begin
pm = PermMatrix([3, 2, 4, 1], [0.2, 0.6, 0.1, 0.3])
res = pm .* 3im
@test res == PermMatrix([3, 2, 4, 1], [0.2, 0.6, 0.1, 0.3] .* 3im) && res isa PermMatrix
end

@testset "fix dense-perm multiplication" begin
A = randn(ComplexF64, 4, 4)
pm = PermMatrix([3, 2, 4, 1], [0.2im, 0.6im, 0.1, 0.3])
@test A * pm A * Matrix(pm)
end
16 changes: 15 additions & 1 deletion test/broadcast.jl
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ using SparseArrays

@testset "broadcast *" begin

@testset "Diagonal .* $(nameof(typeof(M)))" for M in Any[pmrand(3)]
@testset "Diagonal .* $(nameof(typeof(M)))" for M in [[pmrand(3)]..., pmcscrand(3)]
M1 = Diagonal(rand(3))
out = M1 .* M
@test typeof(out) <: Diagonal
Expand All @@ -29,11 +29,21 @@ using SparseArrays
out = M .* M1
@test typeof(out) <: PermMatrix
@test out M .* Matrix(M1)

M1 = pmcscrand(3)
out = M1 .* M
@test typeof(out) <: PermMatrixCSC
@test out Matrix(M1) .* M

out = M .* M1
!(M isa PermMatrix) && @test typeof(out) <: PermMatrixCSC
@test out M .* Matrix(M1)
end

@testset "IMatrix .* $(nameof(typeof(M)))" for M in Any[
rand(3, 3),
pmrand(3),
pmcscrand(3),
sprand(3, 3, 0.5),
]
eye = IMatrix(3)
Expand Down Expand Up @@ -77,6 +87,10 @@ end
M1 = pmrand(3)
@test M1 .- M Matrix(M1) .- M
@test M .- M1 M .- Matrix(M1)

M1 = pmcscrand(3)
@test M1 .- M Matrix(M1) .- M
@test M .- M1 M .- Matrix(M1)
end

@testset "IMatrix .* $(nameof(typeof(M)))" for M in Any[
Expand Down
1 change: 1 addition & 0 deletions test/iterate.jl
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ using Test, LuxurySparse, SparseArrays, LinearAlgebra
@testset "iterate" begin
for M in Any[
pmrand(10),
pmcscrand(10),
Diagonal(randn(10)),
IMatrix(10),
randn(10, 10),
Expand Down
11 changes: 7 additions & 4 deletions test/kronecker.jl
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
using Test, Random, SparseArrays, LinearAlgebra
import LuxurySparse: IMatrix, PermMatrix
import LuxurySparse: IMatrix, PermMatrix, PermMatrixCSC, basetype, AbstractPermMatrix

@testset "kron" begin
Random.seed!(2)
Expand All @@ -8,12 +8,15 @@ import LuxurySparse: IMatrix, PermMatrix
sp = sprand(ComplexF64, 4, 4, 0.5)
ds = rand(ComplexF64, 4, 4)
pm = PermMatrix([2, 3, 4, 1], randn(4))
pm = PermMatrix([2, 3, 4, 1], randn(4))
pmc = PermMatrixCSC([2, 3, 4, 1], randn(4))
v = [0.5, 0.3im, 0.2, 1.0]
dv = Diagonal(v)

for source in Any[p1, sp, ds, dv, pm],
target in Any[p1, sp, ds, dv, pm]
for source in Any[p1, sp, ds, dv, pm, pmc],
target in Any[p1, sp, ds, dv, pm, pmc]
if source isa AbstractPermMatrix && target isa AbstractPermMatrix && basetype(source) != basetype(target)
continue
end
lres = kron(source, target)
rres = kron(target, source)
flres = kron(Matrix(source), Matrix(target))
Expand Down
1 change: 1 addition & 0 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ end

@testset "PermMatrix" begin
include("PermMatrix.jl")
include("PermMatrixCSC.jl")
end

@testset "SparseMatrixCOO" begin
Expand Down

0 comments on commit f28ce2b

Please sign in to comment.