Skip to content

Commit e0a4b77

Browse files
authored
Generalize strides for ReinterpretArray and ReshapedArray (#44027)
1 parent 5181e36 commit e0a4b77

File tree

5 files changed

+140
-46
lines changed

5 files changed

+140
-46
lines changed

base/reinterpretarray.jl

Lines changed: 44 additions & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -43,7 +43,7 @@ struct ReinterpretArray{T,N,S,A<:AbstractArray{S},IsReshaped} <: AbstractArray{T
4343
if N != 0 && sizeof(S) != sizeof(T)
4444
ax1 = axes(a)[1]
4545
dim = length(ax1)
46-
if Base.issingletontype(T)
46+
if issingletontype(T)
4747
dim == 0 || throwsingleton(S, T, "a non-empty")
4848
else
4949
rem(dim*sizeof(S),sizeof(T)) == 0 || thrownonint(S, T, dim)
@@ -75,15 +75,15 @@ struct ReinterpretArray{T,N,S,A<:AbstractArray{S},IsReshaped} <: AbstractArray{T
7575
if sizeof(S) == sizeof(T)
7676
N = ndims(a)
7777
elseif sizeof(S) > sizeof(T)
78-
Base.issingletontype(T) && throwsingleton(S, T, "with reshape a")
78+
issingletontype(T) && throwsingleton(S, T, "with reshape a")
7979
rem(sizeof(S), sizeof(T)) == 0 || throwintmult(S, T)
8080
N = ndims(a) + 1
8181
else
82-
Base.issingletontype(S) && throwfromsingleton(S, T)
82+
issingletontype(S) && throwfromsingleton(S, T)
8383
rem(sizeof(T), sizeof(S)) == 0 || throwintmult(S, T)
8484
N = ndims(a) - 1
8585
N > -1 || throwsize0(S, T, "larger")
86-
axes(a, 1) == Base.OneTo(sizeof(T) ÷ sizeof(S)) || throwsize1(a, T)
86+
axes(a, 1) == OneTo(sizeof(T) ÷ sizeof(S)) || throwsize1(a, T)
8787
end
8888
readable = array_subpadding(T, S)
8989
writable = array_subpadding(S, T)
@@ -148,33 +148,39 @@ StridedVector{T} = StridedArray{T,1}
148148
StridedMatrix{T} = StridedArray{T,2}
149149
StridedVecOrMat{T} = Union{StridedVector{T}, StridedMatrix{T}}
150150

151-
# the definition of strides for Array{T,N} is tuple() if N = 0, otherwise it is
152-
# a tuple containing 1 and a cumulative product of the first N-1 sizes
153-
# this definition is also used for StridedReshapedArray and StridedReinterpretedArray
154-
# which have the same memory storage as Array
155-
stride(a::Union{DenseArray,StridedReshapedArray,StridedReinterpretArray}, i::Int) = _stride(a, i)
156-
157-
function stride(a::ReinterpretArray, i::Int)
158-
a.parent isa StridedArray || throw(ArgumentError("Parent must be strided."))
159-
return _stride(a, i)
160-
end
151+
strides(a::Union{DenseArray,StridedReshapedArray,StridedReinterpretArray}) = size_to_strides(1, size(a)...)
161152

162-
function _stride(a, i)
163-
if i > ndims(a)
164-
return length(a)
153+
function strides(a::ReshapedReinterpretArray)
154+
ap = parent(a)
155+
els, elp = elsize(a), elsize(ap)
156+
stp = strides(ap)
157+
els == elp && return stp
158+
els < elp && return (1, _checked_strides(stp, els, elp)...)
159+
stp[1] == 1 || throw(ArgumentError("Parent must be contiguous in the 1st dimension!"))
160+
return _checked_strides(tail(stp), els, elp)
161+
end
162+
163+
function strides(a::NonReshapedReinterpretArray)
164+
ap = parent(a)
165+
els, elp = elsize(a), elsize(ap)
166+
stp = strides(ap)
167+
els == elp && return stp
168+
stp[1] == 1 || throw(ArgumentError("Parent must be contiguous in the 1st dimension!"))
169+
return (1, _checked_strides(tail(stp), els, elp)...)
170+
end
171+
172+
@inline function _checked_strides(stp::Tuple, els::Integer, elp::Integer)
173+
if elp > els && rem(elp, els) == 0
174+
N = div(elp, els)
175+
return map(i -> N * i, stp)
165176
end
166-
s = 1
167-
for n = 1:(i-1)
168-
s *= size(a, n)
169-
end
170-
return s
177+
drs = map(i -> divrem(elp * i, els), stp)
178+
all(i->iszero(i[2]), drs) ||
179+
throw(ArgumentError("Parent's strides could not be exactly divided!"))
180+
map(first, drs)
171181
end
172182

173-
function strides(a::ReinterpretArray)
174-
a.parent isa StridedArray || throw(ArgumentError("Parent must be strided."))
175-
size_to_strides(1, size(a)...)
176-
end
177-
strides(a::Union{DenseArray,StridedReshapedArray,StridedReinterpretArray}) = size_to_strides(1, size(a)...)
183+
_checkcontiguous(::Type{Bool}, A::ReinterpretArray) = _checkcontiguous(Bool, parent(A))
178184

179185
similar(a::ReinterpretArray, T::Type, d::Dims) = similar(a.parent, T, d)
180186

@@ -227,12 +233,12 @@ SCartesianIndices2{K}(indices2::AbstractUnitRange{Int}) where {K} = (@assert K::
227233
eachindex(::IndexSCartesian2{K}, A::ReshapedReinterpretArray) where {K} = SCartesianIndices2{K}(eachindex(IndexLinear(), parent(A)))
228234
@inline function eachindex(style::IndexSCartesian2{K}, A::AbstractArray, B::AbstractArray...) where {K}
229235
iter = eachindex(style, A)
230-
Base._all_match_first(C->eachindex(style, C), iter, B...) || Base.throw_eachindex_mismatch_indices(IndexSCartesian2{K}(), axes(A), axes.(B)...)
236+
_all_match_first(C->eachindex(style, C), iter, B...) || throw_eachindex_mismatch_indices(IndexSCartesian2{K}(), axes(A), axes.(B)...)
231237
return iter
232238
end
233239

234240
size(iter::SCartesianIndices2{K}) where K = (K, length(iter.indices2))
235-
axes(iter::SCartesianIndices2{K}) where K = (Base.OneTo(K), iter.indices2)
241+
axes(iter::SCartesianIndices2{K}) where K = (OneTo(K), iter.indices2)
236242

237243
first(iter::SCartesianIndices2{K}) where {K} = SCartesianIndex2{K}(1, first(iter.indices2))
238244
last(iter::SCartesianIndices2{K}) where {K} = SCartesianIndex2{K}(K, last(iter.indices2))
@@ -300,27 +306,27 @@ unaliascopy(a::ReshapedReinterpretArray{T}) where {T} = reinterpret(reshape, T,
300306

301307
function size(a::NonReshapedReinterpretArray{T,N,S} where {N}) where {T,S}
302308
psize = size(a.parent)
303-
size1 = Base.issingletontype(T) ? psize[1] : div(psize[1]*sizeof(S), sizeof(T))
309+
size1 = issingletontype(T) ? psize[1] : div(psize[1]*sizeof(S), sizeof(T))
304310
tuple(size1, tail(psize)...)
305311
end
306312
function size(a::ReshapedReinterpretArray{T,N,S} where {N}) where {T,S}
307313
psize = size(a.parent)
308314
sizeof(S) > sizeof(T) && return (div(sizeof(S), sizeof(T)), psize...)
309-
sizeof(S) < sizeof(T) && return Base.tail(psize)
315+
sizeof(S) < sizeof(T) && return tail(psize)
310316
return psize
311317
end
312318
size(a::NonReshapedReinterpretArray{T,0}) where {T} = ()
313319

314320
function axes(a::NonReshapedReinterpretArray{T,N,S} where {N}) where {T,S}
315321
paxs = axes(a.parent)
316322
f, l = first(paxs[1]), length(paxs[1])
317-
size1 = Base.issingletontype(T) ? l : div(l*sizeof(S), sizeof(T))
323+
size1 = issingletontype(T) ? l : div(l*sizeof(S), sizeof(T))
318324
tuple(oftype(paxs[1], f:f+size1-1), tail(paxs)...)
319325
end
320326
function axes(a::ReshapedReinterpretArray{T,N,S} where {N}) where {T,S}
321327
paxs = axes(a.parent)
322-
sizeof(S) > sizeof(T) && return (Base.OneTo(div(sizeof(S), sizeof(T))), paxs...)
323-
sizeof(S) < sizeof(T) && return Base.tail(paxs)
328+
sizeof(S) > sizeof(T) && return (OneTo(div(sizeof(S), sizeof(T))), paxs...)
329+
sizeof(S) < sizeof(T) && return tail(paxs)
324330
return paxs
325331
end
326332
axes(a::NonReshapedReinterpretArray{T,0}) where {T} = ()
@@ -372,7 +378,7 @@ end
372378
@inline @propagate_inbounds function _getindex_ra(a::NonReshapedReinterpretArray{T,N,S}, i1::Int, tailinds::TT) where {T,N,S,TT}
373379
# Make sure to match the scalar reinterpret if that is applicable
374380
if sizeof(T) == sizeof(S) && (fieldcount(T) + fieldcount(S)) == 0
375-
if Base.issingletontype(T) # singleton types
381+
if issingletontype(T) # singleton types
376382
@boundscheck checkbounds(a, i1, tailinds...)
377383
return T.instance
378384
end
@@ -420,7 +426,7 @@ end
420426
@inline @propagate_inbounds function _getindex_ra(a::ReshapedReinterpretArray{T,N,S}, i1::Int, tailinds::TT) where {T,N,S,TT}
421427
# Make sure to match the scalar reinterpret if that is applicable
422428
if sizeof(T) == sizeof(S) && (fieldcount(T) + fieldcount(S)) == 0
423-
if Base.issingletontype(T) # singleton types
429+
if issingletontype(T) # singleton types
424430
@boundscheck checkbounds(a, i1, tailinds...)
425431
return T.instance
426432
end
@@ -511,7 +517,7 @@ end
511517
v = convert(T, v)::T
512518
# Make sure to match the scalar reinterpret if that is applicable
513519
if sizeof(T) == sizeof(S) && (fieldcount(T) + fieldcount(S)) == 0
514-
if Base.issingletontype(T) # singleton types
520+
if issingletontype(T) # singleton types
515521
@boundscheck checkbounds(a, i1, tailinds...)
516522
# setindex! is a noop except for the index check
517523
else
@@ -577,7 +583,7 @@ end
577583
v = convert(T, v)::T
578584
# Make sure to match the scalar reinterpret if that is applicable
579585
if sizeof(T) == sizeof(S) && (fieldcount(T) + fieldcount(S)) == 0
580-
if Base.issingletontype(T) # singleton types
586+
if issingletontype(T) # singleton types
581587
@boundscheck checkbounds(a, i1, tailinds...)
582588
# setindex! is a noop except for the index check
583589
else

base/reshapedarray.jl

Lines changed: 15 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -242,7 +242,7 @@ end
242242

243243
@inline function _unsafe_getindex(A::ReshapedArray{T,N}, indices::Vararg{Int,N}) where {T,N}
244244
axp = axes(A.parent)
245-
i = offset_if_vec(Base._sub2ind(size(A), indices...), axp)
245+
i = offset_if_vec(_sub2ind(size(A), indices...), axp)
246246
I = ind2sub_rs(axp, A.mi, i)
247247
_unsafe_getindex_rs(parent(A), I)
248248
end
@@ -266,7 +266,7 @@ end
266266

267267
@inline function _unsafe_setindex!(A::ReshapedArray{T,N}, val, indices::Vararg{Int,N}) where {T,N}
268268
axp = axes(A.parent)
269-
i = offset_if_vec(Base._sub2ind(size(A), indices...), axp)
269+
i = offset_if_vec(_sub2ind(size(A), indices...), axp)
270270
@inbounds parent(A)[ind2sub_rs(axes(A.parent), A.mi, i)...] = val
271271
val
272272
end
@@ -292,3 +292,16 @@ substrides(strds::NTuple{N,Int}, I::Tuple{ReshapedUnitRange, Vararg{Any}}) where
292292
(size_to_strides(strds[1], size(I[1])...)..., substrides(tail(strds), tail(I))...)
293293
unsafe_convert(::Type{Ptr{T}}, V::SubArray{T,N,P,<:Tuple{Vararg{Union{RangeIndex,ReshapedUnitRange}}}}) where {T,N,P} =
294294
unsafe_convert(Ptr{T}, V.parent) + (first_index(V)-1)*sizeof(T)
295+
296+
297+
_checkcontiguous(::Type{Bool}, A::AbstractArray) = size_to_strides(1, size(A)...) == strides(A)
298+
_checkcontiguous(::Type{Bool}, A::Array) = true
299+
_checkcontiguous(::Type{Bool}, A::ReshapedArray) = _checkcontiguous(Bool, parent(A))
300+
_checkcontiguous(::Type{Bool}, A::FastContiguousSubArray) = _checkcontiguous(Bool, parent(A))
301+
302+
function strides(a::ReshapedArray)
303+
# We can handle non-contiguous parent if it's a StridedVector
304+
ndims(parent(a)) == 1 && return size_to_strides(only(strides(parent(a))), size(a)...)
305+
_checkcontiguous(Bool, a) || throw(ArgumentError("Parent must be contiguous."))
306+
size_to_strides(1, size(a)...)
307+
end

test/abstractarray.jl

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1561,3 +1561,21 @@ end
15611561
r = Base.IdentityUnitRange(3:4)
15621562
@test reshape(r, :) === reshape(r, (:,)) === r
15631563
end
1564+
1565+
@testset "strides for ReshapedArray" begin
1566+
# Type-based contiguous check is tested in test/compiler/inline.jl
1567+
# General contiguous check
1568+
a = view(rand(10,10), 1:10, 1:10)
1569+
@test strides(vec(a)) == (1,)
1570+
b = view(parent(a), 1:9, 1:10)
1571+
@test_throws "Parent must be contiguous." strides(vec(b))
1572+
# StridedVector parent
1573+
for n in 1:3
1574+
a = view(collect(1:60n), 1:n:60n)
1575+
@test strides(reshape(a, 3, 4, 5)) == (n, 3n, 12n)
1576+
@test strides(reshape(a, 5, 6, 2)) == (n, 5n, 30n)
1577+
b = view(parent(a), 60n:-n:1)
1578+
@test strides(reshape(b, 3, 4, 5)) == (-n, -3n, -12n)
1579+
@test strides(reshape(b, 5, 6, 2)) == (-n, -5n, -30n)
1580+
end
1581+
end

test/compiler/inline.jl

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -907,3 +907,10 @@ end
907907
@test fully_eliminated((String,)) do x
908908
Base.@invoke conditional_escape!(false::Any, x::Any)
909909
end
910+
911+
@testset "strides for ReshapedArray (PR#44027)" begin
912+
# Type-based contiguous check
913+
a = vec(reinterpret(reshape,Int16,reshape(view(reinterpret(Int32,randn(10)),2:11),5,:)))
914+
f(a) = only(strides(a));
915+
@test fully_eliminated(f, Tuple{typeof(a)}) && f(a) == 1
916+
end

test/reinterpretarray.jl

Lines changed: 56 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -157,12 +157,62 @@ let A = collect(reshape(1:20, 5, 4))
157157
@test reshape(R, :) isa StridedArray
158158
end
159159

160-
# and ensure a reinterpret array containing a strided array can have strides computed
161-
let A = view(reinterpret(Int16, collect(reshape(UnitRange{Int64}(1, 20), 5, 4))), :, 1:2)
162-
R = reinterpret(Int32, A)
163-
@test strides(R) == (1, 10)
164-
@test stride(R, 1) == 1
165-
@test stride(R, 2) == 10
160+
function check_strides(A::AbstractArray)
161+
# Make sure stride(A, i) is equivalent with strides(A)[i] (if 1 <= i <= ndims(A))
162+
dims = ntuple(identity, ndims(A))
163+
map(i -> stride(A, i), dims) == strides(A) || return false
164+
# Test strides via value check.
165+
for i in eachindex(IndexLinear(), A)
166+
A[i] === Base.unsafe_load(pointer(A, i)) || return false
167+
end
168+
return true
169+
end
170+
171+
@testset "strides for NonReshapedReinterpretArray" begin
172+
A = Array{Int32}(reshape(1:88, 11, 8))
173+
for viewax2 in (1:8, 1:2:6, 7:-1:1, 5:-2:1, 2:3:8, 7:-6:1, 3:5:11)
174+
# dim1 is contiguous
175+
for T in (Int16, Float32)
176+
@test check_strides(reinterpret(T, view(A, 1:8, viewax2)))
177+
end
178+
if mod(step(viewax2), 2) == 0
179+
@test check_strides(reinterpret(Int64, view(A, 1:8, viewax2)))
180+
else
181+
@test_throws "Parent's strides" strides(reinterpret(Int64, view(A, 1:8, viewax2)))
182+
end
183+
# non-integer-multipled classified
184+
if mod(step(viewax2), 3) == 0
185+
@test check_strides(reinterpret(NTuple{3,Int16}, view(A, 2:7, viewax2)))
186+
else
187+
@test_throws "Parent's strides" strides(reinterpret(NTuple{3,Int16}, view(A, 2:7, viewax2)))
188+
end
189+
if mod(step(viewax2), 5) == 0
190+
@test check_strides(reinterpret(NTuple{5,Int16}, view(A, 2:11, viewax2)))
191+
else
192+
@test_throws "Parent's strides" strides(reinterpret(NTuple{5,Int16}, view(A, 2:11, viewax2)))
193+
end
194+
# dim1 is not contiguous
195+
for T in (Int16, Int64)
196+
@test_throws "Parent must" strides(reinterpret(T, view(A, 8:-1:1, viewax2)))
197+
end
198+
@test check_strides(reinterpret(Float32, view(A, 8:-1:1, viewax2)))
199+
end
200+
end
201+
202+
@testset "strides for ReshapedReinterpretArray" begin
203+
A = Array{Int32}(reshape(1:192, 3, 8, 8))
204+
for viewax1 in (1:8, 1:2:8, 8:-1:1, 8:-2:1), viewax2 in (1:2, 4:-1:1)
205+
for T in (Int16, Float32)
206+
@test check_strides(reinterpret(reshape, T, view(A, 1:2, viewax1, viewax2)))
207+
@test check_strides(reinterpret(reshape, T, view(A, 1:2:3, viewax1, viewax2)))
208+
end
209+
if mod(step(viewax1), 2) == 0
210+
@test check_strides(reinterpret(reshape, Int64, view(A, 1:2, viewax1, viewax2)))
211+
else
212+
@test_throws "Parent's strides" strides(reinterpret(reshape, Int64, view(A, 1:2, viewax1, viewax2)))
213+
end
214+
@test_throws "Parent must" strides(reinterpret(reshape, Int64, view(A, 1:2:3, viewax1, viewax2)))
215+
end
166216
end
167217

168218
@testset "strides" begin

0 commit comments

Comments
 (0)