Skip to content

Commit b52e889

Browse files
committed
Changed the method used for the small sort
1 parent 3766740 commit b52e889

File tree

8 files changed

+747
-159
lines changed

8 files changed

+747
-159
lines changed

src/avx512-16bit-common.h

Lines changed: 69 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -93,30 +93,74 @@ X86_SIMD_SORT_INLINE reg_t sort_zmm_16bit(reg_t zmm)
9393
return zmm;
9494
}
9595

96-
// Assumes zmm is bitonic and performs a recursive half cleaner
97-
template <typename vtype, typename reg_t = typename vtype::reg_t>
98-
X86_SIMD_SORT_INLINE reg_t bitonic_merge_zmm_16bit(reg_t zmm)
99-
{
100-
// 1) half_cleaner[32]: compare 1-17, 2-18, 3-19 etc ..
101-
zmm = cmp_merge<vtype>(
102-
zmm, vtype::permutexvar(vtype::get_network(6), zmm), 0xFFFF0000);
103-
// 2) half_cleaner[16]: compare 1-9, 2-10, 3-11 etc ..
104-
zmm = cmp_merge<vtype>(
105-
zmm, vtype::permutexvar(vtype::get_network(5), zmm), 0xFF00FF00);
106-
// 3) half_cleaner[8]
107-
zmm = cmp_merge<vtype>(
108-
zmm, vtype::permutexvar(vtype::get_network(3), zmm), 0xF0F0F0F0);
109-
// 3) half_cleaner[4]
110-
zmm = cmp_merge<vtype>(
111-
zmm,
112-
vtype::template shuffle<SHUFFLE_MASK(1, 0, 3, 2)>(zmm),
113-
0xCCCCCCCC);
114-
// 3) half_cleaner[2]
115-
zmm = cmp_merge<vtype>(
116-
zmm,
117-
vtype::template shuffle<SHUFFLE_MASK(2, 3, 0, 1)>(zmm),
118-
0xAAAAAAAA);
119-
return zmm;
120-
}
96+
struct avx512_16bit_swizzle_ops{
97+
template <typename vtype, int scale>
98+
X86_SIMD_SORT_INLINE typename vtype::reg_t swap_n(typename vtype::reg_t reg){
99+
__m512i v = vtype::cast_to(reg);
100+
101+
if constexpr (scale == 2){
102+
__m512i mask = _mm512_set_epi16(30, 31, 28, 29, 26, 27, 24, 25, 22, 23, 20, 21, 18, 19, 16, 17, 14, 15, 12, 13, 10, 11, 8, 9, 6, 7, 4, 5, 2, 3, 0, 1);
103+
v = _mm512_permutexvar_epi16(mask, v);
104+
}else if constexpr (scale == 4){
105+
v = _mm512_shuffle_epi32(v, (_MM_PERM_ENUM)0b10110001);
106+
}else if constexpr (scale == 8){
107+
v = _mm512_shuffle_epi32(v, (_MM_PERM_ENUM)0b01001110);
108+
}else if constexpr (scale == 16){
109+
v = _mm512_shuffle_i64x2(v, v, 0b10110001);
110+
}else if constexpr (scale == 32){
111+
v = _mm512_shuffle_i64x2(v, v, 0b01001110);
112+
}else{
113+
static_assert(scale == -1, "should not be reached");
114+
}
115+
116+
return vtype::cast_from(v);
117+
}
118+
119+
template <typename vtype, int scale>
120+
X86_SIMD_SORT_INLINE typename vtype::reg_t reverse_n(typename vtype::reg_t reg){
121+
__m512i v = vtype::cast_to(reg);
122+
123+
if constexpr (scale == 2){
124+
return swap_n<vtype, 2>(reg);
125+
}else if constexpr (scale == 4){
126+
__m512i mask = _mm512_set_epi16(28, 29, 30, 31, 24, 25, 26, 27, 20, 21, 22, 23, 16, 17, 18, 19, 12, 13, 14, 15, 8, 9, 10, 11, 4, 5, 6, 7, 0, 1, 2, 3);
127+
v = _mm512_permutexvar_epi16(mask, v);
128+
}else if constexpr (scale == 8){
129+
__m512i mask = _mm512_set_epi16(24, 25, 26, 27, 28, 29, 30, 31, 16, 17, 18, 19, 20, 21, 22, 23, 8, 9, 10, 11, 12, 13, 14, 15, 0, 1, 2, 3, 4, 5, 6, 7);
130+
v = _mm512_permutexvar_epi16(mask, v);
131+
}else if constexpr (scale == 16){
132+
__m512i mask = _mm512_set_epi16(16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15);
133+
v = _mm512_permutexvar_epi16(mask, v);
134+
}else if constexpr (scale == 32){
135+
return vtype::reverse(reg);
136+
}else{
137+
static_assert(scale == -1, "should not be reached");
138+
}
139+
140+
return vtype::cast_from(v);
141+
}
142+
143+
template <typename vtype, int scale>
144+
X86_SIMD_SORT_INLINE typename vtype::reg_t merge_n(typename vtype::reg_t reg, typename vtype::reg_t other){
145+
__m512i v1 = vtype::cast_to(reg);
146+
__m512i v2 = vtype::cast_to(other);
147+
148+
if constexpr (scale == 2){
149+
v1 = _mm512_mask_blend_epi16(0b01010101010101010101010101010101, v1, v2);
150+
}else if constexpr (scale == 4){
151+
v1 = _mm512_mask_blend_epi16(0b00110011001100110011001100110011, v1, v2);
152+
}else if constexpr (scale == 8){
153+
v1 = _mm512_mask_blend_epi16(0b00001111000011110000111100001111, v1, v2);
154+
}else if constexpr (scale == 16){
155+
v1 = _mm512_mask_blend_epi16(0b00000000111111110000000011111111, v1, v2);
156+
}else if constexpr (scale == 32){
157+
v1 = _mm512_mask_blend_epi16(0b00000000000000001111111111111111, v1, v2);
158+
}else{
159+
static_assert(scale == -1, "should not be reached");
160+
}
161+
162+
return vtype::cast_from(v1);
163+
}
164+
};
121165

