Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix range implementation on Julia master and resolve method ambiguities #514

Merged
merged 19 commits into from
Jan 26, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
124 changes: 84 additions & 40 deletions src/range.jl
Original file line number Diff line number Diff line change
Expand Up @@ -7,46 +7,79 @@ import Base.Broadcast: DefaultArrayStyle, broadcasted
*(y::Units, r::AbstractRange) = *(r,y)
*(r::AbstractRange, y::Units, z::Units...) = *(r, *(y,z...))

Base._range(start::Quantity{<:Real}, ::Nothing, stop, len::Integer) =
_range(promote(start, stop)..., len)
Base._range(start, ::Nothing, stop::Quantity{<:Real}, len::Integer) =
_range(promote(start, stop)..., len)
Base._range(start::Quantity{<:Real}, ::Nothing, stop::Quantity{<:Real}, len::Integer) =
_range(promote(start, stop)..., len)
(Base._range(start::T, ::Nothing, stop::T, len::Integer) where (T<:Quantity{<:Real})) =
LinRange{T}(start, stop, len)
(Base._range(start::T, ::Nothing, stop::T, len::Integer) where (T<:Quantity{<:Integer})) =
Base._linspace(Float64, ustrip(start), ustrip(stop), len, 1)*unit(T)
function Base._range(start::T, ::Nothing, stop::T, len::Integer) where (T<:Quantity{S}
where S<:Union{Float16,Float32,Float64})
range(ustrip(start), stop=ustrip(stop), length=len) * unit(T)
end
function _range(start::Quantity{T}, stop::Quantity{T}, len::Integer) where {T}
# start, stop, length
Base._range(start::Quantity, ::Nothing, stop, len::Integer) =
_unitful_start_stop_length(start, stop, len)
Base._range(start, ::Nothing, stop::Quantity, len::Integer) =
_unitful_start_stop_length(start, stop, len)
Base._range(start::Quantity, ::Nothing, stop::Quantity, len::Integer) =
_unitful_start_stop_length(start, stop, len)
function _unitful_start_stop_length(start, stop, len)
dimension(start) != dimension(stop) && throw(DimensionError(start, stop))
Base._range(start, nothing, stop, len)
a, b = promote(start, stop)
Base._range(a, nothing, b, len)
end
function Base._range(a::T, st::T, ::Nothing, len::Integer) where (T<:Quantity{S}
where S<:Union{Float16,Float32,Float64})
return Base._range(ustrip(a), ustrip(st), nothing, len) * unit(T)
Base._range(start::T, ::Nothing, stop::T, len::Integer) where {T<:Quantity} =
LinRange{T}(start, stop, len)
Base._range(start::T, ::Nothing, stop::T, len::Integer) where {T<:Quantity{<:Integer}} =
Base._linspace(Float64, ustrip(start), ustrip(stop), len, 1)*unit(T)
Base._range(start::T, ::Nothing, stop::T, len::Integer) where {T<:Quantity{<:Base.IEEEFloat}} =
Base._range(ustrip(start), nothing, ustrip(stop), len) * unit(T)

