Skip to content

Commit 99c99b4

Browse files
authored
Specialize 3-arg dot for sparse self-adjoint matrices (#398)
1 parent cb10c1e commit 99c99b4

File tree

2 files changed

+93
-7
lines changed

2 files changed

+93
-7
lines changed

src/linalg.jl

Lines changed: 86 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
# This file is a part of Julia. License is MIT: https://julialang.org/license
22

33
using LinearAlgebra: AbstractTriangular, StridedMaybeAdjOrTransMat, UpperOrLowerTriangular,
4-
checksquare, sym_uplo
4+
RealHermSymComplexHerm, checksquare, sym_uplo
55
using Random: rand!
66

77
# In matrix-vector multiplication, the correct orientation of the vector is assumed.
@@ -900,17 +900,96 @@ function _mul!(nzrang::Function, diagop::Function, odiagop::Function, C::Strided
900900
C
901901
end
902902

903-
# row range up to and including diagonal
904-
function nzrangeup(A, i)
903+
# row range up to (and including if excl=false) diagonal
904+
function nzrangeup(A, i, excl=false)
905905
r = nzrange(A, i); r1 = r.start; r2 = r.stop
906906
rv = rowvals(A)
907-
@inbounds r2 < r1 || rv[r2] <= i ? r : r1:searchsortedlast(rv, i, r1, r2, Forward)
907+
@inbounds r2 < r1 || rv[r2] <= i - excl ? r : r1:searchsortedlast(rv, i - excl, r1, r2, Forward)
908908
end
909-
# row range from diagonal (included) to end
910-
function nzrangelo(A, i)
909+
# row range from diagonal (included if excl=false) to end
910+
function nzrangelo(A, i, excl=false)
911911
r = nzrange(A, i); r1 = r.start; r2 = r.stop
912912
rv = rowvals(A)
913-
@inbounds r2 < r1 || rv[r1] >= i ? r : searchsortedfirst(rv, i, r1, r2, Forward):r2
913+
@inbounds r2 < r1 || rv[r1] >= i + excl ? r : searchsortedfirst(rv, i + excl, r1, r2, Forward):r2
914+
end
915+
916+
dot(x::AbstractVector, A::RealHermSymComplexHerm{<:Any,<:AbstractSparseMatrixCSC}, y::AbstractVector) =
917+
_dot(x, parent(A), y, A.uplo == 'U' ? nzrangeup : nzrangelo, A isa Symmetric ? identity : real, A isa Symmetric ? transpose : adjoint)
918+
function _dot(x::AbstractVector, A::AbstractSparseMatrixCSC, y::AbstractVector, rangefun::Function, diagop::Function, odiagop::Function)
919+
require_one_based_indexing(x, y)
920+
m, n = size(A)
921+
(length(x) == m && n == length(y)) || throw(DimensionMismatch())
922+
if iszero(m) || iszero(n)
923+
return dot(zero(eltype(x)), zero(eltype(A)), zero(eltype(y)))
924+
end
925+
T = promote_type(eltype(x), eltype(A), eltype(y))
926+
r = zero(T)
927+
rvals = getrowval(A)
928+
nzvals = getnzval(A)
929+
@inbounds for col in 1:n
930+
ycol = y[col]
931+
xcol = x[col]
932+
if _isnotzero(ycol) && _isnotzero(xcol)
933+
for k in rangefun(A, col)
934+
i = rvals[k]
935+
Aij = nzvals[k]
936+
if i != col
937+
r += dot(x[i], Aij, ycol)
938+
r += dot(xcol, odiagop(Aij), y[i])
939+
else
940+
r += dot(x[i], diagop(Aij), ycol)
941+
end
942+
end
943+
end
944+
end
945+
return r
946+
end
947+
dot(x::SparseVector, A::RealHermSymComplexHerm{<:Any,<:AbstractSparseMatrixCSC}, y::SparseVector) =
948+
_dot(x, parent(A), y, A.uplo == 'U' ? nzrangeup : nzrangelo, A isa Symmetric ? identity : real)
949+
function _dot(x::SparseVector, A::AbstractSparseMatrixCSC, y::SparseVector, rangefun::Function, diagop::Function)
950+
m, n = size(A)
951+
length(x) == m && n == length(y) || throw(DimensionMismatch())
952+
if iszero(m) || iszero(n)
953+
return dot(zero(eltype(x)), zero(eltype(A)), zero(eltype(y)))
954+
end
955+
r = zero(promote_type(eltype(x), eltype(A), eltype(y)))
956+
xnzind = nonzeroinds(x)
957+
xnzval = nonzeros(x)
958+
ynzind = nonzeroinds(y)
959+
ynzval = nonzeros(y)
960+
Arowval = getrowval(A)
961+
Anzval = getnzval(A)
962+
Acolptr = getcolptr(A)
963+
isempty(Arowval) && return r
964+
# plain triangle without diagonal
965+
for (yi, yv) in zip(ynzind, ynzval)
966+
A_ptr_lo = first(rangefun(A, yi, true))
967+
A_ptr_hi = last(rangefun(A, yi, true))
968+
if A_ptr_lo <= A_ptr_hi
969+
# dot is conjugated in the first argument, so double conjugate a's
970+
r += dot(_spdot((x, a) -> a'x, 1, length(xnzind), xnzind, xnzval,
971+
A_ptr_lo, A_ptr_hi, Arowval, Anzval), yv)
972+
end
973+
end
974+
# view triangle without diagonal
975+
for (xi, xv) in zip(xnzind, xnzval)
976+
A_ptr_lo = first(rangefun(A, xi, true))
977+
A_ptr_hi = last(rangefun(A, xi, true))
978+
if A_ptr_lo <= A_ptr_hi
979+
r += dot(xv, _spdot((a, y) -> a'y, A_ptr_lo, A_ptr_hi, Arowval, Anzval,
980+
1, length(ynzind), ynzind, ynzval))
981+
end
982+
end
983+
# diagonal
984+
for i in 1:m
985+
r1 = Int(Acolptr[i])
986+
r2 = Int(Acolptr[i+1]-1)
987+
r1 > r2 && continue
988+
r1 = searchsortedfirst(Arowval, i, r1, r2, Forward)
989+
((r1 > r2) || (Arowval[r1] != i)) && continue
990+
r += dot(x[i], diagop(Anzval[r1]), y[i])
991+
end
992+
r
914993
end
915994
## end of symmetric/Hermitian
916995

test/linalg.jl

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -811,6 +811,13 @@ end
811811
@test dot(x, A, y) dot(Vector(x), A, Vector(y)) (Vector(x)' * Matrix(A)) * Vector(y)
812812
@test dot(x, A, y) dot(x, Av, y)
813813
end
814+
815+
for (T, trans) in ((Float64, Symmetric), (ComplexF64, Hermitian)), uplo in (:U, :L)
816+
B = sprandn(T, 10, 10, 0.2)
817+
x = sprandn(T, 10, 0.4)
818+
S = trans(B'B, uplo)
819+
@test dot(x, S, x) dot(Vector(x), S, Vector(x)) dot(Vector(x), Matrix(S), Vector(x))
820+
end
814821
end
815822

816823
@testset "conversion to special LinearAlgebra types" begin

0 commit comments

Comments
 (0)