Skip to content

Commit 6f7daea

Browse files
committed
Add zero stride-check to LinearAlgebra.gemv!
Also call BLAS for negative `lda` (if possible)
1 parent 6d4f8b9 commit 6f7daea

File tree

2 files changed

+17
-4
lines changed

2 files changed

+17
-4
lines changed

stdlib/LinearAlgebra/src/matmul.jl

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -498,7 +498,8 @@ function gemv!(y::StridedVector{T}, tA::AbstractChar, A::StridedVecOrMat{T}, x::
498498
nA == 0 && return _rmul_or_fill!(y, β)
499499
alpha, beta = promote(α, β, zero(T))
500500
if alpha isa Union{Bool,T} && beta isa Union{Bool,T} &&
501-
stride(A, 1) == 1 && stride(A, 2) >= size(A, 1)
501+
stride(A, 1) == 1 && abs(stride(A, 2)) >= size(A, 1) &&
502+
!iszero(stride(x, 1)) # We only check input's stride here.
502503
return BLAS.gemv!(tA, alpha, A, x, beta, y)
503504
else
504505
return generic_matvecmul!(y, tA, A, x, MulAddMul(α, β))
@@ -516,8 +517,9 @@ function gemv!(y::StridedVector{Complex{T}}, tA::AbstractChar, A::StridedVecOrMa
516517
nA == 0 && return _rmul_or_fill!(y, β)
517518
alpha, beta = promote(α, β, zero(T))
518519
if alpha isa Union{Bool,T} && beta isa Union{Bool,T} &&
519-
stride(A, 1) == 1 && stride(A, 2) >= size(A, 1) &&
520-
stride(y, 1) == 1 && tA == 'N' # reinterpret-based optimization is valid only for contiguous `y`
520+
stride(A, 1) == 1 && abs(stride(A, 2)) >= size(A, 1) &&
521+
stride(y, 1) == 1 && tA == 'N' && # reinterpret-based optimization is valid only for contiguous `y`
522+
!iszero(stride(x, 1))
521523
BLAS.gemv!(tA, alpha, reinterpret(T, A), x, beta, reinterpret(T, y))
522524
return y
523525
else
@@ -535,7 +537,9 @@ function gemv!(y::StridedVector{Complex{T}}, tA::AbstractChar, A::StridedVecOrMa
535537
mA == 0 && return y
536538
nA == 0 && return _rmul_or_fill!(y, β)
537539
alpha, beta = promote(α, β, zero(T))
538-
@views if alpha isa Union{Bool,T} && beta isa Union{Bool,T} && stride(A, 1) == 1 && stride(A, 2) >= size(A, 1)
540+
@views if alpha isa Union{Bool,T} && beta isa Union{Bool,T} &&
541+
stride(A, 1) == 1 && abs(stride(A, 2)) >= size(A, 1) &&
542+
!iszero(stride(x, 1))
539543
xfl = reinterpret(reshape, T, x) # Use reshape here.
540544
yfl = reinterpret(reshape, T, y)
541545
BLAS.gemv!(tA, alpha, A, xfl[1, :], beta, yfl[1, :])

stdlib/LinearAlgebra/test/matmul.jl

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -297,6 +297,15 @@ end
297297
end
298298
end
299299

300+
@testset "matrix x vector with negtive lda or 0 stride" for T in (Float32, Float64)
301+
for TA in (T, complex(T)), TB in (T, complex(T))
302+
A = view(randn(TA, 10, 10), 1:10, 10:-1:1) # negative lda
303+
v = view([randn(TB)], 1 .+ 0(1:10)) # 0 stride
304+
Ad, vd = copy(A), copy(v)
305+
@test Ad * vd A * vd Ad * v A * v
306+
end
307+
end
308+
300309
@testset "issue #15286" begin
301310
A = reshape(map(Float64, 1:20), 5, 4)
302311
C = zeros(8, 8)

0 commit comments

Comments
 (0)