Skip to content

Commit 95c0102

Browse files
committed
Method for indices(::Tuple)
This basically does the same thing as `eachindex(A1, A2)` but if the resulting iterators are `AbstractUnitRange{<:Integer}` it wraps the result in a `Slice` so that we know the iterator spans the entire dimension.
1 parent 479ac52 commit 95c0102

File tree

2 files changed

+20
-2
lines changed

2 files changed

+20
-2
lines changed

src/ranges.jl

Lines changed: 19 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,6 @@ known_first(::Type{T}) where {T} = nothing
1313
known_first(::Type{Base.OneTo{T}}) where {T} = one(T)
1414
known_first(::Type{T}) where {T<:Base.Slice} = known_first(parent_type(T))
1515

16-
1716
"""
1817
known_last(::Type{T})
1918
@@ -77,6 +76,18 @@ struct OptionallyStaticUnitRange{T,F,L} <: AbstractUnitRange{T}
7776
T = promote_type(typeof(_get(start)), typeof(_get(stop)))
7877
return OptionallyStaticUnitRange{T}(start, stop)
7978
end
79+
80+
function OptionallyStaticUnitRange(x::AbstractRange)
81+
if step(x) == 1
82+
fst = known_first(x)
83+
fst = fst === nothing ? first(x) : Val(fst)
84+
lst = known_last(x)
85+
lst = lst === nothing ? last(x) : Val(lst)
86+
return OptionallyStaticUnitRange(fst, lst)
87+
else
88+
throw(ArgumentError("step must be 1, got $(step(r))"))
89+
end
90+
end
8091
end
8192

8293
Base.first(r::OptionallyStaticUnitRange{<:Any,Val{F}}) where {F} = F
@@ -183,12 +194,18 @@ specified then indices for visiting each index of `x` is returned.
183194
@inline function indices(x)
184195
inds = eachindex(x)
185196
if inds isa AbstractUnitRange{<:Integer}
186-
return Base.Slice(inds)
197+
return Base.Slice(OptionallyStaticUnitRange(inds))
187198
else
188199
return inds
189200
end
190201
end
191202

203+
function indices(x::Tuple)
204+
inds = map(eachindex, x)
205+
@assert all(isequal(first(inds)), Base.tail(inds)) "Not all specified axes are equal: $inds"
206+
return reduce(_pick_range, inds)
207+
end
208+
192209
indices(x, d) = indices(axes(x, d))
193210

194211
@inline function indices(x::NTuple{N,<:Any}, dim) where {N}

test/runtests.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -207,6 +207,7 @@ end
207207
end
208208

209209
@testset "indices" begin
210+
@test @inferred(ArrayInterface.indices((ones(2, 3), ones(3, 2)))) == 1:6
210211
@test @inferred(ArrayInterface.indices(ones(2, 3))) == 1:6
211212
@test @inferred(ArrayInterface.indices(ones(2, 3), 1)) == 1:2
212213
@test @inferred(ArrayInterface.indices((ones(2, 3), ones(3, 2)), (1, 2))) == 1:2

0 commit comments

Comments
 (0)