|
@@ -8,6 +8,7 @@
 #define AVX512_16BIT_COMMON

 #include "avx512-common-qsort.h"
+#include "xss-network-qsort.hpp"

 /*
  * Constants used in sorting 32 elements in a ZMM register. Based on Bitonic
@@ -118,103 +119,6 @@ X86_SIMD_SORT_INLINE zmm_t bitonic_merge_zmm_16bit(zmm_t zmm)
     return zmm;
 }

-// Assumes zmm1 and zmm2 are sorted and performs a recursive half cleaner
-template <typename vtype, typename zmm_t = typename vtype::zmm_t>
-X86_SIMD_SORT_INLINE void bitonic_merge_two_zmm_16bit(zmm_t &zmm1, zmm_t &zmm2)
-{
-    // 1) First step of a merging network: coex of zmm1 and zmm2 reversed
-    zmm2 = vtype::permutexvar(vtype::get_network(4), zmm2);
-    zmm_t zmm3 = vtype::min(zmm1, zmm2);
-    zmm_t zmm4 = vtype::max(zmm1, zmm2);
-    // 2) Recursive half cleaner for each
-    zmm1 = bitonic_merge_zmm_16bit<vtype>(zmm3);
-    zmm2 = bitonic_merge_zmm_16bit<vtype>(zmm4);
-}
-
-// Assumes [zmm0, zmm1] and [zmm2, zmm3] are sorted and performs a recursive
-// half cleaner
-template <typename vtype, typename zmm_t = typename vtype::zmm_t>
-X86_SIMD_SORT_INLINE void bitonic_merge_four_zmm_16bit(zmm_t *zmm)
-{
-    zmm_t zmm2r = vtype::permutexvar(vtype::get_network(4), zmm[2]);
-    zmm_t zmm3r = vtype::permutexvar(vtype::get_network(4), zmm[3]);
-    zmm_t zmm_t1 = vtype::min(zmm[0], zmm3r);
-    zmm_t zmm_t2 = vtype::min(zmm[1], zmm2r);
-    zmm_t zmm_t3 = vtype::permutexvar(vtype::get_network(4),
-                                      vtype::max(zmm[1], zmm2r));
-    zmm_t zmm_t4 = vtype::permutexvar(vtype::get_network(4),
-                                      vtype::max(zmm[0], zmm3r));
-    zmm_t zmm0 = vtype::min(zmm_t1, zmm_t2);
-    zmm_t zmm1 = vtype::max(zmm_t1, zmm_t2);
-    zmm_t zmm2 = vtype::min(zmm_t3, zmm_t4);
-    zmm_t zmm3 = vtype::max(zmm_t3, zmm_t4);
-    zmm[0] = bitonic_merge_zmm_16bit<vtype>(zmm0);
-    zmm[1] = bitonic_merge_zmm_16bit<vtype>(zmm1);
-    zmm[2] = bitonic_merge_zmm_16bit<vtype>(zmm2);
-    zmm[3] = bitonic_merge_zmm_16bit<vtype>(zmm3);
-}
-
-template <typename vtype, typename type_t>
-X86_SIMD_SORT_INLINE void sort_32_16bit(type_t *arr, int32_t N)
-{
-    typename vtype::opmask_t load_mask = ((0x1ull << N) - 0x1ull) & 0xFFFFFFFF;
-    typename vtype::zmm_t zmm
-            = vtype::mask_loadu(vtype::zmm_max(), load_mask, arr);
-    vtype::mask_storeu(arr, load_mask, sort_zmm_16bit<vtype>(zmm));
-}
-
-template <typename vtype, typename type_t>
-X86_SIMD_SORT_INLINE void sort_64_16bit(type_t *arr, int32_t N)
-{
-    if (N <= 32) {
-        sort_32_16bit<vtype>(arr, N);
-        return;
-    }
-    using zmm_t = typename vtype::zmm_t;
-    typename vtype::opmask_t load_mask
-            = ((0x1ull << (N - 32)) - 0x1ull) & 0xFFFFFFFF;
-    zmm_t zmm1 = vtype::loadu(arr);
-    zmm_t zmm2 = vtype::mask_loadu(vtype::zmm_max(), load_mask, arr + 32);
-    zmm1 = sort_zmm_16bit<vtype>(zmm1);
-    zmm2 = sort_zmm_16bit<vtype>(zmm2);
-    bitonic_merge_two_zmm_16bit<vtype>(zmm1, zmm2);
-    vtype::storeu(arr, zmm1);
-    vtype::mask_storeu(arr + 32, load_mask, zmm2);
-}
-
-template <typename vtype, typename type_t>
-X86_SIMD_SORT_INLINE void sort_128_16bit(type_t *arr, int32_t N)
-{
-    if (N <= 64) {
-        sort_64_16bit<vtype>(arr, N);
-        return;
-    }
-    using zmm_t = typename vtype::zmm_t;
-    using opmask_t = typename vtype::opmask_t;
-    zmm_t zmm[4];
-    zmm[0] = vtype::loadu(arr);
-    zmm[1] = vtype::loadu(arr + 32);
-    opmask_t load_mask1 = 0xFFFFFFFF, load_mask2 = 0xFFFFFFFF;
-    if (N != 128) {
-        uint64_t combined_mask = (0x1ull << (N - 64)) - 0x1ull;
-        load_mask1 = combined_mask & 0xFFFFFFFF;
-        load_mask2 = (combined_mask >> 32) & 0xFFFFFFFF;
-    }
-    zmm[2] = vtype::mask_loadu(vtype::zmm_max(), load_mask1, arr + 64);
-    zmm[3] = vtype::mask_loadu(vtype::zmm_max(), load_mask2, arr + 96);
-    zmm[0] = sort_zmm_16bit<vtype>(zmm[0]);
-    zmm[1] = sort_zmm_16bit<vtype>(zmm[1]);
-    zmm[2] = sort_zmm_16bit<vtype>(zmm[2]);
-    zmm[3] = sort_zmm_16bit<vtype>(zmm[3]);
-    bitonic_merge_two_zmm_16bit<vtype>(zmm[0], zmm[1]);
-    bitonic_merge_two_zmm_16bit<vtype>(zmm[2], zmm[3]);
-    bitonic_merge_four_zmm_16bit<vtype>(zmm);
-    vtype::storeu(arr, zmm[0]);
-    vtype::storeu(arr + 32, zmm[1]);
-    vtype::mask_storeu(arr + 64, load_mask1, zmm[2]);
-    vtype::mask_storeu(arr + 96, load_mask2, zmm[3]);
-}
-
 template <typename vtype, typename type_t>
 X86_SIMD_SORT_INLINE type_t get_pivot_16bit(type_t *arr,
                                             const int64_t left,
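
Note on the deleted helpers: they all share one masked-load idiom for ragged tails. For N valid elements in a 32-lane register, ((0x1ull << N) - 0x1ull) builds an opmask with the low N bits set; mask_loadu fills the inactive lanes with vtype::zmm_max() sentinels so they sort past every real element, and the matching mask_storeu writes only the valid lanes back. A minimal scalar sketch of that idiom, with illustrative names (partial_mask, Lanes) that are not part of the library:

#include <cstdint>
#include <limits>

// Scalar model of the masked load/store idiom from the removed
// sort_32/64/128_16bit helpers. `Lanes` stands in for the 32
// 16-bit lanes of a ZMM register.
constexpr int Lanes = 32;

// Low-N-bits mask, as computed in the removed sort_32_16bit().
inline uint32_t partial_mask(int32_t n)
{
    return static_cast<uint32_t>(((0x1ull << n) - 0x1ull) & 0xFFFFFFFF);
}

// Masked load: inactive lanes get the max sentinel, so after sorting
// they sit at the top of the register and are never stored back.
void mask_loadu(uint16_t dst[Lanes], uint32_t mask, const uint16_t *src)
{
    for (int i = 0; i < Lanes; ++i)
        dst[i] = (mask >> i & 1) ? src[i]
                                 : std::numeric_limits<uint16_t>::max();
}

// Masked store: writes back only the lanes selected by the mask.
void mask_storeu(uint16_t *dst, uint32_t mask, const uint16_t src[Lanes])
{
    for (int i = 0; i < Lanes; ++i)
        if (mask >> i & 1) dst[i] = src[i];
}

For example, N = 20 gives partial_mask(20) == 0x000FFFFF: lanes 20..31 hold 0xFFFF sentinels, which is how sort_64_16bit and sort_128_16bit kept their tail vectors sorted without branching on the exact length.
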
@@ -274,7 +178,7 @@ qsort_16bit_(type_t *arr, int64_t left, int64_t right, int64_t max_iters)
      * Base case: use bitonic networks to sort arrays of <= 128 elements
      */
     if (right + 1 - left <= 128) {
-        sort_128_16bit<vtype>(arr + left, (int32_t)(right + 1 - left));
+        xss::sort_n<vtype, 128>(arr + left, (int32_t)(right + 1 - left));
         return;
     }

|
@@ -307,7 +211,7 @@ static void qselect_16bit_(type_t *arr,
      * Base case: use bitonic networks to sort arrays of <= 128 elements
      */
     if (right + 1 - left <= 128) {
-        sort_128_16bit<vtype>(arr + left, (int32_t)(right + 1 - left));
+        xss::sort_n<vtype, 128>(arr + left, (int32_t)(right + 1 - left));
         return;
     }

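
Both base cases are now textually identical and funnel through the generic xss::sort_n from the newly included xss-network-qsort.hpp, which subsumes the per-size sort_32/64/128_16bit chain. A caller-side sketch, assuming only the signature visible in this diff; the wrapper name sort_leq_128 and the remark about tail padding are illustrative, not the library's API:

// Hypothetical wrapper mirroring the two base cases above; `vtype` is
// one of the 16-bit vector traits classes from avx512-common-qsort.h.
template <typename vtype, typename type_t>
void sort_leq_128(type_t *arr, int64_t left, int64_t right)
{
    int64_t len = right + 1 - left;
    if (len <= 128) {
        // N may be anything in [1, 128]; presumably sort_n pads the
        // ragged tail with sentinels, as the removed helpers did.
        xss::sort_n<vtype, 128>(arr + left, static_cast<int32_t>(len));
    }
}

One likely motive for the change: the per-width copies of these bitonic networks in the 16/32/64-bit headers can now share a single templated implementation instead of being maintained separately.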
|
|