Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

NDIndex - Optionally static CartesianIndex #140

Merged
merged 12 commits into from
Apr 7, 2021
9 changes: 5 additions & 4 deletions src/ArrayInterface.jl
Original file line number Diff line number Diff line change
Expand Up @@ -27,8 +27,6 @@ else
end
end

static_ndims(x) = static(ndims(x))

if VERSION ≥ v"1.6.0-DEV.1581"
_is_reshaped(::Type{ReinterpretArray{T,N,S,A,true}}) where {T,N,S,A} = true
_is_reshaped(::Type{ReinterpretArray{T,N,S,A,false}}) where {T,N,S,A} = false
Expand All @@ -51,6 +49,8 @@ const LoTri{T,M} = Union{LowerTriangular{T,M},UnitLowerTriangular{T,M}}
@inline static_last(x) = Static.maybe_static(known_last, last, x)
@inline static_step(x) = Static.maybe_static(known_step, step, x)

include("ndindex.jl")

"""
parent_type(::Type{T})

Expand All @@ -70,6 +70,7 @@ parent_type(::Type{R}) where {S,T,A,N,R<:Base.ReinterpretArray{T,N,S,A}} = A
parent_type(::Type{LoTri{T,M}}) where {T,M} = M
parent_type(::Type{UpTri{T,M}}) where {T,M} = M
parent_type(::Type{Diagonal{T,V}}) where {T,V} = V

"""
has_parent(::Type{T}) -> StaticBool

Expand Down Expand Up @@ -591,7 +592,7 @@ safevec(v::Number) = v
safevec(v::AbstractVector) = v

"""
zeromatrix(u::AbstractVector)
zeromatrix(u::AbstractVector)

Creates the zero'd matrix version of `u`. Note that this is unique because
`similar(u,length(u),length(u))` returns a mutable type, so it is not type-matching,
Expand All @@ -607,7 +608,7 @@ function zeromatrix(u)
end

"""
restructure(x,y)
restructure(x,y)

Restructures the object `y` into a shape of `x`, keeping its values intact. For
simple objects like an `Array`, this simply amounts to a reshape. However, for
Expand Down
203 changes: 91 additions & 112 deletions src/indexing.jl
Original file line number Diff line number Diff line change
Expand Up @@ -28,60 +28,25 @@ argdims(s::ArrayStyle, arg) = argdims(s, typeof(arg))
argdims(::ArrayStyle, ::Type{T}) where {T} = static(0)
argdims(::ArrayStyle, ::Type{T}) where {T<:Colon} = static(1)
argdims(::ArrayStyle, ::Type{T}) where {T<:AbstractArray} = static(ndims(T))
argdims(::ArrayStyle, ::Type{T}) where {N,T<:CartesianIndex{N}} = static(N)
argdims(::ArrayStyle, ::Type{T}) where {N,T<:AbstractArray{CartesianIndex{N}}} = static(N)
argdims(::ArrayStyle, ::Type{T}) where {N,T<:AbstractCartesianIndex{N}} = static(N)
argdims(::ArrayStyle, ::Type{T}) where {N,T<:AbstractArray{<:AbstractCartesianIndex{N}}} = static(N)
argdims(::ArrayStyle, ::Type{T}) where {N,T<:AbstractArray{<:Any,N}} = static(N)
argdims(::ArrayStyle, ::Type{T}) where {N,T<:LogicalIndex{<:Any,<:AbstractArray{Bool,N}}} = static(N)
_argdims(s::ArrayStyle, ::Type{I}, i::StaticInt) where {I} = argdims(s, _get_tuple(I, i))
function argdims(s::ArrayStyle, ::Type{T}) where {N,T<:Tuple{Vararg{Any,N}}}
return eachop(_argdims, nstatic(Val(N)), s, T)
end

is_element_index(i) = is_element_index(typeof(i))
is_element_index(::Type{T}) where {T} = static(false)
is_element_index(::Type{T}) where {T<:AbstractCartesianIndex} = static(true)
is_element_index(::Type{T}) where {T<:Integer} = static(true)
_is_element_index(::Type{T}, i::StaticInt) where {T} = is_element_index(_get_tuple(T, i))
function is_element_index(::Type{T}) where {N,T<:Tuple{Vararg{Any,N}}}
return static(all(eachop(_is_element_index, nstatic(Val(N)), T)))
end

