Skip to content

Commit 501bc34

Browse files
committed
Add negative stride support to BLAS Level 1/2 functions
1 parent d00d457 commit 501bc34

File tree

2 files changed

+124
-159
lines changed

2 files changed

+124
-159
lines changed

stdlib/LinearAlgebra/src/blas.jl

Lines changed: 65 additions & 97 deletions
Original file line numberDiff line numberDiff line change
@@ -159,6 +159,18 @@ end
159159

160160

161161
# Level 1
162+
# `isdense(x)` reports whether `x` is treated as dense storage by the BLAS
# wrappers: a `DenseArray` itself, or a fast contiguous view / reshape /
# reinterpret wrapper whose parent is (recursively) dense.
isdense(x) = x isa DenseArray
isdense(x::Base.FastContiguousSubArray) = isdense(parent(x))
isdense(x::Base.ReshapedArray) = isdense(parent(x))
isdense(x::Base.ReinterpretArray) = isdense(parent(x))

# Return `(ptr, st)` for a vector-like array `x`, where `st` is `stride(x, 1)`
# (or `1` for dense inputs) and `ptr` is the matching pointer: `pointer(x)` for
# non-negative strides, and the address of the last element of `x` for negative
# strides.
#
# Throws `ArgumentError` when `x` is neither dense, one-dimensional, nor laid
# out with strides fully determined by its first stride and its size.
@inline function ptrst1(x::AbstractArray)
    isdense(x) && return pointer(x), 1 # simplify runtime check when possible
    # `stride(x, 1)` (not `strides`): Base has no two-argument `strides` method,
    # so the previous spelling raised a MethodError for non-dense N-d inputs.
    ndims(x) == 1 || strides(x) == Base.size_to_strides(stride(x, 1), size(x)...) ||
        throw(ArgumentError("only support vector like inputs"))
    st = stride(x, 1)
    # For a negative stride the last element of `x` is the one to hand over.
    ptr = st >= 0 ? pointer(x) : pointer(x, lastindex(x))
    return ptr, st
end
162174
## copy
163175

