55#include  " xss-optimal-networks.hpp" 
66
77template  <typename  vtype,
8-           int64_t  numVecs,
9-           typename  reg_t  = typename  vtype::reg_t >
10- X86_SIMD_SORT_FINLINE void  bitonic_clean_n_vec (reg_t  *regs)
11- {
12-     X86_SIMD_SORT_UNROLL_LOOP (512 )
13-     for  (int  num = numVecs / 2 ; num >= 2 ; num /= 2 ) {
14-         X86_SIMD_SORT_UNROLL_LOOP (512 )
15-         for  (int  j = 0 ; j < numVecs; j += num) {
16-             X86_SIMD_SORT_UNROLL_LOOP (512 )
17-             for  (int  i = 0 ; i < num / 2 ; i++) {
18-                 COEX<vtype>(regs[i + j], regs[i + j + num / 2 ]);
19-             }
20-         }
21-     }
22- }
23- 
24- template  <typename  vtype,
25-           int64_t  numVecs,
8+           int  numVecs,
269          typename  reg_t  = typename  vtype::reg_t >
2710X86_SIMD_SORT_FINLINE void  bitonic_sort_n_vec (reg_t  *regs)
2811{
@@ -46,20 +29,11 @@ X86_SIMD_SORT_FINLINE void bitonic_sort_n_vec(reg_t *regs)
4629        optimal_sort_32<vtype>(regs);
4730    }
4831    else  {
49-         //  TODO should we remove this branch? I believe it is never used in the current code
50-         bitonic_sort_n_vec<vtype, numVecs / 2 >(regs);
51-         bitonic_sort_n_vec<vtype, numVecs / 2 >(regs + numVecs / 2 );
52- 
53-         X86_SIMD_SORT_UNROLL_LOOP (64 )
54-         for  (int  i = 0 ; i < numVecs / 2 ; i++) {
55-             COEX<vtype>(regs[i], regs[numVecs - 1  - i]);
56-         }
57- 
58-         bitonic_clean_n_vec<vtype, numVecs>(regs);
32+         static_assert (numVecs == -1 , " should not reach here"  );
5933    }
6034}
6135
62- template  <typename  vtype, int64_t  numVecs, int64_t  scale, bool  first = true >
36+ template  <typename  vtype, int  numVecs, int  scale, bool  first = true >
6337X86_SIMD_SORT_FINLINE void  internal_merge_n_vec (typename  vtype::reg_t  *reg)
6438{
6539    using  reg_t  = typename  vtype::reg_t ;
@@ -94,8 +68,8 @@ X86_SIMD_SORT_FINLINE void internal_merge_n_vec(typename vtype::reg_t *reg)
9468}
9569
9670template  <typename  vtype,
97-           int64_t  numVecs,
98-           int64_t  scale,
71+           int  numVecs,
72+           int  scale,
9973          typename  reg_t  = typename  vtype::reg_t >
10074X86_SIMD_SORT_FINLINE void  merge_substep_n_vec (reg_t  *regs)
10175{
@@ -121,8 +95,8 @@ X86_SIMD_SORT_FINLINE void merge_substep_n_vec(reg_t *regs)
12195}
12296
12397template  <typename  vtype,
124-           int64_t  numVecs,
125-           int64_t  scale,
98+           int  numVecs,
99+           int  scale,
126100          typename  reg_t  = typename  vtype::reg_t >
127101X86_SIMD_SORT_FINLINE void  merge_step_n_vec (reg_t  *regs)
128102{
@@ -134,8 +108,8 @@ X86_SIMD_SORT_FINLINE void merge_step_n_vec(reg_t *regs)
134108}
135109
136110template  <typename  vtype,
137-           int64_t  numVecs,
138-           int64_t  numPer = 2 ,
111+           int  numVecs,
112+           int  numPer = 2 ,
139113          typename  reg_t  = typename  vtype::reg_t >
140114X86_SIMD_SORT_FINLINE void  merge_n_vec (reg_t  *regs)
141115{
@@ -216,22 +190,22 @@ X86_SIMD_SORT_INLINE void sort_n(typename vtype::type_t *arr, int N)
216190
217191template  <typename  vtype, typename  type_t >
218192X86_SIMD_SORT_INLINE type_t  get_pivot (type_t  *arr,
219-                                       const  int64_t  left,
220-                                       const  int64_t  right);
193+                                       const  arrsize_t  left,
194+                                       const  arrsize_t  right);
221195
222196template  <typename  vtype, typename  type_t >
223197X86_SIMD_SORT_INLINE type_t  get_pivot_blocks (type_t  *arr,
224-                                              uint64_t  left,
225-                                              uint64_t  right)
198+                                              arrsize_t  left,
199+                                              arrsize_t  right)
226200{
227201
228202    if  (right - left <= 1024 ) { return  get_pivot<vtype>(arr, left, right); }
229203
230204    using  reg_t  = typename  vtype::reg_t ;
231205    constexpr  int  numVecs = 5 ;
232206
233-     uint64_t  width = (right - vtype::numlanes) - left;
234-     uint64_t  delta = width / numVecs;
207+     arrsize_t  width = (right - vtype::numlanes) - left;
208+     arrsize_t  delta = width / numVecs;
235209
236210    reg_t  vecs[numVecs];
237211    //  Load data
0 commit comments