Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Layouts - combining array structure and indexing #141

Closed
wants to merge 7 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions src/ArrayInterface.jl
Original file line number Diff line number Diff line change
Expand Up @@ -17,8 +17,11 @@ using Base: @propagate_inbounds, tail, OneTo, LogicalIndex, Slice, ReinterpretAr
# Map a possibly-unknown integer to a type: `nothing` gives `Int`, while an `Int`
# value `x` gives `StaticInt{x}`.
function _int_or_static_int(::Nothing)
    return Int
end
function _int_or_static_int(x::Int)
    return StaticInt{x}
end

# Normalize an integer: `Int` and `StaticInt` values pass through untouched, any
# other `Integer` is converted with `Int`.
function _int(i::Integer)
    return Int(i)
end
function _int(i::Int)
    return i
end
function _int(i::StaticInt)
    return i
end

# `ndims(x)` wrapped with `static`.
function static_ndims(x)
    return static(ndims(x))
end

@static if VERSION >= v"1.7.0-DEV.421"
using Base: @aggressive_constprop
else
Expand Down Expand Up @@ -92,6 +95,7 @@ known_length(::Type{<:NamedTuple{L}}) where {L} = length(L)
# A `Slice` defers to the known length of its parent type.
function known_length(::Type{T}) where {T<:Slice}
    return known_length(parent_type(T))
end
# A tuple type with exactly `N` entries has length `N`.
function known_length(::Type{<:Tuple{Vararg{Any,N}}}) where {N}
    return N
end
# A generator defers to the type of the iterator it wraps.
function known_length(::Type{T}) where {Itr,T<:Base.Generator{Itr}}
    return known_length(Itr)
end
# An `N`-dimensional Cartesian index has length `N`.
function known_length(::Type{T}) where {N,T<:AbstractCartesianIndex{N}}
    return N
end
# Numbers are reported with length one.
function known_length(::Type{<:Number})
    return 1
end
function known_length(::Type{T}) where {T}
if parent_type(T) <: T
Expand Down Expand Up @@ -842,6 +846,7 @@ end
end

include("ranges.jl")
include("layouts.jl")
include("indexing.jl")
include("dimensions.jl")
include("axes.jl")
Expand Down
4 changes: 3 additions & 1 deletion src/axes.jl
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,6 @@ end
Returns the type of the axes for `T`
"""
function axes_types(x)
    # instances forward to the method defined on their type
    return axes_types(typeof(x))
end
function axes_types(::Type{T}) where {T<:Array}
    # one `OneTo{Int}` entry per dimension of the `Array` type
    return Tuple{Vararg{OneTo{Int},ndims(T)}}
end
function axes_types(::Type{T}) where {T}
if parent_type(T) <: T
return Tuple{Vararg{OptionallyStaticUnitRange{One,Int},ndims(T)}}
Expand Down Expand Up @@ -141,6 +140,9 @@ axes(A::Base.ReshapedArray, dim::Integer) = Base.axes(A, Int(dim)) # TODO imple

Return a tuple of ranges where each range maps to each element along a dimension of `A`.
"""
@inline function axes(a::Array)
    return _array_axes(size(a))
end
# Build one range per dimension length, each starting at `static(1)`; recursion
# peels the first length off the tuple until it is empty.
@inline function _array_axes(dims::Tuple{Vararg{Int}})
    leading = static(1):first(dims)
    return (leading, _array_axes(tail(dims))...)
end
function _array_axes(::Tuple{})
    return ()