122166
#endif // AVX512_16BIT_COMMON

src/avx512-16bit-qsort.hpp

Lines changed: 24 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,8 @@ struct zmm_vector<float16> {
2222
static const uint8_t numlanes = 32;
2323
static constexpr int network_sort_threshold = 512;
2424
static constexpr int partition_unroll_factor = 0;
25+
26+
using swizzle_ops = avx512_16bit_swizzle_ops;
2527

2628
static reg_t get_network(int index)
2729
{
@@ -159,14 +161,16 @@ struct zmm_vector<float16> {
159161
const auto rev_index = get_network(4);
160162
return permutexvar(rev_index, zmm);
161163
}
162-
static reg_t bitonic_merge(reg_t x)
163-
{
164-
return bitonic_merge_zmm_16bit<zmm_vector<float16>>(x);
165-
}
166164
static reg_t sort_vec(reg_t x)
167165
{
168166
return sort_zmm_16bit<zmm_vector<float16>>(x);
169167
}
168+
static reg_t cast_from(__m512i v){
169+
return v;
170+
}
171+
static __m512i cast_to(reg_t v){
172+
return v;
173+
}
170174
};
171175

172176
template <>
@@ -178,6 +182,8 @@ struct zmm_vector<int16_t> {
178182
static const uint8_t numlanes = 32;
179183
static constexpr int network_sort_threshold = 512;
180184
static constexpr int partition_unroll_factor = 0;
185+
186+
using swizzle_ops = avx512_16bit_swizzle_ops;
181187

182188
static reg_t get_network(int index)
183189
{
@@ -273,14 +279,16 @@ struct zmm_vector<int16_t> {
273279
const auto rev_index = get_network(4);
274280
return permutexvar(rev_index, zmm);
275281
}
276-
static reg_t bitonic_merge(reg_t x)
277-
{
278-
return bitonic_merge_zmm_16bit<zmm_vector<type_t>>(x);
279-
}
280282
static reg_t sort_vec(reg_t x)
281283
{
282284
return sort_zmm_16bit<zmm_vector<type_t>>(x);
283285
}
286+
static reg_t cast_from(__m512i v){
287+
return v;
288+
}
289+
static __m512i cast_to(reg_t v){
290+
return v;
291+
}
284292
};
285293
template <>
286294
struct zmm_vector<uint16_t> {
@@ -291,6 +299,8 @@ struct zmm_vector<uint16_t> {
291299
static const uint8_t numlanes = 32;
292300
static constexpr int network_sort_threshold = 512;
293301
static constexpr int partition_unroll_factor = 0;
302+
303+
using swizzle_ops = avx512_16bit_swizzle_ops;
294304

295305
static reg_t get_network(int index)
296306
{
@@ -384,14 +394,16 @@ struct zmm_vector<uint16_t> {
384394
const auto rev_index = get_network(4);
385395
return permutexvar(rev_index, zmm);
386396
}
387-
static reg_t bitonic_merge(reg_t x)
388-
{
389-
return bitonic_merge_zmm_16bit<zmm_vector<type_t>>(x);
390-
}
391397
static reg_t sort_vec(reg_t x)
392398
{
393399
return sort_zmm_16bit<zmm_vector<type_t>>(x);
394400
}
401+
static reg_t cast_from(__m512i v){
402+
return v;
403+
}
404+
static __m512i cast_to(reg_t v){
405+
return v;
406+
}
395407
};
396408

397409
template <>

src/avx512-32bit-qsort.hpp

Lines changed: 86 additions & 40 deletions
Original file line numberDiff line numberDiff line change
@@ -27,8 +27,7 @@
2727
template <typename vtype, typename reg_t>
2828
X86_SIMD_SORT_INLINE reg_t sort_zmm_32bit(reg_t zmm);
2929

30-
template <typename vtype, typename reg_t>
31-
X86_SIMD_SORT_INLINE reg_t bitonic_merge_zmm_32bit(reg_t zmm);
30+
struct avx512_32bit_swizzle_ops;
3231

3332
template <>
3433
struct zmm_vector<int32_t> {
@@ -39,6 +38,8 @@ struct zmm_vector<int32_t> {
3938
static const uint8_t numlanes = 16;
4039
static constexpr int network_sort_threshold = 256;
4140
static constexpr int partition_unroll_factor = 2;
41+
42+
using swizzle_ops = avx512_32bit_swizzle_ops;
4243

4344
static type_t type_max()
4445
{
@@ -138,14 +139,16 @@ struct zmm_vector<int32_t> {
138139
const auto rev_index = _mm512_set_epi32(NETWORK_32BIT_5);
139140
return permutexvar(rev_index, zmm);
140141
}
141-
static reg_t bitonic_merge(reg_t x)
142-
{
143-
return bitonic_merge_zmm_32bit<zmm_vector<type_t>>(x);
144-
}
145142
static reg_t sort_vec(reg_t x)
146143
{
147144
return sort_zmm_32bit<zmm_vector<type_t>>(x);
148145
}
146+
static reg_t cast_from(__m512i v){
147+
return v;
148+
}
149+
static __m512i cast_to(reg_t v){
150+
return v;
151+
}
149152
};
150153
template <>
151154
struct zmm_vector<uint32_t> {
@@ -156,6 +159,8 @@ struct zmm_vector<uint32_t> {
156159
static const uint8_t numlanes = 16;
157160
static constexpr int network_sort_threshold = 256;
158161
static constexpr int partition_unroll_factor = 2;
162+
163+
using swizzle_ops = avx512_32bit_swizzle_ops;
159164

160165
static type_t type_max()
161166
{
@@ -255,14 +260,16 @@ struct zmm_vector<uint32_t> {
255260
const auto rev_index = _mm512_set_epi32(NETWORK_32BIT_5);
256261
return permutexvar(rev_index, zmm);
257262
}
258-
static reg_t bitonic_merge(reg_t x)
259-
{
260-
return bitonic_merge_zmm_32bit<zmm_vector<type_t>>(x);
261-
}
262263
static reg_t sort_vec(reg_t x)
263264
{
264265
return sort_zmm_32bit<zmm_vector<type_t>>(x);
265266
}
267+
static reg_t cast_from(__m512i v){
268+
return v;
269+
}
270+
static __m512i cast_to(reg_t v){
271+
return v;
272+
}
266273
};
267274
template <>
268275
struct zmm_vector<float> {
@@ -273,6 +280,8 @@ struct zmm_vector<float> {
273280
static const uint8_t numlanes = 16;
274281
static constexpr int network_sort_threshold = 256;
275282
static constexpr int partition_unroll_factor = 2;
283+
284+
using swizzle_ops = avx512_32bit_swizzle_ops;
276285

277286
static type_t type_max()
278287
{
@@ -386,14 +395,16 @@ struct zmm_vector<float> {
386395
const auto rev_index = _mm512_set_epi32(NETWORK_32BIT_5);
387396
return permutexvar(rev_index, zmm);
388397
}
389-
static reg_t bitonic_merge(reg_t x)
390-
{
391-
return bitonic_merge_zmm_32bit<zmm_vector<type_t>>(x);
392-
}
393398
static reg_t sort_vec(reg_t x)
394399
{
395400
return sort_zmm_32bit<zmm_vector<type_t>>(x);
396401
}
402+
static reg_t cast_from(__m512i v){
403+
return _mm512_castsi512_ps(v);
404+
}
405+
static __m512i cast_to(reg_t v){
406+
return _mm512_castps_si512(v);
407+
}
397408
};
398409

399410
/*
@@ -446,31 +457,66 @@ X86_SIMD_SORT_INLINE reg_t sort_zmm_32bit(reg_t zmm)
446457
return zmm;
447458
}
448459

449-
// Assumes zmm is bitonic and performs a recursive half cleaner
450-
template <typename vtype, typename reg_t = typename vtype::reg_t>
451-
X86_SIMD_SORT_INLINE reg_t bitonic_merge_zmm_32bit(reg_t zmm)
452-
{
453-
// 1) half_cleaner[16]: compare 1-9, 2-10, 3-11 etc ..
454-
zmm = cmp_merge<vtype>(
455-
zmm,
456-
vtype::permutexvar(_mm512_set_epi32(NETWORK_32BIT_7), zmm),
457-
0xFF00);
458-
// 2) half_cleaner[8]: compare 1-5, 2-6, 3-7 etc ..
459-
zmm = cmp_merge<vtype>(
460-
zmm,
461-
vtype::permutexvar(_mm512_set_epi32(NETWORK_32BIT_6), zmm),
462-
0xF0F0);
463-
// 3) half_cleaner[4]
464-
zmm = cmp_merge<vtype>(
465-
zmm,
466-
vtype::template shuffle<SHUFFLE_MASK(1, 0, 3, 2)>(zmm),
467-
0xCCCC);
468-
// 3) half_cleaner[1]
469-
zmm = cmp_merge<vtype>(
470-
zmm,
471-
vtype::template shuffle<SHUFFLE_MASK(2, 3, 0, 1)>(zmm),
472-
0xAAAA);
473-
return zmm;
474-
}
460+
struct avx512_32bit_swizzle_ops{
461+
template <typename vtype, int scale>
462+
X86_SIMD_SORT_INLINE typename vtype::reg_t swap_n(typename vtype::reg_t reg){
463+
__m512i v = vtype::cast_to(reg);
464+
465+
if constexpr (scale == 2){
466+
v = _mm512_shuffle_epi32(v, (_MM_PERM_ENUM)0b10110001);
467+
}else if constexpr (scale == 4){
468+
v = _mm512_shuffle_epi32(v, (_MM_PERM_ENUM)0b01001110);
469+
}else if constexpr (scale == 8){
470+
v = _mm512_shuffle_i64x2(v, v, 0b10110001);
471+
}else if constexpr (scale == 16){
472+
v = _mm512_shuffle_i64x2(v, v, 0b01001110);
473+
}else{
474+
static_assert(scale == -1, "should not be reached");
475+
}
476+
477+
return vtype::cast_from(v);
478+
}
479+
480+
template <typename vtype, int scale>
481+
X86_SIMD_SORT_INLINE typename vtype::reg_t reverse_n(typename vtype::reg_t reg){
482+
__m512i v = vtype::cast_to(reg);
483+
484+
if constexpr (scale == 2){
485+
return swap_n<vtype, 2>(reg);
486+
}else if constexpr (scale == 4){
487+
__m512i mask = _mm512_set_epi32(12,13,14,15,8,9,10,11,4,5,6,7,0,1,2,3);
488+
v = _mm512_permutexvar_epi32(mask, v);
489+
}else if constexpr (scale == 8){
490+
__m512i mask = _mm512_set_epi32(8,9,10,11,12,13,14,15,0,1,2,3,4,5,6,7);
491+
v = _mm512_permutexvar_epi32(mask, v);
492+
}else if constexpr (scale == 16){
493+
return vtype::reverse(reg);
494+
}else{
495+
static_assert(scale == -1, "should not be reached");
496+
}
497+
498+
return vtype::cast_from(v);
499+
}
500+
501+
template <typename vtype, int scale>
502+
X86_SIMD_SORT_INLINE typename vtype::reg_t merge_n(typename vtype::reg_t reg, typename vtype::reg_t other){
503+
__m512i v1 = vtype::cast_to(reg);
504+
__m512i v2 = vtype::cast_to(other);
505+
506+
if constexpr (scale == 2){
507+
v1 = _mm512_mask_blend_epi32(0b0101010101010101, v1, v2);
508+
}else if constexpr (scale == 4){
509+
v1 = _mm512_mask_blend_epi32(0b0011001100110011, v1, v2);
510+
}else if constexpr (scale == 8){
511+
v1 = _mm512_mask_blend_epi32(0b0000111100001111, v1, v2);
512+
}else if constexpr (scale == 16){
513+
v1 = _mm512_mask_blend_epi32(0b0000000011111111, v1, v2);
514+
}else{
515+
static_assert(scale == -1, "should not be reached");
516+
}
517+
518+
return vtype::cast_from(v1);
519+
}
520+
};
475521

476522
#endif //AVX512_QSORT_32BIT

0 commit comments

Comments
 (0)