"""
UnsafeIndex(::ArrayStyle, ::Type{I})

`UnsafeIndex` controls how indices that have been bounds checked and converted to
native axes' indices are used to return the stored values of an array. For example,
if the indices at each dimension are single integers then `UnsafeIndex(array, inds)` returns
`UnsafeGetElement()`. Conversely, if any of the indices are vectors then `UnsafeGetCollection()`
is returned, indicating that a new array needs to be reconstructed. This method permits
customizing the terminal behavior of the indexing pipeline based on arguments passed
to `ArrayInterface.getindex`. New subtypes of `UnsafeIndex` should define `promote_rule`.
"""
abstract type UnsafeIndex end

struct UnsafeGetElement <: UnsafeIndex end

struct UnsafeGetCollection <: UnsafeIndex end

UnsafeIndex(x, i) = UnsafeIndex(x, typeof(i))
UnsafeIndex(x, ::Type{I}) where {I} = UnsafeIndex(ArrayStyle(x), I)
UnsafeIndex(s::ArrayStyle, i) = UnsafeIndex(s, typeof(i))
UnsafeIndex(::ArrayStyle, ::Type{I}) where {I} = UnsafeGetElement()
UnsafeIndex(::ArrayStyle, ::Type{I}) where {I<:AbstractArray} = UnsafeGetCollection()

Base.promote_rule(::Type{X}, ::Type{Y}) where {X<:UnsafeIndex,Y<:UnsafeGetElement} = X

@generated function UnsafeIndex(s::ArrayStyle, ::Type{T}) where {N,T<:Tuple{Vararg{Any,N}}}
if N === 0
return UnsafeGetElement()
else
e = Expr(:call, promote_type)
for p in T.parameters
push!(e.args, :(typeof(ArrayInterface.UnsafeIndex(s, $p))))
end
return Expr(:block, Expr(:meta, :inline), Expr(:call, e))
end
_is_element_index(i) = _is_element_index(typeof(i))
_is_element_index(::Type{T}) where {T} = static(false)
_is_element_index(::Type{T}) where {T<:AbstractCartesianIndex} = static(true)
_is_element_index(::Type{T}) where {T<:Integer} = static(true)
__is_element_index(::Type{T}, i::StaticInt) where {T} = _is_element_index(_get_tuple(T, i))
function _is_element_index(::Type{T}) where {N,T<:Tuple{Vararg{Any,N}}}
return static(all(eachop(__is_element_index, nstatic(Val(N)), T)))
end
# empty tuples refer to the single element of 0-dimensional arrays
_is_element_index(::Type{Tuple{}}) = static(true)

# are the indexing arguments provided a linear collection into a multidim collection
is_linear_indexing(A, args::Tuple{Arg}) where {Arg} = argdims(A, Arg) < 2
Expand Down Expand Up @@ -181,6 +146,22 @@ to_index(::IndexLinear, axis, arg::CartesianIndices{1}) = axes(arg, 1)
@propagate_inbounds function to_index(::IndexLinear, axis, arg::AbstractCartesianIndex{1})
return to_index(axis, first(Tuple(arg)))
end
function to_index(::IndexLinear, x, arg::AbstractCartesianIndex{N}) where {N}
inds = Tuple(arg)
o = offsets(x)
s = size(x)
return first(inds) + (offset1(x) - first(o)) + _subs2int(first(s), tail(s), tail(o), tail(inds))
end
@inline function _subs2int(stride, s::Tuple{Any,Vararg}, o::Tuple{Any,Vararg}, inds::Tuple{Any,Vararg})
i = ((first(inds) - first(o)) * stride)
return i + _subs2int(stride * first(s), tail(s), tail(o), tail(inds))
end
function _subs2int(stride, s::Tuple{Any}, o::Tuple{Any}, inds::Tuple{Any})
return (first(inds) - first(o)) * stride
end
# trailing inbounds can only be 1 or 1:1
_subs2int(stride, ::Tuple{}, ::Tuple{}, ::Tuple{Any}) = static(0)

