Skip to content

Commit a5c4c15

Browse files
committed
Check that axes start with 1 for AbstractRange operations
Now that we have Base.IdentityUnitRange and Base.Slice, we need to be careful about fallbacks that just use `first`, `step`, `stop`-style properties.
1 parent 589b96d commit a5c4c15

File tree

4 files changed

+131
-26
lines changed

4 files changed

+131
-26
lines changed

base/broadcast.jl

Lines changed: 25 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@ Module containing the broadcasting implementation.
88
module Broadcast
99

1010
using .Base.Cartesian
11-
using .Base: Indices, OneTo, tail, to_shape, isoperator, promote_typejoin,
11+
using .Base: Indices, OneTo, tail, to_shape, isoperator, promote_typejoin, require_one_based_indexing,
1212
_msk_end, unsafe_bitgetindex, bitcache_chunks, bitcache_size, dumpbitcache, unalias
1313
import .Base: copy, copyto!, axes
1414
export broadcast, broadcast!, BroadcastStyle, broadcast_axes, broadcastable, dotview, @__dot__
@@ -1002,15 +1002,20 @@ broadcasted(::DefaultArrayStyle{1}, ::typeof(+), r::OrdinalRange) = r
10021002
broadcasted(::DefaultArrayStyle{1}, ::typeof(+), r::StepRangeLen) = r
10031003
broadcasted(::DefaultArrayStyle{1}, ::typeof(+), r::LinRange) = r
10041004

1005-
broadcasted(::DefaultArrayStyle{1}, ::typeof(-), r::OrdinalRange) = range(-first(r), step=-step(r), length=length(r))
1005+
broadcasted(::DefaultArrayStyle{1}, ::typeof(-), r::OrdinalRange) =
1006+
(require_one_based_indexing(r); range(-first(r), step=-step(r), length=length(r)))
10061007
broadcasted(::DefaultArrayStyle{1}, ::typeof(-), r::StepRangeLen) = StepRangeLen(-r.ref, -r.step, length(r), r.offset)
10071008
broadcasted(::DefaultArrayStyle{1}, ::typeof(-), r::LinRange) = LinRange(-r.start, -r.stop, length(r))
10081009

1009-
broadcasted(::DefaultArrayStyle{1}, ::typeof(+), x::Real, r::AbstractUnitRange) = range(x + first(r), length=length(r))
1010-
broadcasted(::DefaultArrayStyle{1}, ::typeof(+), r::AbstractUnitRange, x::Real) = range(first(r) + x, length=length(r))
1010+
broadcasted(::DefaultArrayStyle{1}, ::typeof(+), x::Real, r::AbstractUnitRange) =
1011+
(require_one_based_indexing(r); range(x + first(r), length=length(r)))
1012+
broadcasted(::DefaultArrayStyle{1}, ::typeof(+), r::AbstractUnitRange, x::Real) =
1013+
(require_one_based_indexing(r); range(first(r) + x, length=length(r)))
10111014
# For #18336 we need to prevent promotion of the step type:
1012-
broadcasted(::DefaultArrayStyle{1}, ::typeof(+), r::AbstractRange, x::Number) = range(first(r) + x, step=step(r), length=length(r))
1013-
broadcasted(::DefaultArrayStyle{1}, ::typeof(+), x::Number, r::AbstractRange) = range(x + first(r), step=step(r), length=length(r))
1015+
broadcasted(::DefaultArrayStyle{1}, ::typeof(+), r::AbstractRange, x::Number) =
1016+
(require_one_based_indexing(r); range(first(r) + x, step=step(r), length=length(r)))
1017+
broadcasted(::DefaultArrayStyle{1}, ::typeof(+), x::Number, r::AbstractRange) =
1018+
(require_one_based_indexing(r); range(x + first(r), step=step(r), length=length(r)))
10141019
broadcasted(::DefaultArrayStyle{1}, ::typeof(+), r::StepRangeLen{T}, x::Number) where T =
10151020
StepRangeLen{typeof(T(r.ref)+x)}(r.ref + x, r.step, length(r), r.offset)
10161021
broadcasted(::DefaultArrayStyle{1}, ::typeof(+), x::Number, r::StepRangeLen{T}) where T =
@@ -1019,9 +1024,12 @@ broadcasted(::DefaultArrayStyle{1}, ::typeof(+), r::LinRange, x::Number) = LinRa
10191024
broadcasted(::DefaultArrayStyle{1}, ::typeof(+), x::Number, r::LinRange) = LinRange(x + r.start, x + r.stop, length(r))
10201025
broadcasted(::DefaultArrayStyle{1}, ::typeof(+), r1::AbstractRange, r2::AbstractRange) = r1 + r2
10211026

