Skip to content

Commit bf7f733

Browse files
committed
Implement partial sorting algorithms
Each datatype now supports two partial sorting algorithms: 1) sort such that a particular index is valid, and 2) sort such that a range of indices is valid, where 'valid' means that the kth smallest element is in position k. Additionally, transferred a few lingering comments from a refactor earlier in the project.
1 parent aeab737 commit bf7f733

File tree

6 files changed

+208
-10
lines changed

6 files changed

+208
-10
lines changed

src/avx512-16bit-common.h

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -259,6 +259,38 @@ X86_SIMD_SORT_INLINE type_t get_pivot_16bit(type_t *arr,
259259
return ((type_t *)&sort)[16];
260260
}
261261

262+
template <typename vtype, typename type_t>
263+
static void
264+
qsort_partial_16bit_(int64_t k, type_t *arr,
265+
int64_t left, int64_t right,
266+
int64_t max_iters)
267+
{
268+
/*
269+
* Resort to std::sort if quicksort isnt making any progress
270+
*/
271+
if (max_iters <= 0) {
272+
std::sort(arr + left, arr + right + 1, comparison_func<vtype>);
273+
return;
274+
}
275+
/*
276+
* Base case: use bitonic networks to sort arrays <= 128
277+
*/
278+
if (right + 1 - left <= 128) {
279+
sort_128_16bit<vtype>(arr + left, (int32_t)(right + 1 - left));
280+
return;
281+
}
282+
283+
type_t pivot = get_pivot_16bit<vtype>(arr, left, right);
284+
type_t smallest = vtype::type_max();
285+
type_t biggest = vtype::type_min();
286+
int64_t pivot_index = partition_avx512<vtype>(
287+
arr, left, right + 1, pivot, &smallest, &biggest);
288+
if ((pivot != smallest) && (k <= pivot_index))
289+
qsort_partial_16bit_<vtype>(k, arr, left, pivot_index - 1, max_iters - 1);
290+
else if ((pivot != biggest) && (k > pivot_index))
291+
qsort_partial_16bit_<vtype>(k, arr, pivot_index, right, max_iters - 1);
292+
}
293+
262294
template <typename vtype, typename type_t>
263295
static void
264296
qsort_16bit_(type_t *arr, int64_t left, int64_t right, int64_t max_iters)

src/avx512-16bit-qsort.hpp

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -405,6 +405,34 @@ replace_inf_with_nan(uint16_t *arr, int64_t arrsize, int64_t nan_count)
405405
}
406406
}
407407

408+
template <>
409+
void avx512_qsort_partial(int64_t k, int16_t *arr, int64_t arrsize)
410+
{
411+
if (arrsize > 1) {
412+
qsort_partial_16bit_<zmm_vector<int16_t>, int16_t>(
413+
k, arr, 0, arrsize - 1, 2 * (int64_t)log2(arrsize));
414+
}
415+
}
416+
417+
template <>
418+
void avx512_qsort_partial(int64_t k, uint16_t *arr, int64_t arrsize)
419+
{
420+
if (arrsize > 1) {
421+
qsort_partial_16bit_<zmm_vector<uint16_t>, uint16_t>(
422+
k, arr, 0, arrsize - 1, 2 * (int64_t)log2(arrsize));
423+
}
424+
}
425+
426+
void avx512_qsort_fp16_partial(int64_t k, uint16_t *arr, int64_t arrsize)
427+
{
428+
if (arrsize > 1) {
429+
int64_t nan_count = replace_nan_with_inf(arr, arrsize);
430+
qsort_partial_16bit_<zmm_vector<float16>, uint16_t>(
431+
k, arr, 0, arrsize - 1, 2 * (int64_t)log2(arrsize));
432+
replace_inf_with_nan(arr, arrsize, nan_count);
433+
}
434+
}
435+
408436
template <>
409437
void avx512_qsort(int16_t *arr, int64_t arrsize)
410438
{
@@ -432,4 +460,5 @@ void avx512_qsort_fp16(uint16_t *arr, int64_t arrsize)
432460
replace_inf_with_nan(arr, arrsize, nan_count);
433461
}
434462
}
463+
435464
#endif // AVX512_QSORT_16BIT

src/avx512-32bit-qsort.hpp

Lines changed: 61 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -626,6 +626,38 @@ X86_SIMD_SORT_INLINE type_t get_pivot_32bit(type_t *arr,
626626
return ((type_t *)&sort)[8];
627627
}
628628

