@@ -22,37 +22,6 @@
 #define NETWORK_32BIT_AVX2_3 5, 4, 7, 6, 1, 0, 3, 2
 #define NETWORK_32BIT_AVX2_4 3, 2, 1, 0, 7, 6, 5, 4
 
-namespace xss {
-namespace avx2 {
-
-// Assumes ymm is bitonic and performs a recursive half cleaner
-template <typename vtype, typename reg_t = typename vtype::reg_t>
-X86_SIMD_SORT_INLINE reg_t bitonic_merge_ymm_32bit(reg_t ymm)
-{
-
-    const typename vtype::opmask_t oxAA = _mm256_set_epi32(
-            0xFFFFFFFF, 0, 0xFFFFFFFF, 0, 0xFFFFFFFF, 0, 0xFFFFFFFF, 0);
-    const typename vtype::opmask_t oxCC = _mm256_set_epi32(
-            0xFFFFFFFF, 0xFFFFFFFF, 0, 0, 0xFFFFFFFF, 0xFFFFFFFF, 0, 0);
-    const typename vtype::opmask_t oxF0 = _mm256_set_epi32(
-            0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0, 0, 0, 0);
-
-    // 1) half_cleaner[8]: compare 0-4, 1-5, 2-6, 3-7
-    ymm = cmp_merge<vtype>(
-            ymm,
-            vtype::permutexvar(_mm256_set_epi32(NETWORK_32BIT_AVX2_4), ymm),
-            oxF0);
-    // 2) half_cleaner[4]
-    ymm = cmp_merge<vtype>(
-            ymm,
-            vtype::permutexvar(_mm256_set_epi32(NETWORK_32BIT_AVX2_3), ymm),
-            oxCC);
-    // 3) half_cleaner[1]
-    ymm = cmp_merge<vtype>(
-            ymm, vtype::template shuffle<SHUFFLE_MASK(2, 3, 0, 1)>(ymm), oxAA);
-    return ymm;
-}
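A minimal scalar sketch of what the three cmp_merge passes above compute, assuming ascending order; the helper name below is illustrative only. Compare-exchange at strides 4, 2 and 1, selected by the oxF0, oxCC and oxAA masks, turns a bitonic 8-element input into a fully sorted one.

#include <cstdint>

// Scalar reference for an 8-lane bitonic merge (recursive half cleaner).
// Assumes v[] holds a bitonic sequence; after the three passes it is
// sorted ascending, mirroring the oxF0, oxCC and oxAA stages above.
static inline void bitonic_merge8_scalar_ref(int32_t v[8])
{
    for (int stride = 4; stride >= 1; stride /= 2) {
        for (int i = 0; i < 8; ++i) {
            // Only the lower element of each pair initiates the exchange.
            if ((i & stride) == 0 && v[i] > v[i + stride]) {
                int32_t tmp = v[i];
                v[i] = v[i + stride];
                v[i + stride] = tmp;
            }
        }
    }
}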
-
 /*
  * Assumes ymm is random and performs a full sorting network defined in
  * https://en.wikipedia.org/wiki/Bitonic_sorter#/media/File:BitonicSort.svg
@@ -85,7 +54,7 @@ X86_SIMD_SORT_INLINE reg_t sort_ymm_32bit(reg_t ymm)
 struct avx2_32bit_swizzle_ops;
 
 template <>
-struct ymm_vector<int32_t> {
+struct avx2_vector<int32_t> {
     using type_t = int32_t;
     using reg_t = __m256i;
     using ymmi_t = __m256i;
@@ -231,13 +200,9 @@ struct ymm_vector<int32_t> {
     {
         _mm256_storeu_si256((__m256i *)mem, x);
     }
-    static reg_t bitonic_merge(reg_t x)
-    {
-        return bitonic_merge_ymm_32bit<ymm_vector<type_t>>(x);
-    }
     static reg_t sort_vec(reg_t x)
     {
-        return sort_ymm_32bit<ymm_vector<type_t>>(x);
+        return sort_ymm_32bit<avx2_vector<type_t>>(x);
     }
     static reg_t cast_from(__m256i v){
         return v;
@@ -247,7 +212,7 @@ struct ymm_vector<int32_t> {
     }
 };
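A hypothetical usage sketch of the renamed trait, assuming this header is included; sort_vec runs the full 8-lane network on one register.

// Hypothetical example: sort eight int32_t values in place using the
// renamed avx2_vector trait.
static inline void sort8_int32_example(int32_t *data)
{
    __m256i v = _mm256_loadu_si256((__m256i const *)data); // unaligned load
    v = avx2_vector<int32_t>::sort_vec(v); // 8-lane sorting network
    _mm256_storeu_si256((__m256i *)data, v); // store the sorted lanes
}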
 template <>
-struct ymm_vector<uint32_t> {
+struct avx2_vector<uint32_t> {
     using type_t = uint32_t;
     using reg_t = __m256i;
     using ymmi_t = __m256i;
@@ -378,13 +343,9 @@ struct ymm_vector<uint32_t> {
     {
         _mm256_storeu_si256((__m256i *)mem, x);
     }
-    static reg_t bitonic_merge(reg_t x)
-    {
-        return bitonic_merge_ymm_32bit<ymm_vector<type_t>>(x);
-    }
     static reg_t sort_vec(reg_t x)
     {
-        return sort_ymm_32bit<ymm_vector<type_t>>(x);
+        return sort_ymm_32bit<avx2_vector<type_t>>(x);
     }
     static reg_t cast_from(__m256i v){
         return v;
@@ -394,7 +355,7 @@ struct ymm_vector<uint32_t> {
     }
 };
 template <>
-struct ymm_vector<float> {
+struct avx2_vector<float> {
     using type_t = float;
     using reg_t = __m256;
     using ymmi_t = __m256i;
@@ -440,6 +401,19 @@ struct ymm_vector<float> {
     {
         return _mm256_castps_si256(_mm256_cmp_ps(x, y, _CMP_EQ_OQ));
     }
+    static opmask_t get_partial_loadmask(int size)
+    {
+        return (0x0001 << size) - 0x0001;
+    }
+    template <int type>
+    static opmask_t fpclass(reg_t x)
+    {
+        if constexpr (type == (0x01 | 0x80)) {
+            return _mm256_castps_si256(_mm256_cmp_ps(x, x, _CMP_UNORD_Q));
+        } else {
+            static_assert(type == (0x01 | 0x80), "should not reach here");
+        }
+    }
     template <int scale>
     static reg_t
     mask_i64gather(reg_t src, opmask_t mask, __m256i index, void const *base)
@@ -533,13 +507,9 @@ struct ymm_vector<float> {
     {
         _mm256_storeu_ps((float *)mem, x);
     }
-    static reg_t bitonic_merge(reg_t x)
-    {
-        return bitonic_merge_ymm_32bit<ymm_vector<type_t>>(x);
-    }
     static reg_t sort_vec(reg_t x)
     {
-        return sort_ymm_32bit<ymm_vector<type_t>>(x);
+        return sort_ymm_32bit<avx2_vector<type_t>>(x);
     }
     static reg_t cast_from(__m256i v){
         return _mm256_castsi256_ps(v);
@@ -549,32 +519,6 @@ struct ymm_vector<float> {
     }
 };
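A hypothetical sketch of how the new fpclass hook can stand in for the float-specific NaN scan removed below; 0x01 | 0x80 selects quiet and signaling NaNs, following the _mm512_fpclass_ps_mask bit convention. The helper name is made up for illustration.

// Hypothetical example: count NaN lanes in a register via the new
// fpclass<(0x01 | 0x80)> hook instead of an explicit _CMP_NEQ_UQ compare.
static inline int count_nan_lanes_example(__m256 v)
{
    __m256i nanmask = avx2_vector<float>::fpclass<(0x01 | 0x80)>(v);
    int bits = _mm256_movemask_ps(_mm256_castsi256_ps(nanmask)); // one bit per lane
    return _mm_popcnt_u32((unsigned int)bits);
}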
 
-inline arrsize_t replace_nan_with_inf(float *arr, int64_t arrsize)
-{
-    arrsize_t nan_count = 0;
-    __mmask8 loadmask = 0xFF;
-    while (arrsize > 0) {
-        if (arrsize < 8) { loadmask = (0x01 << arrsize) - 0x01; }
-        __m256 in_ymm = ymm_vector<float>::maskz_loadu(loadmask, arr);
-        __m256i nanmask = _mm256_castps_si256(
-                _mm256_cmp_ps(in_ymm, in_ymm, _CMP_NEQ_UQ));
-        nan_count += _mm_popcnt_u32(avx2_mask_helper32(nanmask));
-        ymm_vector<float>::mask_storeu(arr, nanmask, YMM_MAX_FLOAT);
-        arr += 8;
-        arrsize -= 8;
-    }
-    return nan_count;
-}
-
-X86_SIMD_SORT_INLINE void
-replace_inf_with_nan(float *arr, arrsize_t arrsize, arrsize_t nan_count)
-{
-    for (arrsize_t ii = arrsize - 1; nan_count > 0; --ii) {
-        arr[ii] = std::nan("1");
-        nan_count -= 1;
-    }
-}
-
 struct avx2_32bit_swizzle_ops {
     template <typename vtype, int scale>
     X86_SIMD_SORT_INLINE typename vtype::reg_t swap_n(typename vtype::reg_t reg){
@@ -635,7 +579,4 @@ struct avx2_32bit_swizzle_ops{
         return vtype::cast_from(v1);
     }
 };
-
-} // namespace avx2
-} // namespace xss
 #endif