8585#define  X86_SIMD_SORT_FINLINE  static 
8686#endif 
8787
88+ #define  LIKELY (x )       __builtin_expect((x),1 )
89+ #define  UNLIKELY (x )     __builtin_expect((x),0 )
90+ 
8891template  <typename  type>
8992struct  zmm_vector ;
9093
@@ -97,25 +100,54 @@ void avx512_qsort(T *arr, int64_t arrsize);
97100void  avx512_qsort_fp16 (uint16_t  *arr, int64_t  arrsize);
98101
99102template  <typename  T>
100- void  avx512_qselect (T *arr, int64_t  k, int64_t  arrsize);
101- void  avx512_qselect_fp16 (uint16_t  *arr, int64_t  k, int64_t  arrsize);
103+ void  avx512_qselect (T *arr, int64_t  k, int64_t  arrsize,  bool  hasnan =  false );
104+ void  avx512_qselect_fp16 (uint16_t  *arr, int64_t  k, int64_t  arrsize,  bool  hasnan =  false );
102105
103106template  <typename  T>
104- inline  void  avx512_partial_qsort (T *arr, int64_t  k, int64_t  arrsize)
107+ inline  void  avx512_partial_qsort (T *arr, int64_t  k, int64_t  arrsize,  bool  hasnan =  false )
105108{
106-     avx512_qselect<T>(arr, k - 1 , arrsize);
109+     avx512_qselect<T>(arr, k - 1 , arrsize, hasnan );
107110    avx512_qsort<T>(arr, k - 1 );
108111}
109- inline  void  avx512_partial_qsort_fp16 (uint16_t  *arr, int64_t  k, int64_t  arrsize)
112+ inline  void  avx512_partial_qsort_fp16 (uint16_t  *arr, int64_t  k, int64_t  arrsize,  bool  hasnan =  false )
110113{
111-     avx512_qselect_fp16 (arr, k - 1 , arrsize);
114+     avx512_qselect_fp16 (arr, k - 1 , arrsize, hasnan );
112115    avx512_qsort_fp16 (arr, k - 1 );
113116}
114117
115118//  key-value sort routines
116119template  <typename  T>
117120void  avx512_qsort_kv (T *keys, uint64_t  *indexes, int64_t  arrsize);
118121
122+ template  <typename  T>
123+ bool  is_a_nan (T elem)
124+ {
125+     return  std::isnan (elem);
126+ }
127+ 
128+ /* 
129+  * Sort all the NAN's to end of the array and return the index of the last elem 
130+  * in the array which is not a nan 
131+  */  
132+ template  <typename  T>
133+ int64_t  move_nans_to_end_of_array (T* arr, int64_t  arrsize)
134+ {
135+     int64_t  jj = arrsize - 1 ;
136+     int64_t  ii = 0 ;
137+     int64_t  count = 0 ;
138+     while  (ii <= jj) {
139+         if  (is_a_nan (arr[ii])) {
140+             std::swap (arr[ii], arr[jj]);
141+             jj -= 1 ;
142+             count++;
143+         }
144+         else  {
145+             ii += 1 ;
146+         }
147+     }
148+     return  arrsize-count-1 ;
149+ }
150+ 
119151template  <typename  vtype, typename  T = typename  vtype::type_t >
120152bool  comparison_func (const  T &a, const  T &b)
121153{
0 commit comments