Skip to content

Commit fcb943d

Browse files
committed
Extend strides(::ReshapedArray) with non-contiguous strided parent
Use `Base.merge_adjacent_dim` to perform vector layout check before BLAS call.
1 parent 98b4b06 commit fcb943d

File tree

4 files changed

+98
-23
lines changed

4 files changed

+98
-23
lines changed

base/reshapedarray.jl

Lines changed: 43 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -294,14 +294,51 @@ unsafe_convert(::Type{Ptr{T}}, V::SubArray{T,N,P,<:Tuple{Vararg{Union{RangeIndex
294294
unsafe_convert(Ptr{T}, V.parent) + (first_index(V)-1)*sizeof(T)
295295

296296

297-
_checkcontiguous(::Type{Bool}, A::AbstractArray) = size_to_strides(1, size(A)...) == strides(A)
298-
_checkcontiguous(::Type{Bool}, A::Array) = true
297+
_checkcontiguous(::Type{Bool}, A::AbstractArray) = false
298+
# `strides(A::DenseArray)` calls `size_to_strides` by default.
299+
# Thus it's OK to assume all `DenseArray`s are contiguously stored.
300+
_checkcontiguous(::Type{Bool}, A::DenseArray) = true
299301
_checkcontiguous(::Type{Bool}, A::ReshapedArray) = _checkcontiguous(Bool, parent(A))
300302
_checkcontiguous(::Type{Bool}, A::FastContiguousSubArray) = _checkcontiguous(Bool, parent(A))
301303

302304
function strides(a::ReshapedArray)
303-
# We can handle non-contiguous parent if it's a StridedVector
304-
ndims(parent(a)) == 1 && return size_to_strides(only(strides(parent(a))), size(a)...)
305-
_checkcontiguous(Bool, a) || throw(ArgumentError("Parent must be contiguous."))
306-
size_to_strides(1, size(a)...)
305+
_checkcontiguous(Bool, a) && return size_to_strides(1, size(a)...)
306+
apsz::Dims = size(a.parent)
307+
apst::Dims = strides(a.parent)
308+
msz, mst, n = merge_adjacent_dim(apsz, apst) # Try to perform "lazy" reshape
309+
n == ndims(a.parent) && return size_to_strides(mst, size(a)...) # Parent is stridevector like
310+
return _reshaped_strides(size(a), 1, msz, mst, n, apsz, apst)
311+
end
312+
313+
function _reshaped_strides(::Dims{0}, reshaped::Int, msz::Int, ::Int, ::Int, ::Dims, ::Dims)
314+
reshaped == msz || throw(ArgumentError("Input is not strided."))
315+
()
316+
end
317+
function _reshaped_strides(sz::Dims, reshaped::Int, msz::Int, mst::Int, n::Int, apsz::Dims, apst::Dims)
318+
st = reshaped * mst
319+
reshaped = reshaped * sz[1]
320+
if length(sz) > 1 && reshaped == msz && sz[2] != 1
321+
msz, mst, n = merge_adjacent_dim(apsz, apst, n + 1)
322+
reshaped = 1
323+
end
324+
sts = _reshaped_strides(tail(sz), reshaped, msz, mst, n, apsz, apst)
325+
return (st, sts...)
326+
end
327+
328+
merge_adjacent_dim(::Dims{0}, ::Dims{0}) = 1, 1, 0
329+
merge_adjacent_dim(apsz::Dims{1}, apst::Dims{1}) = apsz[1], apst[1], 1
330+
function merge_adjacent_dim(apsz::Dims{N}, apst::Dims{N}, n::Int = 1) where {N}
331+
sz, st = apsz[n], apst[n]
332+
while n < N
333+
szₙ, stₙ = apsz[n+1], apst[n+1]
334+
if sz == 1
335+
sz, st = szₙ, stₙ
336+
elseif stₙ == st * sz || szₙ == 1
337+
sz *= szₙ
338+
else
339+
break
340+
end
341+
n += 1
342+
end
343+
return sz, st, n
307344
end

stdlib/LinearAlgebra/src/blas.jl

Lines changed: 10 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -148,18 +148,19 @@ end
148148
# Level 1
149149
# A help function to pick the pointer and inc for 1d like inputs.
150150
@inline function vec_pointer_stride(x::AbstractArray, stride0check = nothing)
151-
isdense(x) && return pointer(x), 1 # simpify runtime check when possibe
152-
ndims(x) == 1 || strides(x) == Base.size_to_strides(stride(x, 1), size(x)...) ||
153-
throw(ArgumentError("only support vector like inputs"))
154-
st = stride(x, 1)
151+
Base._checkcontiguous(Bool, x) && return pointer(x), 1 # simpify runtime check when possibe
152+
st, ptr = checkedstride(x), pointer(x)
155153
isnothing(stride0check) || (st == 0 && throw(stride0check))
156-
ptr = st > 0 ? pointer(x) : pointer(x, lastindex(x))
154+
ptr += min(st, 0) * sizeof(eltype(x)) * (length(x) - 1)
157155
ptr, st
158156
end
159-
isdense(x) = x isa DenseArray
160-
isdense(x::Base.FastContiguousSubArray) = isdense(parent(x))
161-
isdense(x::Base.ReshapedArray) = isdense(parent(x))
162-
isdense(x::Base.ReinterpretArray) = isdense(parent(x))
157+
function checkedstride(x::AbstractArray)
158+
szs::Dims = size(x)
159+
sts::Dims = strides(x)
160+
sz, st, n = Base.merge_adjacent_dim(szs, sts)
161+
n === ndims(x) || throw(ArgumentError("only support vector like inputs"))
162+
return st
163+
end
163164
## copy
164165

165166
"""

stdlib/LinearAlgebra/test/blas.jl

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -18,9 +18,14 @@ function pack(A, uplo)
1818
end
1919

2020
@testset "vec_pointer_stride" begin
21-
a = zeros(4,4,4)
22-
@test BLAS.asum(view(a,1:2:4,:,:)) == 0 # vector like
21+
a = float(rand(1:20,4,4,4))
22+
@test BLAS.asum(a) == sum(a) # dense case
23+
@test BLAS.asum(view(a,1:2:4,:,:)) == sum(view(a,1:2:4,:,:)) # vector like
24+
@test BLAS.asum(view(a,1:3,2:2,3:3)) == sum(view(a,1:3,2:2,3:3))
25+
@test BLAS.asum(view(a,1:1,1:3,1:1)) == sum(view(a,1:1,1:3,1:1))
26+
@test BLAS.asum(view(a,1:1,1:1,1:3)) == sum(view(a,1:1,1:1,1:3))
2327
@test_throws ArgumentError BLAS.asum(view(a,1:3:4,:,:)) # non-vector like
28+
@test_throws ArgumentError BLAS.asum(view(a,1:2,1:1,1:3))
2429
end
2530
Random.seed!(100)
2631
## BLAS tests - testing the interface code to BLAS routines

test/abstractarray.jl

Lines changed: 38 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1567,22 +1567,54 @@ end
15671567
@test reshape(r, :) === reshape(r, (:,)) === r
15681568
end
15691569

1570+
struct FakeZeroDimArray <: AbstractArray{Int, 0} end
1571+
Base.strides(::FakeZeroDimArray) = ()
1572+
Base.size(::FakeZeroDimArray) = ()
15701573
@testset "strides for ReshapedArray" begin
15711574
# Type-based contiguous check is tested in test/compiler/inline.jl
1575+
function check_strides(A::AbstractArray)
1576+
# Make sure stride(A, i) is equivalent with strides(A)[i] (if 1 <= i <= ndims(A))
1577+
dims = ntuple(identity, ndims(A))
1578+
map(i -> stride(A, i), dims) == @inferred(strides(A)) || return false
1579+
# Test strides via value check.
1580+
for i in eachindex(IndexLinear(), A)
1581+
A[i] === Base.unsafe_load(pointer(A, i)) || return false
1582+
end
1583+
return true
1584+
end
15721585
# General contiguous check
15731586
a = view(rand(10,10), 1:10, 1:10)
1574-
@test strides(vec(a)) == (1,)
1587+
@test check_strides(vec(a))
15751588
b = view(parent(a), 1:9, 1:10)
1576-
@test_throws "Parent must be contiguous." strides(vec(b))
1589+
@test_throws "Input is not strided." strides(vec(b))
15771590
# StridedVector parent
15781591
for n in 1:3
15791592
a = view(collect(1:60n), 1:n:60n)
1580-
@test strides(reshape(a, 3, 4, 5)) == (n, 3n, 12n)
1581-
@test strides(reshape(a, 5, 6, 2)) == (n, 5n, 30n)
1593+
@test check_strides(reshape(a, 3, 4, 5))
1594+
@test check_strides(reshape(a, 5, 6, 2))
15821595
b = view(parent(a), 60n:-n:1)
1583-
@test strides(reshape(b, 3, 4, 5)) == (-n, -3n, -12n)
1584-
@test strides(reshape(b, 5, 6, 2)) == (-n, -5n, -30n)
1596+
@test check_strides(reshape(b, 3, 4, 5))
1597+
@test check_strides(reshape(b, 5, 6, 2))
15851598
end
1599+
# StridedVector like parent
1600+
a = randn(10, 10, 10)
1601+
b = view(a, 1:10, 1:1, 5:5)
1602+
@test check_strides(reshape(b, 2, 5))
1603+
# Other StridedArray parent
1604+
a = view(randn(10,10), 1:9, 1:10)
1605+
@test check_strides(reshape(a,3,3,2,5))
1606+
@test check_strides(reshape(a,3,3,5,2))
1607+
@test check_strides(reshape(a,9,5,2))
1608+
@test check_strides(reshape(a,3,3,10))
1609+
@test check_strides(reshape(a,1,3,1,3,1,5,1,2))
1610+
@test check_strides(reshape(a,3,3,5,1,1,2,1,1))
1611+
@test_throws "Input is not strided." strides(reshape(a,3,6,5))
1612+
@test_throws "Input is not strided." strides(reshape(a,3,2,3,5))
1613+
@test_throws "Input is not strided." strides(reshape(a,3,5,3,2))
1614+
@test_throws "Input is not strided." strides(reshape(a,5,3,3,2))
1615+
# Zero dimensional parent
1616+
a = reshape(FakeZeroDimArray(),1,1,1)
1617+
@test @inferred(strides(a)) == (1, 1, 1)
15861618
end
15871619

15881620
@testset "stride for 0 dims array #44087" begin

0 commit comments

Comments
 (0)