Skip to content

Commit

Permalink
[libc] Improve qsort (#120450)
Browse files Browse the repository at this point in the history
  • Loading branch information
Voultapher authored Dec 29, 2024
1 parent 7f3428d commit d2e71c9
Show file tree
Hide file tree
Showing 14 changed files with 539 additions and 291 deletions.
12 changes: 6 additions & 6 deletions libc/src/stdlib/heap_sort.h
Original file line number Diff line number Diff line change
Expand Up @@ -18,11 +18,12 @@ namespace internal {
// A simple in-place heapsort implementation.
// Follow the implementation in https://en.wikipedia.org/wiki/Heapsort.

LIBC_INLINE void heap_sort(const Array &array) {
size_t end = array.size();
template <typename A, typename F>
LIBC_INLINE void heap_sort(const A &array, const F &is_less) {
size_t end = array.len();
size_t start = end / 2;

auto left_child = [](size_t i) -> size_t { return 2 * i + 1; };
const auto left_child = [](size_t i) -> size_t { return 2 * i + 1; };

while (end > 1) {
if (start > 0) {
Expand All @@ -40,12 +41,11 @@ LIBC_INLINE void heap_sort(const Array &array) {
while (left_child(root) < end) {
size_t child = left_child(root);
// If there are two children, set child to the greater.
if (child + 1 < end &&
array.elem_compare(child, array.get(child + 1)) < 0)
if ((child + 1 < end) && is_less(array.get(child), array.get(child + 1)))
++child;

// If the root is less than the greater child
if (array.elem_compare(root, array.get(child)) >= 0)
if (!is_less(array.get(root), array.get(child)))
break;

// Swap the root with the greater child and continue sifting down.
Expand Down
10 changes: 4 additions & 6 deletions libc/src/stdlib/qsort.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -18,14 +18,12 @@ namespace LIBC_NAMESPACE_DECL {
LLVM_LIBC_FUNCTION(void, qsort,
(void *array, size_t array_size, size_t elem_size,
int (*compare)(const void *, const void *))) {
if (array == nullptr || array_size == 0 || elem_size == 0)
return;
internal::Comparator c(compare);

auto arr = internal::Array(reinterpret_cast<uint8_t *>(array), array_size,
elem_size, c);
const auto is_less = [compare](const void *a, const void *b) -> bool {
return compare(a, b) < 0;
};

internal::sort(arr);
internal::unstable_sort(array, array_size, elem_size, is_less);
}

} // namespace LIBC_NAMESPACE_DECL
171 changes: 101 additions & 70 deletions libc/src/stdlib/qsort_data.h
Original file line number Diff line number Diff line change
Expand Up @@ -17,91 +17,122 @@
namespace LIBC_NAMESPACE_DECL {
namespace internal {

using Compare = int(const void *, const void *);
using CompareWithState = int(const void *, const void *, void *);

enum class CompType { COMPARE, COMPARE_WITH_STATE };

struct Comparator {
union {
Compare *comp_func;
CompareWithState *comp_func_r;
};
const CompType comp_type;

void *arg;

Comparator(Compare *func)
: comp_func(func), comp_type(CompType::COMPARE), arg(nullptr) {}

Comparator(CompareWithState *func, void *arg_val)
: comp_func_r(func), comp_type(CompType::COMPARE_WITH_STATE),
arg(arg_val) {}

#if defined(__clang__)
// Recent upstream changes to -fsanitize=function find more instances of
// function type mismatches. One case is with the comparator passed to this
// class. Libraries will tend to pass comparators that take pointers to
// varying types while this comparator expects to accept const void pointers.
// Ideally those tools would pass a function that strictly accepts const
// void*s to avoid UB, or would use qsort_r to pass their own comparator.
[[clang::no_sanitize("function")]]
#endif
int comp_vals(const void *a, const void *b) const {
if (comp_type == CompType::COMPARE) {
return comp_func(a, b);
} else {
return comp_func_r(a, b, arg);
// View over a caller-provided buffer of `array_len` contiguous elements whose
// element size is only known at run time. The class stores just the base
// pointer and the two sizes; none of the accessors perform bounds checks.
class ArrayGenericSize {
  cpp::byte *array_base; // First byte of element 0 of the current view.
  size_t array_len;      // Number of elements in the current view.
  size_t elem_size;      // Size of one element in bytes.

  // Returns the address of the first byte of element |i|.
  LIBC_INLINE cpp::byte *get_internal(size_t i) const {
    return array_base + (i * elem_size);
  }

public:
  // Wraps the buffer |a| as |s| elements of |e| bytes each.
  LIBC_INLINE ArrayGenericSize(void *a, size_t s, size_t e)
      : array_base(reinterpret_cast<cpp::byte *>(a)), array_len(s),
        elem_size(e) {}

  // The element size of this array type is a run-time value.
  static constexpr bool has_fixed_size() { return false; }

  // Returns an untyped pointer to element |i|.
  LIBC_INLINE void *get(size_t i) const { return get_internal(i); }

  // Exchanges the contents of elements |i| and |j| byte for byte.
  // Calling it with |i| == |j| leaves the element unchanged.
  LIBC_INLINE void swap(size_t i, size_t j) const {
    // It's possible to use 8 byte blocks with `uint64_t`, but that
    // generates more machine code as the remainder loop gets
    // unrolled, plus 4 byte operations are more likely to be
    // efficient on a wider variety of hardware. On x86 LLVM tends
    // to unroll the block loop again into 2 16 byte swaps per
    // iteration which is another reason that 4 byte blocks yields
    // good performance even for big types.
    using block_t = uint32_t;
    constexpr size_t BLOCK_SIZE = sizeof(block_t);

    alignas(block_t) cpp::byte tmp_block[BLOCK_SIZE];

    cpp::byte *elem_i = get_internal(i);
    cpp::byte *elem_j = get_internal(j);

    // Split the element into whole blocks followed by a tail of fewer than
    // BLOCK_SIZE bytes.
    const size_t elem_size_rem = elem_size % BLOCK_SIZE;
    const cpp::byte *elem_i_block_end = elem_i + (elem_size - elem_size_rem);

    while (elem_i != elem_i_block_end) {
      __builtin_memcpy(tmp_block, elem_i, BLOCK_SIZE);
      // When |i| == |j| source and destination are the same bytes. memmove is
      // defined for overlapping ranges whereas memcpy is not, and for a small
      // constant size both lower to the same load/store pair.
      __builtin_memmove(elem_i, elem_j, BLOCK_SIZE);
      __builtin_memcpy(elem_j, tmp_block, BLOCK_SIZE);

      elem_i += BLOCK_SIZE;
      elem_j += BLOCK_SIZE;
    }

    // Both cursors now point at the tail; swap it one byte at a time.
    for (size_t n = 0; n < elem_size_rem; ++n) {
      cpp::byte tmp = elem_i[n];
      elem_i[n] = elem_j[n];
      elem_j[n] = tmp;
    }
  }

  // Number of elements in the current view.
  LIBC_INLINE size_t len() const { return array_len; }

  // Make an Array starting at index |i| and length |s|. The result refers to
  // the same underlying bytes and uses the same element size.
  LIBC_INLINE ArrayGenericSize make_array(size_t i, size_t s) const {
    return ArrayGenericSize(get_internal(i), s, elem_size);
  }

  // Reset this Array to point at a different interval of the same
  // items starting at index |i|. The index is relative to the current base,
  // which is replaced by the address of element |i|; |s| becomes the new
  // length.
  LIBC_INLINE void reset_bounds(size_t i, size_t s) {
    array_base = get_internal(i);
    array_len = s;
  }
};

class Array {
uint8_t *array;
size_t array_size;
size_t elem_size;
Comparator compare;
// Having a specialized Array type for sorting that knows at
// compile-time what the size of the element is, allows for much more
// efficient swapping and for cheaper offset calculations.
template <size_t ELEM_SIZE> class ArrayFixedSize {
cpp::byte *array_base;
size_t array_len;

public:
Array(uint8_t *a, size_t s, size_t e, Comparator c)
: array(a), array_size(s), elem_size(e), compare(c) {}

uint8_t *get(size_t i) const { return array + i * elem_size; }

void swap(size_t i, size_t j) const {
uint8_t *elem_i = get(i);
uint8_t *elem_j = get(j);
for (size_t b = 0; b < elem_size; ++b) {
uint8_t temp = elem_i[b];
elem_i[b] = elem_j[b];
elem_j[b] = temp;
}
LIBC_INLINE cpp::byte *get_internal(size_t i) const {
return array_base + (i * ELEM_SIZE);
}

int elem_compare(size_t i, const uint8_t *other) const {
// An element must compare equal to itself so we don't need to consult the
// user provided comparator.
if (get(i) == other)
return 0;
return compare.comp_vals(get(i), other);
public:
LIBC_INLINE ArrayFixedSize(void *a, size_t s)
: array_base(reinterpret_cast<cpp::byte *>(a)), array_len(s) {}

// Beware this function is used as a heuristic for cheap to swap types, so
// instantiating `ArrayFixedSize` with `ELEM_SIZE > 100` is probably a bad
// idea perf wise.
static constexpr bool has_fixed_size() { return true; }

LIBC_INLINE void *get(size_t i) const { return get_internal(i); }

LIBC_INLINE void swap(size_t i, size_t j) const {
alignas(32) cpp::byte tmp[ELEM_SIZE];

cpp::byte *elem_i = get_internal(i);
cpp::byte *elem_j = get_internal(j);

__builtin_memcpy(tmp, elem_i, ELEM_SIZE);
__builtin_memmove(elem_i, elem_j, ELEM_SIZE);
__builtin_memcpy(elem_j, tmp, ELEM_SIZE);
}

size_t size() const { return array_size; }
LIBC_INLINE size_t len() const { return array_len; }

// Make an Array starting at index |i| and size |s|.
LIBC_INLINE Array make_array(size_t i, size_t s) const {
return Array(get(i), s, elem_size, compare);
// Make an Array starting at index |i| and length |s|.
LIBC_INLINE ArrayFixedSize<ELEM_SIZE> make_array(size_t i, size_t s) const {
return ArrayFixedSize<ELEM_SIZE>(get_internal(i), s);
}

// Reset this Array to point at a different interval of the same items.
LIBC_INLINE void reset_bounds(uint8_t *a, size_t s) {
array = a;
array_size = s;
// Reset this Array to point at a different interval of the same
// items starting at index |i|.
LIBC_INLINE void reset_bounds(size_t i, size_t s) {
array_base = get_internal(i);
array_len = s;
}
};

using SortingRoutine = void(const Array &);

} // namespace internal
} // namespace LIBC_NAMESPACE_DECL

Expand Down
85 changes: 85 additions & 0 deletions libc/src/stdlib/qsort_pivot.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,85 @@
//===-- Implementation header for qsort utilities ---------------*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#ifndef LLVM_LIBC_SRC_STDLIB_QSORT_PIVOT_H
#define LLVM_LIBC_SRC_STDLIB_QSORT_PIVOT_H

#include <stdint.h>

namespace LIBC_NAMESPACE_DECL {
namespace internal {

// Recursively select a pseudomedian if above this threshold.
constexpr size_t PSEUDO_MEDIAN_REC_THRESHOLD = 64;

// Picks the index of a pivot element in `array`. Algorithm taken from
// glidesort by Orson Peters.
//
// The number of sampled points adapts to the input length, approximating the
// quality of a median of sqrt(n) elements. Inputs shorter than 8 elements
// always yield index 0.
template <typename A, typename F>
size_t choose_pivot(const A &array, const F &is_less) {
  const size_t n = array.len();
  if (n < 8)
    return 0;

  // Split the input into eighths and sample from three of them.
  const size_t eighth = n / 8;
  const size_t first = 0;           // [0, floor(n/8))
  const size_t middle = eighth * 4; // [4*floor(n/8), 5*floor(n/8))
  const size_t last = eighth * 7;   // [7*floor(n/8), 8*floor(n/8))

  // Short inputs use a plain median of three; longer ones use the recursive
  // pseudomedian.
  return (n < PSEUDO_MEDIAN_REC_THRESHOLD)
             ? median3(array, first, middle, last, is_less)
             : median3_rec(array, first, middle, last, eighth, is_less);
}

// Returns the index of an approximate median of the three sections that start
// at `a`, `b` and `c`, each `n` elements long. Sections that are large enough
// are first reduced to their own approximate median by recursing on three
// sub-sections of n/8 elements each.
//
// By dividing the size of each section by 8 when recursing we have logarithmic
// recursion depth and overall sample from f(n) = 3*f(n/8) -> f(n) =
// O(n^(log(3)/log(8))) ~= O(n^0.528) elements.
template <typename A, typename F>
size_t median3_rec(const A &array, size_t a, size_t b, size_t c, size_t n,
                   const F &is_less) {
  const bool sections_are_large = (n * 8 >= PSEUDO_MEDIAN_REC_THRESHOLD);
  if (!sections_are_large)
    return median3(array, a, b, c, is_less);

  // Offsets of the middle and upper sub-section within each section.
  const size_t sub_len = n / 8;
  const size_t mid_offset = sub_len * 4;
  const size_t top_offset = sub_len * 7;

  a = median3_rec(array, a, a + mid_offset, a + top_offset, sub_len, is_less);
  b = median3_rec(array, b, b + mid_offset, b + top_offset, sub_len, is_less);
  c = median3_rec(array, c, c + mid_offset, c + top_offset, sub_len, is_less);

  return median3(array, a, b, c, is_less);
}

/// Returns whichever of the indices `a`, `b`, `c` refers to the median of the
/// three elements, using two or three calls to `is_less`.
template <typename A, typename F>
size_t median3(const A &array, size_t a, size_t b, size_t c, const F &is_less) {
  const void *const elem_a = array.get(a);
  const void *const elem_b = array.get(b);
  const void *const elem_c = array.get(c);

  const bool a_lt_b = is_less(elem_a, elem_b);
  const bool a_lt_c = is_less(elem_a, elem_c);

  // The two answers disagree: either c <= a < b or b <= a < c, so a is the
  // median and no third comparison is needed.
  if (a_lt_b != a_lt_c)
    return a;

  // Both false: b, c <= a, so the median is max(b, c).
  // Both true:  a < b, c, so the median is min(b, c).
  // Flipping the outcome of b < c whenever a < b holds selects the right one.
  const bool b_lt_c = is_less(elem_b, elem_c);
  const bool pick_c = (b_lt_c != a_lt_b);
  return pick_c ? c : b;
}

} // namespace internal
} // namespace LIBC_NAMESPACE_DECL

#endif // LLVM_LIBC_SRC_STDLIB_QSORT_PIVOT_H
11 changes: 5 additions & 6 deletions libc/src/stdlib/qsort_r.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -19,13 +19,12 @@ LLVM_LIBC_FUNCTION(void, qsort_r,
(void *array, size_t array_size, size_t elem_size,
int (*compare)(const void *, const void *, void *),
void *arg)) {
if (array == nullptr || array_size == 0 || elem_size == 0)
return;
internal::Comparator c(compare, arg);
auto arr = internal::Array(reinterpret_cast<uint8_t *>(array), array_size,
elem_size, c);

internal::sort(arr);
const auto is_less = [compare, arg](const void *a, const void *b) -> bool {
return compare(a, b, arg) < 0;
};

internal::unstable_sort(array, array_size, elem_size, is_less);
}

} // namespace LIBC_NAMESPACE_DECL
Loading

0 comments on commit d2e71c9

Please sign in to comment.