629+
template <typename vtype, typename type_t>
630+
static void
631+
qsort_partial_32bit_(int64_t k, type_t *arr,
632+
int64_t left, int64_t right,
633+
int64_t max_iters)
634+
{
635+
/*
636+
* Resort to std::sort if quicksort isnt making any progress
637+
*/
638+
if (max_iters <= 0) {
639+
std::sort(arr + left, arr + right + 1);
640+
return;
641+
}
642+
/*
643+
* Base case: use bitonic networks to sort arrays <= 128
644+
*/
645+
if (right + 1 - left <= 128) {
646+
sort_128_32bit<vtype>(arr + left, (int32_t)(right + 1 - left));
647+
return;
648+
}
649+
650+
type_t pivot = get_pivot_32bit<vtype>(arr, left, right);
651+
type_t smallest = vtype::type_max();
652+
type_t biggest = vtype::type_min();
653+
int64_t pivot_index = partition_avx512<vtype>(
654+
arr, left, right + 1, pivot, &smallest, &biggest);
655+
if ((pivot != smallest) && (k <= pivot_index))
656+
qsort_partial_32bit_<vtype>(k, arr, left, pivot_index - 1, max_iters - 1);
657+
else if ((pivot != biggest) && (k > pivot_index))
658+
qsort_partial_32bit_<vtype>(k, arr, pivot_index, right, max_iters - 1);
659+
}
660+
629661
template <typename vtype, typename type_t>
630662
static void
631663
qsort_32bit_(type_t *arr, int64_t left, int64_t right, int64_t max_iters)
@@ -681,6 +713,35 @@ replace_inf_with_nan(float *arr, int64_t arrsize, int64_t nan_count)
681713
}
682714
}
683715

716+
template <>
717+
void avx512_qsort_partial<int32_t>(int64_t k, int32_t *arr, int64_t arrsize)
718+
{
719+
if (arrsize > 1) {
720+
qsort_partial_32bit_<zmm_vector<int32_t>, int32_t>(
721+
k, arr, 0, arrsize - 1, 2 * (int64_t)log2(arrsize));
722+
}
723+
}
724+
725+
template <>
726+
void avx512_qsort_partial<uint32_t>(int64_t k, uint32_t *arr, int64_t arrsize)
727+
{
728+
if (arrsize > 1) {
729+
qsort_partial_32bit_<zmm_vector<uint32_t>, uint32_t>(
730+
k, arr, 0, arrsize - 1, 2 * (int64_t)log2(arrsize));
731+
}
732+
}
733+
734+
template <>
735+
void avx512_qsort_partial<float>(int64_t k, float *arr, int64_t arrsize)
736+
{
737+
if (arrsize > 1) {
738+
int64_t nan_count = replace_nan_with_inf(arr, arrsize);
739+
qsort_partial_32bit_<zmm_vector<float>, float>(
740+
k, arr, 0, arrsize - 1, 2 * (int64_t)log2(arrsize));
741+
replace_inf_with_nan(arr, arrsize, nan_count);
742+
}
743+
}
744+
684745
template <>
685746
void avx512_qsort<int32_t>(int32_t *arr, int64_t arrsize)
686747
{

src/avx512-64bit-common.h

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4,10 +4,16 @@
44
* Authors: Raghuveer Devulapalli <raghuveer.devulapalli@intel.com>
55
* ****************************************************************/
66

7-
#ifndef AVX512_64BIT_COMMOM
8-
#define AVX512_64BIT_COMMOM
7+
#ifndef AVX512_64BIT_COMMON
8+
#define AVX512_64BIT_COMMON
99
#include "avx512-common-qsort.h"
1010

11+
/*
12+
* Constants used in sorting 8 elements in a ZMM registers. Based on Bitonic
13+
* sorting network (see
14+
* https://en.wikipedia.org/wiki/Bitonic_sorter#/media/File:BitonicSort.svg)
15+
*/
16+
// ZMM 7, 6, 5, 4, 3, 2, 1, 0
1117
#define NETWORK_64BIT_1 4, 5, 6, 7, 0, 1, 2, 3
1218
#define NETWORK_64BIT_2 0, 1, 2, 3, 4, 5, 6, 7
1319
#define NETWORK_64BIT_3 5, 4, 7, 6, 1, 0, 3, 2

src/avx512-64bit-qsort.hpp

Lines changed: 61 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -9,13 +9,6 @@
99

1010
#include "avx512-64bit-common.h"
1111

12-
/*
13-
* Constants used in sorting 8 elements in a ZMM registers. Based on Bitonic
14-
* sorting network (see
15-
* https://en.wikipedia.org/wiki/Bitonic_sorter#/media/File:BitonicSort.svg)
16-
*/
17-
// ZMM 7, 6, 5, 4, 3, 2, 1, 0
18-
1912
// Assumes zmm is bitonic and performs a recursive half cleaner
2013
template <typename vtype, typename zmm_t = typename vtype::zmm_t>
2114
X86_SIMD_SORT_INLINE zmm_t bitonic_merge_zmm_64bit(zmm_t zmm)
@@ -408,6 +401,67 @@ qsort_64bit_(type_t *arr, int64_t left, int64_t right, int64_t max_iters)
408401
qsort_64bit_<vtype>(arr, pivot_index, right, max_iters - 1);
409402
}
410403

