Skip to content

Commit 1f6100d

Browse files
committed
Minor cleanup changes
1 parent 84592b2 commit 1f6100d

File tree

3 files changed

+28
-46
lines changed

3 files changed

+28
-46
lines changed

src/avx2-32bit-half.hpp

Lines changed: 13 additions & 24 deletions
Original file line number | Diff line number | Diff line change
@@ -29,37 +29,26 @@
2929
template <typename vtype, typename reg_t = typename vtype::reg_t>
3030
X86_SIMD_SORT_INLINE reg_t sort_ymm_32bit_half(reg_t ymm)
3131
{
32-
//static_assert(vtype::numlanes == 0, "This function is not implemented");
33-
typename vtype::type_t buffer[vtype::numlanes];
34-
vtype::storeu(buffer, ymm);
35-
std::sort(&buffer[0], &buffer[vtype::numlanes], comparison_func<vtype>);
36-
return vtype::loadu(buffer);
37-
/*
38-
const typename vtype::opmask_t oxAA = _mm256_set_epi32(
39-
0xFFFFFFFF, 0, 0xFFFFFFFF, 0, 0xFFFFFFFF, 0, 0xFFFFFFFF, 0);
40-
const typename vtype::opmask_t oxCC = _mm256_set_epi32(
41-
0xFFFFFFFF, 0xFFFFFFFF, 0, 0, 0xFFFFFFFF, 0xFFFFFFFF, 0, 0);
42-
const typename vtype::opmask_t oxF0 = _mm256_set_epi32(
43-
0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0, 0, 0, 0);
44-
45-
const typename vtype::ymmi_t rev_index = vtype::seti(NETWORK_32BIT_AVX2_2);
46-
ymm = cmp_merge<vtype>(
47-
ymm, vtype::template shuffle<SHUFFLE_MASK(2, 3, 0, 1)>(ymm), oxAA);
32+
using swizzle = typename vtype::swizzle_ops;
33+
34+
const typename vtype::opmask_t oxAA
35+
= vtype::seti(-1, 0, -1, 0);
36+
const typename vtype::opmask_t oxCC
37+
= vtype::seti(-1, -1, 0, 0);
38+
4839
ymm = cmp_merge<vtype>(
4940
ymm,
50-
vtype::permutexvar(vtype::seti(NETWORK_32BIT_AVX2_1), ymm),
51-
oxCC);
52-
ymm = cmp_merge<vtype>(
53-
ymm, vtype::template shuffle<SHUFFLE_MASK(2, 3, 0, 1)>(ymm), oxAA);
54-
ymm = cmp_merge<vtype>(ymm, vtype::permutexvar(rev_index, ymm), oxF0);
41+
swizzle::template swap_n<vtype, 2>(ymm),
42+
oxAA);
5543
ymm = cmp_merge<vtype>(
5644
ymm,
57-
vtype::permutexvar(vtype::seti(NETWORK_32BIT_AVX2_3), ymm),
45+
vtype::reverse(ymm),
5846
oxCC);
5947
ymm = cmp_merge<vtype>(
60-
ymm, vtype::template shuffle<SHUFFLE_MASK(2, 3, 0, 1)>(ymm), oxAA);
48+
ymm,
49+
swizzle::template swap_n<vtype, 2>(ymm),
50+
oxAA);
6151
return ymm;
62-
*/
6352
}
6453

6554
struct avx2_32bit_half_swizzle_ops;

src/avx512-64bit-argsort.hpp

Lines changed: 2 additions & 7 deletions
Original file line number | Diff line number | Diff line change
@@ -8,14 +8,9 @@
88
#define AVX512_ARGSORT_64BIT
99

1010
#include "xss-common-qsort.h"
11-
//#include "avx512-64bit-common.h"
12-
//#include "avx2-32bit-half.hpp"
1311
#include "xss-network-keyvaluesort.hpp"
1412
#include <numeric>
1513