@propagate_inbounds function to_index(::IndexLinear, x, arg::Union{Array{Bool}, BitArray})
@boundscheck checkbounds(x, arg)
return LogicalIndex{Int}(arg)
Expand All @@ -194,7 +175,7 @@ end
return arg
end
@propagate_inbounds function to_index(::IndexLinear, x, arg::Integer)
@boundscheck checkindex(Bool, x, arg) || throw(BoundsError(x, arg))
@boundscheck checkindex(Bool, indices(x), arg) || throw(BoundsError(x, arg))
return _int(arg)
end
@propagate_inbounds function to_index(::IndexLinear, axis, arg::AbstractArray{Bool})
Expand All @@ -209,25 +190,11 @@ end
@boundscheck checkindex(Bool, indices(axis), arg) || throw(BoundsError(axis, arg))
return static_first(arg):static_step(arg):static_last(arg)
end
to_index(::IndexLinear, x, inds::Tuple{Any}) = first(inds)
function to_index(::IndexLinear, x, inds::Tuple{Any,Vararg{Any}})
o = offsets(x)
s = size(x)
return first(inds) + (offset1(x) - first(o)) + _subs2int(first(s), tail(s), tail(o), tail(inds))
end
@inline function _subs2int(stride, s::Tuple{Any,Vararg}, o::Tuple{Any,Vararg}, inds::Tuple{Any,Vararg})
i = ((first(inds) - first(o)) * stride)
return i + _subs2int(stride * first(s), tail(s), tail(o), tail(inds))
end
function _subs2int(stride, s::Tuple{Any}, o::Tuple{Any}, inds::Tuple{Any})
return (first(inds) - first(o)) * stride
end
# trailing inbounds can only be 1 or 1:1
_subs2int(stride, ::Tuple{}, ::Tuple{}, ::Tuple{Any}) = static(0)

## IndexCartesian ##
to_index(::IndexCartesian, x, arg::Colon) = CartesianIndices(x)
to_index(::IndexCartesian, x, arg::CartesianIndices{0}) = arg
to_index(::IndexCartesian, x, arg::AbstractCartesianIndex) = arg
function to_index(::IndexCartesian, x, arg)
@boundscheck _multi_check_index(axes(x), arg) || throw(BoundsError(x, arg))
return arg
Expand All @@ -253,15 +220,13 @@ end
@boundscheck checkbounds(x, arg)
return LogicalIndex{Int}(arg)
end
to_index(::IndexCartesian, x, i::Integer) = _int2subs(axes(x), i - offset1(x))
@inline function _int2subs(axs::Tuple{Any,Vararg{Any}}, i)
axis = first(axs)
len = static_length(axis)
to_index(::IndexCartesian, x, i::Integer) = NDIndex(_int2subs(offsets(x), size(x), i - offset1(x)))
@inline function _int2subs(o::Tuple{Any,Vararg{Any}}, s::Tuple{Any,Vararg{Any}}, i)
len = first(s)
inext = div(i, len)
return (_int(i - len * inext + static_first(axis)), _int2subs(tail(axs), inext)...)
return (_int(i - len * inext + first(o)), _int2subs(tail(o), tail(s), inext)...)
end
_int2subs(axs::Tuple{Any}, i) = _int(i + static_first(first(axs)))

_int2subs(o::Tuple{Any}, s::Tuple{Any}, i) = _int(i + first(o))

"""
unsafe_reconstruct(A, data; kwargs...)
Expand Down Expand Up @@ -353,6 +318,9 @@ end
end
to_axis(S::IndexLinear, axis, inds) = StaticInt(1):static_length(inds)

################
### getindex ###
################
"""
ArrayInterface.getindex(A, args...)