404+
template <typename vtype, typename type_t>
405+
static void
406+
qsort_partial_64bit_(int64_t k, type_t *arr,
407+
int64_t left, int64_t right,
408+
int64_t max_iters)
409+
{
410+
/*
411+
* Resort to std::sort if quicksort isnt making any progress
412+
*/
413+
if (max_iters <= 0) {
414+
std::sort(arr + left, arr + right + 1);
415+
return;
416+
}
417+
/*
418+
* Base case: use bitonic networks to sort arrays <= 128
419+
*/
420+
if (right + 1 - left <= 128) {
421+
sort_128_64bit<vtype>(arr + left, (int32_t)(right + 1 - left));
422+
return;
423+
}
424+
425+
type_t pivot = get_pivot_64bit<vtype>(arr, left, right);
426+
type_t smallest = vtype::type_max();
427+
type_t biggest = vtype::type_min();
428+
int64_t pivot_index = partition_avx512<vtype>(
429+
arr, left, right + 1, pivot, &smallest, &biggest);
430+
if ((pivot != smallest) && (k <= pivot_index))
431+
qsort_partial_64bit_<vtype>(k, arr, left, pivot_index - 1, max_iters - 1);
432+
else if ((pivot != biggest) && (k > pivot_index))
433+
qsort_partial_64bit_<vtype>(k, arr, pivot_index, right, max_iters - 1);
434+
}
435+
436+
template <>
437+
void avx512_qsort_partial<int64_t>(int64_t k, int64_t *arr, int64_t arrsize)
438+
{
439+
if (arrsize > 1) {
440+
qsort_partial_64bit_<zmm_vector<int64_t>, int64_t>(
441+
k, arr, 0, arrsize - 1, 2 * (int64_t)log2(arrsize));
442+
}
443+
}
444+
445+
template <>
446+
void avx512_qsort_partial<uint64_t>(int64_t k, uint64_t *arr, int64_t arrsize)
447+
{
448+
if (arrsize > 1) {
449+
qsort_partial_64bit_<zmm_vector<uint64_t>, uint64_t>(
450+
k, arr, 0, arrsize - 1, 2 * (int64_t)log2(arrsize));
451+
}
452+
}
453+
454+
template <>
455+
void avx512_qsort_partial<double>(int64_t k, double *arr, int64_t arrsize)
456+
{
457+
if (arrsize > 1) {
458+
int64_t nan_count = replace_nan_with_inf(arr, arrsize);
459+
qsort_partial_64bit_<zmm_vector<double>, double>(
460+
k, arr, 0, arrsize - 1, 2 * (int64_t)log2(arrsize));
461+
replace_inf_with_nan(arr, arrsize, nan_count);
462+
}
463+
}
464+
411465
template <>
412466
void avx512_qsort<int64_t>(int64_t *arr, int64_t arrsize)
413467
{

src/avx512-common-qsort.h

Lines changed: 17 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -88,9 +88,25 @@ struct zmm_vector;
8888

8989
template <typename T>
9090
void avx512_qsort(T *arr, int64_t arrsize);
91-
9291
void avx512_qsort_fp16(uint16_t *arr, int64_t arrsize);
9392

93+
template <typename T>
94+
void avx512_qsort_partial(int64_t k, T *arr, int64_t arrsize);
95+
void avx512_qsort_fp16_partial(int64_t k, uint16_t *arr, int64_t arrsize);
96+
97+
template <typename T>
98+
void avx512_qsort_partialrange(int64_t kfrom, int64_t kto, T *arr, int64_t arrsize) {
99+
avx512_qsort_partial<T>(kto, arr, arrsize);
100+
avx512_qsort_partial<T>(kfrom, arr, kto);
101+
avx512_qsort<T>(arr + kfrom, kto - kfrom);
102+
}
103+
inline void avx512_qsort_fp16_partialrange(int64_t kfrom, int64_t kto, uint16_t *arr, int64_t arrsize)
104+
{
105+
avx512_qsort_fp16_partial(kto, arr, arrsize);
106+
avx512_qsort_fp16_partial(kfrom, arr, kto);
107+
avx512_qsort_fp16(arr + kfrom, kto - kfrom);
108+
}
109+
94110
template <typename vtype, typename T = typename vtype::type_t>
95111
bool comparison_func(const T &a, const T &b)
96112
{

0 commit comments

Comments
 (0)