16-
template <typename T>
17-
struct avx2_half_vector;
18-
1914
template <typename T>
2015
X86_SIMD_SORT_INLINE void std_argselect_withnan(
2116
T *arr, arrsize_t *arg, arrsize_t k, arrsize_t left, arrsize_t right)
@@ -146,9 +141,9 @@ X86_SIMD_SORT_INLINE int32_t partition_vec(type_t *arg,
146141
reg_t *smallest_vec,
147142
reg_t *biggest_vec)
148143
{
149-
if constexpr (sizeof (argreg_t) == 64){
144+
if constexpr (vtype::vec_type == simd_type::AVX512){
150145
return partition_vec_avx512<vtype, argtype, type_t>(arg, left, right, arg_vec, curr_vec, pivot_vec, smallest_vec, biggest_vec);
151-
}else if constexpr (sizeof (argreg_t) == 32){
146+
}else if constexpr (vtype::vec_type == simd_type::AVX2){
152147
return partition_vec_avx2<vtype, argtype, type_t>(arg, left, right, arg_vec, curr_vec, pivot_vec, smallest_vec, biggest_vec);
153148
}else{
154149
static_assert(sizeof(argreg_t) == 0, "Should not get here");

src/xss-network-keyvaluesort.hpp

Lines changed: 13 additions & 15 deletions
Original file line number | Diff line number | Diff line change
@@ -362,6 +362,16 @@ X86_SIMD_SORT_INLINE void argsort_n_vec(typename keyType::type_t *keys,
362362

363363
kreg_t keyVecs[numVecs];
364364
ireg_t indexVecs[numVecs];
365+
366+
// Generate masks for loading and storing
367+
typename keyType::opmask_t ioMasks[numVecs - numVecs / 2];
368+
X86_SIMD_SORT_UNROLL_LOOP(64)
369+
for (int i = numVecs / 2, j = 0; i < numVecs; i++, j++) {
370+
uint64_t num_to_read
371+
= std::min((uint64_t)std::max(0, N - i * keyType::numlanes),
372+
(uint64_t)keyType::numlanes);
373+
ioMasks[j] = keyType::get_partial_loadmask(num_to_read);
374+
}
365375

366376
// Unmasked part of the load
367377
X86_SIMD_SORT_UNROLL_LOOP(64)
@@ -373,20 +383,13 @@ X86_SIMD_SORT_INLINE void argsort_n_vec(typename keyType::type_t *keys,
373383
// Masked part of the load
374384
X86_SIMD_SORT_UNROLL_LOOP(64)
375385
for (int i = numVecs / 2; i < numVecs; i++) {
376-
uint64_t num_to_read
377-
= std::min((uint64_t)std::max(0, N - i * keyType::numlanes),
378-
(uint64_t)keyType::numlanes);
379-
380-
auto indexMask = indexType::get_partial_loadmask(num_to_read);
381-
auto keyMask = keyType::get_partial_loadmask(num_to_read);
382-
383386
indexVecs[i] = indexType::mask_loadu(indexType::zmm_max(),
384-
indexMask,
387+
extend_mask<keyType, indexType>(ioMasks[i - numVecs/2]),
385388
indices + i * indexType::numlanes);
386389

387390
keyVecs[i] = keyType::template mask_i64gather<sizeof(
388391
typename keyType::type_t)>(
389-
keyType::zmm_max(), keyMask, indexVecs[i], keys);
392+
keyType::zmm_max(), ioMasks[i - numVecs / 2], indexVecs[i], keys);
390393
}
391394

392395
// Sort each loaded vector
@@ -406,13 +409,8 @@ X86_SIMD_SORT_INLINE void argsort_n_vec(typename keyType::type_t *keys,
406409
// Masked part of the store
407410
X86_SIMD_SORT_UNROLL_LOOP(64)
408411
for (int i = numVecs / 2, j = 0; i < numVecs; i++, j++) {
409-
uint64_t num_to_read
410-
= std::min((uint64_t)std::max(0, N - i * keyType::numlanes),
411-
(uint64_t)keyType::numlanes);
412-
413-
auto indexMask = indexType::get_partial_loadmask(num_to_read);
414412
indexType::mask_storeu(
415-
indices + i * indexType::numlanes, indexMask, indexVecs[i]);
413+
indices + i * indexType::numlanes, extend_mask<keyType, indexType>(ioMasks[i - numVecs/2]), indexVecs[i]);
416414
}
417415
}
418416

0 commit comments

Comments
 (0)