Skip to content

Commit

Permalink
promote ranges to the largest of the start, step, or length (as appli…
Browse files Browse the repository at this point in the history
…cable)

Be careful to use `oneunit` instead of `1`, so that arithmetic on
user-given types does not promote first to Int.

Fixes #35711
Fixes #10554
  • Loading branch information
vtjnash committed Dec 8, 2021
1 parent ff185b7 commit 401fd1e
Show file tree
Hide file tree
Showing 2 changed files with 52 additions and 29 deletions.
65 changes: 36 additions & 29 deletions base/range.jl
Original file line number Diff line number Diff line change
@@ -1,21 +1,21 @@
# This file is a part of Julia. License is MIT: https://julialang.org/license

(:)(a::Real, b::Real) = (:)(promote(a,b)...)
(:)(a::Real, b::Real) = (:)(promote(a, b)...)

(:)(start::T, stop::T) where {T<:Real} = UnitRange{T}(start, stop)

(:)(start::T, stop::T) where {T} = (:)(start, oftype(stop >= start ? stop - start : start - stop, 1), stop)

# promote start and stop, leaving step alone
(:)(start::A, step, stop::C) where {A<:Real,C<:Real} =
(:)(convert(promote_type(A,C),start), step, convert(promote_type(A,C),stop))
(:)(start::A, step, stop::C) where {A<:Real, C<:Real} =
(:)(convert(promote_type(A, C), start), step, convert(promote_type(A, C), stop))

# AbstractFloat specializations
(:)(a::T, b::T) where {T<:AbstractFloat} = (:)(a, T(1), b)

(:)(a::T, b::AbstractFloat, c::T) where {T<:Real} = (:)(promote(a,b,c)...)
(:)(a::T, b::AbstractFloat, c::T) where {T<:AbstractFloat} = (:)(promote(a,b,c)...)
(:)(a::T, b::Real, c::T) where {T<:AbstractFloat} = (:)(promote(a,b,c)...)
(:)(a::T, b::AbstractFloat, c::T) where {T<:Real} = (:)(promote(a, b, c)...)
(:)(a::T, b::AbstractFloat, c::T) where {T<:AbstractFloat} = (:)(promote(a, b, c)...)
(:)(a::T, b::Real, c::T) where {T<:AbstractFloat} = (:)(promote(a, b, c)...)

(:)(start::T, step::T, stop::T) where {T<:AbstractFloat} =
_colon(OrderStyle(T), ArithmeticStyle(T), start, step, stop)
Expand Down Expand Up @@ -168,15 +168,15 @@ range_stop(stop) = range_start_stop(oftype(stop, 1), stop)
range_stop(stop::Integer) = range_length(stop)

# Stop and length as the only argument
range_stop_length(a::Real, len::Integer) = UnitRange{typeof(a)}(oftype(a, a-len+1), a)
range_stop_length(a::Real, len::Integer) = UnitRange(promote(a - (len - oneunit(len)), a)...)
range_stop_length(a::AbstractFloat, len::Integer) = range_step_stop_length(oftype(a, 1), a, len)
range_stop_length(a, len::Integer) = range_step_stop_length(oftype(a-a, 1), a, len)
range_stop_length(a, len::Integer) = range_step_stop_length(oftype(a - a, 1), a, len)

range_step_stop_length(step, stop, length) = reverse(range_start_step_length(stop, -step, length))

range_start_length(a::Real, len::Integer) = UnitRange{typeof(a)}(a, oftype(a, a+len-1))
range_start_length(a::Real, len::Integer) = UnitRange(promote(a, a + (len - oneunit(len)))...)
range_start_length(a::AbstractFloat, len::Integer) = range_start_step_length(a, oftype(a, 1), len)
range_start_length(a, len::Integer) = range_start_step_length(a, oftype(a-a, 1), len)
range_start_length(a, len::Integer) = range_start_step_length(a, oftype(a - a, 1), len)

range_start_stop(start, stop) = start:stop

Expand All @@ -201,15 +201,13 @@ function range_start_step_length(a::T, step, len::Integer) where {T}
end

function _rangestyle(::Ordered, ::ArithmeticWraps, a, step, len::Integer)
start = a + zero(step)
stop = a + step * (len - 1)
T = typeof(start)
return StepRange{T,typeof(step)}(start, step, convert(T, stop))
stop = a + step * (len - oneunit(len))
T = typeof(stop)
return StepRange{T,typeof(step)}(convert(T, a), step, stop)
end
function _rangestyle(::Any, ::Any, a, step, len::Integer)
start = a + zero(step)
T = typeof(a)
return StepRangeLen{typeof(start),T,typeof(step)}(a, step, len)
stop = a + step * (len - oneunit(len))
return StepRangeLen{typeof(stop),typeof(a),typeof(step)}(a, step, len)
end

