Skip to content

Safe co-iteration across an axis for 1+ arrays #63

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 6 commits into from
Aug 17, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 8 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,14 @@ Returns the parent array that `x` wraps.
Returns `true` if the size of `T` can change, in which case operations
such as `pop!` and `popfirst!` are available for collections of type `T`.

## indices(x[, d])

Given an array `x`, this returns the indices along dimension `d`. If `x` is a tuple
of arrays then the indices corresponding to dimension `d` of all arrays in `x` are
returned. If any indices are not equal along dimension `d` an error is thrown. A
tuple may be used to specify a different dimension for each array. If `d` is not
specified then indices for visiting each index of `x` is returned.

## ismutable(x)

A trait function for whether `x` is a mutable or immutable array. Used for
Expand Down
57 changes: 18 additions & 39 deletions src/ArrayInterface.jl
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,8 @@ using Requires
using LinearAlgebra
using SparseArrays

using Base: OneTo

Base.@pure __parameterless_type(T) = Base.typename(T).wrapper
parameterless_type(x) = parameterless_type(typeof(x))
parameterless_type(x::Type) = __parameterless_type(x)
Expand All @@ -20,8 +22,21 @@ parent_type(::Type{Adjoint{T,S}}) where {T,S} = S
parent_type(::Type{Transpose{T,S}}) where {T,S} = S
parent_type(::Type{Symmetric{T,S}}) where {T,S} = S
parent_type(::Type{<:LinearAlgebra.AbstractTriangular{T,S}}) where {T,S} = S
parent_type(::Type{<:PermutedDimsArray{T,N,I1,I2,A}}) where {T,N,I1,I2,A} = A
parent_type(::Type{Base.Slice{T}}) where {T} = T
parent_type(::Type{T}) where {T} = T

"""
known_length(::Type{T})

If `length` of an instance of type `T` is known at compile time, return it.
Otherwise, return `nothing`.
"""
known_length(x) = known_length(typeof(x))
known_length(::Type{<:NTuple{N,<:Any}}) where {N} = N
known_length(::Type{<:NamedTuple{L}}) where {L} = length(L)
known_length(::Type{T}) where {T<:Base.Slice} = known_length(parent_type(T))

"""
can_change_size(::Type{T}) -> Bool

Expand Down Expand Up @@ -503,45 +518,6 @@ function restructure(x::Array,y)
reshape(convert(Array,y),size(x)...)
end

"""
known_first(::Type{T})

If `first` of an instance of type `T` is known at compile time, return it.
Otherwise, return `nothing`.

@test isnothing(known_first(typeof(1:4)))
@test isone(known_first(typeof(Base.OneTo(4))))
"""
known_first(x) = known_first(typeof(x))
known_first(::Type{T}) where {T} = nothing
known_first(::Type{Base.OneTo{T}}) where {T} = one(T)

"""
known_last(::Type{T})

If `last` of an instance of type `T` is known at compile time, return it.
Otherwise, return `nothing`.

@test isnothing(known_last(typeof(1:4)))
using StaticArrays
@test known_last(typeof(SOneTo(4))) == 4
"""
known_last(x) = known_last(typeof(x))
known_last(::Type{T}) where {T} = nothing

"""
known_step(::Type{T})

If `step` of an instance of type `T` is known at compile time, return it.
Otherwise, return `nothing`.

