Skip to content

Commit aa3ebdd

Browse files
N5N3vtjnash
authored andcommitted
improve type-based offset axes check (JuliaLang#45260)
* Follow up to JuliaLang#45236 (make `length(::StepRange{Int8,Int128})` type-stable) * Fully drop `_tuple_any` (unneeded now) * Make sure `has_offset_axes(::StepRange)` could be const folded. And define some "cheap" `firstindex` * Do offset axes check on `A`'s parent rather than itself. This avoid some unneeded `axes` call, thus more possible be folded by the compiler. Co-authored-by: Jameson Nash <vtjnash+github@gmail.com>
1 parent b05a843 commit aa3ebdd

File tree

10 files changed

+86
-45
lines changed

10 files changed

+86
-45
lines changed

base/abstractarray.jl

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -104,9 +104,11 @@ If multiple arguments are passed, equivalent to `has_offset_axes(A) | has_offset
104104
105105
See also [`require_one_based_indexing`](@ref).
106106
"""
107-
has_offset_axes(A) = _tuple_any(x->Int(first(x))::Int != 1, axes(A))
107+
has_offset_axes(A) = _any_tuple(x->Int(first(x))::Int != 1, false, axes(A)...)
108108
has_offset_axes(A::AbstractVector) = Int(firstindex(A))::Int != 1 # improve performance of a common case (ranges)
109-
has_offset_axes(A...) = _tuple_any(has_offset_axes, A)
109+
# Use `_any_tuple` to avoid unneeded invoke.
110+
# note: this could call `any` directly if the compiler can infer it
111+
has_offset_axes(As...) = _any_tuple(has_offset_axes, false, As...)
110112
has_offset_axes(::Colon) = false
111113

112114
"""

base/multidimensional.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -335,6 +335,7 @@ module IteratorsMD
335335
# AbstractArray implementation
336336
Base.axes(iter::CartesianIndices{N,R}) where {N,R} = map(Base.axes1, iter.indices)
337337
Base.IndexStyle(::Type{CartesianIndices{N,R}}) where {N,R} = IndexCartesian()
338+
Base.has_offset_axes(iter::CartesianIndices) = Base.has_offset_axes(iter.indices...)
338339
# getindex for a 0D CartesianIndices is necessary for disambiguation
339340
@propagate_inbounds function Base.getindex(iter::CartesianIndices{0,R}) where {R}
340341
CartesianIndex()

base/permuteddimsarray.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -48,6 +48,7 @@ end
4848
Base.parent(A::PermutedDimsArray) = A.parent
4949
Base.size(A::PermutedDimsArray{T,N,perm}) where {T,N,perm} = genperm(size(parent(A)), perm)
5050
Base.axes(A::PermutedDimsArray{T,N,perm}) where {T,N,perm} = genperm(axes(parent(A)), perm)
51+
Base.has_offset_axes(A::PermutedDimsArray) = Base.has_offset_axes(A.parent)
5152

5253
Base.similar(A::PermutedDimsArray, T::Type, dims::Base.Dims) = similar(parent(A), T, dims)
5354

base/range.jl

Lines changed: 26 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -689,6 +689,9 @@ step_hp(r::AbstractRange) = step(r)
689689

690690
axes(r::AbstractRange) = (oneto(length(r)),)
691691

692+
# Needed to ensure `has_offset_axes` can constant-fold.
693+
has_offset_axes(::StepRange) = false
694+
692695
# n.b. checked_length for these is defined iff checked_add and checked_sub are
693696
# defined between the relevant types
694697
function checked_length(r::OrdinalRange{T}) where T
@@ -750,64 +753,66 @@ length(r::OneTo) = Integer(r.stop - zero(r.stop))
750753
length(r::StepRangeLen) = r.len
751754
length(r::LinRange) = r.len
752755

753-
let bigints = Union{Int, UInt, Int64, UInt64, Int128, UInt128}
754-
global length, checked_length
756+
let bigints = Union{Int, UInt, Int64, UInt64, Int128, UInt128},
757+
smallints = (Int === Int64 ?
758+
Union{Int8, UInt8, Int16, UInt16, Int32, UInt32} :
759+
Union{Int8, UInt8, Int16, UInt16}),
760+
bitints = Union{bigints, smallints}
761+
global length, checked_length, firstindex
755762
# compile optimization for which promote_type(T, Int) == T
756763
length(r::OneTo{T}) where {T<:bigints} = r.stop
757764
# slightly more accurate length and checked_length in extreme cases
758765
# (near typemax) for types with known `unsigned` functions
759766
function length(r::OrdinalRange{T}) where T<:bigints
760767
s = step(r)
761-
isempty(r) && return zero(T)
762768
diff = last(r) - first(r)
769+
isempty(r) && return zero(diff)
763770
# if |s| > 1, diff might have overflowed, but unsigned(diff)÷s should
764771
# therefore still be valid (if the result is representable at all)
765772
# n.b. !(s isa T)
766773
if s isa Unsigned || -1 <= s <= 1 || s == -s
767-
a = div(diff, s) % T
774+
a = div(diff, s) % typeof(diff)
768775
elseif s < 0
769-
a = div(unsigned(-diff), -s) % T
776+
a = div(unsigned(-diff), -s) % typeof(diff)
770777
else
771-
a = div(unsigned(diff), s) % T
778+
a = div(unsigned(diff), s) % typeof(diff)
772779
end
773-
return a + oneunit(T)
780+
return a + oneunit(a)
774781
end
775782
function checked_length(r::OrdinalRange{T}) where T<:bigints
776783
s = step(r)
777-
isempty(r) && return zero(T)
778784
stop, start = last(r), first(r)
785+
ET = promote_type(typeof(stop), typeof(start))
786+
isempty(r) && return zero(ET)
779787
# n.b. !(s isa T)
780788
if s > 1
781789
diff = stop - start
782-
a = convert(T, div(unsigned(diff), s))
790+
a = convert(ET, div(unsigned(diff), s))
783791
elseif s < -1
784792
diff = start - stop
785-
a = convert(T, div(unsigned(diff), -s))
793+
a = convert(ET, div(unsigned(diff), -s))
786794
elseif s > 0
787-
a = div(checked_sub(stop, start), s)
795+
a = convert(ET, div(checked_sub(stop, start), s))
788796
else
789-
a = div(checked_sub(start, stop), -s)
797+
a = convert(ET, div(checked_sub(start, stop), -s))
790798
end
791-
return checked_add(convert(T, a), oneunit(T))
799+
return checked_add(a, oneunit(a))
792800
end
793-
end
801+
firstindex(r::StepRange{<:bigints,<:bitints}) = one(last(r)-first(r))
794802

795-
# some special cases to favor default Int type
796-
let smallints = (Int === Int64 ?
797-
Union{Int8, UInt8, Int16, UInt16, Int32, UInt32} :
798-
Union{Int8, UInt8, Int16, UInt16})
799-
global length, checked_length
800-
# n.b. !(step isa T)
803+
# some special cases to favor default Int type
801804
function length(r::OrdinalRange{<:smallints})
802805
s = step(r)
803806
isempty(r) && return 0
804-
return div(Int(last(r)) - Int(first(r)), s) + 1
807+
# n.b. !(step isa T)
808+
return Int(div(Int(last(r)) - Int(first(r)), s)) + 1
805809
end
806810
length(r::AbstractUnitRange{<:smallints}) = Int(last(r)) - Int(first(r)) + 1
807811
length(r::OneTo{<:smallints}) = Int(r.stop)
808812
checked_length(r::OrdinalRange{<:smallints}) = length(r)
809813
checked_length(r::AbstractUnitRange{<:smallints}) = length(r)
810814
checked_length(r::OneTo{<:smallints}) = length(r)
815+
firstindex(::StepRange{<:smallints,<:bitints}) = 1
811816
end
812817

813818
first(r::OrdinalRange{T}) where {T} = convert(T, r.start)

base/reinterpretarray.jl

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -325,6 +325,8 @@ function axes(a::ReshapedReinterpretArray{T,N,S} where {N}) where {T,S}
325325
end
326326
axes(a::NonReshapedReinterpretArray{T,0}) where {T} = ()
327327

328+
has_offset_axes(a::ReinterpretArray) = has_offset_axes(a.parent)
329+
328330
elsize(::Type{<:ReinterpretArray{T}}) where {T} = sizeof(T)
329331
unsafe_convert(::Type{Ptr{T}}, a::ReinterpretArray{T,N,S} where N) where {T,S} = Ptr{T}(unsafe_convert(Ptr{S},a.parent))
330332

base/subarray.jl

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -459,3 +459,5 @@ function _indices_sub(i1::AbstractArray, I...)
459459
@inline
460460
(axes(i1)..., _indices_sub(I...)...)
461461
end
462+
463+
has_offset_axes(S::SubArray) = has_offset_axes(S.indices...)

base/tuple.jl

Lines changed: 0 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -581,15 +581,6 @@ any(x::Tuple{Bool}) = x[1]
581581
any(x::Tuple{Bool, Bool}) = x[1]|x[2]
582582
any(x::Tuple{Bool, Bool, Bool}) = x[1]|x[2]|x[3]
583583

584-
# equivalent to any(f, t), to be used only in bootstrap
585-
_tuple_any(f::Function, t::Tuple) = _tuple_any(f, false, t...)
586-
function _tuple_any(f::Function, tf::Bool, a, b...)
587-
@inline
588-
_tuple_any(f, tf | f(a), b...)
589-
end
590-
_tuple_any(f::Function, tf::Bool) = tf
591-
592-
593584
# a version of `in` esp. for NamedTuple, to make it pure, and not compiled for each tuple length
594585
function sym_in(x::Symbol, @nospecialize itr::Tuple{Vararg{Symbol}})
595586
@_total_meta

test/abstractarray.jl

Lines changed: 29 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1583,11 +1583,11 @@ end
15831583
@test length(rr) == length(r)
15841584
end
15851585

1586-
struct FakeZeroDimArray <: AbstractArray{Int, 0} end
1587-
Base.strides(::FakeZeroDimArray) = ()
1588-
Base.size(::FakeZeroDimArray) = ()
1586+
module IRUtils
1587+
include("compiler/irutils.jl")
1588+
end
1589+
15891590
@testset "strides for ReshapedArray" begin
1590-
# Type-based contiguous check is tested in test/compiler/inline.jl
15911591
function check_strides(A::AbstractArray)
15921592
# Make sure stride(A, i) is equivalent with strides(A)[i] (if 1 <= i <= ndims(A))
15931593
dims = ntuple(identity, ndims(A))
@@ -1598,6 +1598,10 @@ Base.size(::FakeZeroDimArray) = ()
15981598
end
15991599
return true
16001600
end
1601+
# Type-based contiguous Check
1602+
a = vec(reinterpret(reshape, Int16, reshape(view(reinterpret(Int32, randn(10)), 2:11), 5, :)))
1603+
f(a) = only(strides(a));
1604+
@test IRUtils.fully_eliminated(f, Base.typesof(a)) && f(a) == 1
16011605
# General contiguous check
16021606
a = view(rand(10,10), 1:10, 1:10)
16031607
@test check_strides(vec(a))
@@ -1629,6 +1633,9 @@ Base.size(::FakeZeroDimArray) = ()
16291633
@test_throws "Input is not strided." strides(reshape(a,3,5,3,2))
16301634
@test_throws "Input is not strided." strides(reshape(a,5,3,3,2))
16311635
# Zero dimensional parent
1636+
struct FakeZeroDimArray <: AbstractArray{Int, 0} end
1637+
Base.strides(::FakeZeroDimArray) = ()
1638+
Base.size(::FakeZeroDimArray) = ()
16321639
a = reshape(FakeZeroDimArray(),1,1,1)
16331640
@test @inferred(strides(a)) == (1, 1, 1)
16341641
# Dense parent (but not StridedArray)
@@ -1660,3 +1667,21 @@ end
16601667
@test (@inferred A[i,i,i]) === A[1]
16611668
@test (@inferred to_indices([], (1, CIdx(1, 1), 1, CIdx(1, 1), 1, CIdx(1, 1), 1))) == ntuple(Returns(1), 10)
16621669
end
1670+
1671+
@testset "type-based offset axes check" begin
1672+
a = randn(ComplexF64, 10)
1673+
ta = reinterpret(Float64, a)
1674+
tb = reinterpret(Float64, view(a, 1:2:10))
1675+
tc = reinterpret(Float64, reshape(view(a, 1:3:10), 2, 2, 1))
1676+
# Issue #44040
1677+
@test IRUtils.fully_eliminated(Base.require_one_based_indexing, Base.typesof(ta, tc))
1678+
@test IRUtils.fully_eliminated(Base.require_one_based_indexing, Base.typesof(tc, tc))
1679+
@test IRUtils.fully_eliminated(Base.require_one_based_indexing, Base.typesof(ta, tc, tb))
1680+
# Ranges && CartesianIndices
1681+
@test IRUtils.fully_eliminated(Base.require_one_based_indexing, Base.typesof(1:10, Base.OneTo(10), 1.0:2.0, LinRange(1.0, 2.0, 2), 1:2:10, CartesianIndices((1:2:10, 1:2:10))))
1682+
# Remind us to call `any` in `Base.has_offset_axes` once our compiler is ready.
1683+
@inline _has_offset_axes(A) = @inline any(x -> Int(first(x))::Int != 1, axes(A))
1684+
@inline _has_offset_axes(As...) = @inline any(_has_offset_axes, As)
1685+
a, b = zeros(2, 2, 2), zeros(2, 2)
1686+
@test_broken IRUtils.fully_eliminated(_has_offset_axes, Base.typesof(a, a, b, b))
1687+
end

test/compiler/inline.jl

Lines changed: 0 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -988,13 +988,6 @@ end
988988
@invoke conditional_escape!(false::Any, x::Any)
989989
end
990990

991-
@testset "strides for ReshapedArray (PR#44027)" begin
992-
# Type-based contiguous check
993-
a = vec(reinterpret(reshape,Int16,reshape(view(reinterpret(Int32,randn(10)),2:11),5,:)))
994-
f(a) = only(strides(a));
995-
@test fully_eliminated(f, Tuple{typeof(a)}) && f(a) == 1
996-
end
997-
998991
@testset "elimination of `get_binding_type`" begin
999992
m = Module()
1000993
@eval m begin

test/ranges.jl

Lines changed: 21 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2031,8 +2031,17 @@ end
20312031
end
20322032

20332033
@testset "length(StepRange()) type stability" begin
2034-
typeof(length(StepRange(1,Int128(1),1))) == typeof(length(StepRange(1,Int128(1),0)))
2035-
typeof(checked_length(StepRange(1,Int128(1),1))) == typeof(checked_length(StepRange(1,Int128(1),0)))
2034+
for SR in (StepRange{Int,Int128}, StepRange{Int8,Int128})
2035+
r1, r2 = SR(1, 1, 1), SR(1, 1, 0)
2036+
@test typeof(length(r1)) == typeof(checked_length(r1)) ==
2037+
typeof(length(r2)) == typeof(checked_length(r2))
2038+
end
2039+
SR = StepRange{Union{Int64,Int128},Int}
2040+
test_length(r, l) = length(r) === checked_length(r) === l
2041+
@test test_length(SR(Int64(1), 1, Int128(1)), Int128(1))
2042+
@test test_length(SR(Int64(1), 1, Int128(0)), Int128(0))
2043+
@test test_length(SR(Int64(1), 1, Int64(1)), Int64(1))
2044+
@test test_length(SR(Int64(1), 1, Int64(0)), Int64(0))
20362045
end
20372046

20382047
@testset "LinRange eltype for element types that wrap integers" begin
@@ -2346,3 +2355,13 @@ end
23462355
@test isempty(range(typemax(Int), length=0, step=UInt(2)))
23472356

23482357
@test length(range(1, length=typemax(Int128))) === typemax(Int128)
2358+
2359+
@testset "firstindex(::StepRange{<:Base.BitInteger})" begin
2360+
test_firstindex(x) = firstindex(x) === first(Base.axes1(x))
2361+
for T in Base.BitInteger_types, S in Base.BitInteger_types
2362+
@test test_firstindex(StepRange{T,S}(1, 1, 1))
2363+
@test test_firstindex(StepRange{T,S}(1, 1, 0))
2364+
end
2365+
@test test_firstindex(StepRange{Union{Int64,Int128},Int}(Int64(1), 1, Int128(1)))
2366+
@test test_firstindex(StepRange{Union{Int64,Int128},Int}(Int64(1), 1, Int128(0)))
2367+
end

0 commit comments

Comments
 (0)