Add BracketedSort a new, faster algorithm for partialsort and friends #52006

Merged on Nov 23, 2023 (41 commits)

Commits
9f4ad8b
add sample implementation
Oct 28, 2023
e47321f
add fallback and remove instrumentation
Oct 28, 2023
c61e6c3
add a faster, non-allocating version
Oct 28, 2023
9093cb9
small tweaks
Oct 29, 2023
ca7bd59
add tests and support target ranges
Oct 29, 2023
a170fc7
Add tuning
Oct 29, 2023
8c6eff6
implement threshold
Oct 29, 2023
4a6fce7
Merge branch 'master' into lh/fast-partialsort
Nov 2, 2023
1f1fc3b
add slow version to Julia
Nov 2, 2023
4aad01f
fix some bugs and fiddle with optimization passes (specifically disa…
Nov 2, 2023
8e66d4b
a bit more fiddling. The remaining performance gap is due to NaN safety
Nov 2, 2023
5d07ffe
revert whitespace change
Nov 2, 2023
0f81beb
update comments and increase tries from 4 to 5
Nov 2, 2023
e1df36e
remove 'deleteme' development file
Nov 2, 2023
0ebef7e
Merge branch 'master' into lh/fast-partialsort
Nov 2, 2023
8003a0c
update docstring
Nov 2, 2023
8e933c3
support non-unit-range targets
Nov 3, 2023
76d2833
bugfix TODO: add tests that catch this
Nov 3, 2023
847172e
another bugfix (this one caught by CI)
Nov 3, 2023
5a85c03
update invalid lt tests
Nov 3, 2023
86fc129
add todo
Nov 3, 2023
a3a6c47
Tweak dispatch to avoid >100% regressions on 39 element arrays & opti…
Nov 3, 2023
bda1b6d
more performance characteristic tweaks (and a dynamic dispatch perfor…
Nov 3, 2023
b2e4529
use standard optimizations for recursive calls
Nov 3, 2023
fd8d967
cleanup, add comments, and admit weakness against inputs with duplica…
Nov 3, 2023
8361184
make lots of duplicates non-pathological (still not great, but not te…
Nov 3, 2023
5d52194
fix some bugs (wow, we need better test coverage!) and add a dispatch…
Nov 3, 2023
6f8048f
change offset from .5 to .7 (helps a huge amount for small to medium …
Nov 3, 2023
d0a38a2
noting that running a hundred benchmarks doesn't fail a single trial,…
Nov 3, 2023
52a6785
implement NFC todo that requires rebuilding Julia
Nov 3, 2023
1d90487
fix typo
Nov 3, 2023
83c9e27
check and document the invariant that makes the `@inbounds`s safe
Nov 4, 2023
0b2b399
fix some unimportant off by one errors that have been bugging me
Nov 4, 2023
422a14b
round less coarsely
Nov 4, 2023
ad82125
micro-refactor to use more code sharing
Nov 4, 2023
ccb5c99
Avoid overflow and nfc refactor add comments, and variable rename for…
Nov 4, 2023
650c6a2
randomize initial hash seed; use consistent recursive algorithms; add…
Nov 4, 2023
5c18e25
REVERT ME: revert the re-introduction of PartialQuickSort
Nov 4, 2023
5657e5f
implement Oscar's suggestion to speed up heuristic computation
Nov 5, 2023
069c453
accept that invalid lt continues to work
Nov 6, 2023
eb86ec5
Merge branch 'master' into lh/fast-partialsort
LilithHafner Nov 22, 2023
199 changes: 198 additions & 1 deletion base/sort.jl
@@ -90,7 +90,15 @@ issorted(itr;
    issorted(itr, ord(lt,by,rev,order))

function partialsort!(v::AbstractVector, k::Union{Integer,OrdinalRange}, o::Ordering)
-    _sort!(v, InitialOptimizations(ScratchQuickSort(k)), o, (;))
    # TODO move k from `alg` to `kw`
Review comment (Member, Author): This is a TODO that predates this PR. I'm just adding a note because this PR touches target range handling and reminded me that this is not the best approach.

    # Don't perform InitialOptimizations before Bracketing. The optimizations take O(n)
    # time and so does the whole sort. But do perform them before recursive calls because
    # that can cause significant speedups when the target range is large so the runtime is
    # dominated by k log k and the optimizations run in O(k) time.
    _sort!(v, BoolOptimization(
        Small{12}( # Very small inputs should go straight to insertion sort
            BracketedSort(k))),
        o, (;))
    maybeview(v, k)
end
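
For orientation, here is a minimal usage sketch of the public entry points this hunk changes. It relies only on the exported `partialsort` and `partialsort!`; the sample vector is made up for illustration, and which internal algorithm ends up handling such a tiny input is an implementation detail that this sketch does not depend on.

```julia
# Illustrative data only; any vector with a total order would do.
v = [9, 3, 7, 1, 8, 2, 6]

# Single index: the element that would sit at position 3 after a full sort.
third = partialsort(v, 3)

# Range: the elements that would occupy positions 2:4, in sorted order.
middle = partialsort(v, 2:4)

# In-place variant: positions 2:4 of `w` hold their fully sorted values afterwards.
w = copy(v)
partialsort!(w, 2:4)

@assert third == 3
@assert middle == [2, 3, 6]
@assert w[2:4] == [2, 3, 6]
```

Only the requested positions are guaranteed to match the fully sorted order; the rest of `w` may be left in any arrangement.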

@@ -1138,6 +1146,195 @@ function _sort!(v::AbstractVector, a::ScratchQuickSort, o::Ordering, kw;
end


"""
    BracketedSort(target[, next::Algorithm]) <: Algorithm

Perform a partialsort for the elements that fall into the indices specified by the `target`
using BracketedSort with the `next` algorithm for subproblems.

BracketedSort takes a random* sample of the input, estimates the quantiles of the input
using the quantiles of the sample to find signposts that almost certainly bracket the target
values, filters the values in the input that fall between the signpost values to the front of
the input, and then, if that "almost certainly" turned out to be true, finds the target
within the small chunk of elements that are, by value, between the signposts and now, by
position, at the front of the vector. On small inputs or when target is close to the size of
the input, BracketedSort falls back to the `next` algorithm directly. Otherwise,
BracketedSort uses the `next` algorithm only to compute quantiles of the sample and to find
the target within the small chunk.

## Performance

If the `next` algorithm has `O(n * log(n))` runtime and the input is not pathological then
the runtime of this algorithm is `O(n + k * log(k))` where `n` is the length of the input
and `k` is `length(target)`. On pathological inputs the asymptotic runtime is the same as
the runtime of the `next` algorithm.

BracketedSort itself does not allocate. If `next` is in-place then BracketedSort is also
in-place. If `next` is not in-place and its space usage increases monotonically with input
length then BracketedSort's maximum space usage will never be more than the space usage
of `next` on the input BracketedSort receives. For large nonpathological inputs and targets
substantially smaller than the size of the input, BracketedSort's maximum memory usage will
be much less than `next`'s. If the maximum additional space usage of `next` scales linearly
then for small k the average* maximum additional space usage of BracketedSort will be
`O(n^(2.3/3))`.

By default, BracketedSort uses the `O(n)` space and `O(n + k log k)` runtime
`ScratchQuickSort` algorithm recursively.

*Sorting is unable to depend on Random.jl because Random.jl depends on sorting.
Consequently, we use `hash` as a source of randomness. The average runtime guarantees
assume that `hash(x::Int)` produces a random result. However, as this randomization is
deterministic, if you try hard enough you can find inputs that consistently reach the
worst case bounds. Actually constructing such inputs is an exercise left to the reader.
Have fun :).

Characteristics:
* *unstable*: does not preserve the ordering of elements that compare equal
  (e.g. "a" and "A" in a sort of letters that ignores case).
* *in-place* in memory if the `next` algorithm is in-place.
* *estimate-and-filter* strategy
* *linear runtime* if `length(target)` is constant and `next` is reasonable
* *n + k log k* worst case runtime if `next` has that runtime.
* *pathological inputs* can significantly increase constant factors.
"""
struct BracketedSort{T, F} <: Algorithm
    target::T
    get_next::F
end

# TODO: this composition between BracketedSort and ScratchQuickSort does not bring me joy
Review comment (Member, Author): Can be avoided via moving k from alg to kw.

BracketedSort(k) = BracketedSort(k, k -> InitialOptimizations(ScratchQuickSort(k)))
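
To make the docstring's estimate-and-filter description concrete, here is a simplified, allocating sketch of the idea. It is not the Base implementation: `toy_bracketed_select`, `sample_size`, and `slack` are hypothetical names, the sample is evenly spaced rather than chosen through `hash`, it assumes a 1-based vector and a single target index, and it compares with `isless` instead of an `Ordering`.

```julia
# Toy version of the bracketing strategy (hypothetical helper, not part of Base).
function toy_bracketed_select(v::AbstractVector, k::Integer;
                              sample_size::Integer = 64, slack::Integer = 8)
    n = length(v)
    1 <= k <= n || throw(BoundsError(v, k))
    n <= 4 * sample_size && return sort(v)[k]      # small input: just sort

    # 1. Take a sample and sort it (evenly spaced here, an assumption of this sketch).
    step = n ÷ sample_size
    sample = sort!(v[1:step:n])
    m = length(sample)

    # 2. Use sample quantiles around the target's expected position as signposts.
    pos = round(Int, k / n * m)
    lo_i, hi_i = pos - slack, pos + slack
    lo_signpost = 1 <= lo_i <= m ? sample[lo_i] : nothing
    hi_signpost = 1 <= hi_i <= m ? sample[hi_i] : nothing

    # 3. Count what lies below the low signpost and keep what lies between the signposts.
    below(x) = lo_signpost !== nothing && isless(x, lo_signpost)
    above(x) = hi_signpost !== nothing && isless(hi_signpost, x)
    count_below = count(below, v)
    middle = filter(x -> !below(x) && !above(x), v)

    # 4. If the signposts really bracket the target, finish inside the small chunk.
    k_in_middle = k - count_below
    if 1 <= k_in_middle <= length(middle)
        return sort!(middle)[k_in_middle]
    end
    return sort(v)[k]                              # bracket missed: fall back
end

x = [mod(i * 7919, 10007) for i in 1:5000]         # deterministic scrambled data
@assert toy_bracketed_select(x, 2500) == sort(x)[2500]
```

The fallback in the last line is what keeps the sketch correct on any input; the bracketing only affects how much work is done, which mirrors the role of the `next` algorithm in the docstring.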

function bracket_kernel!(v::AbstractVector, lo, hi, lo_signpost, hi_signpost, o)
    i = 0
    count_below = 0
    checkbounds(v, lo:hi)
    for j in lo:hi
        x = @inbounds v[j]
        a = lo_signpost !== nothing && lt(o, x, lo_signpost)
        b = hi_signpost === nothing || !lt(o, hi_signpost, x)
        count_below += a
        # if a != b # This branch is almost never taken, so making it branchless is bad.
        #     @inbounds v[i], v[j] = v[j], v[i]
        #     i += 1
        # end
        c = a != b # JK, this is faster.
        k = i * c + j
        # Invariant: @assert firstindex(v) ≤ lo ≤ i + j ≤ k ≤ j ≤ hi ≤ lastindex(v)
        @inbounds v[j], v[k] = v[k], v[j]
        i += c - 1
    end
    count_below, i+hi
end
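
The branchless index arithmetic in the kernel can be hard to read, so here is a branching sketch that should produce the same partition. It is a simplification under stated assumptions: `toy_bracket!` is a hypothetical name, it requires both signposts, works on the whole vector rather than `lo:hi`, and compares with `isless` instead of an `Ordering`.

```julia
# Readable, branching counterpart of the kernel (hypothetical helper, not part of Base).
function toy_bracket!(v::AbstractVector, lo_signpost, hi_signpost)
    next_slot = firstindex(v)          # next free position at the front
    count_below = 0
    for j in eachindex(v)
        x = v[j]
        below = isless(x, lo_signpost)
        above = isless(hi_signpost, x)
        count_below += below
        if !below && !above            # x lies between the signposts (inclusive)
            v[next_slot], v[j] = v[j], v[next_slot]
            next_slot += 1
        end
    end
    return count_below, next_slot - 1  # (number below, last index of the middle chunk)
end

v = [5, 1, 9, 4, 7, 3, 8]
count_below, last_middle = toy_bracket!(v, 3, 7)
@assert count_below == 1               # only 1 is below the low signpost
@assert last_middle == 4               # 5, 4, 7, 3 were moved to v[1:4]
```

In the PR's kernel, `i` stays at minus the number of non-middle elements seen so far, so `i + j` plays the role of `next_slot` here and the swap degenerates to a self-swap when the element is not in the middle.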

function move!(v, target, source)
    # This function never dominates runtime; only add `@inbounds` if you can demonstrate a
    # performance improvement. And if you do, also double check behavior when `target`
    # is out of bounds.
    @assert length(target) == length(source)
    if length(target) == 1 || isdisjoint(target, source)
        for (i, j) in zip(target, source)
            v[i], v[j] = v[j], v[i]
        end
    else
        @assert minimum(source) <= minimum(target)
        reverse!(v, minimum(source), maximum(target))
        reverse!(v, minimum(target), maximum(target))
    end
end
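
The overlapping branch of `move!` relies on two reversals: the first carries the source elements into the target positions in reversed order, and the second restores their order. A small worked example, with made-up values and ranges, illustrates this:

```julia
v = [10, 20, 30, 40, 50, 60, 70, 80]
source, target = 2:5, 4:7      # overlapping, equal length, source starts first
expected = v[source]           # [20, 30, 40, 50]

reverse!(v, minimum(source), maximum(target))   # reverse positions 2:7
reverse!(v, minimum(target), maximum(target))   # reverse positions 4:7 back into order

@assert v[target] == expected
```

An element at `source[i]` lands at `target[end - i + 1]` after the first reversal and at `target[i]` after the second, so the block is shifted without any scratch space.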

function _sort!(v::AbstractVector, a::BracketedSort, o::Ordering, kw)
    @getkw lo hi scratch
    # TODO for further optimization: reuse scratch between trials better, from signpost
    # selection to recursive calls, and from the fallback (but be aware of type stability,
    # especially when sorting IEEE floats).

    # We don't need to bounds check target because that is done higher up in the stack.
    # However, we cannot assume the target is inbounds.
    lo < hi || return scratch
    ln = hi - lo + 1

    # This is simply a precomputed short-circuit to avoid doing scalar math for small inputs.
    # It does not change dispatch at all.
    ln < 260 && return _sort!(v, a.get_next(a.target), o, kw)

    target = a.target
    k = cbrt(ln)
    k2 = round(Int, k^2)
    k2ln = k2/ln
    offset = .15k2*top_set_bit(k2) # TODO for further optimization: tune this
    lo_signpost_i, hi_signpost_i =
        (floor(Int, (tar - lo) * k2ln + lo + off) for (tar, off) in
            ((minimum(target), -offset), (maximum(target), offset)))
    lastindex_sample = lo+k2-1
    expected_middle_ln = (min(lastindex_sample, hi_signpost_i) - max(lo, lo_signpost_i) + 1) / k2ln
    # This heuristic is complicated because it fairly accurately reflects the runtime of
    # this algorithm which is necessary to get good dispatch when both the target and
    # the input are large.
    # expected_middle_ln is a float and k2 is significantly below typemax(Int), so this will
    # not overflow:
    # TODO move target from alg to kw to avoid this ickyness:
    ln <= 130 + 2k2 + 2expected_middle_ln && return _sort!(v, a.get_next(a.target), o, kw)

    # We store the random sample in
    #     sample = view(v, lo:lo+k2)
    # but views are not quite as fast as using the input array directly,
    # so we don't actually construct this view at runtime.

    # TODO for further optimization: handle lots of duplicates better.
    # Right now lots of duplicates rounds up when it could use some super fast optimizations
    # in some cases.
    # e.g.
    #
    # Target: |----|
    # Sorted input: 000000000000000000011111112222223333333333
    #
    # Will filter all zeros and ones to the front when it could just take the first few
    # it encounters. This optimization would be especially potent when `allequal(ans)` and
    # equal elements are egal.

    # 3 random trials should typically give us 0.99999 reliability; we can assume
    # the input is pathological and abort to fallback if we fail three trials.
    seed = hash(ln, Int === Int64 ? 0x85eb830e0216012d : 0xae6c4e15)
    for attempt in 1:3
        seed = hash(attempt, seed)
        for i in lo:lo+k2-1
            j = mod(hash(i, seed), i:hi) # TODO for further optimization: be sneaky and remove this division
            v[i], v[j] = v[j], v[i]
        end
        count_below, lastindex_middle = if lo_signpost_i <= lo && lastindex_sample <= hi_signpost_i
            # The heuristics higher up in this function that dispatch to the `next`
            # algorithm should prevent this from happening.
            # Specifically, this means that expected_middle_ln == ln, so
            #     ln <= ... + 2.0expected_middle_ln && return ...
            # will trigger.
            @assert false
            # But if it does happen, the kernel reduces to
            0, hi
        elseif lo_signpost_i <= lo
            _sort!(v, a.get_next(hi_signpost_i), o, (;kw..., hi=lastindex_sample))
            bracket_kernel!(v, lo, hi, nothing, v[hi_signpost_i], o)
        elseif lastindex_sample <= hi_signpost_i
            _sort!(v, a.get_next(lo_signpost_i), o, (;kw..., hi=lastindex_sample))
            bracket_kernel!(v, lo, hi, v[lo_signpost_i], nothing, o)
        else
            # TODO for further optimization: don't sort the middle elements
            _sort!(v, a.get_next(lo_signpost_i:hi_signpost_i), o, (;kw..., hi=lastindex_sample))
            bracket_kernel!(v, lo, hi, v[lo_signpost_i], v[hi_signpost_i], o)
        end
        target_in_middle = target .- count_below
        if lo <= minimum(target_in_middle) && maximum(target_in_middle) <= lastindex_middle
            scratch = _sort!(v, a.get_next(target_in_middle), o, (;kw..., hi=lastindex_middle))
            move!(v, target, target_in_middle)
            return scratch
        end
        # This line almost never runs.
    end
    # This line only runs on pathological inputs. Make sure it's covered by tests :)
    _sort!(v, a.get_next(target), o, kw)
end
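
The sampling loop above is a partial Fisher-Yates shuffle driven by `hash` instead of a random number generator: after `m` steps the first `m` positions hold a pseudo-random sample and the vector is still a permutation of the input. A standalone sketch of just that step, assuming a 1-based `Vector` and `m <= length(v)` (`toy_sample_to_front!` and the seed value are hypothetical):

```julia
# Hash-driven partial shuffle (hypothetical helper mirroring the inner loop above).
function toy_sample_to_front!(v::Vector, m::Integer, seed::UInt)
    hi = lastindex(v)
    for i in 1:m
        j = mod(hash(i, seed), i:hi)   # pseudo-random index in i:hi
        v[i], v[j] = v[j], v[i]
    end
    return v
end

v = collect(1:20)
toy_sample_to_front!(v, 5, UInt(0x1234))

@assert sort(v) == collect(1:20)       # still a permutation of the input
```

Because the swaps depend only on the indices and the seed, the same seed always selects the same positions, which is the determinism the docstring's footnote warns about.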


"""
StableCheckSorted(next) <: Algorithm

37 changes: 37 additions & 0 deletions test/sorting.jl
@@ -721,6 +721,8 @@ end
for alg in safe_algs
@test sort(1:n, alg=alg, lt = (i,j) -> v[i]<=v[j]) == perm
end
# This could easily break with minor heuristic adjustments
# because partialsort is not even guaranteed to be stable:
@test partialsort(1:n, 172, lt = (i,j) -> v[i]<=v[j]) == perm[172]
@test partialsort(1:n, 315:415, lt = (i,j) -> v[i]<=v[j]) == perm[315:415]

@@ -1034,6 +1036,41 @@ end
@test issorted(sort!(rand(100), Base.Sort.InitialOptimizations(DispatchLoopTestAlg()), Base.Order.Forward))
end

@testset "partialsort tests added for BracketedSort #52006" begin
    x = rand(Int, 1000)
    @test partialsort(x, 1) == minimum(x)
    @test partialsort(x, 1000) == maximum(x)
    sx = sort(x)
    for i in [1, 2, 4, 10, 11, 425, 500, 845, 991, 997, 999, 1000]
        @test partialsort(x, i) == sx[i]
    end
    for i in [1:1, 1:2, 1:5, 1:8, 1:9, 1:11, 1:108, 135:812, 220:586, 363:368, 450:574, 458:597, 469:638, 487:488, 500:501, 584:594, 1000:1000]
        @test partialsort(x, i) == sx[i]
    end

    # Semi-pathological input
    seed = hash(1000, Int === Int64 ? 0x85eb830e0216012d : 0xae6c4e15)
    seed = hash(1, seed)
    for i in 1:100
        j = mod(hash(i, seed), i:1000)
        x[j] = typemax(Int)
    end
    @test partialsort(x, 500) == sort(x)[500]

    # Fully pathological input
    # it would be too much trouble to actually construct a valid pathological input, so we
    # construct an invalid pathological input.
    # This test is kind of sketchy because it passes invalid inputs to the function
    for i in [1:6, 1:483, 1:957, 77:86, 118:478, 223:227, 231:970, 317:958, 500:501, 500:501, 500:501, 614:620, 632:635, 658:665, 933:940, 937:942, 997:1000, 999:1000]
        x = rand(1:5, 1000)
        @test partialsort(x, i, lt=(<=)) == sort(x)[i]
    end
    for i in [1, 7, 8, 490, 495, 852, 993, 996, 1000]
        x = rand(1:5, 1000)
        @test partialsort(x, i, lt=(<=)) == sort(x)[i]
    end
end

# This testset is at the end of the file because it is slow.
@testset "searchsorted" begin
numTypes = [ Int8, Int16, Int32, Int64, Int128,