Skip to content

Commit a975bcf

Browse files
Lilith HafnerLilith Hafner
authored andcommitted
fix unexpected allocations in Radix Sort
fixes #47474 in this PR rather than separate to avoid dealing with the merge
1 parent 2da298f commit a975bcf

File tree

1 file changed

+35
-18
lines changed

1 file changed

+35
-18
lines changed

base/sort.jl

Lines changed: 35 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -864,17 +864,27 @@ function _sort!(v::AbstractVector, a::RadixSort, o::DirectOrdering, kw)
864864

865865
len = hi-lo + 1
866866
U = UIntMappable(eltype(v), o)
867+
# A large if-else chain to avoid type instabilities and dynamic dispatch
867868
if scratch !== nothing && checkbounds(Bool, scratch, lo:hi) # Fully preallocated and aligned scratch
868-
u2 = radix_sort!(u, lo, hi, bits, reinterpret(U, scratch))
869-
uint_unmap!(v, u2, lo, hi, o, umn)
869+
if radix_sort!(u, lo, hi, bits, reinterpret(U, scratch))
870+
uint_unmap!(v, u, lo, hi, o, umn)
871+
else
872+
uint_unmap!(v, t, lo, hi, o, umn)
873+
end
870874
elseif scratch !== nothing && (applicable(resize!, scratch, len) || length(scratch) >= len) # Viable scratch
871875
length(scratch) >= len || resize!(scratch, len)
872876
t1 = axes(scratch, 1) isa OneTo ? scratch : view(scratch, firstindex(scratch):lastindex(scratch))
873-
u2 = radix_sort!(view(u, lo:hi), 1, len, bits, reinterpret(U, t1))
874-
uint_unmap!(view(v, lo:hi), u2, 1, len, o, umn)
877+
if radix_sort!(view(u, lo:hi), 1, len, bits, reinterpret(U, t1))
878+
uint_unmap!(view(v, lo:hi), u, 1, len, o, umn)
879+
else
880+
uint_unmap!(view(v, lo:hi), t, 1, len, o, umn)
881+
end
875882
else # No viable scratch
876-
u2 = radix_sort!(u, lo, hi, bits, similar(u))
877-
uint_unmap!(v, u2, lo, hi, o, umn)
883+
if radix_sort!(u, lo, hi, bits, similar(u))
884+
uint_unmap!(v, u, lo, hi, o, umn)
885+
else
886+
uint_unmap!(v, t, lo, hi, o, umn)
887+
end
878888
end
879889
end
880890

@@ -1025,16 +1035,28 @@ function _sort!(v::AbstractVector, a::StableCheckSorted, o::Ordering, kw)
10251035
end
10261036

10271037

1028-
# In the case of an odd number of passes, the returned vector will === the input vector t,
1029-
# not v. This is one of the many reasons radix_sort! is not exported.
1038+
# The return value indicates whether v is sorted (true) or t is sorted (false)
1039+
# This is one of the many reasons radix_sort! is not exported.
10301040
function radix_sort!(v::AbstractVector{U}, lo::Integer, hi::Integer, bits::Unsigned,
10311041
t::AbstractVector{U}, chunk_size=radix_chunk_size_heuristic(lo, hi, bits)) where U <: Unsigned
10321042
# bits is unsigned for performance reasons.
1033-
mask = UInt(1) << chunk_size - 1
1034-
counts = Vector{Int}(undef, mask+2)
1035-
1036-
@inbounds for shift in 0:chunk_size:bits-1
1037-
1043+
counts = Vector{Int}(undef, 1 << chunk_size + 1)
1044+
1045+
shift = 0
1046+
while true
1047+
@noinline radix_sort_pass!(t, lo, hi, counts, v, shift, chunk_size)
1048+
# the latest data resides in t
1049+
shift += chunk_size
1050+
shift < bits || return false
1051+
@noinline radix_sort_pass!(v, lo, hi, counts, t, shift, chunk_size)
1052+
# the latest data resides in v
1053+
shift += chunk_size
1054+
shift < bits || return true
1055+
end
1056+
end
1057+
function radix_sort_pass!(t, lo, hi, counts, v, shift, chunk_size)
1058+
mask = UInt(1) << chunk_size - 1 # mask is defined in pass so that the compiler
1059+
@inbounds begin # ↳ knows it's shape
10381060
# counts[2:mask+2] will store the number of elements that fall into each bucket.
10391061
# if chunk_size = 8, counts[2] is bucket 0x00 and counts[257] is bucket 0xff.
10401062
counts .= 0
@@ -1058,12 +1080,7 @@ function radix_sort!(v::AbstractVector{U}, lo::Integer, hi::Integer, bits::Unsig
10581080
t[j] = x # put the element where it belongs
10591081
counts[i] = j + 1 # increment the target index for the next
10601082
end # ↳ element in this bucket
1061-
1062-
v, t = t, v # swap the now sorted destination vector t back into primary vector v
1063-
10641083
end
1065-
1066-
v
10671084
end
10681085
function radix_chunk_size_heuristic(lo::Integer, hi::Integer, bits::Unsigned)
10691086
# chunk_size is the number of bits to radix over at once.

0 commit comments

Comments
 (0)