Commit 70424a6

Changed quicksort and quickselect to use template-based sorting networks
1 parent: 42c672f

7 files changed: +310 −1001 lines changed
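In short: the hand-unrolled helpers sort_32_16bit, sort_64_16bit and sort_128_16bit below are deleted, and both quicksort and quickselect now call a single generic entry point, xss::sort_n<vtype, 128>, from the new xss-network-qsort.hpp header; each zmm_vector specialization gains the static hooks (reverse, bitonic_merge, sort_vec) that such generic code needs. As a minimal, self-contained sketch of the underlying idea only — not the actual xss-network-qsort.hpp implementation — a template-based bitonic sorting network over a compile-time-sized array looks like this:

// Illustrative bitonic sorting network on std::array<T, N>.
// The real code does the same compare-exchange and merge steps on
// whole ZMM registers via vtype::min/max and permutexvar.
#include <algorithm>
#include <array>
#include <cstddef>
#include <iostream>

// Compare-exchange: order a and b ascending or descending.
template <typename T>
void coex(T &a, T &b, bool ascending)
{
    if ((a > b) == ascending) std::swap(a, b);
}

// Merge a bitonic sequence of length `len` starting at `lo`.
template <typename T, std::size_t N>
void bitonic_merge(std::array<T, N> &arr, std::size_t lo, std::size_t len, bool ascending)
{
    if (len <= 1) return;
    std::size_t half = len / 2;
    for (std::size_t i = lo; i < lo + half; ++i)
        coex(arr[i], arr[i + half], ascending);
    bitonic_merge(arr, lo, half, ascending);
    bitonic_merge(arr, lo + half, half, ascending);
}

// Bitonic sort: sort the halves in opposite directions, then merge.
template <typename T, std::size_t N>
void bitonic_sort(std::array<T, N> &arr, std::size_t lo, std::size_t len, bool ascending)
{
    if (len <= 1) return;
    std::size_t half = len / 2;
    bitonic_sort(arr, lo, half, true);
    bitonic_sort(arr, lo + half, half, false);
    bitonic_merge(arr, lo, len, ascending);
}

int main()
{
    std::array<int, 8> a = {5, 1, 7, 3, 8, 2, 6, 4};
    bitonic_sort(a, 0, a.size(), true);
    for (int v : a) std::cout << v << ' ';  // prints: 1 2 3 4 5 6 7 8
}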

src/avx512-16bit-common.h

Lines changed: 3 additions & 99 deletions
@@ -8,6 +8,7 @@
 #define AVX512_16BIT_COMMON
 
 #include "avx512-common-qsort.h"
+#include "xss-network-qsort.hpp"
 
 /*
  * Constants used in sorting 32 elements in a ZMM registers. Based on Bitonic
@@ -118,103 +119,6 @@ X86_SIMD_SORT_INLINE zmm_t bitonic_merge_zmm_16bit(zmm_t zmm)
     return zmm;
 }
 
-// Assumes zmm1 and zmm2 are sorted and performs a recursive half cleaner
-template <typename vtype, typename zmm_t = typename vtype::zmm_t>
-X86_SIMD_SORT_INLINE void bitonic_merge_two_zmm_16bit(zmm_t &zmm1, zmm_t &zmm2)
-{
-    // 1) First step of a merging network: coex of zmm1 and zmm2 reversed
-    zmm2 = vtype::permutexvar(vtype::get_network(4), zmm2);
-    zmm_t zmm3 = vtype::min(zmm1, zmm2);
-    zmm_t zmm4 = vtype::max(zmm1, zmm2);
-    // 2) Recursive half cleaner for each
-    zmm1 = bitonic_merge_zmm_16bit<vtype>(zmm3);
-    zmm2 = bitonic_merge_zmm_16bit<vtype>(zmm4);
-}
-
-// Assumes [zmm0, zmm1] and [zmm2, zmm3] are sorted and performs a recursive
-// half cleaner
-template <typename vtype, typename zmm_t = typename vtype::zmm_t>
-X86_SIMD_SORT_INLINE void bitonic_merge_four_zmm_16bit(zmm_t *zmm)
-{
-    zmm_t zmm2r = vtype::permutexvar(vtype::get_network(4), zmm[2]);
-    zmm_t zmm3r = vtype::permutexvar(vtype::get_network(4), zmm[3]);
-    zmm_t zmm_t1 = vtype::min(zmm[0], zmm3r);
-    zmm_t zmm_t2 = vtype::min(zmm[1], zmm2r);
-    zmm_t zmm_t3 = vtype::permutexvar(vtype::get_network(4),
-                                      vtype::max(zmm[1], zmm2r));
-    zmm_t zmm_t4 = vtype::permutexvar(vtype::get_network(4),
-                                      vtype::max(zmm[0], zmm3r));
-    zmm_t zmm0 = vtype::min(zmm_t1, zmm_t2);
-    zmm_t zmm1 = vtype::max(zmm_t1, zmm_t2);
-    zmm_t zmm2 = vtype::min(zmm_t3, zmm_t4);
-    zmm_t zmm3 = vtype::max(zmm_t3, zmm_t4);
-    zmm[0] = bitonic_merge_zmm_16bit<vtype>(zmm0);
-    zmm[1] = bitonic_merge_zmm_16bit<vtype>(zmm1);
-    zmm[2] = bitonic_merge_zmm_16bit<vtype>(zmm2);
-    zmm[3] = bitonic_merge_zmm_16bit<vtype>(zmm3);
-}
-
-template <typename vtype, typename type_t>
-X86_SIMD_SORT_INLINE void sort_32_16bit(type_t *arr, int32_t N)
-{
-    typename vtype::opmask_t load_mask = ((0x1ull << N) - 0x1ull) & 0xFFFFFFFF;
-    typename vtype::zmm_t zmm
-            = vtype::mask_loadu(vtype::zmm_max(), load_mask, arr);
-    vtype::mask_storeu(arr, load_mask, sort_zmm_16bit<vtype>(zmm));
-}
-
-template <typename vtype, typename type_t>
-X86_SIMD_SORT_INLINE void sort_64_16bit(type_t *arr, int32_t N)
-{
-    if (N <= 32) {
-        sort_32_16bit<vtype>(arr, N);
-        return;
-    }
-    using zmm_t = typename vtype::zmm_t;
-    typename vtype::opmask_t load_mask
-            = ((0x1ull << (N - 32)) - 0x1ull) & 0xFFFFFFFF;
-    zmm_t zmm1 = vtype::loadu(arr);
-    zmm_t zmm2 = vtype::mask_loadu(vtype::zmm_max(), load_mask, arr + 32);
-    zmm1 = sort_zmm_16bit<vtype>(zmm1);
-    zmm2 = sort_zmm_16bit<vtype>(zmm2);
-    bitonic_merge_two_zmm_16bit<vtype>(zmm1, zmm2);
-    vtype::storeu(arr, zmm1);
-    vtype::mask_storeu(arr + 32, load_mask, zmm2);
-}
-
-template <typename vtype, typename type_t>
-X86_SIMD_SORT_INLINE void sort_128_16bit(type_t *arr, int32_t N)
-{
-    if (N <= 64) {
-        sort_64_16bit<vtype>(arr, N);
-        return;
-    }
-    using zmm_t = typename vtype::zmm_t;
-    using opmask_t = typename vtype::opmask_t;
-    zmm_t zmm[4];
-    zmm[0] = vtype::loadu(arr);
-    zmm[1] = vtype::loadu(arr + 32);
-    opmask_t load_mask1 = 0xFFFFFFFF, load_mask2 = 0xFFFFFFFF;
-    if (N != 128) {
-        uint64_t combined_mask = (0x1ull << (N - 64)) - 0x1ull;
-        load_mask1 = combined_mask & 0xFFFFFFFF;
-        load_mask2 = (combined_mask >> 32) & 0xFFFFFFFF;
-    }
-    zmm[2] = vtype::mask_loadu(vtype::zmm_max(), load_mask1, arr + 64);
-    zmm[3] = vtype::mask_loadu(vtype::zmm_max(), load_mask2, arr + 96);
-    zmm[0] = sort_zmm_16bit<vtype>(zmm[0]);
-    zmm[1] = sort_zmm_16bit<vtype>(zmm[1]);
-    zmm[2] = sort_zmm_16bit<vtype>(zmm[2]);
-    zmm[3] = sort_zmm_16bit<vtype>(zmm[3]);
-    bitonic_merge_two_zmm_16bit<vtype>(zmm[0], zmm[1]);
-    bitonic_merge_two_zmm_16bit<vtype>(zmm[2], zmm[3]);
-    bitonic_merge_four_zmm_16bit<vtype>(zmm);
-    vtype::storeu(arr, zmm[0]);
-    vtype::storeu(arr + 32, zmm[1]);
-    vtype::mask_storeu(arr + 64, load_mask1, zmm[2]);
-    vtype::mask_storeu(arr + 96, load_mask2, zmm[3]);
-}
-
 template <typename vtype, typename type_t>
 X86_SIMD_SORT_INLINE type_t get_pivot_16bit(type_t *arr,
                                             const int64_t left,
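A note on the masked-load arithmetic in the helpers deleted above (the generic replacement presumably keeps the same trick): for a partial chunk of N elements, ((0x1ull << N) - 0x1ull) sets the low N mask bits, so mask_loadu fills only those lanes and pads the rest with vtype::zmm_max() sentinels, which sort to the end and are never stored back. A tiny standalone check:

// Worked example of the load-mask computation for N = 20 of 32 lanes.
#include <cstdint>
#include <cstdio>

int main()
{
    int32_t N = 20;  // sorting 20 of the 32 16-bit lanes in one ZMM register
    uint64_t load_mask = ((0x1ull << N) - 0x1ull) & 0xFFFFFFFF;
    std::printf("%#010llx\n", (unsigned long long)load_mask);  // 0x000fffff
}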
@@ -274,7 +178,7 @@ qsort_16bit_(type_t *arr, int64_t left, int64_t right, int64_t max_iters)
      * Base case: use bitonic networks to sort arrays <= 128
      */
     if (right + 1 - left <= 128) {
-        sort_128_16bit<vtype>(arr + left, (int32_t)(right + 1 - left));
+        xss::sort_n<vtype, 128>(arr + left, (int32_t)(right + 1 - left));
         return;
     }
 
@@ -307,7 +211,7 @@ static void qselect_16bit_(type_t *arr,
      * Base case: use bitonic networks to sort arrays <= 128
      */
     if (right + 1 - left <= 128) {
-        sort_128_16bit<vtype>(arr + left, (int32_t)(right + 1 - left));
+        xss::sort_n<vtype, 128>(arr + left, (int32_t)(right + 1 - left));
         return;
     }
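Note that the two base cases (in qsort_16bit_ above and qselect_16bit_ here) are now literally the same call: 128 elements is exactly four ZMM registers' worth of 16-bit values at 32 lanes each, and the compile-time bound in xss::sort_n<vtype, 128> covers what the hand-written sort_32/64/128_16bit cascade handled before, in both code paths.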

src/avx512-16bit-qsort.hpp

Lines changed: 40 additions & 0 deletions
@@ -8,6 +8,7 @@
 #define AVX512_QSORT_16BIT
 
 #include "avx512-16bit-common.h"
+#include "xss-network-qsort.hpp"
 
 struct float16 {
     uint16_t val;
@@ -152,6 +153,19 @@ struct zmm_vector<float16> {
     {
         return _mm512_storeu_si512(mem, x);
     }
+    static zmm_t reverse(zmm_t zmm)
+    {
+        const auto rev_index = get_network(4);
+        return permutexvar(rev_index, zmm);
+    }
+    static zmm_t bitonic_merge(zmm_t x)
+    {
+        return bitonic_merge_zmm_16bit<zmm_vector<float16>>(x);
+    }
+    static zmm_t sort_vec(zmm_t x)
+    {
+        return sort_zmm_16bit<zmm_vector<float16>>(x);
+    }
 };
 
 template <>
@@ -251,6 +265,19 @@ struct zmm_vector<int16_t> {
     {
         return _mm512_storeu_si512(mem, x);
    }
+    static zmm_t reverse(zmm_t zmm)
+    {
+        const auto rev_index = get_network(4);
+        return permutexvar(rev_index, zmm);
+    }
+    static zmm_t bitonic_merge(zmm_t x)
+    {
+        return bitonic_merge_zmm_16bit<zmm_vector<type_t>>(x);
+    }
+    static zmm_t sort_vec(zmm_t x)
+    {
+        return sort_zmm_16bit<zmm_vector<type_t>>(x);
+    }
 };
 template <>
 struct zmm_vector<uint16_t> {
@@ -347,6 +374,19 @@ struct zmm_vector<uint16_t> {
     {
         return _mm512_storeu_si512(mem, x);
     }
+    static zmm_t reverse(zmm_t zmm)
+    {
+        const auto rev_index = get_network(4);
+        return permutexvar(rev_index, zmm);
+    }
+    static zmm_t bitonic_merge(zmm_t x)
+    {
+        return bitonic_merge_zmm_16bit<zmm_vector<type_t>>(x);
+    }
+    static zmm_t sort_vec(zmm_t x)
+    {
+        return sort_zmm_16bit<zmm_vector<type_t>>(x);
+    }
 };
 
 template <>
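The three static members added to each specialization (reverse, bitonic_merge, sort_vec) give the generic network code a uniform surface: it can sort, merge and reverse a register without knowing which element type it holds. A minimal, self-contained sketch of that trait-hook pattern, using a scalar std::array stand-in instead of a real ZMM type (scalar_vector, reg_t and sort_then_reverse are illustrative names, not part of the library):

// Generic code only talks to the trait's static hooks, never to a
// concrete register type — mirroring how xss code uses zmm_vector<T>.
#include <algorithm>
#include <array>
#include <iostream>

template <typename T>
struct scalar_vector {            // stand-in for zmm_vector<T>
    using reg_t = std::array<T, 4>;
    static reg_t reverse(reg_t v)
    {
        std::reverse(v.begin(), v.end());
        return v;
    }
    static reg_t sort_vec(reg_t v)
    {
        std::sort(v.begin(), v.end());
        return v;
    }
};

// A generic routine written purely against the vtype hooks.
template <typename vtype, typename reg_t = typename vtype::reg_t>
reg_t sort_then_reverse(reg_t v)
{
    return vtype::reverse(vtype::sort_vec(v));
}

int main()
{
    auto out = sort_then_reverse<scalar_vector<int>>({3, 1, 4, 2});
    for (int x : out) std::cout << x << ' ';  // prints: 4 3 2 1
}

Generic code written against vtype compiles once per specialization, which is how a single sort_n template can serve float16, int16_t and uint16_t alike.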