# start, step, length
Base._range(a::T, step::T, ::Nothing, len::Integer) where {T<:Quantity{<:Base.IEEEFloat}} =
Base._range(ustrip(a), ustrip(step), nothing, len) * unit(T)
Base._range(a::T, step::T, ::Nothing, len::Integer) where {T<:Quantity{<:AbstractFloat}} =
StepRangeLen{typeof(step*len),typeof(a),typeof(step)}(a, step, len)
Base._range(a::T, step::T, ::Nothing, len::Integer) where {T<:Quantity} =
@static if VERSION ≥ v"1.8.0-DEV"
Base.range_start_step_length(a, step, len)
else
Base._rangestyle(OrderStyle(a), ArithmeticStyle(a), a, step, len)
end
Base._range(a::Quantity{<:Real}, step::Quantity{<:AbstractFloat}, ::Nothing, len::Integer) =
_unitful_start_step_length(float(a), step, len)
Base._range(a::Quantity{<:AbstractFloat}, step::Quantity{<:Real}, ::Nothing, len::Integer) =
_unitful_start_step_length(a, float(step), len)
Base._range(a::Quantity{<:AbstractFloat}, step::Quantity{<:AbstractFloat}, ::Nothing, len::Integer) =
_unitful_start_step_length(a, step, len)
Base._range(a, step::Quantity, ::Nothing, len::Integer) =
_unitful_start_step_length(a, step, len)
Base._range(a::Quantity, step, ::Nothing, len::Integer) =
_unitful_start_step_length(a, step, len)
Base._range(a::Quantity, step::Quantity, ::Nothing, len::Integer) =
_unitful_start_step_length(a, step, len)
function _unitful_start_step_length(start, step, len)
dimension(start) != dimension(step) && throw(DimensionError(start,step))
Base._range(promote(start, uconvert(unit(start), step))..., nothing, len)
end
Base._range(a::Quantity{<:Real}, st::Quantity{<:AbstractFloat}, ::Nothing, len::Integer) =
Base._range(float(a), st, nothing, len)
Base._range(a::Quantity{<:AbstractFloat}, st::Quantity{<:Real}, ::Nothing, len::Integer) =
Base._range(a, float(st), nothing, len)
function Base._range(a::Quantity{<:AbstractFloat}, st::Quantity{<:AbstractFloat}, ::Nothing, len::Integer)
dimension(a) != dimension(st) && throw(DimensionError(a, st))
Base._range(promote(a, uconvert(unit(a), st))..., nothing, len)

# start, length (step defaults to 1)
Base._range(a::Quantity, ::Nothing, ::Nothing, len::Integer) =
Base._range(a, one(a), nothing, len)

# step, stop, length
@static if VERSION ≥ v"1.7"
Base._range(::Nothing, step, stop::Quantity, len::Integer) =
_unitful_step_stop_length(step, stop, len)
Base._range(::Nothing, step::Quantity, stop, len::Integer) =
_unitful_step_stop_length(step, stop, len)
Base._range(::Nothing, step::Quantity, stop::Quantity, len::Integer) =
_unitful_step_stop_length(step, stop, len)
Base._range(::Nothing, step::Quantity, ::Nothing, len::Integer) =
Base.range_error(nothing, step, nothing, len)
function _unitful_step_stop_length(step, stop, len)
dimension(stop) != dimension(step) && throw(DimensionError(stop,step))
Base.range_step_stop_length(promote(uconvert(unit(stop), step), stop)..., len)
end
end
Base._range(a::Quantity, st::Real, ::Nothing, len::Integer) =
Base._range(promote(a, uconvert(unit(a), st))..., nothing, len)
Base._range(a::Real, st::Quantity, ::Nothing, len::Integer) =
Base._range(promote(a, uconvert(unit(a), st))..., nothing, len)
# the following is needed to give sane error messages when doing e.g. range(1°, 2V, 5)
function Base._range(a::Quantity, step, ::Nothing, len::Integer)
dimension(a) != dimension(step) && throw(DimensionError(a,step))
_a, _step = promote(a, uconvert(unit(a), step))
return Base._rangestyle(OrderStyle(_a), ArithmeticStyle(_a), _a, _step, len)

# stop, length (step defaults to 1)
@static if VERSION ≥ v"1.7"
Base._range(::Nothing, ::Nothing, stop::Quantity, len::Integer) =
Base._range(nothing, one(stop), stop, len)
end

*(r::AbstractRange, y::Units) = range(first(r)*y, step=step(r)*y, length=length(r))

# first promote start and stop, leaving step alone
Expand All @@ -70,7 +103,7 @@ end
OrderStyle(::Type{<:AbstractQuantity{T}}) where T = OrderStyle(T)
ArithmeticStyle(::Type{<:AbstractQuantity{T}}) where T = ArithmeticStyle(T)

