Skip to content

Vec * SparseMat & Mat * SparseMat support and other improvements #30

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 21 commits into from
Aug 26, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions src/MKLSparse.jl
Original file line number Diff line number Diff line change
Expand Up @@ -55,4 +55,6 @@ include("deprecated.jl")
include("generic.jl")
include("interface.jl")

export MKLSparseError

end # module
44 changes: 29 additions & 15 deletions src/generic.jl
Original file line number Diff line number Diff line change
Expand Up @@ -18,42 +18,56 @@ end
return body
end

function mv!(transa::Char, alpha::T, A::AbstractSparseMatrix{T}, descr::matrix_descr, x::StridedVector{T}, beta::T, y::StridedVector{T}) where T
function mv!(transa::Char, alpha::T, A::AbstractSparseMatrix{T}, descr::matrix_descr,
x::StridedVector{T}, beta::T, y::StridedVector{T}
) where T
check_transa(transa)
check_mat_op_sizes(y, A, transa, x, 'N')
mkl_call(Val{:mkl_sparse_T_mvI}(), typeof(A),
res = mkl_call(Val{:mkl_sparse_T_mvI}(), typeof(A),
transa, alpha, MKLSparseMatrix(A), descr, x, beta, y)
check_status(res)
return y
end

function mm!(transa::Char, alpha::T, A::AbstractSparseMatrix{T}, descr::matrix_descr, x::StridedMatrix{T}, beta::T, y::StridedMatrix{T}) where T
function mm!(transa::Char, alpha::T, A::AbstractSparseMatrix{T}, descr::matrix_descr,
x::StridedMatrix{T}, beta::T, y::StridedMatrix{T};
dense_layout::sparse_layout_t = SPARSE_LAYOUT_COLUMN_MAJOR
) where T
check_transa(transa)
check_mat_op_sizes(y, A, transa, x, 'N')
columns = size(y, 2)
check_mat_op_sizes(y, A, transa, x, 'N'; dense_layout)
columns = size(y, dense_layout == SPARSE_LAYOUT_COLUMN_MAJOR ? 2 : 1)
ldx = stride(x, 2)
ldy = stride(y, 2)
mkl_call(Val{:mkl_sparse_T_mmI}(), typeof(A),
transa, alpha, MKLSparseMatrix(A), descr, 'C', x, columns, ldx, beta, y, ldy)
res = mkl_call(Val{:mkl_sparse_T_mmI}(), typeof(A),
transa, alpha, MKLSparseMatrix(A), descr, dense_layout, x, columns, ldx, beta, y, ldy)
check_status(res)
return y
end

function trsv!(transa::Char, alpha::T, A::AbstractSparseMatrix{T}, descr::matrix_descr, x::StridedVector{T}, y::StridedVector{T}) where T
function trsv!(transa::Char, alpha::T, A::AbstractSparseMatrix{T}, descr::matrix_descr,
x::StridedVector{T}, y::StridedVector{T}
) where T
checksquare(A)
check_transa(transa)
check_mat_op_sizes(y, A, transa, x, 'N')
mkl_call(Val{:mkl_sparse_T_trsvI}(), typeof(A),
transa, alpha, MKLSparseMatrix(A), descr, x, y)
res = mkl_call(Val{:mkl_sparse_T_trsvI}(), typeof(A),
transa, alpha, MKLSparseMatrix(A), descr, x, y)
check_status(res)
return y
end

function trsm!(transa::Char, alpha::T, A::AbstractSparseMatrix{T}, descr::matrix_descr, x::StridedMatrix{T}, y::StridedMatrix{T}) where T
function trsm!(transa::Char, alpha::T, A::AbstractSparseMatrix{T}, descr::matrix_descr,
x::StridedMatrix{T}, y::StridedMatrix{T};
dense_layout::sparse_layout_t = SPARSE_LAYOUT_COLUMN_MAJOR
) where T
checksquare(A)
check_transa(transa)
check_mat_op_sizes(y, A, transa, x, 'N')
columns = size(y, 2)
check_mat_op_sizes(y, A, transa, x, 'N'; dense_layout)
columns = size(y, dense_layout == SPARSE_LAYOUT_COLUMN_MAJOR ? 2 : 1)
ldx = stride(x, 2)
ldy = stride(y, 2)
mkl_call(Val{:mkl_sparse_T_trsmI}(), typeof(A),
transa, alpha, MKLSparseMatrix(A), descr, 'C', x, columns, ldx, y, ldy)
res = mkl_call(Val{:mkl_sparse_T_trsmI}(), typeof(A),
transa, alpha, MKLSparseMatrix(A), descr, dense_layout, x, columns, ldx, y, ldy)
check_status(res)
return y
end
131 changes: 102 additions & 29 deletions src/interface.jl
Original file line number Diff line number Diff line change
Expand Up @@ -3,64 +3,137 @@ import LinearAlgebra: mul!, ldiv!

MKLSparseMat{T} = Union{SparseArrays.AbstractSparseMatrixCSC{T}, SparseMatrixCSR{T}, SparseMatrixCOO{T}}

SimpleOrSpecialMat{T, M} = Union{M, LowerTriangular{T,<:M}, UpperTriangular{T,<:M},
UnitLowerTriangular{T,<:M}, UnitUpperTriangular{T,<:M},
Symmetric{T,<:M}, Hermitian{T,<:M}}
SimpleOrSpecialOrAdjMat{T, M} = Union{SimpleOrSpecialMat{T, M},
Adjoint{T, <:SimpleOrSpecialMat{T, M}},
Transpose{T, <:SimpleOrSpecialMat{T, M}}}
SimpleOrAdjMat{T, M} = Union{M, Adjoint{T, <:M}, Transpose{T, <:M}}

SpecialMat{T, M} = Union{LowerTriangular{T,<:M}, UpperTriangular{T,<:M},
UnitLowerTriangular{T,<:M}, UnitUpperTriangular{T,<:M},
Symmetric{T,<:M}, Hermitian{T,<:M}}
SimpleOrSpecialMat{T, M} = Union{M, SpecialMat{T, <:M}}
SimpleOrSpecialOrAdjMat{T, M} = Union{SimpleOrAdjMat{T, <:SimpleOrSpecialMat{T, <:M}},
SimpleOrSpecialMat{T, <:SimpleOrAdjMat{T, <:M}}}

unwrapa(A::AbstractMatrix) = A
unwrapa(A::Union{LowerTriangular, UpperTriangular,
UnitLowerTriangular, UnitUpperTriangular,
Symmetric, Hermitian}) = parent(A)
unwrapa(A::Union{Adjoint, Transpose}) = unwrapa(parent(A))
unwrapa(A::SpecialMat) = unwrapa(parent(A))

# returns a tuple of transa, matdescra and unwrapped A
describe_and_unwrap(A::AbstractMatrix) = ('N', matrixdescra(A), unwrapa(A))
describe_and_unwrap(A::Adjoint) = ('C', matrixdescra(A), unwrapa(parent(A)))
describe_and_unwrap(A::Transpose) = ('T', matrixdescra(A), unwrapa(parent(A)))
describe_and_unwrap(A::AbstractMatrix) = ('N', matrix_descr(A), unwrapa(A))
describe_and_unwrap(A::Adjoint) = ('C', matrix_descr(A), unwrapa(parent(A)))
describe_and_unwrap(A::Transpose) = ('T', matrix_descr(A), unwrapa(parent(A)))
describe_and_unwrap(A::LowerTriangular{<:Any, T}) where T <: Union{Adjoint, Transpose} =
(T <: Adjoint ? 'C' : 'T', matrix_descr('T', 'U', 'N'), unwrapa(A))
describe_and_unwrap(A::UpperTriangular{<:Any, T}) where T <: Union{Adjoint, Transpose} =
(T <: Adjoint ? 'C' : 'T', matrix_descr('T', 'L', 'N'), unwrapa(A))
describe_and_unwrap(A::UnitLowerTriangular{<:Any, T}) where T <: Union{Adjoint, Transpose} =
(T <: Adjoint ? 'C' : 'T', matrix_descr('T', 'U', 'U'), unwrapa(A))
describe_and_unwrap(A::UnitUpperTriangular{<:Any, T}) where T <: Union{Adjoint, Transpose} =
(T <: Adjoint ? 'C' : 'T', matrix_descr('T', 'L', 'U'), unwrapa(A))
describe_and_unwrap(A::Symmetric{<:Any, T}) where T <: Union{Adjoint, Transpose} =
(T <: Transpose || (eltype(A) <: Real) ? 'N' : 'C', matrix_descr('S', A.uplo, 'N'), unwrapa(A))
describe_and_unwrap(A::Hermitian{<:Any, T}) where T <: Union{Adjoint, Transpose} =
(T <: Adjoint || (eltype(A) <: Real) ? 'N' : 'T', matrix_descr('H', A.uplo, 'N'), unwrapa(A))

# 5-arg mul!()
function mul!(y::StridedVector{T}, A::SimpleOrSpecialOrAdjMat{T, S}, x::StridedVector{T}, alpha::Number, beta::Number) where {T <: BlasFloat, S <: MKLSparseMat{T}}
function mul!(y::StridedVector{T}, A::SimpleOrSpecialOrAdjMat{T, S},
x::StridedVector{T}, alpha::Number, beta::Number
) where {T <: BlasFloat, S <: MKLSparseMat{T}}
transA, descrA, unwrapA = describe_and_unwrap(A)
# fix the strange behaviour of multipling adjoint vectors by triangular matrices
# looks like wrong the triangle is being used
if descrA.type == SPARSE_MATRIX_TYPE_TRIANGULAR && transA == 'C'
descrA = lazypermutedims(descrA)
end
mv!(transA, T(alpha), unwrapA, descrA, x, T(beta), y)
end

function mul!(C::StridedMatrix{T}, A::SimpleOrSpecialOrAdjMat{T, S}, B::StridedMatrix{T}, alpha::Number, beta::Number) where {T <: BlasFloat, S <: MKLSparseMat{T}}
function mul!(C::StridedMatrix{T}, A::SimpleOrSpecialOrAdjMat{T, S},
B::StridedMatrix{T}, alpha::Number, beta::Number
) where {T <: BlasFloat, S <: MKLSparseMat{T}}
transA, descrA, unwrapA = describe_and_unwrap(A)
mm!(transA, T(alpha), unwrapA, descrA, B, T(beta), C)
end

# ColMajorRes = ColMajorMtx*SparseMatrixCSC is implemented via
# RowMajorRes = SparseMatrixCSR*RowMajorMtx Sparse MKL BLAS calls
# Switching the B layout from CSC to CSR is required, because MKLSparse
# does not support CSC 1-based multiplication with row-major matrices.
# Only CSC is supported as for the other sparse formats the combination
# of indexing, storage and dense layout would be unsupported,
# see https://www.intel.com/content/www/us/en/docs/onemkl/developer-reference-c/2024-2/mkl-sparse-mm.html
# (one potential workaround is to temporarily switch to 0-based indexing)
function mul!(C::StridedMatrix{T}, A::StridedMatrix{T},
B::SimpleOrSpecialOrAdjMat{T, S}, alpha::Number, beta::Number
) where {T <: BlasFloat, S <: SparseArrays.AbstractSparseMatrixCSC{T}}
transB, descrB, unwrapB = describe_and_unwrap(B)
mm!(transB, T(alpha), lazypermutedims(unwrapB), lazypermutedims(descrB), A,
T(beta), C, dense_layout = SPARSE_LAYOUT_ROW_MAJOR)
end

# 3-arg mul!() calls 5-arg mul!()
mul!(C::StridedMatrix{T}, A::SimpleOrSpecialOrAdjMat{T, S}, B::StridedMatrix{T}) where {T <: BlasFloat, S <: MKLSparseMat{T}} =
mul!(C, A, B, one(T), zero(T))
mul!(y::StridedVector{T}, A::SimpleOrSpecialOrAdjMat{T, S}, x::StridedVector{T}) where {T <: BlasFloat, S <: MKLSparseMat{T}} =
mul!(y::StridedVector{T}, A::SimpleOrSpecialOrAdjMat{T, S},
x::StridedVector{T}) where {T <: BlasFloat, S <: MKLSparseMat{T}} =
mul!(y, A, x, one(T), zero(T))

mul!(C::StridedMatrix{T}, A::SimpleOrSpecialOrAdjMat{T, S},
B::StridedMatrix{T}) where {T <: BlasFloat, S <: MKLSparseMat{T}} =
mul!(C, A, B, one(T), zero(T))
mul!(C::StridedMatrix{T}, A::StridedMatrix{T},
B::SimpleOrSpecialOrAdjMat{T, S}) where {T <: BlasFloat, S <: MKLSparseMat{T}} =
mul!(C, A, B, one(T), zero(T))

# define 4-arg ldiv!(C, A, B, a) (C := alpha*inv(A)*B) that is not present in standard LinearAlgrebra
# redefine 3-arg ldiv!(C, A, B) using 4-arg ldiv!(C, A, B, 1)
function ldiv!(y::StridedVector{T}, A::SimpleOrSpecialOrAdjMat{T, S}, x::StridedVector{T}, alpha::Number = one(T)) where {T <: BlasFloat, S <: MKLSparseMat{T}}
function ldiv!(y::StridedVector{T}, A::SimpleOrSpecialOrAdjMat{T, S},
x::StridedVector{T}, alpha::Number = one(T)) where {T <: BlasFloat, S <: MKLSparseMat{T}}
transA, descrA, unwrapA = describe_and_unwrap(A)
trsv!(transA, alpha, unwrapA, descrA, x, y)
end

function LinearAlgebra.ldiv!(C::StridedMatrix{T}, A::SimpleOrSpecialOrAdjMat{T, S}, B::StridedMatrix{T}, alpha::Number = one(T)) where {T <: BlasFloat, S <: MKLSparseMat{T}}
function LinearAlgebra.ldiv!(C::StridedMatrix{T}, A::SimpleOrSpecialOrAdjMat{T, S},
B::StridedMatrix{T}, alpha::Number = one(T)) where {T <: BlasFloat, S <: MKLSparseMat{T}}
transA, descrA, unwrapA = describe_and_unwrap(A)
trsm!(transA, alpha, unwrapA, descrA, B, C)
end

function (*)(A::SimpleOrSpecialOrAdjMat{T, S}, x::StridedVector{T}) where {T <: BlasFloat, S <: MKLSparseMat{T}}
m, n = size(A)
y = Vector{T}(undef, m)
return mul!(y, A, x, one(T), zero(T))
if VERSION < v"1.10"
# stdlib v1.9 does not provide these methods

(*)(A::SimpleOrSpecialOrAdjMat{T, S}, x::StridedVector{T}) where {T <: BlasFloat, S <: MKLSparseMat{T}} =
mul!(Vector{T}(undef, size(A, 1)), A, x)

(*)(A::SimpleOrSpecialOrAdjMat{T, S}, B::StridedMatrix{T}) where {T <: BlasFloat, S <: MKLSparseMat{T}} =
mul!(Matrix{T}(undef, size(A, 1), size(B, 2)), A, B)

# xᵀ * B = (Bᵀ * x)ᵀ
(*)(x::Transpose{T, <:StridedVector{T}}, B::SimpleOrSpecialMat{T, S}) where {T <: BlasFloat, S <: MKLSparseMat{T}} =
transpose(mul!(similar(x, size(B, 2)), transpose(B), parent(x)))

# xᴴ * B = (Bᴴ * x)ᴴ
(*)(x::Adjoint{T, <:StridedVector{T}}, B::SimpleOrSpecialMat{T, S}) where {T <: BlasFloat, S <: MKLSparseMat{T}} =
adjoint(mul!(similar(x, size(B, 2)), adjoint(B), parent(x)))

end # if VERSION < v"1.10"

(*)(A::StridedMatrix{T}, B::SimpleOrSpecialOrAdjMat{T, S}) where {T <: BlasFloat, S <: MKLSparseMat{T}} =
mul!(Matrix{T}(undef, size(A, 1), size(B, 2)), A, B)

# stdlib does not provide these methods for complex types

# xᴴ * Bᵀ = (Bᵀᴴ * x)ᴴ
function (*)(x::Adjoint{T, <:StridedVector{T}}, B::Transpose{T, <:SimpleOrSpecialMat{T, S}}
) where {T <: Union{ComplexF32, ComplexF64}, S <: MKLSparseMat{T}}
transB, descrB, unwrapB = describe_and_unwrap(parent(B))
y = similar(x, size(B, 2))
adjoint(mv!('C', one(T), lazypermutedims(unwrapB), lazypermutedims(descrB), parent(x),
zero(T), y))
end

function (*)(A::SimpleOrSpecialOrAdjMat{T, S}, B::StridedMatrix{T}) where {T <: BlasFloat, S <: MKLSparseMat{T}}
m, k = size(A)
p, n = size(B)
C = Matrix{T}(undef, m, n)
return mul!(C, A, B, one(T), zero(T))
# xᵀ * Bᴴ = (Bᵀᴴ * x)ᵀ
function (*)(x::Transpose{T, <:StridedVector{T}}, B::Adjoint{T, <:SimpleOrSpecialMat{T, S}}
) where {T <: Union{ComplexF32, ComplexF64}, S <: MKLSparseMat{T}}
transB, descrB, unwrapB = describe_and_unwrap(parent(B))
y = similar(x, size(B, 2))
transpose(mv!('C', one(T), lazypermutedims(unwrapB), lazypermutedims(descrB), parent(x),
zero(T), y))
end

function (\)(A::SimpleOrSpecialOrAdjMat{T, S}, x::StridedVector{T}) where {T <: BlasFloat, S <: MKLSparseMat{T}}
Expand Down
Loading