Skip to content

Commit 75cb2a5

Browse files
authored
eliminate the dead iterate branch in _unsafe_(get)setindex!. (#52809)
It's sad that compiler can't do this automatically. Some benchmark with `setindex!`: ```julia julia> a = zeros(Int, 100, 100); julia> @Btime $a[:,:] = $(1:10000); 1.340 μs (0 allocations: 0 bytes) #master: 3.350 μs (0 allocations: 0 bytes) julia> @Btime $a[:,:] = $(view(LinearIndices(a), 1:100, 1:100)); 10.000 μs (0 allocations: 0 bytes) #master: 11.000 μs (0 allocations: 0 bytes) ``` BTW optimization for `FastSubArray` introduced in #45371 still work after this change as the parent array might have their own `copyto!` optimization.
1 parent e8bf9bc commit 75cb2a5

File tree

2 files changed

+48
-13
lines changed

2 files changed

+48
-13
lines changed

base/multidimensional.jl

Lines changed: 30 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -879,6 +879,28 @@ _maybe_linear_logical_index(::IndexLinear, A, i) = LogicalIndex{Int}(i)
879879
uncolon(::Tuple{}) = Slice(OneTo(1))
880880
uncolon(inds::Tuple) = Slice(inds[1])
881881

882+
"""
883+
_prechecked_iterate(iter[, state])
884+
885+
Internal function used to eliminate the dead branch in `iterate`.
886+
Fallback to `iterate` by default, but optimized for indices type in `Base`.
887+
"""
888+
@propagate_inbounds _prechecked_iterate(iter) = iterate(iter)
889+
@propagate_inbounds _prechecked_iterate(iter, state) = iterate(iter, state)
890+
891+
_prechecked_iterate(iter::AbstractUnitRange, i = first(iter)) = i, convert(eltype(iter), i + step(iter))
892+
_prechecked_iterate(iter::LinearIndices, i = first(iter)) = i, i + 1
893+
_prechecked_iterate(iter::CartesianIndices) = first(iter), first(iter)
894+
function _prechecked_iterate(iter::CartesianIndices, i)
895+
i′ = IteratorsMD.inc(i.I, iter.indices)
896+
return i′, i′
897+
end
898+
_prechecked_iterate(iter::SCartesianIndices2) = first(iter), first(iter)
899+
function _prechecked_iterate(iter::SCartesianIndices2{K}, (;i, j)) where {K}
900+
I = i < K ? SCartesianIndex2{K}(i + 1, j) : SCartesianIndex2{K}(1, j + 1)
901+
return I, I
902+
end
903+
882904
### From abstractarray.jl: Internal multidimensional indexing definitions ###
883905
getindex(x::Union{Number,AbstractChar}, ::CartesianIndex{0}) = x
884906
getindex(t::Tuple, i::CartesianIndex{1}) = getindex(t, i.I[1])
@@ -910,14 +932,11 @@ function _generate_unsafe_getindex!_body(N::Int)
910932
quote
911933
@inline
912934
D = eachindex(dest)
913-
Dy = iterate(D)
935+
Dy = _prechecked_iterate(D)
914936
@inbounds @nloops $N j d->I[d] begin
915-
# This condition is never hit, but at the moment
916-
# the optimizer is not clever enough to split the union without it
917-
Dy === nothing && return dest
918-
(idx, state) = Dy
937+
(idx, state) = Dy::NTuple{2,Any}
919938
dest[idx] = @ncall $N getindex src j
920-
Dy = iterate(D, state)
939+
Dy = _prechecked_iterate(D, state)
921940
end
922941
return dest
923942
end
@@ -953,14 +972,12 @@ function _generate_unsafe_setindex!_body(N::Int)
953972
@nexprs $N d->(I_d = unalias(A, I[d]))
954973
idxlens = @ncall $N index_lengths I
955974
@ncall $N setindex_shape_check x′ (d->idxlens[d])
956-
Xy = iterate(x′)
975+
X = eachindex(x′)
976+
Xy = _prechecked_iterate(X)
957977
@inbounds @nloops $N i d->I_d begin
958-
# This is never reached, but serves as an assumption for
959-
# the optimizer that it does not need to emit error paths
960-
Xy === nothing && break
961-
(val, state) = Xy
962-
@ncall $N setindex! A val i
963-
Xy = iterate(x′, state)
978+
(idx, state) = Xy::NTuple{2,Any}
979+
@ncall $N setindex! A x′[idx] i
980+
Xy = _prechecked_iterate(X, state)
964981
end
965982
A
966983
end

test/abstractarray.jl

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1961,3 +1961,21 @@ end
19611961
@test zero([[2,2], [3,3,3]]) isa Vector{Vector{Int}}
19621962
@test zero([[2,2], [3,3,3]]) == [[0,0], [0, 0, 0]]
19631963
end
1964+
1965+
@testset "`_prechecked_iterate` optimization" begin
1966+
function test_prechecked_iterate(iter)
1967+
Js = Base._prechecked_iterate(iter)
1968+
for I in iter
1969+
J, s = Js::NTuple{2,Any}
1970+
@test J === I
1971+
Js = Base._prechecked_iterate(iter, s)
1972+
end
1973+
end
1974+
test_prechecked_iterate(1:10)
1975+
test_prechecked_iterate(Base.OneTo(10))
1976+
test_prechecked_iterate(CartesianIndices((3, 3)))
1977+
test_prechecked_iterate(CartesianIndices(()))
1978+
test_prechecked_iterate(LinearIndices((3, 3)))
1979+
test_prechecked_iterate(LinearIndices(()))
1980+
test_prechecked_iterate(Base.SCartesianIndices2{3}(1:3))
1981+
end

0 commit comments

Comments
 (0)