1022-
broadcasted(::DefaultArrayStyle{1}, ::typeof(-), r::AbstractUnitRange, x::Number) = range(first(r)-x, length=length(r))
1023-
broadcasted(::DefaultArrayStyle{1}, ::typeof(-), r::AbstractRange, x::Number) = range(first(r)-x, step=step(r), length=length(r))
1024-
broadcasted(::DefaultArrayStyle{1}, ::typeof(-), x::Number, r::AbstractRange) = range(x-first(r), step=-step(r), length=length(r))
1027+
broadcasted(::DefaultArrayStyle{1}, ::typeof(-), r::AbstractUnitRange, x::Number) =
1028+
(require_one_based_indexing(r); range(first(r)-x, length=length(r)))
1029+
broadcasted(::DefaultArrayStyle{1}, ::typeof(-), r::AbstractRange, x::Number) =
1030+
(require_one_based_indexing(r); range(first(r)-x, step=step(r), length=length(r)))
1031+
broadcasted(::DefaultArrayStyle{1}, ::typeof(-), x::Number, r::AbstractRange) =
1032+
(require_one_based_indexing(r); range(x-first(r), step=-step(r), length=length(r)))
10251033
broadcasted(::DefaultArrayStyle{1}, ::typeof(-), r::StepRangeLen{T}, x::Number) where T =
10261034
StepRangeLen{typeof(T(r.ref)-x)}(r.ref - x, r.step, length(r), r.offset)
10271035
broadcasted(::DefaultArrayStyle{1}, ::typeof(-), x::Number, r::StepRangeLen{T}) where T =
@@ -1030,22 +1038,26 @@ broadcasted(::DefaultArrayStyle{1}, ::typeof(-), r::LinRange, x::Number) = LinRa
10301038
broadcasted(::DefaultArrayStyle{1}, ::typeof(-), x::Number, r::LinRange) = LinRange(x - r.start, x - r.stop, length(r))
10311039
broadcasted(::DefaultArrayStyle{1}, ::typeof(-), r1::AbstractRange, r2::AbstractRange) = r1 - r2
10321040

1033-
broadcasted(::DefaultArrayStyle{1}, ::typeof(*), x::Number, r::AbstractRange) = range(x*first(r), step=x*step(r), length=length(r))
1041+
broadcasted(::DefaultArrayStyle{1}, ::typeof(*), x::Number, r::AbstractRange) =
1042+
(require_one_based_indexing(r); range(x*first(r), step=x*step(r), length=length(r)))
10341043
broadcasted(::DefaultArrayStyle{1}, ::typeof(*), x::Number, r::StepRangeLen{T}) where {T} =
10351044
StepRangeLen{typeof(x*T(r.ref))}(x*r.ref, x*r.step, length(r), r.offset)
10361045
broadcasted(::DefaultArrayStyle{1}, ::typeof(*), x::Number, r::LinRange) = LinRange(x * r.start, x * r.stop, r.len)
10371046
# separate in case of noncommutative multiplication
1038-
broadcasted(::DefaultArrayStyle{1}, ::typeof(*), r::AbstractRange, x::Number) = range(first(r)*x, step=step(r)*x, length=length(r))
1047+
broadcasted(::DefaultArrayStyle{1}, ::typeof(*), r::AbstractRange, x::Number) =
1048+
(require_one_based_indexing(r); range(first(r)*x, step=step(r)*x, length=length(r)))
10391049
broadcasted(::DefaultArrayStyle{1}, ::typeof(*), r::StepRangeLen{T}, x::Number) where {T} =
10401050
StepRangeLen{typeof(T(r.ref)*x)}(r.ref*x, r.step*x, length(r), r.offset)
10411051
broadcasted(::DefaultArrayStyle{1}, ::typeof(*), r::LinRange, x::Number) = LinRange(r.start * x, r.stop * x, r.len)
10421052