@test isnothing(known_step(typeof(1:0.2:4)))
@test isone(known_step(typeof(1:4)))
"""
known_step(x) = known_step(typeof(x))
known_step(::Type{T}) where {T} = nothing
known_step(::Type{<:AbstractUnitRange{T}}) where {T} = one(T)

function __init__()

@require SuiteSparse="4607b0f0-06f3-5cda-b6b1-a6196a1729e9" begin
Expand Down Expand Up @@ -575,6 +551,7 @@ function __init__()

known_first(::Type{<:StaticArrays.SOneTo}) = 1
known_last(::Type{StaticArrays.SOneTo{N}}) where {N} = N
known_length(::Type{StaticArrays.SOneTo{N}}) where {N} = N

@require Adapt="79e6a3ab-5dfb-504d-930d-738a2a938a0e" begin
function Adapt.adapt_storage(::Type{<:StaticArrays.SArray{S}},xs::Array) where S
Expand Down Expand Up @@ -697,4 +674,6 @@ function __init__()
end
end

include("ranges.jl")

end
231 changes: 231 additions & 0 deletions src/ranges.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,231 @@

"""
known_first(::Type{T})

If `first` of an instance of type `T` is known at compile time, return it.
Otherwise, return `nothing`.

@test isnothing(known_first(typeof(1:4)))
@test isone(known_first(typeof(Base.OneTo(4))))
"""
known_first(x) = known_first(typeof(x))
known_first(::Type{T}) where {T} = nothing
known_first(::Type{Base.OneTo{T}}) where {T} = one(T)
known_first(::Type{T}) where {T<:Base.Slice} = known_first(parent_type(T))

"""
known_last(::Type{T})

If `last` of an instance of type `T` is known at compile time, return it.
Otherwise, return `nothing`.

@test isnothing(known_last(typeof(1:4)))
using StaticArrays
@test known_last(typeof(SOneTo(4))) == 4
"""
known_last(x) = known_last(typeof(x))
known_last(::Type{T}) where {T} = nothing
known_last(::Type{T}) where {T<:Base.Slice} = known_last(parent_type(T))

"""
known_step(::Type{T})

If `step` of an instance of type `T` is known at compile time, return it.
Otherwise, return `nothing`.

@test isnothing(known_step(typeof(1:0.2:4)))
@test isone(known_step(typeof(1:4)))
"""
known_step(x) = known_step(typeof(x))
known_step(::Type{T}) where {T} = nothing
known_step(::Type{<:AbstractUnitRange{T}}) where {T} = one(T)

# add methods to support ArrayInterface

_get(x) = x
_get(::Val{V}) where {V} = V
_convert(::Type{T}, x) where {T} = convert(T, x)
_convert(::Type{T}, ::Val{V}) where {T,V} = Val(convert(T, V))

"""
OptionallyStaticUnitRange{T<:Integer}(start, stop) <: OrdinalRange{T,T}