range_start_step_stop(start, step, stop) = start:step:stop
Expand Down Expand Up @@ -893,7 +891,7 @@ _in_unit_range(v::UnitRange, val, i::Integer) = i > 0 && val <= v.stop && val >=
function getindex(v::UnitRange{T}, i::Integer) where T
@inline
i isa Bool && throw(ArgumentError("invalid index: $i of type Bool"))
val = convert(T, v.start + (i - 1))
val = convert(T, v.start + (i - oneunit(i)))
@boundscheck _in_unit_range(v, val, i) || throw_boundserror(v, i)
val
end
Expand All @@ -904,7 +902,7 @@ const OverflowSafe = Union{Bool,Int8,Int16,Int32,Int64,Int128,
function getindex(v::UnitRange{T}, i::Integer) where {T<:OverflowSafe}
@inline
i isa Bool && throw(ArgumentError("invalid index: $i of type Bool"))
val = v.start + (i - 1)
val = v.start + (i - oneunit(i))
@boundscheck _in_unit_range(v, val, i) || throw_boundserror(v, i)
val % T
end
Expand All @@ -919,7 +917,7 @@ end
function getindex(v::AbstractRange{T}, i::Integer) where T
@inline
i isa Bool && throw(ArgumentError("invalid index: $i of type Bool"))
ret = convert(T, first(v) + (i - 1)*step_hp(v))
ret = convert(T, first(v) + (i - oneunit(i))*step_hp(v))
ok = ifelse(step(v) > zero(step(v)),
(ret <= last(v)) & (ret >= first(v)),
(ret <= first(v)) & (ret >= last(v)))
Expand Down Expand Up @@ -949,7 +947,7 @@ end

function unsafe_getindex(r::LinRange, i::Integer)
i isa Bool && throw(ArgumentError("invalid index: $i of type Bool"))
lerpi(i-1, r.lendiv, r.start, r.stop)
lerpi(i-oneunit(i), r.lendiv, r.start, r.stop)
end

function lerpi(j::Integer, d::Integer, a::T, b::T) where T
Expand All @@ -968,8 +966,10 @@ function getindex(r::AbstractUnitRange, s::AbstractUnitRange{T}) where {T<:Integ
return range(first(s) ? first(r) : last(r), length = last(s))
else
f = first(r)
start = oftype(f, f + first(s)-firstindex(r))
return range(start, length=length(s))
start = oftype(f, f + first(s) - firstindex(r))
len = length(s)
stop = oftype(f, start + (len - oneunit(len)))
return range(start, stop)
end
end

Expand All @@ -984,11 +984,14 @@ function getindex(r::AbstractUnitRange, s::StepRange{T}) where {T<:Integer}
@boundscheck checkbounds(r, s)

if T === Bool
return range(first(s) ? first(r) : last(r), step=oneunit(eltype(r)), length = last(s))
return range(first(s) ? first(r) : last(r), step=oneunit(eltype(r)), length=last(s))
else
f = first(r)
start = oftype(f, f + s.start-firstindex(r))
return range(start, step=step(s), length=length(s))
start = oftype(f, f + s.start - firstindex(r))
st = step(s)
len = length(s)
stop = oftype(f, start + (len - oneunit(len)) * st)
return range(start, stop; step=st)
end
end

Expand All @@ -1011,9 +1014,13 @@ function getindex(r::StepRange, s::AbstractRange{T}) where {T<:Integer}
return range(start, step=step(r); length=len)
else
f = r.start
fs = first(s)
st = r.step
start = oftype(f, f + (first(s)-oneunit(first(s)))*st)
return range(start; step=st*step(s), length=length(s))
start = oftype(f, f + (fs - oneunit(fs)) * st)
st = st * step(s)
len = length(s)
stop = oftype(f, start + (len - oneunit(len)) * st)
return range(start, stop; step=st)
end
end

Expand Down Expand Up @@ -1042,7 +1049,7 @@ function getindex(r::StepRangeLen{T}, s::OrdinalRange{S}) where {T, S<:Integer}
# Find closest approach to offset by s
ind = LinearIndices(s)
offset = L(max(min(1 + round(L, (r.offset - first(s))/sstep), last(ind)), first(ind)))
ref = _getindex_hiprec(r, first(s) + (offset-1)*sstep)
ref = _getindex_hiprec(r, first(s) + (offset - oneunit(offset)) * sstep)
return StepRangeLen{T}(ref, rstep*sstep, len, offset)
end
end
Expand Down
16 changes: 16 additions & 0 deletions test/ranges.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2257,3 +2257,19 @@ let r = Ptr{Cvoid}(20):-UInt(2):Ptr{Cvoid}(10)
@test step(r) === -UInt(2)
@test last(r) === Ptr{Cvoid}(10)
end

# test behavior of wrap-around and promotion of empty ranges (#35711)
@test length(range(0, length=UInt(0))) === UInt(0)
@test !isempty(range(0, length=UInt(0)))
@test length(range(typemax(Int), length=UInt(0))) === UInt(0)
@test isempty(range(typemax(Int), length=UInt(0)))
@test length(range(0, length=UInt(0), step=UInt(2))) == typemin(Int) % UInt
@test !isempty(range(0, length=UInt(0), step=UInt(2)))
@test length(range(typemax(Int), length=UInt(0), step=UInt(2))) === UInt(0)
@test isempty(range(typemax(Int), length=UInt(0), step=UInt(2)))
@test length(range(typemax(Int), length=UInt(0), step=2)) === UInt(0)
@test isempty(range(typemax(Int), length=UInt(0), step=2))
@test length(range(typemax(Int), length=0, step=UInt(2))) === UInt(0)
@test isempty(range(typemax(Int), length=0, step=UInt(2)))

@test length(range(1, length=typemax(Int128))) === typemax(Int128)

0 comments on commit 401fd1e

Please sign in to comment.