Expand All @@ -362,14 +330,19 @@ Changing indexing based on a given argument from `args` should be done through,
[`to_index`](@ref), or [`to_axis`](@ref).
"""
@propagate_inbounds getindex(A, args...) = unsafe_get_index(A, to_indices(A, args))
@propagate_inbounds getindex(A; kwargs...) = A[order_named_inds(dimnames(A), kwargs.data)...]
@propagate_inbounds function getindex(A; kwargs...)
return unsafe_get_index(A, to_indices(A, order_named_inds(dimnames(A), kwargs.data)))
end
@propagate_inbounds getindex(x::Tuple, i::Int) = getfield(x, i)
@propagate_inbounds getindex(x::Tuple, ::StaticInt{i}) where {i} = getfield(x, i)

## unsafe_get_index ##
unsafe_get_index(A, inds::Tuple) = _unsafe_get_index(is_element_index(inds), A, inds)
_unsafe_get_index(::True, A, inds::Tuple) = unsafe_get_element(A, inds)
unsafe_get_index(A, inds::Tuple) = _unsafe_get_index(_is_element_index(inds), A, inds)
_unsafe_get_index(::False, A, inds::Tuple) = unsafe_get_collection(A, inds)
_unsafe_get_index(::True, A, inds::Tuple) = __unsafe_get_index(A, inds)
__unsafe_get_index(A, inds::Tuple{}) = unsafe_get_element(A, ())
__unsafe_get_index(A, inds::Tuple{Any}) = unsafe_get_element(A, first(inds))
__unsafe_get_index(A, inds::Tuple{Any,Vararg{Any}}) = unsafe_get_element(A, NDIndex(inds))

"""
unsafe_get_element(A::AbstractArray{T}, inds::Tuple) -> T
Expand All @@ -380,22 +353,30 @@ must define `unsafe_get_element(::NewArrayType, inds)`.
"""
unsafe_get_element(a::A, inds) where {A} = _unsafe_get_element(has_parent(A), a, inds)
_unsafe_get_element(::True, a, inds) = unsafe_get_element(parent(a), inds)
_unsafe_get_element(::False, a, inds) = @inbounds(parent(a)[inds...])
_unsafe_get_element(::False, a::AbstractArray2, inds) = unsafe_get_element_error(a, inds)
_unsafe_get_element(::False, a, inds) = @inbounds(parent(a)[inds])
_unsafe_get_element(::False, a::AbstractArray2, i) = unsafe_get_element_error(a, i)

## Array ##
unsafe_get_element(A::Array, ::Tuple{}) = Base.arrayref(false, A, 1)
unsafe_get_element(A::Array, inds) = Base.arrayref(false, A, Int(to_index(A, inds)))
unsafe_get_element(A::LinearIndices, inds) = Int(to_index(A, inds))
@inline function unsafe_get_element(A::CartesianIndices, inds)
if length(inds) === 1
return CartesianIndex(to_index(A, first(inds)))
else
return CartesianIndex(Base._to_subscript_indices(A, inds...))
end
unsafe_get_element(A::Array, i::Integer) = Base.arrayref(false, A, Int(i))
unsafe_get_element(A::Array, i::NDIndex) = unsafe_get_element(A, to_index(A, i))

## LinearIndices ##
unsafe_get_element(A::LinearIndices, i::Integer) = Int(i)
unsafe_get_element(A::LinearIndices, i::NDIndex) = unsafe_get_element(A, to_index(A, i))

unsafe_get_element(A::CartesianIndices, i::NDIndex) = CartesianIndex(i)
unsafe_get_element(A::CartesianIndices, i::Integer) = unsafe_get_element(A, to_index(A, i))

unsafe_get_element(A::ReshapedArray, i::Integer) = unsafe_get_element(parent(A), i)
function unsafe_get_element(A::ReshapedArray, i::NDIndex)
return unsafe_get_element(parent(A), to_index(IndexLinear(), A, i))
end
unsafe_get_element(A::ReshapedArray, inds) = @inbounds(A[inds...])
unsafe_get_element(A::SubArray, inds) = @inbounds(A[inds...])

unsafe_get_element_error(A, inds) = throw(MethodError(unsafe_get_element, (A, inds)))
unsafe_get_element(A::SubArray, i) = @inbounds(A[i])
function unsafe_get_element_error(@nospecialize(A), @nospecialize(i))
throw(MethodError(unsafe_get_element, (A, i)))
end

# This is based on Base._unsafe_getindex from https://github.com/JuliaLang/julia/blob/c5ede45829bf8eb09f2145bfd6f089459d77b2b1/base/multidimensional.jl#L755.
"""
Expand Down Expand Up @@ -424,7 +405,7 @@ function _generate_unsafe_get_index!_body(N::Int)
# the optimizer is not clever enough to split the union without it
Dy === nothing && return dest
(idx, state) = Dy
dest[idx] = unsafe_get_element(src, Base.Cartesian.@ntuple($N, j))
dest[idx] = unsafe_get_element(src, NDIndex(Base.Cartesian.@ntuple($N, j)))
Dy = iterate(D, state)
end
return dest
Expand Down Expand Up @@ -453,37 +434,36 @@ end
end
end

