Skip to content

Commit 17f45e8

Browse files
committed
Remove compilation tricks and put upper limit on chunk_size
I couldn't verify performance gains from making CHUNK_SIZE and MASK known at compile time, so I'm trying to resist the urge to abuse the compiler. I also noticed this was 20% slower than SortingAlgorithms.jl's RadixSort for `@belapsed sort!(x) setup=(x=rand(Int, 3000000)) evals=1` because it used a chunk size of 13 which is too high. On my computer the best chunk size for that case is 10, and I couldn't find a size where higher than 10 was better than 10. Once I set max chunk size to 10, the 20% regression turned into a 2.5% regression, well within margin of error.
1 parent 5eca90a commit 17f45e8

File tree

1 file changed

+26
-43
lines changed

1 file changed

+26
-43
lines changed

base/sort.jl

Lines changed: 26 additions & 43 deletions
Original file line numberDiff line numberDiff line change
@@ -671,12 +671,12 @@ end
671671

672672
# This is a stable least significant bit first radix sort.
673673
#
674-
# That is, it first sorts the entire vector by the last CHUNK_SIZE bits, then by the second
675-
# to last CHUNK_SIZE bits, and so on. Stability means that it will not reorder two elements
674+
# That is, it first sorts the entire vector by the last chunk_size bits, then by the second
675+
# to last chunk_size bits, and so on. Stability means that it will not reorder two elements
676676
# that compare equal. This is essential so that the order introduced by earlier,
677677
# less significant passes is preserved by later passes.
678678
#
679-
# Each pass divides the input into 2^CHUNK_SIZE == MASK+1 buckets. To do this, it
679+
# Each pass divides the input into 2^chunk_size == mask+1 buckets. To do this, it
680680
# * counts the number of entries that fall into each bucket
681681
# * uses those counts to compute the indices to move elements of those buckets into
682682
# * moves elements into the computed indices in the swap array
@@ -685,32 +685,32 @@ end
685685
# In the case of an odd number of passes, the returned vector will === the input vector t,
686686
# not v. This is one of the many reasons radix_sort! is not exported.
687687
function radix_sort!(v::AbstractVector{U}, lo::Integer, hi::Integer, bits::Unsigned,
688-
::Val{CHUNK_SIZE}, t::AbstractVector{U}) where {U <: Unsigned, CHUNK_SIZE}
689-
# bits is unsigned and CHUNK_SIZE is a compile time constant for performance reasons.
690-
MASK = UInt(1) << CHUNK_SIZE - 0x1
691-
counts = Vector{UInt}(undef, MASK+2)
688+
t::AbstractVector{U}, chunk_size=radix_chunk_size_heuristic(lo, hi, bits)) where U <: Unsigned
689+
# bits is unsigned for performance reasons.
690+
mask = UInt(1) << chunk_size - 0x1
691+
counts = Vector{UInt}(undef, mask+2)
692692

693-
@inbounds for shift in 0:CHUNK_SIZE:bits-1
693+
@inbounds for shift in 0:chunk_size:bits-1
694694

695-
# counts[2:MASK+2] will store the number of elements that fall into each bucket.
696-
# if CHUNK_SIZE = 8, counts[2] is bucket 0x00 and counts[257] is bucket 0xff.
695+
# counts[2:mask+2] will store the number of elements that fall into each bucket.
696+
# if chunk_size = 8, counts[2] is bucket 0x00 and counts[257] is bucket 0xff.
697697
counts .= 0
698698
for k in lo:hi
699699
x = v[k] # lookup the element
700-
i = (x >> shift)&MASK + 2 # compute its bucket's index for this pass
700+
i = (x >> shift)&mask + 2 # compute its bucket's index for this pass
701701
counts[i] += 1 # increment that bucket's count
702702
end
703703

704704
counts[1] = lo # set target index for the first bucket
705705
cumsum!(counts, counts) # set target indices for subsequent buckets
706-
# counts[1:MASK+1] now stores indices where the first member of each bucket
706+
# counts[1:mask+1] now stores indices where the first member of each bucket
707707
# belongs, not the number of elements in each bucket. We will put the first element
708708
# of bucket 0x00 in t[counts[1]], the next element of bucket 0x00 in t[counts[1]+1],
709709
# and the last element of bucket 0x00 in t[counts[2]-1].
710710

