Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Use shared pre-computation for by and perm orderings. #52033

Open
wants to merge 7 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
101 changes: 92 additions & 9 deletions base/ordering.jl
Original file line number Diff line number Diff line change
Expand Up @@ -114,17 +114,100 @@ ReverseOrdering(perm::Perm) = Perm(ReverseOrdering(perm.order), perm.data)
lt(o::Ordering, a, b) -> Bool

Test whether `a` is less than `b` according to the ordering `o`.
""" # No see-also because the prepared ordering system is experimental.
function lt end

"""
lt(o::ForwardOrdering, a, b) = isless(a,b)
lt(o::ReverseOrdering, a, b) = lt(o.fwd,b,a)
lt(o::By, a, b) = lt(o.order,o.by(a),o.by(b))
lt(o::Lt, a, b) = o.lt(a,b)
lt_prepared(o::Ordering, a, b)

@propagate_inbounds function lt(p::Perm, a::Integer, b::Integer)
da = p.data[a]
db = p.data[b]
(lt(p.order, da, db)::Bool) | (!(lt(p.order, db, da)::Bool) & (a < b))
end
Test whether `a` is less than `b` according to the ordering `o`, assuming both `a` and `b`
have been prepared with `prepare`.

`lt_prepared(o, prepare(o, a), prepare(o, b))` is equivalent to `lt(o, a, b)`.

!!! warning
Comparing a prepared element `prepare(o1, x)` under a different ordering `o2`
is undefined behavior and may, for example, result in segfaults.

See also `lt_prepared_1`, `lt_prepared_2`.
"""
function lt_prepared end

"""
lt_prepared_1(o::Ordering, a, b)

Test whether `a` is less than `b` according to the ordering `o`, assuming `a` has been
prepared with `prepare`.

`lt_prepared_1(o, prepare(o, a), b)` is equivalent to `lt(o, a, b)`.

!!! warning
Comparing a prepared element `prepare(o1, x)` under a different ordering `o2`
is undefined behavior and may, for example, result in segfaults.

See also `lt`, `lt_prepared`.
"""
@propagate_inbounds lt_prepared_1(o::Ordering, a, b) = lt_prepared(o, a, prepare(o, b))

"""
lt_prepared_2(o::Ordering, a, b)

Test whether `a` is less than `b` according to the ordering `o`, assuming `b` has been
prepared with `prepare`.

!!! warning
Comparing a prepared element `prepare(o1, x)` under a different ordering `o2`
is undefined behavior and may, for example, result in segfaults.

See also `lt`, `lt_prepared`.
"""
@propagate_inbounds lt_prepared_2(o::Ordering, a, b) = lt_prepared(o, prepare(o, a), b)

"""
prepare(o::Ordering, x)

Prepare an element `x` for efficient comparison with `lt_prepared`.

`lt(o::MyOrdering, a, b)` and `lt_prepared(o, prepare(o, a), prepare(o, b))` are
equivalent. They must have indistinguishable behavior and have the same performance
characteristics.

If you define `prepare` on a custom `Ordering`, you should also define `lt_prepared` and
should not define `lt` for that order.