#################
### setindex! ###
#################
"""
ArrayInterface.setindex!(A, args...)

Store the given values at the given key or index within a collection.
"""
@propagate_inbounds function setindex!(A, val, args...)
if can_setindex(A)
return unsafe_setindex!(A, val, to_indices(A, args))
return unsafe_set_index!(A, val, to_indices(A, args))
else
error("Instance of type $(typeof(A)) are not mutable and cannot change elements after construction.")
end
end
@propagate_inbounds function setindex!(A, val; kwargs...)
if has_dimnames(A)
return setindex!(A, val, order_named_inds(dimnames(A), kwargs.data)...)
else
return unsafe_setindex!(A, val, to_indices(A, ()))
end
return unsafe_set_index!(A, val, to_indices(A, order_named_inds(dimnames(A), kwargs.data)))
end

"""
unsafe_setindex!(A, val, inds::Tuple)

Sets indices (`inds`) of `A` to `val`. This method assumes that `inds` have already been
bounds-checked. This step of the processing pipeline can be customized by:
"""
unsafe_setindex!(A, val, i::Tuple) = unsafe_setindex!(UnsafeIndex(A, i), A, val, i)
unsafe_setindex!(::UnsafeGetElement, A, val, i::Tuple) = unsafe_set_element!(A, val, i)
unsafe_setindex!(::UnsafeGetCollection, A, v, i::Tuple) = unsafe_set_collection!(A, v, i)
unsafe_set_index!(A, v, inds::Tuple) = _unsafe_set_index!(_is_element_index(inds), A, v, inds)
_unsafe_set_index!(::False, A, v, inds::Tuple) = unsafe_set_collection!(A, v, inds)
_unsafe_set_index!(::True, A, v, inds::Tuple) = __unsafe_set_index!(A, v, inds)
__unsafe_set_index!(A, v, inds::Tuple{}) = unsafe_set_element!(A, v, ())
function __unsafe_set_index!(A, v, inds::Tuple{Any})
return unsafe_set_element!(A, v, to_index(A, first(inds)))
end
function __unsafe_set_index!(A, v, inds::Tuple{Any,Vararg{Any}})
return unsafe_set_element!(A, v, to_index(A, NDIndex(inds)))
end

unsafe_set_element_error(A, v, i) = throw(MethodError(unsafe_set_element!, (A, v, i)))

"""
unsafe_set_element!(A, val, inds::Tuple)
Expand All @@ -494,19 +474,18 @@ must define `unsafe_set_element!(::NewArrayType, val, inds)`.
"""
unsafe_set_element!(a, val, inds) = _unsafe_set_element!(has_parent(a), a, val, inds)
_unsafe_set_element!(::True, a, val, inds) = unsafe_set_element!(parent(a), val, inds)
_unsafe_set_element!(::False, a, val,inds) = @inbounds(parent(a)[inds...] = val)
_unsafe_set_element!(::False, a, val, inds) = @inbounds(parent(a)[inds] = val)

function _unsafe_set_element!(::False, a::AbstractArray2, val, inds)
unsafe_set_element_error(a, val, inds)
end
unsafe_set_element_error(A, v, i) = throw(MethodError(unsafe_set_element!, (A, v, i)))

function unsafe_set_element!(A::Array{T}, val, inds::Tuple) where {T}
if length(inds) === 0
return Base.arrayset(false, A, convert(T, val)::T, 1)
elseif inds isa Tuple{Vararg{Int}}
return Base.arrayset(false, A, convert(T, val)::T, inds...)
else
throw(MethodError(unsafe_set_element!, (A, inds)))
end
function unsafe_set_element!(A::Array{T}, val, ::Tuple{}) where {T}
Base.arrayset(false, A, convert(T, val)::T, 1)
end
function unsafe_set_element!(A::Array{T}, val, i::Integer) where {T}
return Base.arrayset(false, A, convert(T, val)::T, Int(i))
end

# This is based on Base._unsafe_setindex!.
Expand All @@ -529,7 +508,7 @@ function _generate_unsafe_setindex!_body(N::Int)
# the optimizer that it does not need to emit error paths
Xy === nothing && break
(val, state) = Xy
unsafe_set_element!(A, val, Base.Cartesian.@ntuple($N, i))
unsafe_set_element!(A, val, NDIndex(Base.Cartesian.@ntuple($N, i)))
Xy = iterate(x′, state)
end
A
Expand Down
Loading