711711
for k in lo:hi
712712
x = v[k] # lookup the element
713-
i = (x >> shift)&MASK + 1 # compute its bucket's index for this pass
713+
i = (x >> shift)&mask + 1 # compute its bucket's index for this pass
714714
j = counts[i] # lookup the target index
715715
t[j] = x # put the element where it belongs
716716
counts[i] = j + 1 # increment the target index for the next
@@ -722,6 +722,18 @@ function radix_sort!(v::AbstractVector{U}, lo::Integer, hi::Integer, bits::Unsig
722722

723723
v
724724
end
725+
function radix_chunk_size_heuristic(lo::Integer, hi::Integer, bits::Unsigned)
726+
# chunk_size is the number of bits to radix over at once.
727+
# We need to allocate an array of size 2^chunk size, and on the other hand the higher
728+
# the chunk size the fewer passes we need. Theoretically, chunk size should be based on
729+
# the Lambert W function applied to length. Empirically, we use this heuristic:
730+
guess = min(10, log(maybe_unsigned(hi-lo))*3/4+3)
731+
# TODO the maximum chunk size should be based on archetecture cache size.
732+
733+
# We need iterations * chunk size ≥ bits, and these cld's
734+
# make an effort to get iterations * chunk size ≈ bits
735+
UInt8(cld(bits, cld(bits, guess)))
736+
end
725737

726738
# For AbstractVector{Bool}, counting sort is always best.
727739
# This is an implementation of counting sort specialized for Bools.
@@ -832,36 +844,7 @@ function sort!(v::AbstractVector, lo::Integer, hi::Integer, a::AdaptiveSort, o::
832844
u[i] -= u_min
833845
end
834846

835-
# chunk_size is the number of bits to radix over at once.
836-
# We need to allocate an array of size 2^chunk size, and on the other hand the higher
837-
# the chunk size the fewer passes we need. Theoretically, chunk size should be based on
838-
# the Lambert W function applied to length. Empirically, we use this heuristic:
839-
guess = log(lenm1)*3/4+3
840-
# We need iterations * chunk size ≥ bits, and these cld's
841-
# make an effort to get iterations * chunk size ≈ bits
842-
chunk_size = UInt8(cld(bits, cld(bits, guess)))
843-
@assert chunk_size >= 3
844-
845-
t = similar(u)
846-
# This if else chain is to avoid dynamic dispatch for small cases.
847-
# Chunk sizes less than 3 should never occur, and chunk sizes greater than 8
848-
# only occur for arrays of length greater than 950, and tend to occur only for arrays
849-
# of length greater than about 4000 where a single dynamic dispatch is less costly
850-
u2 = if chunk_size == 3
851-
radix_sort!(u, lo, hi, bits, Val(0x3), t)
852-
elseif chunk_size == 4
853-
radix_sort!(u, lo, hi, bits, Val(0x4), t)
854-
elseif chunk_size == 5
855-
radix_sort!(u, lo, hi, bits, Val(0x5), t)
856-
elseif chunk_size == 6 # 9% to 15% savings over dynamic dispatch
857-
radix_sort!(u, lo, hi, bits, Val(0x6), t)
858-
elseif chunk_size == 7 # 2% to 7% savings over dynamic dispatch
859-
radix_sort!(u, lo, hi, bits, Val(0x7), t)
860-
elseif chunk_size == 8 # -1% to 10% savings and common for lengths between 300 and 3000
861-
radix_sort!(u, lo, hi, bits, Val(0x8), t)
862-
else
863-
radix_sort!(u, lo, hi, bits, Val(chunk_size), t) # dynamic dispatch
864-
end
847+
u2 = radix_sort!(u, lo, hi, bits, similar(u))
865848
Serial.deserialize!(v, u2, lo, hi, o, u_min)
866849
end
867850

0 commit comments

Comments
 (0)