2727template  <typename  vtype, typename  reg_t >
2828X86_SIMD_SORT_INLINE reg_t  sort_zmm_32bit (reg_t  zmm);
2929
30- template  <typename  vtype, typename  reg_t >
31- X86_SIMD_SORT_INLINE reg_t  bitonic_merge_zmm_32bit (reg_t  zmm);
30+ struct  avx512_32bit_swizzle_ops ;
3231
3332template  <>
3433struct  zmm_vector <int32_t > {
@@ -39,6 +38,8 @@ struct zmm_vector<int32_t> {
3938    static  const  uint8_t  numlanes = 16 ;
4039    static  constexpr  int  network_sort_threshold = 256 ;
4140    static  constexpr  int  partition_unroll_factor = 2 ;
41+     
42+     using  swizzle_ops = avx512_32bit_swizzle_ops;
4243
4344    static  type_t  type_max ()
4445    {
@@ -138,14 +139,16 @@ struct zmm_vector<int32_t> {
138139        const  auto  rev_index = _mm512_set_epi32 (NETWORK_32BIT_5);
139140        return  permutexvar (rev_index, zmm);
140141    }
141-     static  reg_t  bitonic_merge (reg_t  x)
142-     {
143-         return  bitonic_merge_zmm_32bit<zmm_vector<type_t >>(x);
144-     }
145142    static  reg_t  sort_vec (reg_t  x)
146143    {
147144        return  sort_zmm_32bit<zmm_vector<type_t >>(x);
148145    }
146+     static  reg_t  cast_from (__m512i v){
147+         return  v;
148+     }
149+     static  __m512i cast_to (reg_t  v){
150+         return  v;
151+     }
149152};
150153template  <>
151154struct  zmm_vector <uint32_t > {
@@ -156,6 +159,8 @@ struct zmm_vector<uint32_t> {
156159    static  const  uint8_t  numlanes = 16 ;
157160    static  constexpr  int  network_sort_threshold = 256 ;
158161    static  constexpr  int  partition_unroll_factor = 2 ;
162+     
163+     using  swizzle_ops = avx512_32bit_swizzle_ops;
159164
160165    static  type_t  type_max ()
161166    {
@@ -255,14 +260,16 @@ struct zmm_vector<uint32_t> {
255260        const  auto  rev_index = _mm512_set_epi32 (NETWORK_32BIT_5);
256261        return  permutexvar (rev_index, zmm);
257262    }
258-     static  reg_t  bitonic_merge (reg_t  x)
259-     {
260-         return  bitonic_merge_zmm_32bit<zmm_vector<type_t >>(x);
261-     }
262263    static  reg_t  sort_vec (reg_t  x)
263264    {
264265        return  sort_zmm_32bit<zmm_vector<type_t >>(x);
265266    }
267+     static  reg_t  cast_from (__m512i v){
268+         return  v;
269+     }
270+     static  __m512i cast_to (reg_t  v){
271+         return  v;
272+     }
266273};
267274template  <>
268275struct  zmm_vector <float > {
@@ -273,6 +280,8 @@ struct zmm_vector<float> {
273280    static  const  uint8_t  numlanes = 16 ;
274281    static  constexpr  int  network_sort_threshold = 256 ;
275282    static  constexpr  int  partition_unroll_factor = 2 ;
283+     
284+     using  swizzle_ops = avx512_32bit_swizzle_ops;
276285
277286    static  type_t  type_max ()
278287    {
@@ -386,14 +395,16 @@ struct zmm_vector<float> {
386395        const  auto  rev_index = _mm512_set_epi32 (NETWORK_32BIT_5);
387396        return  permutexvar (rev_index, zmm);
388397    }
389-     static  reg_t  bitonic_merge (reg_t  x)
390-     {
391-         return  bitonic_merge_zmm_32bit<zmm_vector<type_t >>(x);
392-     }
393398    static  reg_t  sort_vec (reg_t  x)
394399    {
395400        return  sort_zmm_32bit<zmm_vector<type_t >>(x);
396401    }
402+     static  reg_t  cast_from (__m512i v){
403+         return  _mm512_castsi512_ps (v);
404+     }
405+     static  __m512i cast_to (reg_t  v){
406+         return  _mm512_castps_si512 (v);
407+     }
397408};
398409
399410/* 
@@ -446,31 +457,66 @@ X86_SIMD_SORT_INLINE reg_t sort_zmm_32bit(reg_t zmm)
446457    return  zmm;
447458}
448459
449- //  Assumes zmm is bitonic and performs a recursive half cleaner
450- template  <typename  vtype, typename  reg_t  = typename  vtype::reg_t >
451- X86_SIMD_SORT_INLINE reg_t  bitonic_merge_zmm_32bit (reg_t  zmm)
452- {
453-     //  1) half_cleaner[16]: compare 1-9, 2-10, 3-11 etc ..
454-     zmm = cmp_merge<vtype>(
455-             zmm,
456-             vtype::permutexvar (_mm512_set_epi32 (NETWORK_32BIT_7), zmm),
457-             0xFF00 );
458-     //  2) half_cleaner[8]: compare 1-5, 2-6, 3-7 etc ..
459-     zmm = cmp_merge<vtype>(
460-             zmm,
461-             vtype::permutexvar (_mm512_set_epi32 (NETWORK_32BIT_6), zmm),
462-             0xF0F0 );
463-     //  3) half_cleaner[4]
464-     zmm = cmp_merge<vtype>(
465-             zmm,
466-             vtype::template  shuffle<SHUFFLE_MASK (1 , 0 , 3 , 2 )>(zmm),
467-             0xCCCC );
468-     //  3) half_cleaner[1]
469-     zmm = cmp_merge<vtype>(
470-             zmm,
471-             vtype::template  shuffle<SHUFFLE_MASK (2 , 3 , 0 , 1 )>(zmm),
472-             0xAAAA );
473-     return  zmm;
474- }
460+ struct  avx512_32bit_swizzle_ops {
461+     template  <typename  vtype, int  scale>
462+     X86_SIMD_SORT_INLINE typename  vtype::reg_t  swap_n (typename  vtype::reg_t  reg){
463+         __m512i v = vtype::cast_to (reg);
464+         
465+         if  constexpr  (scale == 2 ){
466+             v = _mm512_shuffle_epi32 (v, (_MM_PERM_ENUM)0b10110001 );
467+         }else  if  constexpr  (scale == 4 ){
468+             v = _mm512_shuffle_epi32 (v, (_MM_PERM_ENUM)0b01001110 );
469+         }else  if  constexpr  (scale == 8 ){
470+             v = _mm512_shuffle_i64x2 (v, v, 0b10110001 );
471+         }else  if  constexpr  (scale == 16 ){
472+             v = _mm512_shuffle_i64x2 (v, v, 0b01001110 );
473+         }else {
474+             static_assert (scale == -1 , " should not be reached" 
475+         }
476+         
477+         return  vtype::cast_from (v);
478+     } 
479+     
480+     template  <typename  vtype, int  scale>
481+     X86_SIMD_SORT_INLINE typename  vtype::reg_t  reverse_n (typename  vtype::reg_t  reg){
482+         __m512i v = vtype::cast_to (reg);
483+         
484+         if  constexpr  (scale == 2 ){
485+             return  swap_n<vtype, 2 >(reg);
486+         }else  if  constexpr  (scale == 4 ){
487+             __m512i mask = _mm512_set_epi32 (12 ,13 ,14 ,15 ,8 ,9 ,10 ,11 ,4 ,5 ,6 ,7 ,0 ,1 ,2 ,3 );
488+             v = _mm512_permutexvar_epi32 (mask, v);
489+         }else  if  constexpr  (scale == 8 ){
490+             __m512i mask = _mm512_set_epi32 (8 ,9 ,10 ,11 ,12 ,13 ,14 ,15 ,0 ,1 ,2 ,3 ,4 ,5 ,6 ,7 );
491+             v = _mm512_permutexvar_epi32 (mask, v);
492+         }else  if  constexpr  (scale == 16 ){
493+             return  vtype::reverse (reg);
494+         }else {
495+             static_assert (scale == -1 , " should not be reached" 
496+         }
497+         
498+         return  vtype::cast_from (v);
499+     }
500+     
501+     template  <typename  vtype, int  scale>
502+     X86_SIMD_SORT_INLINE typename  vtype::reg_t  merge_n (typename  vtype::reg_t  reg, typename  vtype::reg_t  other){
503+         __m512i v1 = vtype::cast_to (reg);
504+         __m512i v2 = vtype::cast_to (other);
505+         
506+         if  constexpr  (scale == 2 ){
507+             v1 = _mm512_mask_blend_epi32 (0b0101010101010101 , v1, v2);
508+         }else  if  constexpr  (scale == 4 ){
509+             v1 = _mm512_mask_blend_epi32 (0b0011001100110011 , v1, v2);
510+         }else  if  constexpr  (scale == 8 ){
511+             v1 = _mm512_mask_blend_epi32 (0b0000111100001111 , v1, v2);
512+         }else  if  constexpr  (scale == 16 ){
513+             v1 = _mm512_mask_blend_epi32 (0b0000000011111111 , v1, v2);
514+         }else {
515+             static_assert (scale == -1 , " should not be reached" 
516+         }
517+         
518+         return  vtype::cast_from (v1);
519+     } 
520+ };
475521
476522#endif  // AVX512_QSORT_32BIT
0 commit comments