|
9 | 9 |
|
10 | 10 | #include "avx512-64bit-common.h" |
11 | 11 |
|
12 | | -/* |
13 | | - * Constants used in sorting 8 elements in a ZMM registers. Based on Bitonic |
14 | | - * sorting network (see |
15 | | - * https://en.wikipedia.org/wiki/Bitonic_sorter#/media/File:BitonicSort.svg) |
16 | | - */ |
17 | | -// ZMM 7, 6, 5, 4, 3, 2, 1, 0 |
18 | | - |
19 | 12 | // Assumes zmm is bitonic and performs a recursive half cleaner |
20 | 13 | template <typename vtype, typename zmm_t = typename vtype::zmm_t> |
21 | 14 | X86_SIMD_SORT_INLINE zmm_t bitonic_merge_zmm_64bit(zmm_t zmm) |
@@ -408,6 +401,67 @@ qsort_64bit_(type_t *arr, int64_t left, int64_t right, int64_t max_iters) |
408 | 401 | qsort_64bit_<vtype>(arr, pivot_index, right, max_iters - 1); |
409 | 402 | } |
410 | 403 |
|
| 404 | +template <typename vtype, typename type_t> |
| 405 | +static void |
| 406 | +qsort_partial_64bit_(int64_t k, type_t *arr, |
| 407 | + int64_t left, int64_t right, |
| 408 | + int64_t max_iters) |
| 409 | +{ |
| 410 | + /* |
| 411 | +     * Resort to std::sort if quicksort isn't making any progress |
| 412 | + */ |
| 413 | + if (max_iters <= 0) { |
| 414 | + std::sort(arr + left, arr + right + 1); |
| 415 | + return; |
| 416 | + } |
| 417 | + /* |
| 418 | +     * Base case: use bitonic networks to sort arrays with <= 128 elements |
| 419 | + */ |
| 420 | + if (right + 1 - left <= 128) { |
| 421 | + sort_128_64bit<vtype>(arr + left, (int32_t)(right + 1 - left)); |
| 422 | + return; |
| 423 | + } |
| 424 | + |
| 425 | + type_t pivot = get_pivot_64bit<vtype>(arr, left, right); |
| 426 | + type_t smallest = vtype::type_max(); |
| 427 | + type_t biggest = vtype::type_min(); |
| 428 | + int64_t pivot_index = partition_avx512<vtype>( |
| 429 | + arr, left, right + 1, pivot, &smallest, &biggest); |
| 430 | + if ((pivot != smallest) && (k <= pivot_index)) |
| 431 | + qsort_partial_64bit_<vtype>(k, arr, left, pivot_index - 1, max_iters - 1); |
| 432 | + else if ((pivot != biggest) && (k > pivot_index)) |
| 433 | + qsort_partial_64bit_<vtype>(k, arr, pivot_index, right, max_iters - 1); |
| 434 | +} |
| 435 | + |
| 436 | +template <> |
| 437 | +void avx512_qsort_partial<int64_t>(int64_t k, int64_t *arr, int64_t arrsize) |
| 438 | +{ |
| 439 | + if (arrsize > 1) { |
| 440 | + qsort_partial_64bit_<zmm_vector<int64_t>, int64_t>( |
| 441 | + k, arr, 0, arrsize - 1, 2 * (int64_t)log2(arrsize)); |
| 442 | + } |
| 443 | +} |
| 444 | + |
| 445 | +template <> |
| 446 | +void avx512_qsort_partial<uint64_t>(int64_t k, uint64_t *arr, int64_t arrsize) |
| 447 | +{ |
| 448 | + if (arrsize > 1) { |
| 449 | + qsort_partial_64bit_<zmm_vector<uint64_t>, uint64_t>( |
| 450 | + k, arr, 0, arrsize - 1, 2 * (int64_t)log2(arrsize)); |
| 451 | + } |
| 452 | +} |
| 453 | + |
| 454 | +template <> |
| 455 | +void avx512_qsort_partial<double>(int64_t k, double *arr, int64_t arrsize) |
| 456 | +{ |
| 457 | + if (arrsize > 1) { |
| 458 | + int64_t nan_count = replace_nan_with_inf(arr, arrsize); |
| 459 | + qsort_partial_64bit_<zmm_vector<double>, double>( |
| 460 | + k, arr, 0, arrsize - 1, 2 * (int64_t)log2(arrsize)); |
| 461 | + replace_inf_with_nan(arr, arrsize, nan_count); |
| 462 | + } |
| 463 | +} |
| 464 | + |
411 | 465 | template <> |
412 | 466 | void avx512_qsort<int64_t>(int64_t *arr, int64_t arrsize) |
413 | 467 | { |
|
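A short usage sketch for the new avx512_qsort_partial entry point (not part of the diff above; the include name "avx512-64bit-qsort.hpp" is an assumption, and the stated post-condition is my reading of the recursion rather than a documented contract): since qsort_partial_64bit_ only recurses into the partition containing index k - 1, the call appears to leave the k smallest values in the first k positions, with the block around index k - 1 fully sorted by the bitonic base case and the earlier blocks only partitioned. The sketch cross-checks that reading against std::nth_element.

#include <algorithm>
#include <cassert>
#include <cstdint>
#include <random>
#include <vector>
#include "avx512-64bit-qsort.hpp" // assumed include name for the header in this diff

int main()
{
    std::vector<int64_t> arr(1000);
    std::mt19937_64 rng(42);
    for (auto &v : arr)
        v = (int64_t)(rng() % 10000);

    std::vector<int64_t> ref = arr;
    const int64_t k = 100;

    // Partially sort: only the partition containing index k - 1 is recursed into.
    avx512_qsort_partial<int64_t>(k, arr.data(), (int64_t)arr.size());

    // Cross-check the assumed guarantee against std::nth_element: every element
    // in arr[0..k-1] should be <= the k-th smallest value, and every element in
    // arr[k..n-1] should be >= it.
    std::nth_element(ref.begin(), ref.begin() + (k - 1), ref.end());
    const int64_t kth = ref[k - 1];
    for (int64_t i = 0; i < k; ++i)
        assert(arr[i] <= kth);
    for (int64_t i = k; i < (int64_t)arr.size(); ++i)
        assert(arr[i] >= kth);
    return 0;
}

Note that the (pivot != smallest) and (pivot != biggest) guards in qsort_partial_64bit_ mirror the ones in qsort_64bit_ and stop the recursion from looping on inputs where every element equals the pivot.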