1043-
broadcasted(::DefaultArrayStyle{1}, ::typeof(/), r::AbstractRange, x::Number) = range(first(r)/x, step=step(r)/x, length=length(r))
1053+
broadcasted(::DefaultArrayStyle{1}, ::typeof(/), r::AbstractRange, x::Number) =
1054+
(require_one_based_indexing(r); range(first(r)/x, step=step(r)/x, length=length(r)))
10441055
broadcasted(::DefaultArrayStyle{1}, ::typeof(/), r::StepRangeLen{T}, x::Number) where {T} =
10451056
StepRangeLen{typeof(T(r.ref)/x)}(r.ref/x, r.step/x, length(r), r.offset)
10461057
broadcasted(::DefaultArrayStyle{1}, ::typeof(/), r::LinRange, x::Number) = LinRange(r.start / x, r.stop / x, r.len)
10471058

1048-
broadcasted(::DefaultArrayStyle{1}, ::typeof(\), x::Number, r::AbstractRange) = range(x\first(r), step=x\step(r), length=length(r))
1059+
broadcasted(::DefaultArrayStyle{1}, ::typeof(\), x::Number, r::AbstractRange) =
1060+
(require_one_based_indexing(r); range(x\first(r), step=x\step(r), length=length(r)))
10491061
broadcasted(::DefaultArrayStyle{1}, ::typeof(\), x::Number, r::StepRangeLen) = StepRangeLen(x\r.ref, x\r.step, length(r), r.offset)
10501062
broadcasted(::DefaultArrayStyle{1}, ::typeof(\), x::Number, r::LinRange) = LinRange(x \ r.start, x \ r.stop, r.len)
10511063

base/range.jl

Lines changed: 50 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -142,6 +142,14 @@ RangeStepStyle(::Type{<:AbstractRange{<:Integer}}) = RangeStepRegular()
142142

143143
convert(::Type{T}, r::AbstractRange) where {T<:AbstractRange} = r isa T ? r : T(r)
144144

145+
AxesStartStyle(::Type{<:AbstractRange}) = AxesStartAny()
146+
AxesStartStyle(r::AbstractRange) = AxesStartStyle(typeof(r))
147+
148+
require_one_based_indexing(r::AbstractRange) = _require_one_based_indexing(AxesStartStyle(r), r)
149+
_require_one_based_indexing(::AxesStartStyle, r) =
150+
!has_offset_axes(r) || throw(ArgumentError("offset arrays are not supported but got an array with index other than 1"))
151+
_require_one_based_indexing(::AxesStart1, r) = true
152+
145153
## ordinal ranges
146154

147155
"""
@@ -250,6 +258,8 @@ steprange_last_empty(start, step, stop) = start - step
250258

251259
StepRange(start::T, step::S, stop::T) where {T,S} = StepRange{T,S}(start, step, stop)
252260

261+
AxesStartStyle(::Type{<:StepRange}) = AxesStart1()
262+
253263
"""
254264
UnitRange{T<:Real}
255265
@@ -297,6 +307,8 @@ if isdefined(Main, :Base)
297307
end
298308
end
299309

310+
AxesStartStyle(::Type{<:UnitRange}) = AxesStart1()
311+
300312
"""
301313
Base.OneTo(n)
302314
@@ -318,6 +330,8 @@ end
318330
OneTo(stop::T) where {T<:Integer} = OneTo{T}(stop)
319331
OneTo(r::AbstractRange{T}) where {T<:Integer} = OneTo{T}(r)
320332

333+
AxesStartStyle(::Type{<:OneTo}) = AxesStart1()
334+
321335
## Step ranges parameterized by length
322336

323337
"""
@@ -350,6 +364,8 @@ StepRangeLen(ref::R, step::S, len::Integer, offset::Integer = 1) where {R,S} =
350364
StepRangeLen{T}(ref::R, step::S, len::Integer, offset::Integer = 1) where {T,R,S} =
351365
StepRangeLen{T,R,S}(ref, step, len, offset)
352366