(colon(start::T, step::T, stop::T) where T <: Quantity{<:Real}) =
colon(start::T, step::T, stop::T) where {T<:Quantity{<:Real}} =
_colon(OrderStyle(T), ArithmeticStyle(T), start, step, stop)
_colon(::Ordered, ::Any, start::T, step, stop::T) where {T} = StepRange(start, step, stop)
_colon(::Ordered, ::ArithmeticRounds, start::T, step, stop::T) where {T} =
Expand All @@ -83,10 +116,20 @@ _colon(::Any, ::Any, start::T, step, stop::T) where {T} =
*(x::Base.TwicePrecision, y::Quantity) = (x * ustrip(y)) * unit(y)
uconvert(y, x::Base.TwicePrecision) = Base.TwicePrecision(uconvert(y, x.hi), uconvert(y, x.lo))

function colon(start::T, step::T, stop::T) where (T<:Quantity{S}
where S<:Union{Float16,Float32,Float64})
# This will always return a StepRangeLen
return colon(ustrip(start), ustrip(step), ustrip(stop)) * unit(T)
colon(start::T, step::T, stop::T) where {T<:Quantity{<:Base.IEEEFloat}} =
colon(ustrip(start), ustrip(step), ustrip(stop)) * unit(T) # This will always return a StepRangeLen

# two-argument colon
colon(start, stop::Quantity) = _unitful_start_stop(start, stop)
colon(start::Quantity, stop) = _unitful_start_stop(start, stop)
colon(start::Quantity, stop::Quantity) = _unitful_start_stop(start, stop)
function _unitful_start_stop(start, stop)
dimension(start) != dimension(stop) && throw(DimensionError(start, stop))
colon(promote(start, stop)...)
end
function colon(start::T, stop::T) where {T<:Quantity}
step = uconvert(unit(start), one(start))
colon(promote(start, step, stop)...)
end