!!! warning
Comparing a prepared element `prepare(o1, x)` under a different ordering `o2`
is undefined behavior and may, for example, result in segfaults.
"""
function prepare end

# Fallbacks
@propagate_inbounds lt_prepared(o::Ordering, a, b) = lt(o, a, b) # TODO: remove this in Julia 2.0
prepare(o::Ordering, x) = x
# Not defining this because it would cause a stack overflow for invalid `Ordering`s:
# lt(o::Ordering, a, b) = lt_prepared(o, prepare(o, a), prepare(o, b))

# Forward
lt(o::ForwardOrdering, a, b) = isless(a, b)

# Reverse
prepare(o::ReverseOrdering, x) = prepare(o.fwd, x)
lt_prepared(o::ReverseOrdering, a, b) = lt_prepared(o.fwd, b, a)
lt(o::ReverseOrdering, a, b) = lt_prepared(o, prepare(o, a), prepare(o, b))

# By
prepare(o::By, x) = prepare(o.order, o.by(x))
lt_prepared(o::By, a, b) = lt_prepared(o.order, a, b)
lt(o::By, a, b) = lt_prepared(o, prepare(o, a), prepare(o, b))

# Perm
@propagate_inbounds prepare(o::Perm, i) = (prepare(o.order, o.data[i]), i)
lt_prepared(p::Perm, (da, a), (db, b)) =
(lt_prepared(p.order, da, db)::Bool) | (!(lt_prepared(p.order, db, da)::Bool) & (a < b))
@propagate_inbounds lt(o::Perm, a, b) = lt_prepared(o, prepare(o, a), prepare(o, b))

## Lt
lt(o::Lt, a, b) = o.lt(a, b)


_ord(lt::typeof(isless), by, order::Ordering) = _by(by, order)
Expand Down
34 changes: 23 additions & 11 deletions base/sort.jl
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,8 @@ module Sort

using Base.Order

using Base.Order: prepare, lt_prepared, lt_prepared_1, lt_prepared_2

using Base: copymutable, midpoint, require_one_based_indexing, uinttype,
sub_with_overflow, add_with_overflow, OneTo, BitSigned, BitIntegerType, top_set_bit

Expand Down Expand Up @@ -51,11 +53,13 @@ function issorted(itr, order::Ordering)
y = iterate(itr)
y === nothing && return true
prev, state = y
prev_p = prepare(order, prev)
y = iterate(itr, state)
while y !== nothing
this, state = y
lt(order, this, prev) && return false
prev = this
this_p = prepare(order, this)
lt_prepared(order, this_p, prev_p) && return false
prev_p = this_p
y = iterate(itr, state)
end
return true
Expand Down Expand Up @@ -179,10 +183,11 @@ partialsort(v::AbstractVector, k::Union{Integer,OrdinalRange}; kws...) =
function searchsortedfirst(v::AbstractVector, x, lo::T, hi::T, o::Ordering)::keytype(v) where T<:Integer
hi = hi + T(1)
len = hi - lo
x_p = prepare(o, x)
@inbounds while len != 0
half_len = len >>> 0x01
m = lo + half_len
if lt(o, v[m], x)
if lt_prepared_2(o, v[m], x_p)
lo = m + 1
len -= half_len + 1
else
Expand All @@ -199,9 +204,10 @@ function searchsortedlast(v::AbstractVector, x, lo::T, hi::T, o::Ordering)::keyt
u = T(1)
lo = lo - u
hi = hi + u
x_p = prepare(o, x)
@inbounds while lo < hi - u
m = midpoint(lo, hi)
if lt(o, x, v[m])
if lt_prepared_1(o, x_p, v[m])
hi = m
else
lo = m
Expand All @@ -217,13 +223,15 @@ function searchsorted(v::AbstractVector, x, ilo::T, ihi::T, o::Ordering)::UnitRa
u = T(1)
lo = ilo - u
hi = ihi + u
x_p = prepare(o, x)
@inbounds while lo < hi - u
m = midpoint(lo, hi)
if lt(o, v[m], x)
if lt_prepared_2(o, v[m], x_p)
lo = m
elseif lt(o, x, v[m])
elseif lt_prepared_1(o, x_p, v[m])
hi = m
else
# TODO for further optimization: perform recursive calls with prepared inputs
a = searchsortedfirst(v, x, max(lo,ilo), m, o)
b = searchsortedlast(v, x, m, min(hi,ihi), o)
return a : b
Expand Down Expand Up @@ -820,9 +828,10 @@ function _sort!(v::AbstractVector, ::InsertionSortAlg, o::Ordering, kw)
@inbounds for i = lo_plus_1:hi
j = i
x = v[i]
x_p = prepare(o, x)
while j > lo
y = v[j-1]
if !(lt(o, x, y)::Bool)
if !(lt_prepared_1(o, x_p, y)::Bool)
break
end
v[j] = y
Expand Down Expand Up @@ -1074,16 +1083,17 @@ function partition!(t::AbstractVector, lo::Integer, hi::Integer, offset::Integer
pivot_index = mod(hash(lo), lo:hi)
@inbounds begin
pivot = v[pivot_index]
pivot_p = prepare(o, pivot)
while lo < pivot_index
x = v[lo]
fx = rev ? !lt(o, x, pivot) : lt(o, pivot, x)
fx = rev ? !lt_prepared_2(o, x, pivot_p) : lt_prepared_1(o, pivot_p, x)
t[(fx ? hi : lo) - offset] = x
offset += fx
lo += 1
end
while lo < hi
x = v[lo+1]
fx = rev ? lt(o, pivot, x) : !lt(o, x, pivot)
fx = rev ? lt_prepared_1(o, pivot_p, x) : !lt_prepared_2(o, x, pivot_p)
t[(fx ? hi : lo) - offset] = x
offset += fx
lo += 1
Expand Down Expand Up @@ -1425,6 +1435,7 @@ end
maybe_unsigned(x::Integer) = x # this is necessary to avoid calling unsigned on BigInt
maybe_unsigned(x::BitSigned) = unsigned(x)
function _issorted(v::AbstractVector, lo::Integer, hi::Integer, o::Ordering)
# TODO: replace this with `issorted(view(v, lo:hi), order=o)` once views are fast.
@boundscheck checkbounds(v, lo:hi)
@inbounds for i in (lo+1):hi
lt(o, v[i], v[i-1]) && return false
Expand Down Expand Up @@ -2336,12 +2347,13 @@ end

function partition!(v::AbstractVector, lo::Integer, hi::Integer, o::Ordering)
pivot = selectpivot!(v, lo, hi, o)
pivot_p = prepare(o, pivot)
# pivot == v[lo], v[hi] > pivot
i, j = lo, hi
@inbounds while true
i += 1; j -= 1
while lt(o, v[i], pivot); i += 1; end;
while lt(o, pivot, v[j]); j -= 1; end;
while lt_prepared_2(o, v[i], pivot_p); i += 1; end;
while lt_prepared_1(o, pivot_p, v[j]); j -= 1; end;
i >= j && break
v[i], v[j] = v[j], v[i]
end
Expand Down
44 changes: 44 additions & 0 deletions test/sorting.jl
Original file line number Diff line number Diff line change
Expand Up @@ -1035,6 +1035,50 @@ end
@test issorted(sort!(rand(100), Base.Sort.InitialOptimizations(DispatchLoopTestAlg()), Base.Order.Forward))
end

@testset "Performance (how many timems By is called)" begin
# Intentional regressions are acceptable, accedental regressions are not.
cnt = Ref(0)
incr_identity = x -> (cnt[] += 1; x)
x = 1:50

cnt[] = 0
@test issorted(x; by=incr_identity)
@test cnt[] == 50 # Any less would be buggy.

cnt[] = 0
@test !issorted(x; by=incr_identity, rev=true)
@test cnt[] == 2 # Any less would be buggy.

cnt[] = 0
@test searchsortedfirst(x, 1; by=incr_identity) == 1
@test cnt[] <= 7

cnt[] = 0
@test searchsorted(repeat(1:10, inner=10), 3; by=incr_identity) == 21:30
@test cnt[] <= 16

cnt[] = 0
@test sort(x; by=incr_identity) == x
@test cnt[] <= 98

cnt[] = 0
@test sort(1:1000; by=incr_identity) == 1:1000
@test cnt[] <= 1998

cnt[] = 0
Random.seed!(1729)
x = randperm(1000)
@test sort!(x; by=incr_identity) == 1:1000
# This should succeed at least 99.99% of the time on random inputs
# and therefore should not be broken by changes to the rng
@test cnt[] <= 17203

cnt[] = 0
x = hash.(1:1000)
@test sort(x; by=incr_identity) == sort(x)
@test cnt[] <= 12999
end

@testset "partialsort tests added for BracketedSort #52006" begin
x = rand(Int, 1000)
@test partialsort(x, 1) == minimum(x)
Expand Down