164176
"""
@@ -249,7 +261,10 @@ for (fname, elty) in ((:dscal_,:Float64),
249261
DX
250262
end
251263

252-
scal!(DA::$elty, DX::AbstractArray{$elty}) = scal!(length(DX),DA,DX,stride(DX,1))
264+
# Scale every element of `DX` by `DA` in place and return `DX`.
# The pointer and stride come from `ptrst1`; only the stride's magnitude is
# forwarded to the pointer-based `scal!` method.
function scal!(DA::$elty, DX::AbstractArray{$elty})
    ptr, inc = ptrst1(DX)
    GC.@preserve DX scal!(length(DX), DA, ptr, abs(inc))
    return DX
end
253268
end
254269
end
255270
scal(n, DA, DX, incx) = scal!(n, DA, copy(DX), incx)
@@ -353,75 +368,18 @@ for (fname, elty) in ((:cblas_zdotu_sub,:ComplexF64),
353368
end
354369
end
355370

356-
@inline function _dot_length_check(x,y)
357-
n = length(x)
358-
if n != length(y)
359-
throw(DimensionMismatch("dot product arguments have lengths $(length(x)) and $(length(y))"))
360-
end
361-
n
362-
end
363-
364371
for (elty, f) in ((Float32, :dot), (Float64, :dot),
365372
(ComplexF32, :dotc), (ComplexF64, :dotc),
366373
(ComplexF32, :dotu), (ComplexF64, :dotu))
367374
@eval begin
368-
function $f(x::DenseArray{$elty}, y::DenseArray{$elty})
369-
n = _dot_length_check(x,y)
370-
$f(n, x, 1, y, 1)
371-
end
372-
373-
function $f(x::StridedVector{$elty}, y::DenseArray{$elty})
374-
n = _dot_length_check(x,y)
375-
xstride = stride(x,1)
376-
ystride = stride(y,1)
377-
x_delta = xstride < 0 ? n : 1
378-
GC.@preserve x $f(n,pointer(x,x_delta),xstride,y,ystride)
379-
end
380-
381-
function $f(x::DenseArray{$elty}, y::StridedVector{$elty})
382-
n = _dot_length_check(x,y)
383-
xstride = stride(x,1)
384-
ystride = stride(y,1)
385-
y_delta = ystride < 0 ? n : 1
386-
GC.@preserve y $f(n,x,xstride,pointer(y,y_delta),ystride)
387-
end
388-
389-
function $f(x::StridedVector{$elty}, y::StridedVector{$elty})
390-
n = _dot_length_check(x,y)
391-
xstride = stride(x,1)
392-
ystride = stride(y,1)
393-
x_delta = xstride < 0 ? n : 1
394-
y_delta = ystride < 0 ? n : 1
395-
GC.@preserve x y $f(n,pointer(x,x_delta),xstride,pointer(y,y_delta),ystride)
375+
# Vector-level entry point: check that both vectors have the same length, then
# call the pointer-based method with the pointer/stride pairs from `ptrst1`.
function $f(x::AbstractVector{$elty}, y::AbstractVector{$elty})
    n = length(x)
    m = length(y)
    if n != m
        throw(DimensionMismatch("dot product arguments have lengths $n and $m"))
    end
    # Keep `x` and `y` rooted while raw pointers into them are in use.
    GC.@preserve x y begin
        px, incx = ptrst1(x)
        py, incy = ptrst1(y)
        $f(n, px, incx, py, incy)
    end
end
397380
end
398381
end
399382

400-
function dot(DX::Union{DenseArray{T},AbstractVector{T}}, DY::Union{DenseArray{T},AbstractVector{T}}) where T<:BlasReal
401-
require_one_based_indexing(DX, DY)
402-
n = length(DX)
403-
if n != length(DY)
404-
throw(DimensionMismatch("dot product arguments have lengths $(length(DX)) and $(length(DY))"))
405-
end
406-
return dot(n, DX, stride(DX, 1), DY, stride(DY, 1))
407-
end
408-
function dotc(DX::Union{DenseArray{T},AbstractVector{T}}, DY::Union{DenseArray{T},AbstractVector{T}}) where T<:BlasComplex
409-
require_one_based_indexing(DX, DY)
410-
n = length(DX)
411-
if n != length(DY)
412-
throw(DimensionMismatch("dot product arguments have lengths $(length(DX)) and $(length(DY))"))
413-
end
414-
return dotc(n, DX, stride(DX, 1), DY, stride(DY, 1))
415-
end
416-
function dotu(DX::Union{DenseArray{T},AbstractVector{T}}, DY::Union{DenseArray{T},AbstractVector{T}}) where T<:BlasComplex
417-
require_one_based_indexing(DX, DY)
418-
n = length(DX)
419-
if n != length(DY)
420-
throw(DimensionMismatch("dot product arguments have lengths $(length(DX)) and $(length(DY))"))
421-
end
422-
return dotu(n, DX, stride(DX, 1), DY, stride(DY, 1))
423-
end
424-
425383
## nrm2
426384

427385
"""
@@ -453,7 +411,10 @@ for (fname, elty, ret_type) in ((:dnrm2_,:Float64,:Float64),
453411
end
454412
end
455413
end
456-
nrm2(x::Union{AbstractVector,DenseArray}) = nrm2(length(x), x, stride1(x))
414+
# OpenBLAS returns 0 for a negative stride, so the stride's absolute value is passed
415+
# Array-level entry point: obtain pointer and stride from `ptrst1` and call the
# pointer-based `nrm2` with the stride's magnitude.
function nrm2(x::Union{AbstractArray})
    ptr, inc = ptrst1(x)
    return GC.@preserve x nrm2(length(x), ptr, abs(inc))
end
457418

458419
## asum
459420

@@ -490,8 +451,9 @@ for (fname, elty, ret_type) in ((:dasum_,:Float64,:Float64),
490451
end
491452
end
492453
end
493-
asum(x::Union{AbstractVector,DenseArray}) = asum(length(x), x, stride1(x))
494-
454+
# Array-level entry point: obtain pointer and stride from `ptrst1` and call the
# pointer-based `asum` with the stride's magnitude.
function asum(x::Union{AbstractArray})
    ptr, inc = ptrst1(x)
    return GC.@preserve x asum(length(x), ptr, abs(inc))
end
495457
## axpy
496458

497459
"""
@@ -538,7 +500,8 @@ function axpy!(alpha::Number, x::Union{DenseArray{T},StridedVector{T}}, y::Union
538500
if length(x) != length(y)
539501
throw(DimensionMismatch("x has length $(length(x)), but y has length $(length(y))"))
540502
end
541-
return axpy!(length(x), convert(T,alpha), x, stride(x, 1), y, stride(y, 1))
503+
GC.@preserve x y axpy!(length(x), convert(T,alpha), ptrst1(x)..., ptrst1(y)...)
504+
y
542505
end
543506

544507
function axpy!(alpha::Number, x::Array{T}, rx::Union{UnitRange{Ti},AbstractRange{Ti}},
@@ -555,9 +518,9 @@ function axpy!(alpha::Number, x::Array{T}, rx::Union{UnitRange{Ti},AbstractRange
555518
GC.@preserve x y axpy!(
556519
length(rx),
557520
convert(T, alpha),
558-
pointer(x) + (first(rx) - 1)*sizeof(T),
521+
pointer(x, minimum(rx)),
559522
step(rx),
560-
pointer(y) + (first(ry) - 1)*sizeof(T),
523+
pointer(y, minimum(ry)),
561524
step(ry))
562525

563526
return y
@@ -609,7 +572,8 @@ function axpby!(alpha::Number, x::Union{DenseArray{T},AbstractVector{T}}, beta::
609572
if length(x) != length(y)
610573
throw(DimensionMismatch("x has length $(length(x)), but y has length $(length(y))"))
611574
end
612-
return axpby!(length(x), convert(T, alpha), x, stride(x, 1), convert(T, beta), y, stride(y, 1))
575+
GC.@preserve x y axpby!(length(x), convert(T, alpha), ptrst1(x)..., convert(T, beta), ptrst1(y)...)
576+
y
613577
end
614578

615579
## iamax
@@ -666,10 +630,7 @@ for (fname, elty) in ((:dgemv_,:Float64),
666630
chkstride1(A)
667631
lda = stride(A,2)
668632
lda >= max(1, size(A,1)) || error("`stride(A,2)` must be at least `max(1, size(A,1))`")
669-
sX = stride(X,1)
670-
pX = pointer(X, sX > 0 ? firstindex(X) : lastindex(X))
671-
sY = stride(Y,1)
672-
pY = pointer(Y, sY > 0 ? firstindex(Y) : lastindex(Y))
633+
(pX, sX), (pY, sY) = ptrst1(X), ptrst1(Y)
673634
GC.@preserve X Y ccall((@blasfunc($fname), libblastrampoline), Cvoid,
674635
(Ref{UInt8}, Ref{BlasInt}, Ref{BlasInt}, Ref{$elty},
675636
Ptr{$elty}, Ref{BlasInt}, Ptr{$elty}, Ref{BlasInt},
@@ -750,14 +711,15 @@ for (fname, elty) in ((:dgbmv_,:Float64),
750711
y::AbstractVector{$elty})
751712
require_one_based_indexing(A, x, y)
752713
chkstride1(A)
753-
ccall((@blasfunc($fname), libblastrampoline), Cvoid,
714+
(px, stx), (py, sty) = ptrst1(x), ptrst1(y)
715+
GC.@preserve x y ccall((@blasfunc($fname), libblastrampoline), Cvoid,
754716
(Ref{UInt8}, Ref{BlasInt}, Ref{BlasInt}, Ref{BlasInt},
755717
Ref{BlasInt}, Ref{$elty}, Ptr{$elty}, Ref{BlasInt},
756718
Ptr{$elty}, Ref{BlasInt}, Ref{$elty}, Ptr{$elty},
757719
Ref{BlasInt}, Clong),
758720
trans, m, size(A,2), kl,
759721
ku, alpha, A, max(1,stride(A,2)),
760-
x, stride(x,1), beta, y, stride(y,1), 1)
722+
px, stx, beta, py, sty, 1)
761723
y
762724
end
763725
function gbmv(trans::AbstractChar, m::Integer, kl::Integer, ku::Integer, alpha::($elty), A::AbstractMatrix{$elty}, x::AbstractVector{$elty})
@@ -810,13 +772,14 @@ for (fname, elty, lib) in ((:dsymv_,:Float64,libblastrampoline),
810772
throw(DimensionMismatch("A has size $(size(A)), and y has length $(length(y))"))
811773
end
812774
chkstride1(A)
813-
ccall((@blasfunc($fname), $lib), Cvoid,
775+
(px, stx), (py, sty) = ptrst1(x), ptrst1(y)
776+
GC.@preserve x y ccall((@blasfunc($fname), $lib), Cvoid,
814777
(Ref{UInt8}, Ref{BlasInt}, Ref{$elty}, Ptr{$elty},
815778
Ref{BlasInt}, Ptr{$elty}, Ref{BlasInt}, Ref{$elty},
816779
Ptr{$elty}, Ref{BlasInt}, Clong),
817780
uplo, n, alpha, A,
818-
max(1,stride(A,2)), x, stride(x,1), beta,
819-
y, stride(y,1), 1)
781+
max(1,stride(A,2)), px, stx, beta,
782+
py, sty, 1)
820783
y
821784
end
822785
function symv(uplo::AbstractChar, alpha::($elty), A::AbstractMatrix{$elty}, x::AbstractVector{$elty})
@@ -872,15 +835,14 @@ for (fname, elty) in ((:zhemv_,:ComplexF64),
872835
end
873836
chkstride1(A)
874837
lda = max(1, stride(A, 2))
875-
incx = stride(x, 1)
876-
incy = stride(y, 1)
877-
ccall((@blasfunc($fname), libblastrampoline), Cvoid,
838+
(px, stx), (py, sty) = ptrst1(x), ptrst1(y)
839+
GC.@preserve x y ccall((@blasfunc($fname), libblastrampoline), Cvoid,
878840
(Ref{UInt8}, Ref{BlasInt}, Ref{$elty}, Ptr{$elty},
879841
Ref{BlasInt}, Ptr{$elty}, Ref{BlasInt}, Ref{$elty},
880842
Ptr{$elty}, Ref{BlasInt}, Clong),
881843
uplo, n, α, A,
882-
lda, x, incx, β,
883-
y, incy, 1)
844+
lda, px, stx, β,
845+
py, sty, 1)
884846
y
885847
end
886848
function hemv(uplo::AbstractChar, α::($elty), A::AbstractMatrix{$elty}, x::AbstractVector{$elty})
@@ -968,7 +930,8 @@ function hpmv!(uplo::AbstractChar,
968930
if 2*length(AP) < N*(N + 1)
969931
throw(DimensionMismatch("Packed Hermitian matrix A has size smaller than length(x) = $(N)."))
970932
end
971-
return hpmv!(uplo, N, convert(T, α), AP, x, stride(x, 1), convert(T, β), y, stride(y, 1))
933+
GC.@preserve x y hpmv!(uplo, N, convert(T, α), AP, ptrst1(x)..., convert(T, β), ptrst1(y)...)
934+
y
972935
end
973936

974937
"""
@@ -1009,13 +972,14 @@ for (fname, elty) in ((:dsbmv_,:Float64),
1009972
function sbmv!(uplo::AbstractChar, k::Integer, alpha::($elty), A::AbstractMatrix{$elty}, x::AbstractVector{$elty}, beta::($elty), y::AbstractVector{$elty})
1010973
require_one_based_indexing(A, x, y)
1011974
chkstride1(A)
1012-
ccall((@blasfunc($fname), libblastrampoline), Cvoid,
975+
(px, stx), (py, sty) = ptrst1(x), ptrst1(y)
976+
GC.@preserve x y ccall((@blasfunc($fname), libblastrampoline), Cvoid,
1013977
(Ref{UInt8}, Ref{BlasInt}, Ref{BlasInt}, Ref{$elty},
1014978
Ptr{$elty}, Ref{BlasInt}, Ptr{$elty}, Ref{BlasInt},
1015979
Ref{$elty}, Ptr{$elty}, Ref{BlasInt}, Clong),
1016980
uplo, size(A,2), k, alpha,
1017-
A, max(1,stride(A,2)), x, stride(x,1),
1018-
beta, y, stride(y,1), 1)
981+
A, max(1,stride(A,2)), px, stx,
982+
beta, py, sty, 1)
1019983
y
1020984
end
1021985
function sbmv(uplo::AbstractChar, k::Integer, alpha::($elty), A::AbstractMatrix{$elty}, x::AbstractVector{$elty})
@@ -1118,7 +1082,8 @@ function spmv!(uplo::AbstractChar,
11181082
if 2*length(AP) < N*(N + 1)
11191083
throw(DimensionMismatch("Packed symmetric matrix A has size smaller than length(x) = $(N)."))
11201084
end
1121-
return spmv!(uplo, N, convert(T, α), AP, x, stride(x, 1), convert(T, β), y, stride(y, 1))
1085+
GC.@preserve x y spmv!(uplo, N, convert(T, α), AP, ptrst1(x)..., convert(T, β), ptrst1(y)...)
1086+
y
11221087
end
11231088

11241089
"""
@@ -1159,13 +1124,14 @@ for (fname, elty) in ((:zhbmv_,:ComplexF64),
11591124
function hbmv!(uplo::AbstractChar, k::Integer, alpha::($elty), A::AbstractMatrix{$elty}, x::AbstractVector{$elty}, beta::($elty), y::AbstractVector{$elty})
11601125
require_one_based_indexing(A, x, y)
11611126
chkstride1(A)
1162-
ccall((@blasfunc($fname), libblastrampoline), Cvoid,
1127+
(px, stx), (py, sty) = ptrst1(x), ptrst1(y)
1128+
GC.@preserve x y ccall((@blasfunc($fname), libblastrampoline), Cvoid,
11631129
(Ref{UInt8}, Ref{BlasInt}, Ref{BlasInt}, Ref{$elty},
11641130
Ptr{$elty}, Ref{BlasInt}, Ptr{$elty}, Ref{BlasInt},
11651131
Ref{$elty}, Ptr{$elty}, Ref{BlasInt}, Clong),
11661132
uplo, size(A,2), k, alpha,
1167-
A, max(1,stride(A,2)), x, stride(x,1),
1168-
beta, y, stride(y,1), 1)
1133+
A, max(1,stride(A,2)), px, stx,
1134+
beta, py, sty, 1)
11691135
y
11701136
end
11711137
function hbmv(uplo::AbstractChar, k::Integer, alpha::($elty), A::AbstractMatrix{$elty}, x::AbstractVector{$elty})
@@ -1219,12 +1185,13 @@ for (fname, elty) in ((:dtrmv_,:Float64),
12191185
throw(DimensionMismatch("A has size ($n,$n), x has length $(length(x))"))
12201186
end
12211187
chkstride1(A)
1222-
ccall((@blasfunc($fname), libblastrampoline), Cvoid,
1188+
px, stx = ptrst1(x)
1189+
GC.@preserve x ccall((@blasfunc($fname), libblastrampoline), Cvoid,
12231190
(Ref{UInt8}, Ref{UInt8}, Ref{UInt8}, Ref{BlasInt},
12241191
Ptr{$elty}, Ref{BlasInt}, Ptr{$elty}, Ref{BlasInt},
12251192
Clong, Clong, Clong),
12261193
uplo, trans, diag, n,
1227-
A, max(1,stride(A,2)), x, max(1,stride(x, 1)), 1, 1, 1)
1194+
A, max(1,stride(A,2)), px, stx, 1, 1, 1)
12281195
x
12291196
end
12301197
function trmv(uplo::AbstractChar, trans::AbstractChar, diag::AbstractChar, A::AbstractMatrix{$elty}, x::AbstractVector{$elty})
@@ -1274,12 +1241,13 @@ for (fname, elty) in ((:dtrsv_,:Float64),
12741241
throw(DimensionMismatch("size of A is $n != length(x) = $(length(x))"))
12751242
end
12761243
chkstride1(A)
1277-
ccall((@blasfunc($fname), libblastrampoline), Cvoid,
1244+
px, stx = ptrst1(x)
1245+
GC.@preserve x ccall((@blasfunc($fname), libblastrampoline), Cvoid,
12781246
(Ref{UInt8}, Ref{UInt8}, Ref{UInt8}, Ref{BlasInt},
12791247
Ptr{$elty}, Ref{BlasInt}, Ptr{$elty}, Ref{BlasInt},
12801248
Clong, Clong, Clong),
12811249
uplo, trans, diag, n,
1282-
A, max(1,stride(A,2)), x, stride(x, 1), 1, 1, 1)
1250+
A, max(1,stride(A,2)), px, stx, 1, 1, 1)
12831251
x
12841252
end
12851253
function trsv(uplo::AbstractChar, trans::AbstractChar, diag::AbstractChar, A::AbstractMatrix{$elty}, x::AbstractVector{$elty})
@@ -1993,9 +1961,9 @@ function copyto!(dest::Array{T}, rdest::Union{UnitRange{Ti},AbstractRange{Ti}},
19931961
end
19941962
GC.@preserve src dest BLAS.blascopy!(
19951963
length(rsrc),
1996-
pointer(src) + (first(rsrc) - 1) * sizeof(T),
1964+
pointer(src, minimum(rsrc)),
19971965
step(rsrc),
1998-
pointer(dest) + (first(rdest) - 1) * sizeof(T),
1966+
pointer(dest, minimum(rdest)),
19991967
step(rdest))
20001968

20011969
return dest

0 commit comments

Comments
 (0)