Skip to content

Commit

Permalink
Extend strides(::ReshapedArray) with non-contiguous strided parent
Browse files Browse the repository at this point in the history
Use `Base.merge_adjacent_dim` to perform vector layout check before BLAS call.
  • Loading branch information
N5N3 committed Mar 22, 2022
1 parent 7cde4be commit 9c9d28e
Show file tree
Hide file tree
Showing 4 changed files with 98 additions and 23 deletions.
49 changes: 43 additions & 6 deletions base/reshapedarray.jl
Original file line number Diff line number Diff line change
Expand Up @@ -294,14 +294,51 @@ unsafe_convert(::Type{Ptr{T}}, V::SubArray{T,N,P,<:Tuple{Vararg{Union{RangeIndex
unsafe_convert(Ptr{T}, V.parent) + (first_index(V)-1)*sizeof(T)


_checkcontiguous(::Type{Bool}, A::AbstractArray) = size_to_strides(1, size(A)...) == strides(A)
_checkcontiguous(::Type{Bool}, A::Array) = true
_checkcontiguous(::Type{Bool}, A::AbstractArray) = false
# `strides(A::DenseArray)` calls `size_to_strides` by default.
# Thus it's OK to assume all `DenseArray`s are contiguously stored.
_checkcontiguous(::Type{Bool}, A::DenseArray) = true
_checkcontiguous(::Type{Bool}, A::ReshapedArray) = _checkcontiguous(Bool, parent(A))
_checkcontiguous(::Type{Bool}, A::FastContiguousSubArray) = _checkcontiguous(Bool, parent(A))

function strides(a::ReshapedArray)
# We can handle non-contiguous parent if it's a StridedVector
ndims(parent(a)) == 1 && return size_to_strides(only(strides(parent(a))), size(a)...)
_checkcontiguous(Bool, a) || throw(ArgumentError("Parent must be contiguous."))
size_to_strides(1, size(a)...)
_checkcontiguous(Bool, a) && return size_to_strides(1, size(a)...)
apsz::Dims = size(a.parent)
apst::Dims = strides(a.parent)
msz, mst, n = merge_adjacent_dim(apsz, apst) # Try to perform "lazy" reshape
n == ndims(a.parent) && return size_to_strides(mst, size(a)...) # Parent is stridevector like
return _reshaped_strides(size(a), 1, msz, mst, n, apsz, apst)
end

function _reshaped_strides(::Dims{0}, reshaped::Int, msz::Int, ::Int, ::Int, ::Dims, ::Dims)
reshaped == msz || throw(ArgumentError("Input is not strided."))
()
end
function _reshaped_strides(sz::Dims, reshaped::Int, msz::Int, mst::Int, n::Int, apsz::Dims, apst::Dims)
st = reshaped * mst
reshaped = reshaped * sz[1]
if length(sz) > 1 && reshaped == msz && sz[2] != 1
msz, mst, n = merge_adjacent_dim(apsz, apst, n + 1)
reshaped = 1
end
sts = _reshaped_strides(tail(sz), reshaped, msz, mst, n, apsz, apst)
return (st, sts...)
end

merge_adjacent_dim(::Dims{0}, ::Dims{0}) = 1, 1, 0
merge_adjacent_dim(apsz::Dims{1}, apst::Dims{1}) = apsz[1], apst[1], 1
function merge_adjacent_dim(apsz::Dims{N}, apst::Dims{N}, n::Int = 1) where {N}
sz, st = apsz[n], apst[n]
while n < N
szₙ, stₙ = apsz[n+1], apst[n+1]
if sz == 1
sz, st = szₙ, stₙ
elseif stₙ == st * sz || szₙ == 1
sz *= szₙ
else
break
end
n += 1
end
return sz, st, n
end
19 changes: 10 additions & 9 deletions stdlib/LinearAlgebra/src/blas.jl
Original file line number Diff line number Diff line change
Expand Up @@ -148,18 +148,19 @@ end
# Level 1
# A help function to pick the pointer and inc for 1d like inputs.
@inline function vec_pointer_stride(x::AbstractArray, stride0check = nothing)
isdense(x) && return pointer(x), 1 # simpify runtime check when possibe
ndims(x) == 1 || strides(x) == Base.size_to_strides(stride(x, 1), size(x)...) ||
throw(ArgumentError("only support vector like inputs"))
st = stride(x, 1)
Base._checkcontiguous(Bool, x) && return pointer(x), 1 # simpify runtime check when possibe
st, ptr = checkedstride(x), pointer(x)
isnothing(stride0check) || (st == 0 && throw(stride0check))
ptr = st > 0 ? pointer(x) : pointer(x, lastindex(x))
ptr += min(st, 0) * sizeof(eltype(x)) * (length(x) - 1)
ptr, st
end
isdense(x) = x isa DenseArray
isdense(x::Base.FastContiguousSubArray) = isdense(parent(x))
isdense(x::Base.ReshapedArray) = isdense(parent(x))
isdense(x::Base.ReinterpretArray) = isdense(parent(x))
function checkedstride(x::AbstractArray)
szs::Dims = size(x)
sts::Dims = strides(x)
sz, st, n = Base.merge_adjacent_dim(szs, sts)
n === ndims(x) || throw(ArgumentError("only support vector like inputs"))
return st
end
## copy

"""
Expand Down
9 changes: 7 additions & 2 deletions stdlib/LinearAlgebra/test/blas.jl
Original file line number Diff line number Diff line change
Expand Up @@ -18,9 +18,14 @@ function pack(A, uplo)
end

@testset "vec_pointer_stride" begin
a = zeros(4,4,4)
@test BLAS.asum(view(a,1:2:4,:,:)) == 0 # vector like
a = float(rand(1:20,4,4,4))
@test BLAS.asum(a) == sum(a) # dense case
@test BLAS.asum(view(a,1:2:4,:,:)) == sum(view(a,1:2:4,:,:)) # vector like
@test BLAS.asum(view(a,1:3,2:2,3:3)) == sum(view(a,1:3,2:2,3:3))
@test BLAS.asum(view(a,1:1,1:3,1:1)) == sum(view(a,1:1,1:3,1:1))
@test BLAS.asum(view(a,1:1,1:1,1:3)) == sum(view(a,1:1,1:1,1:3))
@test_throws ArgumentError BLAS.asum(view(a,1:3:4,:,:)) # non-vector like
@test_throws ArgumentError BLAS.asum(view(a,1:2,1:1,1:3))
end
Random.seed!(100)
## BLAS tests - testing the interface code to BLAS routines
Expand Down
44 changes: 38 additions & 6 deletions test/abstractarray.jl
Original file line number Diff line number Diff line change
Expand Up @@ -1567,22 +1567,54 @@ end
@test reshape(r, :) === reshape(r, (:,)) === r
end

struct FakeZeroDimArray <: AbstractArray{Int, 0} end
Base.strides(::FakeZeroDimArray) = ()
Base.size(::FakeZeroDimArray) = ()
@testset "strides for ReshapedArray" begin
# Type-based contiguous check is tested in test/compiler/inline.jl
function check_strides(A::AbstractArray)
# Make sure stride(A, i) is equivalent with strides(A)[i] (if 1 <= i <= ndims(A))
dims = ntuple(identity, ndims(A))
map(i -> stride(A, i), dims) == @inferred(strides(A)) || return false
# Test strides via value check.
for i in eachindex(IndexLinear(), A)
A[i] === Base.unsafe_load(pointer(A, i)) || return false
end
return true
end
# General contiguous check
a = view(rand(10,10), 1:10, 1:10)
@test strides(vec(a)) == (1,)
@test check_strides(vec(a))
b = view(parent(a), 1:9, 1:10)
@test_throws "Parent must be contiguous." strides(vec(b))
@test_throws "Input is not strided." strides(vec(b))
# StridedVector parent
for n in 1:3
a = view(collect(1:60n), 1:n:60n)
@test strides(reshape(a, 3, 4, 5)) == (n, 3n, 12n)
@test strides(reshape(a, 5, 6, 2)) == (n, 5n, 30n)
@test check_strides(reshape(a, 3, 4, 5))
@test check_strides(reshape(a, 5, 6, 2))
b = view(parent(a), 60n:-n:1)
@test strides(reshape(b, 3, 4, 5)) == (-n, -3n, -12n)
@test strides(reshape(b, 5, 6, 2)) == (-n, -5n, -30n)
@test check_strides(reshape(b, 3, 4, 5))
@test check_strides(reshape(b, 5, 6, 2))
end
# StridedVector like parent
a = randn(10, 10, 10)
b = view(a, 1:10, 1:1, 5:5)
@test check_strides(reshape(b, 2, 5))
# Other StridedArray parent
a = view(randn(10,10), 1:9, 1:10)
@test check_strides(reshape(a,3,3,2,5))
@test check_strides(reshape(a,3,3,5,2))
@test check_strides(reshape(a,9,5,2))
@test check_strides(reshape(a,3,3,10))
@test check_strides(reshape(a,1,3,1,3,1,5,1,2))
@test check_strides(reshape(a,3,3,5,1,1,2,1,1))
@test_throws "Input is not strided." strides(reshape(a,3,6,5))
@test_throws "Input is not strided." strides(reshape(a,3,2,3,5))
@test_throws "Input is not strided." strides(reshape(a,3,5,3,2))
@test_throws "Input is not strided." strides(reshape(a,5,3,3,2))
# Zero dimensional parent
a = reshape(FakeZeroDimArray(),1,1,1)
@test @inferred(strides(a)) == (1, 1, 1)
end

@testset "stride for 0 dims array #44087" begin
Expand Down

0 comments on commit 9c9d28e

Please sign in to comment.