Skip to content

Commit 187e8c2

Browse files
authored
Add BracketedSort a new, faster algorithm for partialsort and friends (#52006)
1 parent 79a845c commit 187e8c2

File tree

2 files changed

+235
-1
lines changed

2 files changed

+235
-1
lines changed

base/sort.jl

Lines changed: 198 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -90,7 +90,15 @@ issorted(itr;
9090
issorted(itr, ord(lt,by,rev,order))
9191

9292
function partialsort!(v::AbstractVector, k::Union{Integer,OrdinalRange}, o::Ordering)
93-
_sort!(v, InitialOptimizations(ScratchQuickSort(k)), o, (;))
93+
# TODO move k from `alg` to `kw`
94+
# Don't perform InitialOptimizations before Bracketing. The optimizations take O(n)
95+
# time and so does the whole sort. But do perform them before recursive calls because
96+
# that can cause significant speedups when the target range is large so the runtime is
97+
# dominated by k log k and the optimizations run in O(k) time.
98+
_sort!(v, BoolOptimization(
99+
Small{12}( # Very small inputs should go straight to insertion sort
100+
BracketedSort(k))),
101+
o, (;))
94102
maybeview(v, k)
95103
end
96104

@@ -1138,6 +1146,195 @@ function _sort!(v::AbstractVector, a::ScratchQuickSort, o::Ordering, kw;
11381146
end
11391147

11401148

1149+
"""
1150+
BracketedSort(target[, next::Algorithm]) <: Algorithm
1151+
1152+
Perform a partialsort for the elements that fall into the indices specified by the `target`
1153+
using BracketedSort with the `next` algorithm for subproblems.
1154+
1155+
BracketedSort takes a random* sample of the input, estimates the quantiles of the input
1156+
using the quantiles of the sample to find signposts that almost certainly bracket the target
1157+
values, filters the values in the input that fall between the signpost values to the front of
1158+
the input, and then, if that "almost certainly" turned out to be true, finds the target
1159+
within the small chunk that are, by value, between the signposts and now by position, at the
1160+
front of the vector. On small inputs or when target is close to the size of the input,
1161+
BracketedSort falls back to the `next` algorithm directly. Otherwise, BracketedSort uses the
1162+
`next` algorithm only to compute quantiles of the sample and to find the target within the
1163+
small chunk.
1164+
1165+
## Performance
1166+
1167+
If the `next` algorithm has `O(n * log(n))` runtime and the input is not pathological then
1168+
the runtime of this algorithm is `O(n + k * log(k))` where `n` is the length of the input
1169+
and `k` is `length(target)`. On pathological inputs the asymptotic runtime is the same as
1170+
the runtime of the `next` algorithm.
1171+
1172+
BracketedSort itself does not allocate. If `next` is in-place then BracketedSort is also
1173+
in-place. If `next` is not in place, and its space usage increases monotonically with input
1174+
length then BracketedSort's maximum space usage will never be more than the space usage
1175+
of `next` on the input BracketedSort receives. For large nonpathological inputs and targets
1176+
substantially smaller than the size of the input, BracketedSort's maximum memory usage will
1177+
be much less than `next`'s. If the maximum additional space usage of `next` scales linearly
1178+
then for small k the average* maximum additional space usage of BracketedSort will be
1179+
`O(n^(2.3/3))`.
1180+
1181+
By default, BracketedSort uses the `O(n)` space and `O(n + k log k)` runtime
1182+
`ScratchQuickSort` algorithm recursively.
1183+
1184+
*Sorting is unable to depend on Random.jl because Random.jl depends on sorting.
1185+
Consequently, we use `hash` as a source of randomness. The average runtime guarantees
1186+
assume that `hash(x::Int)` produces a random result. However, as this randomization is
1187+
deterministic, if you try hard enough you can find inputs that consistently reach the
1188+
worst case bounds. Actually constructing such inputs is an exercise left to the reader.
1189+
Have fun :).
1190+
1191+
Characteristics:
1192+
* *unstable*: does not preserve the ordering of elements that compare equal
1193+
(e.g. "a" and "A" in a sort of letters that ignores case).
1194+
* *in-place* in memory if the `next` algorithm is in-place.
1195+
* *estimate-and-filter* strategy
1196+
* *linear runtime* if `length(target)` is constant and `next` is reasonable
1197+
* *n + k log k* worst case runtime if `next` has that runtime.
1198+
* *pathological inputs* can significantly increase constant factors.
1199+
"""
1200+
struct BracketedSort{T, F} <: Algorithm
1201+
# Index or indices whose sorted contents are requested (the target of the partial sort).
target::T
1202+
# Callable, not an `Algorithm` instance: it is invoked as `get_next(some_target)` and
# returns the algorithm used for subproblems and fallbacks with that target.
get_next::F
1203+
end
1204+
1205+
# TODO: this composition between BracketedSort and ScratchQuickSort does not bring me joy
1206+
# Convenience constructor: subproblems for any target `k` are delegated to
# `InitialOptimizations(ScratchQuickSort(k))`.
BracketedSort(k) = BracketedSort(k, k -> InitialOptimizations(ScratchQuickSort(k)))
1207+
1208+
# Partition `v[lo:hi]` in a single pass so that every element lying between the two
# signposts ends up at the front of that range, in its original relative order.
#
# `lo_signpost` / `hi_signpost` may each be `nothing`, meaning "no bound on that side".
# An element `x` is kept when it is not `lt(o, x, lo_signpost)` and not
# `lt(o, hi_signpost, x)`.
#
# Returns `(n_below, last_kept)`: the number of elements ordered strictly before
# `lo_signpost`, and the index of the last kept element (`lo - 1` when none were kept).
function bracket_kernel!(v::AbstractVector, lo, hi, lo_signpost, hi_signpost, o)
    # One up-front check covers every `@inbounds` access in the loop.
    checkbounds(v, lo:hi)
    n_below = 0
    # `shift` is minus the number of rejected elements seen so far, which makes
    # `j + shift` the first slot after the kept prefix.
    shift = 0
    for j in lo:hi
        x = @inbounds v[j]
        below = lo_signpost !== nothing && lt(o, x, lo_signpost)
        not_above = hi_signpost === nothing || !lt(o, hi_signpost, x)
        n_below += below
        # With consistent signposts `below` implies `not_above`, so the two flags
        # differ exactly when `x` lies between the signposts.
        keep = below != not_above
        # Branchless on purpose: kept elements are swapped into the prefix, rejected
        # ones are swapped with themselves.
        # Invariant: firstindex(v) ≤ lo ≤ j + shift ≤ dest ≤ j ≤ hi ≤ lastindex(v)
        dest = j + shift * keep
        @inbounds v[j], v[dest] = v[dest], v[j]
        shift -= !keep
    end
    return n_below, hi + shift
end
1229+
1230+
# Relocate the elements stored at indices `source` to indices `target`, in order.
# Other elements of `v` may be permuted as a side effect.
#
# `target` and `source` must have equal length; each may be a single index or a
# collection of indices. When they overlap, `source` must not start after `target`.
#
# This function never dominates runtime, so only add `@inbounds` if you can demonstrate
# a performance improvement. If you do, also double check behavior when `target` is out
# of bounds.
function move!(v, target, source)
    @assert length(target) == length(source)
    if length(target) != 1 && !isdisjoint(target, source)
        # Overlapping index sets: rotate the enclosing span with two reversals. The first
        # drops the source elements into the target slots in reverse order; the second
        # restores their order.
        @assert minimum(source) <= minimum(target)
        span_end = maximum(target)
        reverse!(v, minimum(source), span_end)
        return reverse!(v, minimum(target), span_end)
    end
    # Single index or disjoint index sets: pairwise swaps are safe.
    for (dst, src) in zip(target, source)
        v[dst], v[src] = v[src], v[dst]
    end
end
1245+
1246+
# `_sort!` method implementing BracketedSort.
#
# Outline: take a sample of roughly `ln^(2/3)` elements, sort the sample around two
# "signpost" positions, move every element of `v[lo:hi]` that lies between the signpost
# values to the front, then look for the target inside that (hopefully small) front chunk.
# Returns `scratch` on the early-exit and success paths; otherwise returns whatever the
# fallback `_sort!` call returns.
function _sort!(v::AbstractVector, a::BracketedSort, o::Ordering, kw)
1247+
@getkw lo hi scratch
1248+
# TODO for further optimization: reuse scratch between trials better, from signpost
1249+
# selection to recursive calls, and from the fallback (but be aware of type stability,
1250+
# especially when sorting IEEE floats).
1251+
1252+
# We don't need to bounds check target because that is done higher up in the stack
1253+
# However, we cannot assume the target is inbounds.
1254+
lo < hi || return scratch
1255+
ln = hi - lo + 1
1256+
1257+
# This is simply a precomputed short-circuit to avoid doing scalar math for small inputs.
1258+
# It does not change dispatch at all.
1259+
ln < 260 && return _sort!(v, a.get_next(a.target), o, kw)
1260+
1261+
target = a.target
1262+
k = cbrt(ln)
1263+
# Sample size: approximately ln^(2/3).
k2 = round(Int, k^2)
1264+
# Ratio of sample length to input length; scales an index in the input to the
# corresponding index in the sample.
k2ln = k2/ln
1265+
# Safety margin (in sample indices) applied outward on both sides of the target.
offset = .15k2*top_set_bit(k2) # TODO for further optimization: tune this
1266+
# Sample indices whose values will serve as the low and high signposts.
lo_signpost_i, hi_signpost_i =
1267+
(floor(Int, (tar - lo) * k2ln + lo + off) for (tar, off) in
1268+
((minimum(target), -offset), (maximum(target), offset)))
1269+
lastindex_sample = lo+k2-1
1270+
# Estimated number of input elements that will fall between the signposts.
expected_middle_ln = (min(lastindex_sample, hi_signpost_i) - max(lo, lo_signpost_i) + 1) / k2ln
1271+
# This heuristic is complicated because it fairly accurately reflects the runtime of
1272+
# this algorithm which is necessary to get good dispatch when both the target and the
1273+
# input are large.
1274+
# expected_middle_ln is a float and k2 is significantly below typemax(Int), so this will
1275+
# not overflow:
1276+
# TODO move target from alg to kw to avoid this ickyness:
1277+
ln <= 130 + 2k2 + 2expected_middle_ln && return _sort!(v, a.get_next(a.target), o, kw)
1278+
1279+
# We store the random sample in
1280+
# sample = view(v, lo:lo+k2-1)
1281+
# but views are not quite as fast as using the input array directly,
1282+
# so we don't actually construct this view at runtime.
1283+
1284+
# TODO for further optimization: handle lots of duplicates better.
1285+
# Right now lots of duplicates rounds up when it could use some super fast optimizations
1286+
# in some cases.
1287+
# e.g.
1288+
#
1289+
# Target: |----|
1290+
# Sorted input: 000000000000000000011111112222223333333333
1291+
#
1292+
# Will filter all zeros and ones to the front when it could just take the first few
1293+
# it encounters. This optimization would be especially potent when `allequal(ans)` and
1294+
# equal elements are egal.
1295+
1296+
# 3 random trials should typically give us 0.99999 reliability; we can assume
1297+
# the input is pathological and abort to fallback if we fail three trials.
1298+
seed = hash(ln, Int === Int64 ? 0x85eb830e0216012d : 0xae6c4e15)
1299+
for attempt in 1:3
1300+
# Derive a fresh seed for each attempt so a failed trial is not simply repeated.
seed = hash(attempt, seed)
1301+
# Build the sample in place: swap each slot of v[lo:lastindex_sample] with a
# hash-selected element of v[i:hi].
for i in lo:lo+k2-1
1302+
j = mod(hash(i, seed), i:hi) # TODO for further optimization: be sneaky and remove this division
1303+
v[i], v[j] = v[j], v[i]
1304+
end
1305+
# Each branch below sorts the sample around the needed signpost(s), then partitions
# v[lo:hi] around the signpost values. A signpost index that falls outside the
# sample means "no bound on that side" and is passed to the kernel as `nothing`.
count_below, lastindex_middle = if lo_signpost_i <= lo && lastindex_sample <= hi_signpost_i
1306+
# The heuristics higher up in this function that dispatch to the `next`
1307+
# algorithm should prevent this from happening.
1308+
# Specifically, this means that expected_middle_ln == ln, so
1309+
# ln <= ... + 2.0expected_middle_ln && return ...
1310+
# will trigger.
1311+
@assert false
1312+
# But if it does happen, the kernel reduces to
1313+
0, hi
1314+
elseif lo_signpost_i <= lo
1315+
# Only an upper signpost is usable.
_sort!(v, a.get_next(hi_signpost_i), o, (;kw..., hi=lastindex_sample))
1316+
bracket_kernel!(v, lo, hi, nothing, v[hi_signpost_i], o)
1317+
elseif lastindex_sample <= hi_signpost_i
1318+
# Only a lower signpost is usable.
_sort!(v, a.get_next(lo_signpost_i), o, (;kw..., hi=lastindex_sample))
1319+
bracket_kernel!(v, lo, hi, v[lo_signpost_i], nothing, o)
1320+
else
1321+
# TODO for further optimization: don't sort the middle elements
1322+
_sort!(v, a.get_next(lo_signpost_i:hi_signpost_i), o, (;kw..., hi=lastindex_sample))
1323+
bracket_kernel!(v, lo, hi, v[lo_signpost_i], v[hi_signpost_i], o)
1324+
end
1325+
# The elements below the low signpost were not moved to the front, so the target's
# position within the front chunk is shifted down by their count.
target_in_middle = target .- count_below
1326+
# Success only if the shifted target lies entirely inside the front chunk, i.e. the
# signposts really did bracket the target.
if lo <= minimum(target_in_middle) && maximum(target_in_middle) <= lastindex_middle
1327+
scratch = _sort!(v, a.get_next(target_in_middle), o, (;kw..., hi=lastindex_middle))
1328+
# Put the found elements at the indices the caller asked for.
move!(v, target, target_in_middle)
1329+
return scratch
1330+
end
1331+
# This line almost never runs.
1332+
end
1333+
# This line only runs on pathological inputs. Make sure it's covered by tests :)
1334+
_sort!(v, a.get_next(target), o, kw)
1335+
end
1336+
1337+
11411338
"""
11421339
StableCheckSorted(next) <: Algorithm
11431340

test/sorting.jl

Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -721,6 +721,8 @@ end
721721
for alg in safe_algs
722722
@test sort(1:n, alg=alg, lt = (i,j) -> v[i]<=v[j]) == perm
723723
end
724+
# This could easily break with minor heuristic adjustments
725+
# because partialsort is not even guaranteed to be stable:
724726
@test partialsort(1:n, 172, lt = (i,j) -> v[i]<=v[j]) == perm[172]
725727
@test partialsort(1:n, 315:415, lt = (i,j) -> v[i]<=v[j]) == perm[315:415]
726728

@@ -1034,6 +1036,41 @@ end
10341036
@test issorted(sort!(rand(100), Base.Sort.InitialOptimizations(DispatchLoopTestAlg()), Base.Order.Forward))
10351037
end
10361038

1039+
@testset "partialsort tests added for BracketedSort #52006" begin
1040+
# Length 1000 is above the `ln < 260` short-circuit in the BracketedSort `_sort!` method.
x = rand(Int, 1000)
1041+
@test partialsort(x, 1) == minimum(x)
1042+
@test partialsort(x, 1000) == maximum(x)
1043+
sx = sort(x)
1044+
# Scalar targets near both ends and in the middle of the input.
for i in [1, 2, 4, 10, 11, 425, 500, 845, 991, 997, 999, 1000]
1045+
@test partialsort(x, i) == sx[i]
1046+
end
1047+
# Range targets of varying width and position.
for i in [1:1, 1:2, 1:5, 1:8, 1:9, 1:11, 1:108, 135:812, 220:586, 363:368, 450:574, 458:597, 469:638, 487:488, 500:501, 584:594, 1000:1000]
1048+
@test partialsort(x, i) == sx[i]
1049+
end
1050+
1051+
# Semi-pathological input
# The seed and index formula below mirror the first sampling attempt of the
# BracketedSort `_sort!` method for an input of length 1000 (sample size 100).
# The swaps themselves are not replayed, so the poisoned positions only
# approximate the positions that attempt will sample.
1052+
seed = hash(1000, Int === Int64 ? 0x85eb830e0216012d : 0xae6c4e15)
1053+
seed = hash(1, seed)
1054+
for i in 1:100
1055+
j = mod(hash(i, seed), i:1000)
1056+
x[j] = typemax(Int)
1057+
end
1058+
@test partialsort(x, 500) == sort(x)[500]
1059+
1060+
# Fully pathological input
1061+
# it would be too much trouble to actually construct a valid pathological input, so we
1062+
# construct an invalid pathological input.
1063+
# This test is kind of sketchy because it passes invalid inputs to the function
# (`<=` is not a strict order, so it is not a valid `lt`). `x` is regenerated on every
# iteration, so repeated targets such as `500:501` are independent trials.
1064+
for i in [1:6, 1:483, 1:957, 77:86, 118:478, 223:227, 231:970, 317:958, 500:501, 500:501, 500:501, 614:620, 632:635, 658:665, 933:940, 937:942, 997:1000, 999:1000]
1065+
x = rand(1:5, 1000)
1066+
@test partialsort(x, i, lt=(<=)) == sort(x)[i]
1067+
end
1068+
for i in [1, 7, 8, 490, 495, 852, 993, 996, 1000]
1069+
x = rand(1:5, 1000)
1070+
@test partialsort(x, i, lt=(<=)) == sort(x)[i]
1071+
end
1072+
end
1073+
10371074
# This testset is at the end of the file because it is slow.
10381075
@testset "searchsorted" begin
10391076
numTypes = [ Int8, Int16, Int32, Int64, Int128,

0 commit comments

Comments
 (0)