# No need to confuse things by changing the type once units are on there,
Expand All @@ -95,6 +138,7 @@ end
StepRangeLen{typeof(zero(eltype(r))*y)}(r.ref*y, r.step*y, length(r), r.offset)
*(r::LinRange, y::Units) = LinRange(r.start*y, r.stop*y, length(r))
*(r::StepRange, y::Units) = StepRange(r.start*y, r.step*y, r.stop*y)
*(r::AbstractUnitRange, y::Units) = StepRange(first(r)*y, oneunit(first(r))*y, last(r)*y)
function /(x::Base.TwicePrecision, v::Quantity)
x / Base.TwicePrecision(oftype(ustrip(x.hi)/ustrip(v)*unit(v), v))
end
Expand Down
67 changes: 67 additions & 0 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -1146,6 +1146,7 @@ end
@test (-2.0Hz:1.0Hz:2.0Hz)/1.0Hz == -2.0:1.0:2.0 # issue 160
@test (range(0, stop=2, length=5) * u"°")[2:end] ==
range(0.5, stop=2, length=4) * u"°" # issue 241
@test range(big(1.0)m, step=big(1.0)m, length=5) == (big(1.0):big(1.0):big(5.0))*m
end
@testset ">> LinSpace" begin
# Not using Compat.range for these because kw args don't infer in julia 0.6.2
Expand Down Expand Up @@ -1207,6 +1208,72 @@ end
@test @inferred((1:2:5) .* cm .|> mm .|> ustrip) === 10:20:50
@test @inferred((1f0:2f0:5f0) .* cm .|> mm .|> ustrip) === 10f0:20f0:50f0
end
@testset ">> quantities and non-quantities" begin
@test range(1, step=1m/mm, length=5) == 1:1000:4001
@test range(1, step=1mm/m, length=5) == (1//1):(1//1000):(251//250)
@test eltype(range(1, step=1m/mm, length=5)) == Int
@test eltype(range(1, step=1mm/m, length=5)) == Rational{Int}
@test range(1m/mm, step=1, length=5) == ((1//1):(1//1000):(251//250)) * m/mm
@test range(1mm/m, step=1, length=5) == (1:1000:4001) * mm/m
@test eltype(range(1m/mm, step=1, length=5)) == typeof((1//1)m/mm)
@test eltype(range(1mm/m, step=1, length=5)) == typeof(1mm/m)
end
@testset ">> complex" begin
@test range((1+2im)m, step=(1+2im)m, length=5) == range(1+2im, step=1+2im, length=5) * m
@test range((1+2im)m, step=(1+2im)mm, length=5) == range(1//1+(2//1)im, step=1//1000+(1//500)im, length=5) * m
@test range((1.0+2.0im)m, stop=(3.0+4.0im)m, length=5) == LinRange(1.0+2.0im, 3.0+4.0im, 5) * m
@test range((1.0+2.0im)mm, stop=(3.0+4.0im)m, length=3) == LinRange(0.001+0.002im, 3.0+4.0im, 3) * m
end
@testset ">> step defaults to 1" begin
@test range(1.0mm/m, length=5) == (1.0mm/m):(1000.0mm/m):(4001.0mm/m)
@test range((1+2im)mm/m, length=5) == range(1+2im, step=1000, length=5)*mm/m
@test_throws DimensionError range(1.0m, length=5)
@test_throws DimensionError range((1+2im)m, length=5)
@test (1mm/m):(5001mm/m) == (1:1000:5001) * mm/m
@test (1m/mm):(5m/mm) == (1//1:1//1000:5//1) * m/mm
@test (1mm/m):(1m/mm) == 1//1000:999001//1000
@test (1m/mm):(1mm/m) == 1000//1:999//1
@test (1.0mm/m):(5001mm/m) == (1.0:1000.0:5001.0) * mm/m
@test (1m/mm):(5.0m/mm) == (1.0:0.001:5.0) * m/mm
@test (1.0mm/m):(1m/mm) == 0.001:999.001
@test (1m/mm):(1.0mm/m) == 1000.0:1.0:999.0
@test_throws DimensionError (1m):(1m)
@test_throws DimensionError (1m):(1000cm)
@test_throws DimensionError (1m):(1s)
@test (1m/cm):1 == 100:99
@test (1m/cm):1000 == 100:1000
@test (1m/cm):1.0 == 100.0:99.0
@test (1.0m/cm):1000 == 100.0:1000.0
@test_throws DimensionError (1m):1
@test 1:(1m/mm) == 1:1000
@test 1000:(1m/mm) == 1000:1000
@test 1.0:(1m/mm) == 1.0:1000.0
@test 1000:(1.0m/mm) == 1000.0:1000.0
@test_throws DimensionError 1:(1m)
end
@static if VERSION ≥ v"1.7"
@testset ">> no start argument" begin
@test range(stop=1.0m, step=2.0m, length=5) == -7.0m:2.0m:1.0m
@test range(stop=1.0mm, step=1.0m, length=5) == -3999.0mm:1000.0mm:1.0mm
@test range(stop=(1.0+2.0im)mm, step=(1.0+1.0im)m, length=5) == range(stop=1.0+2.0im, step=(1000+1000im), length=5)*mm
@test range(stop=1.0mm/m, length=5) == (-3999.0mm/m):(1000.0mm/m):(1.0mm/m)
@test range(stop=(1+2im)mm/m, length=5) == range(stop=1+2im, step=1000, length=5)*mm/m
@test range(stop=1.0mm/m, step=1, length=5) == (-3999.0mm/m):(1000.0mm/m):(1.0mm/m)
@test_throws DimensionError range(stop=1.0m, step=1V, length=5)
@test_throws DimensionError range(stop=(1+2im)m, step=1V, length=5)
@test_throws DimensionError range(stop=1.0m, length=5)
@test_throws DimensionError range(stop=(1+2im)m, length=5)
@test range(stop=1, step=1m/mm, length=5) == -3999:1000:1
@test range(stop=1, step=1mm/m, length=5) == (249//250):(1//1000):(1//1)
@test eltype(range(stop=1, step=1m/mm, length=5)) == Int
@test eltype(range(stop=1, step=1mm/m, length=5)) == Rational{Int}
@test range(stop=1m/mm, step=1, length=5) == ((249//250):(1//1000):(1//1)) * m/mm
@test range(stop=1mm/m, step=1, length=5) == (-3999:1000:1) * mm/m
@test eltype(range(stop=1m/mm, step=1, length=5)) == typeof((1//1)m/mm)
@test eltype(range(stop=1mm/m, step=1, length=5)) == typeof(1mm/m)
@test_throws ArgumentError range(step=1m, length=5)
end
end
end
@testset "> Arrays" begin
@testset ">> Array multiplication" begin
Expand Down