This range permits diverse representations of arrays to comunicate common information
about their indices. Each field may be an integer or `Val(<:Integer)` if it is known
at compile time. An `OptionallyStaticUnitRange` is intended to be constructed internally
from other valid indices. Therefore, users should not expect the same checks are used
to ensure construction of a valid `OptionallyStaticUnitRange` as a `UnitRange`.
"""
struct OptionallyStaticUnitRange{T,F,L} <: AbstractUnitRange{T}
start::F
stop::L

function OptionallyStaticUnitRange{T}(start, stop) where {T<:Real}
if _get(start) isa T
if _get(stop) isa T
return new{T,typeof(start),typeof(stop)}(start, stop)
else
return OptionallyStaticUnitRange{T}(start, _convert(T, stop))
end
else
return OptionallyStaticUnitRange{T}(_convert(T, start), stop)
end
end

function OptionallyStaticUnitRange(start, stop)
T = promote_type(typeof(_get(start)), typeof(_get(stop)))
return OptionallyStaticUnitRange{T}(start, stop)
end

function OptionallyStaticUnitRange(x::AbstractRange)
if step(x) == 1
fst = known_first(x)
fst = fst === nothing ? first(x) : Val(fst)
lst = known_last(x)
lst = lst === nothing ? last(x) : Val(lst)
return OptionallyStaticUnitRange(fst, lst)
else
throw(ArgumentError("step must be 1, got $(step(r))"))
end
end
end

Base.first(r::OptionallyStaticUnitRange{<:Any,Val{F}}) where {F} = F
Base.first(r::OptionallyStaticUnitRange{<:Any,<:Any}) = r.start

Base.step(r::OptionallyStaticUnitRange{T}) where {T} = oneunit(T)

Base.last(r::OptionallyStaticUnitRange{<:Any,<:Any,Val{L}}) where {L} = L
Base.last(r::OptionallyStaticUnitRange{<:Any,<:Any,<:Any}) = r.stop

known_first(::Type{<:OptionallyStaticUnitRange{<:Any,Val{F}}}) where {F} = F
known_step(::Type{<:OptionallyStaticUnitRange{T}}) where {T} = one(T)
known_last(::Type{<:OptionallyStaticUnitRange{<:Any,<:Any,Val{L}}}) where {L} = L

function Base.isempty(r::OptionallyStaticUnitRange)
if known_first(r) === oneunit(eltype(r))
return unsafe_isempty_one_to(last(r))
else
return unsafe_isempty_unit_range(first(r), last(r))
end
end

unsafe_isempty_one_to(lst) = lst <= zero(lst)
unsafe_isempty_unit_range(fst, lst) = fst > lst

unsafe_isempty_unit_range(fst::T, lst::T) where {T} = Integer(lst - fst + one(T))

unsafe_length_one_to(lst::T) where {T<:Int} = T(lst)
unsafe_length_one_to(lst::T) where {T} = Integer(lst - zero(lst))

Base.@propagate_inbounds function Base.getindex(r::OptionallyStaticUnitRange, i::Integer)
if known_first(r) === oneunit(r)
return get_index_one_to(r, i)
else
return get_index_unit_range(r, i)
end
end

@inline function get_index_one_to(r, i)
@boundscheck if ((i > 0) & (i <= last(r)))
throw(BoundsError(r, i))
end
return convert(eltype(r), i)
end

@inline function get_index_unit_range(r, i)
val = first(r) + (i - 1)
@boundscheck if i > 0 && val <= last(r) && val >= first(r)
throw(BoundsError(r, i))
end
return convert(eltype(r), val)
end

_try_static(x, y) = Val(x)
_try_static(::Nothing, y) = Val(y)
_try_static(x, ::Nothing) = Val(x)
_try_static(::Nothing, ::Nothing) = nothing

###
### length
###
@inline function known_length(::Type{T}) where {T<:AbstractUnitRange}
fst = known_first(T)
lst = known_last(T)
if fst === nothing || lst === nothing
return nothing
else
if fst === oneunit(eltype(T))
return unsafe_length_one_to(lst)
else
return unsafe_length_unit_range(fst, lst)
end
end
end

function Base.length(r::OptionallyStaticUnitRange{T}) where {T}
if isempty(r)
return zero(T)
else
if known_one(r) === one(T)
return unsafe_length_one_to(last(r))
else
return unsafe_length_unit_range(first(r), last(r))
end
end
end

function unsafe_length_unit_range(fst::T, lst::T) where {T<:Union{Int,Int64,Int128}}
return Base.checked_add(Base.checked_sub(lst, fst), one(T))
end
function unsafe_length_unit_range(fst::T, lst::T) where {T<:Union{UInt,UInt64,UInt128}}
return Base.checked_add(lst - fst, one(T))
end

"""
indices(x[, d])

