Skip to content

Commit 479ac52

Browse files
committed
Test known_length methods & fix indent size
1 parent 14dd886 commit 479ac52

File tree

3 files changed

+68
-59
lines changed

3 files changed

+68
-59
lines changed

README.md

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,6 @@ returned. If any indices are not equal along dimension `d` an error is thrown. A
2828
tuple may be used to specify a different dimension for each array. If `d` is not
2929
specified then indices for visiting each index of `x` is returned.
3030

31-
3231
## ismutable(x)
3332

3433
A trait function for whether `x` is a mutable or immutable array. Used for

src/ranges.jl

Lines changed: 58 additions & 58 deletions
Original file line numberDiff line numberDiff line change
@@ -58,25 +58,25 @@ from other valid indices. Therefore, users should not expect the same checks are
5858
to ensure construction of a valid `OptionallyStaticUnitRange` as a `UnitRange`.
5959
"""
6060
struct OptionallyStaticUnitRange{T,F,L} <: AbstractUnitRange{T}
61-
start::F
62-
stop::L
63-
64-
function OptionallyStaticUnitRange{T}(start, stop) where {T<:Real}
65-
if _get(start) isa T
66-
if _get(stop) isa T
67-
return new{T,typeof(start),typeof(stop)}(start, stop)
68-
else
69-
return OptionallyStaticUnitRange{T}(start, _convert(T, stop))
70-
end
71-
else
72-
return OptionallyStaticUnitRange{T}(_convert(T, start), stop)
73-
end
61+
start::F
62+
stop::L
63+
64+
function OptionallyStaticUnitRange{T}(start, stop) where {T<:Real}
65+
if _get(start) isa T
66+
if _get(stop) isa T
67+
return new{T,typeof(start),typeof(stop)}(start, stop)
68+
else
69+
return OptionallyStaticUnitRange{T}(start, _convert(T, stop))
70+
end
71+
else
72+
return OptionallyStaticUnitRange{T}(_convert(T, start), stop)
7473
end
74+
end
7575

76-
function OptionallyStaticUnitRange(start, stop)
77-
T = promote_type(typeof(_get(start)), typeof(_get(stop)))
78-
return OptionallyStaticUnitRange{T}(start, stop)
79-
end
76+
function OptionallyStaticUnitRange(start, stop)
77+
T = promote_type(typeof(_get(start)), typeof(_get(stop)))
78+
return OptionallyStaticUnitRange{T}(start, stop)
79+
end
8080
end
8181

8282
Base.first(r::OptionallyStaticUnitRange{<:Any,Val{F}}) where {F} = F
@@ -92,11 +92,11 @@ known_step(::Type{<:OptionallyStaticUnitRange{T}}) where {T} = one(T)
9292
known_last(::Type{<:OptionallyStaticUnitRange{<:Any,<:Any,Val{L}}}) where {L} = L
9393

9494
function Base.isempty(r::OptionallyStaticUnitRange)
95-
if known_first(r) === oneunit(eltype(r))
96-
return unsafe_isempty_one_to(last(r))
97-
else
98-
return unsafe_isempty_unit_range(first(r), last(r))
99-
end
95+
if known_first(r) === oneunit(eltype(r))
96+
return unsafe_isempty_one_to(last(r))
97+
else
98+
return unsafe_isempty_unit_range(first(r), last(r))
99+
end
100100
end
101101

102102
unsafe_isempty_one_to(lst) = lst <= zero(lst)
@@ -108,26 +108,26 @@ unsafe_length_one_to(lst::T) where {T<:Int} = T(lst)
108108
unsafe_length_one_to(lst::T) where {T} = Integer(lst - zero(lst))
109109

110110
Base.@propagate_inbounds function Base.getindex(r::OptionallyStaticUnitRange, i::Integer)
111-
if known_first(r) === oneunit(r)
112-
return get_index_one_to(r, i)
113-
else
114-
return get_index_unit_range(r, i)
115-
end
111+
if known_first(r) === oneunit(r)
112+
return get_index_one_to(r, i)
113+
else
114+
return get_index_unit_range(r, i)
115+
end
116116
end
117117

118118
@inline function get_index_one_to(r, i)
119-
@boundscheck if ((i > 0) & (i <= last(r)))
120-
throw(BoundsError(r, i))
121-
end
122-
return convert(eltype(r), i)
119+
@boundscheck if ((i > 0) & (i <= last(r)))
120+
throw(BoundsError(r, i))
121+
end
122+
return convert(eltype(r), i)
123123
end
124124

125125
@inline function get_index_unit_range(r, i)
126-
val = first(r) + (i - 1)
127-
@boundscheck if i > 0 && val <= last(r) && val >= first(r)
128-
throw(BoundsError(r, i))
129-
end
130-
return convert(eltype(r), val)
126+
val = first(r) + (i - 1)
127+
@boundscheck if i > 0 && val <= last(r) && val >= first(r)
128+
throw(BoundsError(r, i))
129+
end
130+
return convert(eltype(r), val)
131131
end
132132

133133
_try_static(x, y) = Val(x)
@@ -141,7 +141,7 @@ _try_static(::Nothing, ::Nothing) = nothing
141141
@inline function known_length(::Type{T}) where {T<:AbstractUnitRange}
142142
fst = known_first(T)
143143
lst = known_last(T)
144-
if stp === nothing || fst === nothing || lst === nothing
144+
if fst === nothing || lst === nothing
145145
return nothing
146146
else
147147
if fst === oneunit(eltype(T))
@@ -153,26 +153,26 @@ _try_static(::Nothing, ::Nothing) = nothing
153153
end
154154

155155
function Base.length(r::OptionallyStaticUnitRange{T}) where {T}
156-
if isempty(r)
157-
return zero(T)
156+
if isempty(r)
157+
return zero(T)
158+
else
159+
if known_one(r) === one(T)
160+
return unsafe_length_one_to(last(r))
158161
else
159-
if known_one(r) === one(T)
160-
return unsafe_length_one_to(last(r))
161-
else
162-
return unsafe_length_unit_range(first(r), last(r))
163-
end
162+
return unsafe_length_unit_range(first(r), last(r))
164163
end
164+
end
165165
end
166166

167167
function unsafe_length_unit_range(fst::T, lst::T) where {T<:Union{Int,Int64,Int128}}
168-
return Base.checked_add(Base.checked_sub(lst, fst), one(T))
168+
return Base.checked_add(Base.checked_sub(lst, fst), one(T))
169169
end
170170
function unsafe_length_unit_range(fst::T, lst::T) where {T<:Union{UInt,UInt64,UInt128}}
171-
return Base.checked_add(lst - fst, one(T))
171+
return Base.checked_add(lst - fst, one(T))
172172
end
173173

174174
"""
175-
indices(x[, d]) -> AbstractRange
175+
indices(x[, d])
176176
177177
Given an array `x`, this returns the indices along dimension `d`. If `x` is a tuple
178178
of arrays then the indices corresponding to dimension `d` of all arrays in `x` are
@@ -181,12 +181,12 @@ tuple may be used to specify a different dimension for each array. If `d` is not
181181
specified then indices for visiting each index of `x` is returned.
182182
"""
183183
@inline function indices(x)
184-
inds = eachindex(x)
185-
if inds isa AbstractUnitRange{<:Integer}
186-
return Base.Slice(inds)
187-
else
188-
return inds
189-
end
184+
inds = eachindex(x)
185+
if inds isa AbstractUnitRange{<:Integer}
186+
return Base.Slice(inds)
187+
else
188+
return inds
189+
end
190190
end
191191

192192
indices(x, d) = indices(axes(x, d))
@@ -204,11 +204,11 @@ end
204204
end
205205

206206
@inline function _pick_range(x, y)
207-
fst = _try_static(known_first(x), known_first(y))
208-
fst = fst === nothing ? first(x) : fst
207+
fst = _try_static(known_first(x), known_first(y))
208+
fst = fst === nothing ? first(x) : fst
209209

210-
lst = _try_static(known_last(x), known_last(y))
211-
lst = lst === nothing ? last(x) : lst
212-
return Base.Slice(OptionallyStaticUnitRange(fst, lst))
210+
lst = _try_static(known_last(x), known_last(y))
211+
lst = lst === nothing ? last(x) : lst
212+
return Base.Slice(OptionallyStaticUnitRange(fst, lst))
213213
end
214214

test/runtests.jl

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@ using StaticArrays
1010
@test ArrayInterface.ismutable((0.1,1.0)) == false
1111
@test isone(ArrayInterface.known_first(typeof(StaticArrays.SOneTo(7))))
1212
@test ArrayInterface.known_last(typeof(StaticArrays.SOneTo(7))) == 7
13+
@test ArrayInterface.known_length(typeof(StaticArrays.SOneTo(7))) == 7
1314

1415
using LinearAlgebra, SparseArrays
1516

@@ -173,6 +174,8 @@ using ArrayInterface: parent_type
173174
@test parent_type(transpose(x)) <: typeof(x)
174175
@test parent_type(Symmetric(x)) <: typeof(x)
175176
@test parent_type(UpperTriangular(x)) <: typeof(x)
177+
@test parent_type(PermutedDimsArray(x, (2,1))) <: typeof(x)
178+
@test parent_type(Base.Slice(1:10)) <: UnitRange{Int}
176179
end
177180

178181
@testset "Range Interface" begin
@@ -196,6 +199,13 @@ end
196199
@test !ArrayInterface.can_change_size(Tuple{})
197200
end
198201

202+
@testset "known_length" begin
203+
@test ArrayInterface.known_length(ArrayInterface.indices(SOneTo(7))) == 7
204+
@test ArrayInterface.known_length(1:2) == nothing
205+
@test ArrayInterface.known_length((1,)) == 1
206+
@test ArrayInterface.known_length((a=1,b=2)) == 2
207+
end
208+
199209
@testset "indices" begin
200210
@test @inferred(ArrayInterface.indices(ones(2, 3))) == 1:6
201211
@test @inferred(ArrayInterface.indices(ones(2, 3), 1)) == 1:2

0 commit comments

Comments
 (0)