end
@inline function axes(a::A) where {A}
if parent_type(A) <: A
return Base.axes(a)
Expand Down
4 changes: 2 additions & 2 deletions src/dimensions.jl
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ from_parent_dims(::Type{<:SubArray{T,N,A,I}}) where {T,N,A,I} = _from_sub_dims(A
dim_i = 1
for i in 1:ndims(A)
p = I.parameters[i]
if argdims(A, p) > 0
if index_dims_out(A, p) > 0
push!(out.args, :(StaticInt($dim_i)))
dim_i += 1
else
Expand Down Expand Up @@ -98,7 +98,7 @@ to_parent_dims(::Type{<:SubArray{T,N,A,I}}) where {T,N,A,I} = _to_sub_dims(A, I)
out = Expr(:tuple)
n = 1
for p in I.parameters
if argdims(A, p) > 0
if index_dims_out(A, p) > 0
push!(out.args, :(StaticInt($n)))
end
n += 1
Expand Down
594 changes: 402 additions & 192 deletions src/indexing.jl

Large diffs are not rendered by default.

139 changes: 139 additions & 0 deletions src/layouts.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,139 @@

# Convert Base range types into their optionally-static counterparts; anything
# else (including ranges that are already optionally static) is returned as is.
function _as_index(x)
    return x
end
function _as_index(x::OneTo)
    return static(1):length(x)
end
function _as_index(x::StepRange)
    return OptionallyStaticStepRange(x)
end
function _as_index(x::UnitRange)
    return OptionallyStaticUnitRange(x)
end
function _as_index(x::OptionallyStaticRange)
    return x
end

"""
StrideLayout(A)

Produces an array whose elements correspond to the linear buffer position of `A`'s elements.
"""
struct StrideLayout{N,R,O1,S,A<:Tuple{Vararg{Any,N}}} <: AbstractArray2{Int,N}
# stride rank; the constructors in this file fill it with `stride_rank(x)`
rank::R
# offset of the first linear index; filled with `offset1(x)`
offset1::O1
# strides; `strides(x)` in general, or derived from the axis lengths for `DenseArray`
strides::S
# tuple of `N` axes; filled with `axes(x)`
axes::A
end
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

On the VectorizationBase side of things, it'd be nice if StridedPointer could be a StrideLayout + a Ptr.
However, StrideLayout is missing some things, most importantly the contiguous axis, and also has axes which it does not want (although it does want to store map(static_first, axes(x))).

Also, StridedPointer + axes == PtrArray. StrideLayout already subtypes AbstractArray2, so it is almost there itself.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

How about something that is indexable but isn't necessarily an iterator like this

struct StrideLayout{N,S<:Tuple{Vararg{Any,N}},R,O,O1,C}
    strides::S
    rank::R
    offset1::O1
    offsets::O
    contiguous::C
end

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't have any actual support anywhere for ArrayInterface.contiguous_batch_size yet, so we could add that later, but probably best to start carrying that around now as well.
But this looks good.


# Field accessors for `StrideLayout`. `getfield` is used directly rather than
# property access.
function offset1(x::StrideLayout)
    return getfield(x, :offset1)
end
# Per-dimension offsets are the (static) first element of each axis.
function offsets(x::StrideLayout)
    return map(static_first, axes(x))
end
function axes(x::StrideLayout)
    return getfield(x, :axes)
end
# Axis along dimension `i`; dimensions past `N` report the range `static(1):1`.
@inline function axes(x::StrideLayout{N}, i::Int) where {N}
    i > N && return static(1):1
    stored_axes = getfield(x, :axes)
    return getfield(stored_axes, i)
end
# Statically-indexed variant: dimensions past `N` report `static(1):static(1)`.
@inline function axes(x::StrideLayout{N}, ::StaticInt{i}) where {N,i}
    i > N && return static(1):static(1)
    stored_axes = getfield(x, :axes)
    return getfield(stored_axes, i)
end
# Stored strides of the layout.
function strides(x::StrideLayout)
    return getfield(x, :strides)
end
# Stored stride rank of the layout.
function stride_rank(x::StrideLayout)
    return getfield(x, :rank)
end

# Build a `StrideLayout` for a `DenseArray`. Strides are not queried from `x`; they
# are computed from the static lengths of its axes, starting from `static(1)`.
@inline function StrideLayout(x::DenseArray)
    axs = axes(x)
    rnk = stride_rank(x)
    off = offset1(x)
    strds = size_to_strides(map(static_length, axs), static(1))
    return StrideLayout(rnk, off, strds, axs)
end

# TODO optimize this
# Generic fallback: every field is queried directly from `x`.
@inline function StrideLayout(x)
    rnk = stride_rank(x)
    off = offset1(x)
    strds = strides(x)
    return StrideLayout(rnk, off, strds, axes(x))
end

##############
### layout ###
##############
# With an index argument, integer-like indices (a single integer, a vector of
# integers, or a one-dimensional Cartesian index) go through
# `_maybe_linear_layout`; any other index falls back to `layout(x)`.
function layout(x, i)
    return layout(x)
end
function layout(x, i::AbstractVector{<:Integer})
    return _maybe_linear_layout(IndexStyle(x), x)
end
function layout(x, i::Integer)
    return _maybe_linear_layout(IndexStyle(x), x)
end
function layout(x, i::AbstractCartesianIndex{1})
    return _maybe_linear_layout(IndexStyle(x), x)
end
# Vectors of one-dimensional Cartesian indices take the same path as integer
# indices. The element type is matched with `<:` because Julia type parameters are
# invariant: `Vector{CartesianIndex{1}}` is not a subtype of
# `AbstractVector{AbstractCartesianIndex{1}}`, so without `<:` such vectors would
# fall through to the generic `layout(x, i)` method instead of this one.
function layout(x, i::AbstractVector{<:AbstractCartesianIndex{1}})
    return _maybe_linear_layout(IndexStyle(x), x)
end
# Arrays with `IndexLinear` style use their `eachindex` range (converted by
# `_as_index`) as the layout; every other index style uses `layout(x)`.
function _maybe_linear_layout(::IndexLinear, x)
    return _as_index(eachindex(x))
end
function _maybe_linear_layout(::IndexStyle, x)
    return layout(x)
end

# These types are already layouts and are returned unchanged.
function layout(x::StrideLayout)
    return x
end
function layout(x::LinearIndices)
    return x
end
function layout(x::CartesianIndices)
    return x
end
# Generic layout: a `StrideLayout` when `x` defines strides, otherwise linear or
# Cartesian indices over `axes(x)`, chosen by `IndexStyle(x)`.
function layout(x)
    defines_strides(x) && return StrideLayout(x)
    return _layout_indices(IndexStyle(x), axes(x))
end
# Without strides, a transposed array's layout is the transposed parent layout.
function layout(x::Transpose)
    return defines_strides(x) ? StrideLayout(x) : Transpose(layout(parent(x)))
end
# Adjoints of numeric arrays: without strides, the parent layout is wrapped in a
# `Transpose` (not an `Adjoint`).
function layout(x::Adjoint{T}) where {T<:Number}
    return defines_strides(x) ? StrideLayout(x) : Transpose(layout(parent(x)))
end
# Without strides, re-wrap the parent's layout in a `PermutedDimsArray` that reuses
# the same `perm`/`iperm` as `x`.
function layout(x::PermutedDimsArray{T,N,perm,iperm}) where {T,N,perm,iperm}
    defines_strides(x) && return StrideLayout(x)
    inner = layout(parent(x))
    return PermutedDimsArray{eltype(inner),ndims(inner),perm,iperm,typeof(inner)}(inner)
end
# Without strides, a view's layout is a view of the parent layout through the same
# indices; bounds checks are skipped for that inner `view` call.
function layout(x::SubArray)
    defines_strides(x) && return StrideLayout(x)
    return @inbounds view(layout(parent(x)), x.indices...)
end
# Choose the index container from the index style: `LinearIndices` for
# `IndexLinear`, `CartesianIndices` for any other style.
function _layout_indices(::IndexStyle, axs)
    return CartesianIndices(axs)
end
function _layout_indices(::IndexLinear, axs)
    return LinearIndices(axs)
end
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Whether to use linear or Cartesian indexing is of course context dependent (in the rewrite, I'll add support for switching between representations).
E.g., if you're summing a Matrix{Float64} or an Adjoint{Float64,Matrix{Float64}}, you'd want to use linear indexing in both cases (for the Adjoint, the linear indices would traverse it in row-major order). The same holds if you're calculating the sum of 2 instances of Matrix{Float64} or 2 instances of Adjoint{Float64,Matrix{Float64}}. However, if you add a Matrix{Float64} to an Adjoint{Float64,Matrix{Float64}}, you suddenly need Cartesian indexing.
eachindex gets some of this right, but it'll always return IndexCartesian() for the Adjoint, meaning performance for those sums/additions is worse than necessary.

Is there a "combine layouts"?

I'll get LoopVectorization to be able to transform between these correctly in the rewrite.
One of my goals there is for something like this to work:

@avx begin
  C .= 0
  for n in indices((C,B),2), m in indices((C,A),1), k in indices((A,B),(2,1))
    C[m,n] += A[m,k] * B[k,n]
  end
end

and to have LoopVectorization be capable of converting the linear indexing of C .= 0 into cartesian indexing, and then fuse that broadcast with the following loop.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What you're describing is basically what I was getting at with the reference to MLIR. The final "layout" after code gen would be the result of scoping (for summation this would be unordered access to all elements once) which informs how to combine each layer of the arrays.


"""
buffer(x)

Return the raw buffer for `x`, stripping any additional info (structural, indexing,
metadata, etc.).
"""
function buffer(x)
    # anything that is not one of the wrappers below is its own buffer
    return x
end
# Each wrapper type recurses into its parent until a non-wrapper is reached.
@inline function buffer(x::PermutedDimsArray)
    return buffer(parent(x))
end
@inline function buffer(x::Transpose)
    return buffer(parent(x))
end
@inline function buffer(x::Adjoint)
    return buffer(parent(x))
end
@inline function buffer(x::SubArray)
    return buffer(parent(x))
end


""" allocate_memory(::AbstractDevice, ::Type{T}, length::Union{StaticInt,Int}) """
function allocate_memory(::CPUPointer, ::Type{T}, ::StaticInt{N}) where {T,N}
    # Construct the `Ref` instance (note the trailing `()`). The bare expression
    # `Ref{NTuple{N,T}}` is a type, not storage, and would not be accepted by
    # `dereference(::CPUTuple, x::Ref)`.
    return Ref{NTuple{N,T}}()
end
# Dynamic length: an uninitialized `Vector{T}` with `n` elements.
function allocate_memory(::CPUPointer, ::Type{T}, n::Int) where {T}
    return Vector{T}(undef, n)
end
# Static length with `CPUTuple`: an uninitialized `Ref` holding an `NTuple{N,T}`.
function allocate_memory(::CPUTuple, ::Type{T}, ::StaticInt{N}) where {T,N}
    return Ref{NTuple{N,T}}()
end


""" dereference(::AbstractDevice, x) """
function dereference(::CPUPointer, x)
    # `CPUPointer` data is handed back unchanged
    return x
end
function dereference(::CPUTuple, x::Ref)
    # `CPUTuple` data is read out of its `Ref`
    return x[]
end

""" initialize(data, layout) """
function initialize end  # generic function declared here without any methods

16 changes: 16 additions & 0 deletions src/ranges.jl
Original file line number Diff line number Diff line change
Expand Up @@ -264,6 +264,22 @@ unsafe_length_one_to(::StaticInt{L}) where {L} = L
end
end

# Index an `OptionallyStaticUnitRange` with a `StepRange`.
#
# For a non-`Bool` step range the result is an `OptionallyStaticStepRange` whose
# start is shifted by `first(r) - 1`, keeping the step and length of `s`. For a
# `Bool` step range the result is built with `range`: it starts at `first(r)` when
# `first(s)` is true and at `last(r)` otherwise, and has length `Int(last(s))`.
@propagate_inbounds function Base.getindex(
    r::OptionallyStaticUnitRange,
    s::StepRange{T}
) where {T<:Integer}
    @boundscheck checkbounds(r, s)
    if T !== Bool
        shifted_start = first(r) + s.start - 1
        stride = step(s)
        shifted_stop = ((length(s) - 1) * stride) + shifted_start
        return OptionallyStaticStepRange(shifted_start, stride, shifted_stop)
    end
    origin = first(s) ? first(r) : last(r)
    return range(origin, step=oneunit(eltype(r)), length = Int(last(s)))
end

@propagate_inbounds function Base.getindex(
r::OptionallyStaticUnitRange,
s::AbstractUnitRange{<:Integer},
Expand Down
1 change: 1 addition & 0 deletions src/size.jl
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ function size(a::A) where {A}
return size(parent(a))
end
end
# Plain `Array`s forward to `Base.size`.
function size(x::Array)
    return Base.size(x)
end
#size(a::AbstractVector) = (size(a, One()),)

size(x::SubArray) = eachop(_sub_size, to_parent_dims(x), x.indices)
Expand Down
3 changes: 3 additions & 0 deletions src/stridelayout.jl
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,9 @@ Returns the offset of the linear indices for `x`.
function offset1(x)
    return _offset1(has_parent(x), x)
end
# Arrays with a parent defer to it; arrays without one report `static(1)`.
function _offset1(::True, x)
    return offset1(parent(x))
end
function _offset1(::False, x)
    return static(1)
end
# For a view, look up the parent's layout element at the (static) first value of
# every view index.
function offset1(x::SubArray)
    parent_layout = layout(parent(x))
    first_index = NDIndex(map(static_first, x.indices))
    return unsafe_get_element(parent_layout, first_index)
end

"""
contiguous_axis(::Type{T}) -> StaticInt{N}
Expand Down
4 changes: 2 additions & 2 deletions test/dimensions.jl
Original file line number Diff line number Diff line change
Expand Up @@ -147,8 +147,8 @@ end
@test @inferred(ArrayInterface.size(y')) == (1, size(parent(x), 1))
@test @inferred(axes(x, first(d))) == axes(parent(x), 1)
@test strides(x, :x) == ArrayInterface.strides(parent(x))[1]
@test @inferred(ArrayInterface.axes_types(x, static(:x))) <: Base.OneTo{Int}
@test ArrayInterface.axes_types(x, :x) <: Base.OneTo{Int}
@test @inferred(ArrayInterface.axes_types(x, static(:x))) <: AbstractUnitRange{Int}
@test ArrayInterface.axes_types(x, :x) <: AbstractUnitRange{Int}
@test @inferred(ArrayInterface.axes_types(LinearIndices{2,NTuple{2,Base.OneTo{Int}}})) <: NTuple{2,Base.OneTo{Int}}
CI = CartesianIndices{2,Tuple{Base.OneTo{Int},UnitRange{Int}}}
@test @inferred(ArrayInterface.axes_types(CI, static(1))) <: Base.OneTo{Int}
Expand Down
37 changes: 27 additions & 10 deletions test/indexing.jl
Original file line number Diff line number Diff line change
Expand Up @@ -8,12 +8,12 @@ using ArrayInterface: NDIndex
0.047 ns (0 allocations: 0 bytes)
=#

@testset "argdims" begin
@test @inferred(ArrayInterface.argdims(ArrayInterface.DefaultArrayStyle(), (1, CartesianIndex(1,2)))) === static((0, 2))
@test @inferred(ArrayInterface.argdims(ArrayInterface.DefaultArrayStyle(), (1, [CartesianIndex(1,2), CartesianIndex(1,3)]))) === static((0, 2))
@test @inferred(ArrayInterface.argdims(ArrayInterface.DefaultArrayStyle(), (1, CartesianIndex((2,2))))) === static((0, 2))
@test @inferred(ArrayInterface.argdims(ArrayInterface.DefaultArrayStyle(), (CartesianIndex((2,2)), :, :))) === static((2, 1, 1))
@test @inferred(ArrayInterface.argdims(ArrayInterface.DefaultArrayStyle(), Vector{Int})) === static(1)
@testset "index_dims_in" begin
# Each case checks that `index_dims_in` is inferrable and returns the expected
# static value: integers and Colons count as 1, and a two-dimensional
# `CartesianIndex` (or a vector of them) counts as 2.
@test @inferred(ArrayInterface.index_dims_in(IndexLinear(), (1, CartesianIndex(1,2)))) === static((1, 2))
@test @inferred(ArrayInterface.index_dims_in(IndexLinear(), (1, [CartesianIndex(1,2), CartesianIndex(1,3)]))) === static((1, 2))
@test @inferred(ArrayInterface.index_dims_in(IndexLinear(), (1, CartesianIndex((2,2))))) === static((1, 2))
@test @inferred(ArrayInterface.index_dims_in(IndexLinear(), (CartesianIndex((2,2)), :, :))) === static((2, 1, 1))
@test @inferred(ArrayInterface.index_dims_in(IndexLinear(), Vector{Int})) === static(1)
end

@testset "to_index" begin
Expand All @@ -24,9 +24,11 @@ end
@test @inferred(ArrayInterface.to_index(axis, [true, false, false])) == [1]
@test @inferred(ArrayInterface.to_index(axis, CartesianIndices(()))) === CartesianIndices(())

#=
x = LinearIndices((static(0):static(3),static(3):static(5),static(-2):static(0)));
@test @inferred(ArrayInterface.to_index(x, NDIndex((0, 3, -2)))) === 1
@test @inferred(ArrayInterface.to_index(x, NDIndex(static(0), static(3), static(-2)))) === static(1)
=#

@test_throws BoundsError ArrayInterface.to_index(axis, 4)
@test_throws BoundsError ArrayInterface.to_index(axis, 1:4)
Expand Down Expand Up @@ -115,7 +117,7 @@ end
# which returns a UnitRange. Instead we try to preserve axes if at all possible so the
# values are the same but it's still wrapped in LinearIndices struct
@test @inferred(ArrayInterface.getindex(LinearIndices((3,)), 1:2)) == 1:2
@test @inferred(ArrayInterface.getindex(LinearIndices((3,)), 1:2:3)) === 1:2:3
@test @inferred(ArrayInterface.getindex(LinearIndices((3,)), 1:2:3)) == 1:2:3
@test_throws BoundsError ArrayInterface.getindex(LinearIndices((3,)), 2:4)
@test_throws BoundsError ArrayInterface.getindex(CartesianIndices((3,)), 2, 2)
# ambiguity btw cartesian indexing and linear indexing in 1d when
Expand All @@ -124,8 +126,8 @@ end
#@test_throws ArgumentError Base._sub2ind((1:3,), 2)
#@test_throws ArgumentError Base._ind2sub((1:3,), 2)
x = Array{Int,2}(undef, (2, 2))
ArrayInterface.unsafe_set_index!(x, 1, (2, 2))
@test ArrayInterface.unsafe_get_index(x, (2, 2)) === 1
ArrayInterface.unsafe_setindex!(x, 1, (2, 2))
@test ArrayInterface.unsafe_getindex(x, (2, 2)) === 1

# FIXME @test_throws MethodError ArrayInterface.unsafe_set_element!(x, 1, (:x, :x))
# FIXME @test_throws MethodError ArrayInterface.unsafe_get_element(x, (:x, :x))
Expand Down Expand Up @@ -156,7 +158,7 @@ end
@test @inferred(ArrayInterface.getindex(cartesian, cartesian)) == cartesian
@test @inferred(ArrayInterface.getindex(cartesian, vec(cartesian))) == vec(cartesian)
@test @inferred(ArrayInterface.getindex(linear, 2:3)) === 2:3
@test @inferred(ArrayInterface.getindex(linear, 3:-1:1)) === 3:-1:1
@test @inferred(ArrayInterface.getindex(linear, 3:-1:1)) == 3:-1:1
@test_throws BoundsError ArrayInterface.getindex(linear, 4:13)
end

Expand Down Expand Up @@ -193,3 +195,18 @@ end
end
end

@testset "stride indexing" begin
    x = Array{Int,3}(undef, (4, 4, 4))
    x[:] = 1:length(x)
    p = PermutedDimsArray(x, (3, 1, 2))
    v = view(x, :, 2, :)

    # The dense array and its permuted wrapper are each compared against Base for a
    # full slice and for a fixed index in the first and second dimension.
    for arr in (x, p), inds in ((:, :, :), (3, :, :), (:, 3, :))
        @test ArrayInterface.getindex(arr, inds...) == Base.getindex(arr, inds...)
    end

    # The two-dimensional view is compared on a full slice only.
    @test ArrayInterface.getindex(v, :, :) == getindex(v, :, :)
end