|
1 | 1 | # This file is a part of Julia. License is MIT: https://julialang.org/license
|
2 | 2 |
|
3 | 3 | using LinearAlgebra: AbstractTriangular, StridedMaybeAdjOrTransMat, UpperOrLowerTriangular,
|
4 |
| - checksquare, sym_uplo |
| 4 | + RealHermSymComplexHerm, checksquare, sym_uplo |
5 | 5 | using Random: rand!
|
6 | 6 |
|
7 | 7 | # In matrix-vector multiplication, the correct orientation of the vector is assumed.
|
@@ -900,17 +900,96 @@ function _mul!(nzrang::Function, diagop::Function, odiagop::Function, C::Strided
|
900 | 900 | C
|
901 | 901 | end
|
902 | 902 |
|
903 |
| -# row range up to and including diagonal |
904 |
| -function nzrangeup(A, i) |
| 903 | +# row range up to (and including if excl=false) diagonal |
| 904 | +function nzrangeup(A, i, excl=false) |
905 | 905 | r = nzrange(A, i); r1 = r.start; r2 = r.stop
|
906 | 906 | rv = rowvals(A)
|
907 |
| - @inbounds r2 < r1 || rv[r2] <= i ? r : r1:searchsortedlast(rv, i, r1, r2, Forward) |
| 907 | + @inbounds r2 < r1 || rv[r2] <= i - excl ? r : r1:searchsortedlast(rv, i - excl, r1, r2, Forward) |
908 | 908 | end
|
909 |
| -# row range from diagonal (included) to end |
910 |
| -function nzrangelo(A, i) |
| 909 | +# row range from diagonal (included if excl=false) to end |
| 910 | +function nzrangelo(A, i, excl=false) |
911 | 911 | r = nzrange(A, i); r1 = r.start; r2 = r.stop
|
912 | 912 | rv = rowvals(A)
|
913 |
| - @inbounds r2 < r1 || rv[r1] >= i ? r : searchsortedfirst(rv, i, r1, r2, Forward):r2 |
| 913 | + @inbounds r2 < r1 || rv[r1] >= i + excl ? r : searchsortedfirst(rv, i + excl, r1, r2, Forward):r2 |
| 914 | +end |
| 915 | + |
| 916 | +dot(x::AbstractVector, A::RealHermSymComplexHerm{<:Any,<:AbstractSparseMatrixCSC}, y::AbstractVector) = |
| 917 | + _dot(x, parent(A), y, A.uplo == 'U' ? nzrangeup : nzrangelo, A isa Symmetric ? identity : real, A isa Symmetric ? transpose : adjoint) |
| 918 | +function _dot(x::AbstractVector, A::AbstractSparseMatrixCSC, y::AbstractVector, rangefun::Function, diagop::Function, odiagop::Function) |
| 919 | + require_one_based_indexing(x, y) |
| 920 | + m, n = size(A) |
| 921 | + (length(x) == m && n == length(y)) || throw(DimensionMismatch()) |
| 922 | + if iszero(m) || iszero(n) |
| 923 | + return dot(zero(eltype(x)), zero(eltype(A)), zero(eltype(y))) |
| 924 | + end |
| 925 | + T = promote_type(eltype(x), eltype(A), eltype(y)) |
| 926 | + r = zero(T) |
| 927 | + rvals = getrowval(A) |
| 928 | + nzvals = getnzval(A) |
| 929 | + @inbounds for col in 1:n |
| 930 | + ycol = y[col] |
| 931 | + xcol = x[col] |
| 932 | + if _isnotzero(ycol) && _isnotzero(xcol) |
| 933 | + for k in rangefun(A, col) |
| 934 | + i = rvals[k] |
| 935 | + Aij = nzvals[k] |
| 936 | + if i != col |
| 937 | + r += dot(x[i], Aij, ycol) |
| 938 | + r += dot(xcol, odiagop(Aij), y[i]) |
| 939 | + else |
| 940 | + r += dot(x[i], diagop(Aij), ycol) |
| 941 | + end |
| 942 | + end |
| 943 | + end |
| 944 | + end |
| 945 | + return r |
| 946 | +end |
| 947 | +dot(x::SparseVector, A::RealHermSymComplexHerm{<:Any,<:AbstractSparseMatrixCSC}, y::SparseVector) = |
| 948 | + _dot(x, parent(A), y, A.uplo == 'U' ? nzrangeup : nzrangelo, A isa Symmetric ? identity : real) |
| 949 | +function _dot(x::SparseVector, A::AbstractSparseMatrixCSC, y::SparseVector, rangefun::Function, diagop::Function) |
| 950 | + m, n = size(A) |
| 951 | + length(x) == m && n == length(y) || throw(DimensionMismatch()) |
| 952 | + if iszero(m) || iszero(n) |
| 953 | + return dot(zero(eltype(x)), zero(eltype(A)), zero(eltype(y))) |
| 954 | + end |
| 955 | + r = zero(promote_type(eltype(x), eltype(A), eltype(y))) |
| 956 | + xnzind = nonzeroinds(x) |
| 957 | + xnzval = nonzeros(x) |
| 958 | + ynzind = nonzeroinds(y) |
| 959 | + ynzval = nonzeros(y) |
| 960 | + Arowval = getrowval(A) |
| 961 | + Anzval = getnzval(A) |
| 962 | + Acolptr = getcolptr(A) |
| 963 | + isempty(Arowval) && return r |
| 964 | + # plain triangle without diagonal |
| 965 | + for (yi, yv) in zip(ynzind, ynzval) |
| 966 | + A_ptr_lo = first(rangefun(A, yi, true)) |
| 967 | + A_ptr_hi = last(rangefun(A, yi, true)) |
| 968 | + if A_ptr_lo <= A_ptr_hi |
| 969 | + # dot is conjugated in the first argument, so double conjugate a's |
| 970 | + r += dot(_spdot((x, a) -> a'x, 1, length(xnzind), xnzind, xnzval, |
| 971 | + A_ptr_lo, A_ptr_hi, Arowval, Anzval), yv) |
| 972 | + end |
| 973 | + end |
| 974 | + # view triangle without diagonal |
| 975 | + for (xi, xv) in zip(xnzind, xnzval) |
| 976 | + A_ptr_lo = first(rangefun(A, xi, true)) |
| 977 | + A_ptr_hi = last(rangefun(A, xi, true)) |
| 978 | + if A_ptr_lo <= A_ptr_hi |
| 979 | + r += dot(xv, _spdot((a, y) -> a'y, A_ptr_lo, A_ptr_hi, Arowval, Anzval, |
| 980 | + 1, length(ynzind), ynzind, ynzval)) |
| 981 | + end |
| 982 | + end |
| 983 | + # diagonal |
| 984 | + for i in 1:m |
| 985 | + r1 = Int(Acolptr[i]) |
| 986 | + r2 = Int(Acolptr[i+1]-1) |
| 987 | + r1 > r2 && continue |
| 988 | + r1 = searchsortedfirst(Arowval, i, r1, r2, Forward) |
| 989 | + ((r1 > r2) || (Arowval[r1] != i)) && continue |
| 990 | + r += dot(x[i], diagop(Anzval[r1]), y[i]) |
| 991 | + end |
| 992 | + r |
914 | 993 | end
|
915 | 994 | ## end of symmetric/Hermitian
|
916 | 995 |
|
|
0 commit comments