367+
AxesStartStyle(::Type{<:StepRangeLen}) = AxesStart1()
368+
353369
## range with computed step
354370

355371
"""
@@ -387,6 +403,8 @@ function LinRange(start, stop, len::Integer)
387403
LinRange{T}(start, stop, len)
388404
end
389405

406+
AxesStartStyle(::Type{<:LinRange}) = AxesStart1()
407+
390408
function _range(start::T, ::Nothing, stop::S, len::Integer) where {T,S}
391409
a, b = promote(start, stop)
392410
_range(a, nothing, b, len)
@@ -713,10 +731,14 @@ show(io::IO, r::AbstractRange) = print(io, repr(first(r)), ':', repr(step(r)), '
713731
show(io::IO, r::UnitRange) = print(io, repr(first(r)), ':', repr(last(r)))
714732
show(io::IO, r::OneTo) = print(io, "Base.OneTo(", r.stop, ")")
715733

734+
range_axes_first_same(r, s) = _range_axes_first_same(AxesStartStyle(r), AxesStartStyle(s), r, s)
735+
_range_axes_first_same(::AxesStart1, ::AxesStart1, r, s) = true
736+
_range_axes_first_same(::AxesStartStyle, ::AxesStartStyle, r, s) = first(axes1(r)) == first(axes1(s))
737+
716738
==(r::T, s::T) where {T<:AbstractRange} =
717-
(first(r) == first(s)) & (step(r) == step(s)) & (last(r) == last(s))
739+
(first(r) == first(s)) & (step(r) == step(s)) & (last(r) == last(s)) & range_axes_first_same(r, s)
718740
==(r::OrdinalRange, s::OrdinalRange) =
719-
(first(r) == first(s)) & (step(r) == step(s)) & (last(r) == last(s))
741+
(first(r) == first(s)) & (step(r) == step(s)) & (last(r) == last(s)) & range_axes_first_same(r, s)
720742
==(r::T, s::T) where {T<:Union{StepRangeLen,LinRange}} =
721743
(first(r) == first(s)) & (length(r) == length(s)) & (last(r) == last(s))
722744
==(r::Union{StepRange{T},StepRangeLen{T,T}}, s::Union{StepRange{T},StepRangeLen{T,T}}) where {T} =
@@ -727,6 +749,7 @@ function ==(r::AbstractRange, s::AbstractRange)
727749
if lr != length(s)
728750
return false
729751
end
752+
range_axes_first_same(r, s) || return false
730753
yr, ys = iterate(r), iterate(s)
731754
while yr !== nothing
732755
yr[1] == ys[1] || return false
@@ -849,7 +872,7 @@ end
849872

850873
## linear operations on ranges ##
851874

852-
-(r::OrdinalRange) = range(-first(r), step=-step(r), length=length(r))
875+
-(r::OrdinalRange) = (require_one_based_indexing(r); range(-first(r), step=-step(r), length=length(r)))
853876
-(r::StepRangeLen{T,R,S}) where {T,R,S} =
854877
StepRangeLen{T,R,S}(-r.ref, -r.step, length(r), r.offset)
855878
-(r::LinRange) = LinRange(-r.start, -r.stop, length(r))
@@ -873,8 +896,10 @@ OneTo{T}(r::OneTo) where {T<:Integer} = OneTo{T}(r.stop)
873896

874897
promote_rule(a::Type{UnitRange{T1}}, ::Type{UR}) where {T1,UR<:AbstractUnitRange} =
875898
promote_rule(a, UnitRange{eltype(UR)})
876-
UnitRange{T}(r::AbstractUnitRange) where {T<:Real} = UnitRange{T}(first(r), last(r))
877-
UnitRange(r::AbstractUnitRange) = UnitRange(first(r), last(r))
899+
UnitRange{T}(r::AbstractUnitRange) where {T<:Real} =
900+
(require_one_based_indexing(r); UnitRange{T}(first(r), last(r)))
901+
UnitRange(r::AbstractUnitRange) =
902+
(require_one_based_indexing(r); UnitRange(first(r), last(r)))
878903

879904
AbstractUnitRange{T}(r::AbstractUnitRange{T}) where {T} = r
880905
AbstractUnitRange{T}(r::UnitRange) where {T} = UnitRange{T}(r)
@@ -889,10 +914,14 @@ StepRange{T1,T2}(r::StepRange{T1,T2}) where {T1,T2} = r
889914

890915
promote_rule(a::Type{StepRange{T1a,T1b}}, ::Type{UR}) where {T1a,T1b,UR<:AbstractUnitRange} =
891916
promote_rule(a, StepRange{eltype(UR), eltype(UR)})
892-
StepRange{T1,T2}(r::AbstractRange) where {T1,T2} =
917+
function StepRange{T1,T2}(r::AbstractRange) where {T1,T2}
918+
require_one_based_indexing(r)
893919
StepRange{T1,T2}(convert(T1, first(r)), convert(T2, step(r)), convert(T1, last(r)))
894-
StepRange(r::AbstractUnitRange{T}) where {T} =
920+
end
921+
function StepRange(r::AbstractUnitRange{T}) where {T}
922+
require_one_based_indexing(r)
895923
StepRange{T,T}(first(r), step(r), last(r))
924+
end
896925
(::Type{StepRange{T1,T2} where T1})(r::AbstractRange) where {T2} = StepRange{eltype(r),T2}(r)
897926

898927
promote_rule(::Type{StepRangeLen{T1,R1,S1}},::Type{StepRangeLen{T2,R2,S2}}) where {T1,T2,R1,R2,S1,S2} =
@@ -908,15 +937,16 @@ StepRangeLen{T}(r::StepRangeLen) where {T} =
908937
promote_rule(a::Type{StepRangeLen{T,R,S}}, ::Type{OR}) where {T,R,S,OR<:AbstractRange} =
909938
promote_rule(a, StepRangeLen{eltype(OR), eltype(OR), eltype(OR)})
910939
StepRangeLen{T,R,S}(r::AbstractRange) where {T,R,S} =
911-
StepRangeLen{T,R,S}(R(first(r)), S(step(r)), length(r))
940+
(require_one_based_indexing(r); StepRangeLen{T,R,S}(R(first(r)), S(step(r)), length(r)))
912941
StepRangeLen{T}(r::AbstractRange) where {T} =
913-
StepRangeLen(T(first(r)), T(step(r)), length(r))
942+
(require_one_based_indexing(r); StepRangeLen(T(first(r)), T(step(r)), length(r)))
914943
StepRangeLen(r::AbstractRange) = StepRangeLen{eltype(r)}(r)
915944

916945
promote_rule(a::Type{LinRange{T1}}, b::Type{LinRange{T2}}) where {T1,T2} =
917946
el_same(promote_type(T1,T2), a, b)
918947
LinRange{T}(r::LinRange{T}) where {T} = r
919-
LinRange{T}(r::AbstractRange) where {T} = LinRange{T}(first(r), last(r), length(r))
948+
LinRange{T}(r::AbstractRange) where {T} =
949+
(require_one_based_indexing(r); LinRange{T}(first(r), last(r), length(r)))
920950
LinRange(r::AbstractRange{T}) where {T} = LinRange{T}(r)
921951

922952
promote_rule(a::Type{LinRange{T}}, ::Type{OR}) where {T,OR<:OrdinalRange} =
@@ -944,7 +974,10 @@ end
944974
Array{T,1}(r::AbstractRange{T}) where {T} = vcat(r)
945975
collect(r::AbstractRange) = vcat(r)
946976

947-
reverse(r::OrdinalRange) = (:)(last(r), -step(r), first(r))
977+
function reverse(r::OrdinalRange)
978+
require_one_based_indexing(r)
979+
(:)(last(r), -step(r), first(r))
980+
end
948981
function reverse(r::StepRangeLen)
949982
# If `r` is empty, `length(r) - r.offset + 1 will be nonpositive hence
950983
# invalid. As `reverse(r)` is also empty, any offset would work so we keep
@@ -964,8 +997,11 @@ sort!(r::AbstractUnitRange) = r
964997

965998
sort(r::AbstractRange) = issorted(r) ? r : reverse(r)
966999

967-
sortperm(r::AbstractUnitRange) = 1:length(r)
968-
sortperm(r::AbstractRange) = issorted(r) ? (1:1:length(r)) : (length(r):-1:1)
1000+
sortperm(r::AbstractUnitRange) = (require_one_based_indexing(r); 1:length(r))
1001+
function sortperm(r::AbstractRange)
1002+
require_one_based_indexing(r)
1003+
issorted(r) ? (1:1:length(r)) : (length(r):-1:1)
1004+
end
9691005

9701006
function sum(r::AbstractRange{<:Real})
9711007
l = length(r)
@@ -1004,6 +1040,7 @@ function _define_range_op(@nospecialize f)
10041040
r1l = length(r1)
10051041
(r1l == length(r2) ||
10061042
throw(DimensionMismatch("argument dimensions must match")))
1043+
require_one_based_indexing(r1, r2)
10071044
range($f(first(r1), first(r2)), step=$f(step(r1), step(r2)), length=r1l)
10081045
end
10091046

base/traits.jl

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -57,3 +57,18 @@ struct RangeStepRegular <: RangeStepStyle end # range with regular step
5757
struct RangeStepIrregular <: RangeStepStyle end # range with rounding error
5858

5959
RangeStepStyle(instance) = RangeStepStyle(typeof(instance))
60+
61+
# trait that allows skipping of axes-checking on abstract range types (risks overflow on `length`)
62+
"""
63+
AxesStartStyle(instance)
64+
AxesStartStyle(T::Type)
65+
66+
Indicate the value that `axes(instance)` starts with. Containers that return `AxesStart1()`
67+
must have `axes(instance)` start with 1 (e.g., `Base.OneTo` axes). Such containers may
68+
bypass axes checks for certain operations (e.g., range comparisons to avoid risk of overflow).
69+
`AxesStartAny()` indicates that one cannot count on the axes starting with 1, and that
70+
an explicit check is required.
71+
"""
72+
abstract type AxesStartStyle end
73+
struct AxesStart1 <: AxesStartStyle end
74+
struct AxesStartAny <: AxesStartStyle end

test/ranges.jl

Lines changed: 41 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1458,3 +1458,44 @@ end
14581458
Base.TwicePrecision(-1.0, -0.0), 0)
14591459
@test reverse(reverse(1.0:0.0)) === 1.0:0.0
14601460
end
1461+
1462+
@testset "Fallbacks for IdentityUnitRange" begin
1463+
r = Base.IdentityUnitRange(-2:2)
1464+
argerr = ArgumentError("offset arrays are not supported but got an array with index other than 1")
1465+
@test r != -2:2
1466+
@test r != -2:1:2
1467+
@test r == r
1468+
@test r != Base.IdentityUnitRange(-1:2)
1469+
@test +r === r
1470+
@test_throws argerr UnitRange{Int}(r)
1471+
@test_throws argerr UnitRange(r)
1472+
@test_throws argerr StepRange{Int,Int}(r)
1473+
@test_throws argerr StepRange(r)
1474+
@test_throws argerr StepRangeLen(r)
1475+
@test_throws argerr StepRangeLen{Int,Int,Int}(r)
1476+
@test_throws argerr LinRange(r)
1477+
@test_throws argerr -r
1478+
@test_throws argerr .-r
1479+
@test_throws argerr r .+ 1
1480+
@test_throws argerr 1 .+ r
1481+
@test_throws argerr r .+ im
1482+
@test_throws argerr im .+ r
1483+
@test_throws argerr r .- 1
1484+
@test_throws argerr 1 .- r
1485+
@test_throws argerr 2 * r
1486+
@test_throws argerr r * 2
1487+
@test_throws argerr 2 .* r
1488+
@test_throws argerr r .* 2
1489+
@test_throws argerr r / 2
1490+
@test_throws argerr r ./ 2
1491+
@test_throws argerr 2 \ r
1492+
@test_throws argerr 2 .\ r
1493+
@test_throws argerr r + r
1494+
@test_throws argerr r - r
1495+
@test_throws argerr r .+ r
1496+
@test_throws argerr r .- r
1497+
@test_throws MethodError r .* r
1498+
@test_throws DimensionMismatch r .* (-2:2)
1499+
@test_throws argerr reverse(r)
1500+
@test_throws argerr sortperm(r)
1501+
end

0 commit comments

Comments
 (0)