Remove compilation tricks and put upper limit on chunk_size
I couldn't verify performance gains from making CHUNK_SIZE and MASK known at compile time, so I'm trying to resist the urge to abuse the compiler.
I also noticed this was 20% slower than SortingAlgorithms.jl's RadixSort for `@belapsed sort!(x) setup=(x=rand(Int, 3000000)) evals=1` because it used a chunk size of 13, which is too high. On my computer the best chunk size for that case is 10, and I couldn't find a case where a chunk size above 10 beat 10. Once I capped the chunk size at 10, the 20% regression shrank to a 2.5% regression, well within the margin of error.
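
A minimal sketch of the benchmark comparison described above, assuming BenchmarkTools.jl and SortingAlgorithms.jl are available (RadixSort is selected via the standard `alg` keyword of `sort!`); absolute timings will vary by machine:

```julia
using BenchmarkTools, SortingAlgorithms

# Time Base's sort! and SortingAlgorithms' RadixSort on the 3_000_000-element Int
# case from the commit message. A fresh `setup` plus `evals=1` means every
# evaluation sorts an unsorted copy of the data.
base_time  = @belapsed sort!(x)                setup=(x=rand(Int, 3_000_000)) evals=1
radix_time = @belapsed sort!(x, alg=RadixSort) setup=(x=rand(Int, 3_000_000)) evals=1

println("Base sort!:       ", base_time, " s")
println("RadixSort:        ", radix_time, " s")
println("Base / RadixSort: ", base_time / radix_time)
```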
LilithHafner committed Mar 14, 2022
1 parent 5eca90a commit 17f45e8
Showing 1 changed file with 26 additions and 43 deletions.
69 changes: 26 additions & 43 deletions base/sort.jl
@@ -671,12 +671,12 @@ end

# This is a stable least significant bit first radix sort.
#
# That is, it first sorts the entire vector by the last CHUNK_SIZE bits, then by the second
# to last CHUNK_SIZE bits, and so on. Stability means that it will not reorder two elements
# That is, it first sorts the entire vector by the last chunk_size bits, then by the second
# to last chunk_size bits, and so on. Stability means that it will not reorder two elements
# that compare equal. This is essential so that the order introduced by earlier,
# less significant passes is preserved by later passes.
#
# Each pass divides the input into 2^CHUNK_SIZE == MASK+1 buckets. To do this, it
# Each pass divides the input into 2^chunk_size == mask+1 buckets. To do this, it
# * counts the number of entries that fall into each bucket
# * uses those counts to compute the indices to move elements of those buckets into
# * moves elements into the computed indices in the swap array
@@ -685,32 +685,32 @@ end
# In the case of an odd number of passes, the returned vector will === the input vector t,
# not v. This is one of the many reasons radix_sort! is not exported.
function radix_sort!(v::AbstractVector{U}, lo::Integer, hi::Integer, bits::Unsigned,
::Val{CHUNK_SIZE}, t::AbstractVector{U}) where {U <: Unsigned, CHUNK_SIZE}
# bits is unsigned and CHUNK_SIZE is a compile time constant for performance reasons.
MASK = UInt(1) << CHUNK_SIZE - 0x1
counts = Vector{UInt}(undef, MASK+2)
t::AbstractVector{U}, chunk_size=radix_chunk_size_heuristic(lo, hi, bits)) where U <: Unsigned
# bits is unsigned for performance reasons.
mask = UInt(1) << chunk_size - 0x1
counts = Vector{UInt}(undef, mask+2)

@inbounds for shift in 0:CHUNK_SIZE:bits-1
@inbounds for shift in 0:chunk_size:bits-1

# counts[2:MASK+2] will store the number of elements that fall into each bucket.
# if CHUNK_SIZE = 8, counts[2] is bucket 0x00 and counts[257] is bucket 0xff.
# counts[2:mask+2] will store the number of elements that fall into each bucket.
# if chunk_size = 8, counts[2] is bucket 0x00 and counts[257] is bucket 0xff.
counts .= 0
for k in lo:hi
x = v[k] # lookup the element
i = (x >> shift)&MASK + 2 # compute its bucket's index for this pass
i = (x >> shift)&mask + 2 # compute its bucket's index for this pass
counts[i] += 1 # increment that bucket's count
end

counts[1] = lo # set target index for the first bucket
cumsum!(counts, counts) # set target indices for subsequent buckets
# counts[1:MASK+1] now stores indices where the first member of each bucket
# counts[1:mask+1] now stores indices where the first member of each bucket
# belongs, not the number of elements in each bucket. We will put the first element
# of bucket 0x00 in t[counts[1]], the next element of bucket 0x00 in t[counts[1]+1],
# and the last element of bucket 0x00 in t[counts[2]-1].

for k in lo:hi
x = v[k] # lookup the element
i = (x >> shift)&MASK + 1 # compute its bucket's index for this pass
i = (x >> shift)&mask + 1 # compute its bucket's index for this pass
j = counts[i] # lookup the target index
t[j] = x # put the element where it belongs
counts[i] = j + 1 # increment the target index for the next
Expand All @@ -722,6 +722,18 @@ function radix_sort!(v::AbstractVector{U}, lo::Integer, hi::Integer, bits::Unsig

v
end
function radix_chunk_size_heuristic(lo::Integer, hi::Integer, bits::Unsigned)
# chunk_size is the number of bits to radix over at once.
# We need to allocate an array of size 2^chunk size, and on the other hand the higher
# the chunk size the fewer passes we need. Theoretically, chunk size should be based on
# the Lambert W function applied to length. Empirically, we use this heuristic:
guess = min(10, log(maybe_unsigned(hi-lo))*3/4+3)
# TODO the maximum chunk size should be based on architecture cache size.

# We need iterations * chunk size ≥ bits, and these cld's
# make an effort to get iterations * chunk size ≈ bits
UInt8(cld(bits, cld(bits, guess)))
end

# For AbstractVector{Bool}, counting sort is always best.
# This is an implementation of counting sort specialized for Bools.
@@ -832,36 +844,7 @@ function sort!(v::AbstractVector, lo::Integer, hi::Integer, a::AdaptiveSort, o::
u[i] -= u_min
end

# chunk_size is the number of bits to radix over at once.
# We need to allocate an array of size 2^chunk size, and on the other hand the higher
# the chunk size the fewer passes we need. Theoretically, chunk size should be based on
# the Lambert W function applied to length. Empirically, we use this heuristic:
guess = log(lenm1)*3/4+3
# We need iterations * chunk size ≥ bits, and these cld's
# make an effort to get iterations * chunk size ≈ bits
chunk_size = UInt8(cld(bits, cld(bits, guess)))
@assert chunk_size >= 3

t = similar(u)
# This if else chain is to avoid dynamic dispatch for small cases.
# Chunk sizes less than 3 should never occur, and chunk sizes greater than 8
# only occur for arrays of length greater than 950, and tend to occur only for arrays
# of length greater than about 4000 where a single dynamic dispatch is less costly
u2 = if chunk_size == 3
radix_sort!(u, lo, hi, bits, Val(0x3), t)
elseif chunk_size == 4
radix_sort!(u, lo, hi, bits, Val(0x4), t)
elseif chunk_size == 5
radix_sort!(u, lo, hi, bits, Val(0x5), t)
elseif chunk_size == 6 # 9% to 15% savings over dynamic dispatch
radix_sort!(u, lo, hi, bits, Val(0x6), t)
elseif chunk_size == 7 # 2% to 7% savings over dynamic dispatch
radix_sort!(u, lo, hi, bits, Val(0x7), t)
elseif chunk_size == 8 # -1% to 10% savings and common for lengths between 300 and 3000
radix_sort!(u, lo, hi, bits, Val(0x8), t)
else
radix_sort!(u, lo, hi, bits, Val(chunk_size), t) # dynamic dispatch
end
u2 = radix_sort!(u, lo, hi, bits, similar(u))
Serial.deserialize!(v, u2, lo, hi, o, u_min)
end
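
A quick worked example of the arithmetic in `radix_chunk_size_heuristic` for the 3,000,000-element Int benchmark mentioned in the commit message, assuming roughly 64 significant bits remain after the minimum is subtracted (a sketch for intuition, not code from the commit):

```julia
# Reproduce the heuristic's arithmetic for lo = 1, hi = 3_000_000, bits = 64.
lenm1 = 3_000_000 - 1
bits  = 64

guess      = min(10, log(lenm1) * 3/4 + 3)  # log(2_999_999) ≈ 14.9, so the cap gives 10.0
passes     = cld(bits, guess)               # cld(64, 10.0) == 7.0 passes
chunk_size = UInt8(cld(bits, passes))       # cld(64, 7.0) == 10.0, i.e. 0x0a

# Without the min(10, ...) cap, guess ≈ 14.2 and the same cld arithmetic yields
# a chunk size of 13, the value the commit message identifies as too high.
```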