Given an array `x`, this returns the indices along dimension `d`. If `x` is a tuple
of arrays then the indices corresponding to dimension `d` of all arrays in `x` are
returned. If any indices are not equal along dimension `d` an error is thrown. A
tuple may be used to specify a different dimension for each array. If `d` is not
specified then indices for visiting each index of `x` is returned.
"""
@inline function indices(x)
inds = eachindex(x)
if inds isa AbstractUnitRange{<:Integer}
return Base.Slice(OptionallyStaticUnitRange(inds))
else
return inds
end
end

function indices(x::Tuple)
inds = map(eachindex, x)
@assert all(isequal(first(inds)), Base.tail(inds)) "Not all specified axes are equal: $inds"
return reduce(_pick_range, inds)
end

indices(x, d) = indices(axes(x, d))
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This computes the indices of the indices, rather than the indices of x. You're implicitly assuming that axes(x) is idempotent. The older implementations of OffsetArrays would break this, for example, although the community seems to have widely settled on idempotency as a useful characteristic for the axes.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

In terms of iterating values it should theoretically be the same, but it's not entirely idempotent because indices(x::OneTo) = Slice(OptionallyStaticUnitRange(Val(1), last(x))). If indices(x::OffsetArray, d) = indices(axes(x, d)) doesn't produce something that is appropriately offset then it could do something like indices(x::OffsetArray, d) = indices(range(x.offset[d], stop = size(x, d)).


@inline function indices(x::NTuple{N,<:Any}, dim) where {N}
inds = map(x_i -> indices(x_i, dim), x)
@assert all(isequal(first(inds)), Base.tail(inds)) "Not all specified axes are equal: $inds"
return reduce(_pick_range, inds)
end

@inline function indices(x::NTuple{N,<:Any}, dim::NTuple{N,<:Any}) where {N}
inds = map(indices, x, dim)
@assert all(isequal(first(inds)), Base.tail(inds)) "Not all specified axes are equal: $inds"
return reduce(_pick_range, inds)
end

@inline function _pick_range(x, y)
fst = _try_static(known_first(x), known_first(y))
fst = fst === nothing ? first(x) : fst

lst = _try_static(known_last(x), known_last(y))
lst = lst === nothing ? last(x) : lst
return Base.Slice(OptionallyStaticUnitRange(fst, lst))
end

20 changes: 20 additions & 0 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ using StaticArrays
@test ArrayInterface.ismutable((0.1,1.0)) == false
@test isone(ArrayInterface.known_first(typeof(StaticArrays.SOneTo(7))))
@test ArrayInterface.known_last(typeof(StaticArrays.SOneTo(7))) == 7
@test ArrayInterface.known_length(typeof(StaticArrays.SOneTo(7))) == 7

using LinearAlgebra, SparseArrays

Expand Down Expand Up @@ -173,6 +174,8 @@ using ArrayInterface: parent_type
@test parent_type(transpose(x)) <: typeof(x)
@test parent_type(Symmetric(x)) <: typeof(x)
@test parent_type(UpperTriangular(x)) <: typeof(x)
@test parent_type(PermutedDimsArray(x, (2,1))) <: typeof(x)
@test parent_type(Base.Slice(1:10)) <: UnitRange{Int}
end

@testset "Range Interface" begin
Expand All @@ -196,3 +199,20 @@ end
@test !ArrayInterface.can_change_size(Tuple{})
end

@testset "known_length" begin
@test ArrayInterface.known_length(ArrayInterface.indices(SOneTo(7))) == 7
@test ArrayInterface.known_length(1:2) == nothing
@test ArrayInterface.known_length((1,)) == 1
@test ArrayInterface.known_length((a=1,b=2)) == 2
end

@testset "indices" begin
@test @inferred(ArrayInterface.indices((ones(2, 3), ones(3, 2)))) == 1:6
@test @inferred(ArrayInterface.indices(ones(2, 3))) == 1:6
@test @inferred(ArrayInterface.indices(ones(2, 3), 1)) == 1:2
@test @inferred(ArrayInterface.indices((ones(2, 3), ones(3, 2)), (1, 2))) == 1:2
@test @inferred(ArrayInterface.indices((ones(2, 3), ones(2, 3)), 1)) == 1:2
@test_throws AssertionError ArrayInterface.indices((ones(2, 3), ones(3, 3)), 1)
@test_throws AssertionError ArrayInterface.indices((ones(2, 3), ones(3, 3)), (1, 2))
end