Skip to content

Commit 9b22cd4

Browse files
committed
Specialize LinearAlgebra.BLAS.dot for strided vectors of floats.
Fixes #37767.
1 parent 76698ea commit 9b22cd4

File tree

2 files changed

+57
-0
lines changed

2 files changed

+57
-0
lines changed

stdlib/LinearAlgebra/src/blas.jl

Lines changed: 44 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -337,6 +337,50 @@ for (fname, elty) in ((:cblas_zdotu_sub,:ComplexF64),
337337
end
338338
end
339339

340+
@inline function _dot_length_check(x,y)
341+
n = length(x)
342+
if n != length(y)
343+
throw(DimensionMismatch("dot product arguments have lengths $(length(x)) and $(length(y))"))
344+
end
345+
n
346+
end
347+
348+
for (elty, f) in ((Float32, :dot), (Float64, :dot),
349+
(ComplexF32, :dotc), (ComplexF64, :dotc),
350+
(ComplexF32, :dotu), (ComplexF64, :dotu))
351+
@eval begin
352+
function $f(x::DenseArray{$elty}, y::DenseArray{$elty})
353+
n = _dot_length_check(x,y)
354+
$f(n, x, 1, y, 1)
355+
end
356+
357+
function $f(x::StridedVector{$elty}, y::DenseArray{$elty})
358+
n = _dot_length_check(x,y)
359+
xstride = stride(x,1)
360+
ystride = stride(y,1)
361+
x_delta = xstride < 0 ? n : 1
362+
GC.@preserve x $f(n,pointer(x,x_delta),xstride,y,ystride)
363+
end
364+
365+
function $f(x::DenseArray{$elty}, y::StridedVector{$elty})
366+
n = _dot_length_check(x,y)
367+
xstride = stride(x,1)
368+
ystride = stride(y,1)
369+
y_delta = ystride < 0 ? n : 1
370+
GC.@preserve y $f(n,x,xstride,pointer(y,y_delta),ystride)
371+
end
372+
373+
function $f(x::StridedVector{$elty}, y::StridedVector{$elty})
374+
n = _dot_length_check(x,y)
375+
xstride = stride(x,1)
376+
ystride = stride(y,1)
377+
x_delta = xstride < 0 ? n : 1
378+
y_delta = ystride < 0 ? n : 1
379+
GC.@preserve x y $f(n,pointer(x,x_delta),xstride,pointer(y,y_delta),ystride)
380+
end
381+
end
382+
end
383+
340384
function dot(DX::Union{DenseArray{T},AbstractVector{T}}, DY::Union{DenseArray{T},AbstractVector{T}}) where T<:BlasReal
341385
require_one_based_indexing(DX, DY)
342386
n = length(DX)

stdlib/LinearAlgebra/test/matmul.jl

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -205,6 +205,19 @@ end
205205
@test *(Asub, adjoint(Asub)) == *(Aref, adjoint(Aref))
206206
end
207207

208+
@testset "dot product of subarrays of vectors (floats, negative stride, issue #37767)" begin
209+
for T in (Float32, Float64, ComplexF32, ComplexF64)
210+
a = Vector{T}(3:2:7)
211+
b = Vector{T}(1:10)
212+
v = view(b,7:-2:3)
213+
@test dot(a,Vector(v)) 67.0
214+
@test dot(a,v) 67.0
215+
@test dot(v,a) 67.0
216+
@test dot(Vector(v),Vector(v)) 83.0
217+
@test dot(v,v) 83.0
218+
end
219+
end
220+
208221
@testset "Complex matrix x real MatOrVec etc (issue #29224)" for T1 in (Float32,Float64)
209222
for T2 in (Float32,Float64)
210223
for arg1_real in (true,false)

0 commit comments

Comments
 (0)