Skip to content

Commit 31b3b40

Browse files
committed
Fixes/changes many small things
1 parent ff9b0ea commit 31b3b40

File tree

5 files changed

+39
-115
lines changed

5 files changed

+39
-115
lines changed

src/avx2-32bit-common.h

Lines changed: 19 additions & 78 deletions
Original file line numberDiff line numberDiff line change
@@ -22,37 +22,6 @@
2222
#define NETWORK_32BIT_AVX2_3 5, 4, 7, 6, 1, 0, 3, 2
2323
#define NETWORK_32BIT_AVX2_4 3, 2, 1, 0, 7, 6, 5, 4
2424

25-
namespace xss {
26-
namespace avx2 {
27-
28-
// Assumes ymm is bitonic and performs a recursive half cleaner
29-
template <typename vtype, typename reg_t = typename vtype::reg_t>
30-
X86_SIMD_SORT_INLINE reg_t bitonic_merge_ymm_32bit(reg_t ymm)
31-
{
32-
33-
const typename vtype::opmask_t oxAA = _mm256_set_epi32(
34-
0xFFFFFFFF, 0, 0xFFFFFFFF, 0, 0xFFFFFFFF, 0, 0xFFFFFFFF, 0);
35-
const typename vtype::opmask_t oxCC = _mm256_set_epi32(
36-
0xFFFFFFFF, 0xFFFFFFFF, 0, 0, 0xFFFFFFFF, 0xFFFFFFFF, 0, 0);
37-
const typename vtype::opmask_t oxF0 = _mm256_set_epi32(
38-
0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0, 0, 0, 0);
39-
40-
// 1) half_cleaner[8]: compare 0-4, 1-5, 2-6, 3-7
41-
ymm = cmp_merge<vtype>(
42-
ymm,
43-
vtype::permutexvar(_mm256_set_epi32(NETWORK_32BIT_AVX2_4), ymm),
44-
oxF0);
45-
// 2) half_cleaner[4]
46-
ymm = cmp_merge<vtype>(
47-
ymm,
48-
vtype::permutexvar(_mm256_set_epi32(NETWORK_32BIT_AVX2_3), ymm),
49-
oxCC);
50-
// 3) half_cleaner[1]
51-
ymm = cmp_merge<vtype>(
52-
ymm, vtype::template shuffle<SHUFFLE_MASK(2, 3, 0, 1)>(ymm), oxAA);
53-
return ymm;
54-
}
55-
5625
/*
5726
* Assumes ymm is random and performs a full sorting network defined in
5827
* https://en.wikipedia.org/wiki/Bitonic_sorter#/media/File:BitonicSort.svg
@@ -85,7 +54,7 @@ X86_SIMD_SORT_INLINE reg_t sort_ymm_32bit(reg_t ymm)
8554
struct avx2_32bit_swizzle_ops;
8655

8756
template <>
88-
struct ymm_vector<int32_t> {
57+
struct avx2_vector<int32_t> {
8958
using type_t = int32_t;
9059
using reg_t = __m256i;
9160
using ymmi_t = __m256i;
@@ -231,13 +200,9 @@ struct ymm_vector<int32_t> {
231200
{
232201
_mm256_storeu_si256((__m256i *)mem, x);
233202
}
234-
static reg_t bitonic_merge(reg_t x)
235-
{
236-
return bitonic_merge_ymm_32bit<ymm_vector<type_t>>(x);
237-
}
238203
static reg_t sort_vec(reg_t x)
239204
{
240-
return sort_ymm_32bit<ymm_vector<type_t>>(x);
205+
return sort_ymm_32bit<avx2_vector<type_t>>(x);
241206
}
242207
static reg_t cast_from(__m256i v){
243208
return v;
@@ -247,7 +212,7 @@ struct ymm_vector<int32_t> {
247212
}
248213
};
249214
template <>
250-
struct ymm_vector<uint32_t> {
215+
struct avx2_vector<uint32_t> {
251216
using type_t = uint32_t;
252217
using reg_t = __m256i;
253218
using ymmi_t = __m256i;
@@ -378,13 +343,9 @@ struct ymm_vector<uint32_t> {
378343
{
379344
_mm256_storeu_si256((__m256i *)mem, x);
380345
}
381-
static reg_t bitonic_merge(reg_t x)
382-
{
383-
return bitonic_merge_ymm_32bit<ymm_vector<type_t>>(x);
384-
}
385346
static reg_t sort_vec(reg_t x)
386347
{
387-
return sort_ymm_32bit<ymm_vector<type_t>>(x);
348+
return sort_ymm_32bit<avx2_vector<type_t>>(x);
388349
}
389350
static reg_t cast_from(__m256i v){
390351
return v;
@@ -394,7 +355,7 @@ struct ymm_vector<uint32_t> {
394355
}
395356
};
396357
template <>
397-
struct ymm_vector<float> {
358+
struct avx2_vector<float> {
398359
using type_t = float;
399360
using reg_t = __m256;
400361
using ymmi_t = __m256i;
@@ -440,6 +401,19 @@ struct ymm_vector<float> {
440401
{
441402
return _mm256_castps_si256(_mm256_cmp_ps(x, y, _CMP_EQ_OQ));
442403
}
404+
static opmask_t get_partial_loadmask(int size)
405+
{
406+
return (0x0001 << size) - 0x0001;
407+
}
408+
template <int type>
409+
static opmask_t fpclass(reg_t x)
410+
{
411+
if constexpr (type == (0x01 | 0x80)){
412+
return _mm256_castps_si256(_mm256_cmp_ps(x, x, _CMP_UNORD_Q));
413+
}else{
414+
static_assert(type == (0x01 | 0x80), "should not reach here");
415+
}
416+
}
443417
template <int scale>
444418
static reg_t
445419
mask_i64gather(reg_t src, opmask_t mask, __m256i index, void const *base)
@@ -533,13 +507,9 @@ struct ymm_vector<float> {
533507
{
534508
_mm256_storeu_ps((float *)mem, x);
535509
}
536-
static reg_t bitonic_merge(reg_t x)
537-
{
538-
return bitonic_merge_ymm_32bit<ymm_vector<type_t>>(x);
539-
}
540510
static reg_t sort_vec(reg_t x)
541511
{
542-
return sort_ymm_32bit<ymm_vector<type_t>>(x);
512+
return sort_ymm_32bit<avx2_vector<type_t>>(x);
543513
}
544514
static reg_t cast_from(__m256i v){
545515
return _mm256_castsi256_ps(v);
@@ -549,32 +519,6 @@ struct ymm_vector<float> {
549519
}
550520
};
551521

552-
inline arrsize_t replace_nan_with_inf(float *arr, int64_t arrsize)
553-
{
554-
arrsize_t nan_count = 0;
555-
__mmask8 loadmask = 0xFF;
556-
while (arrsize > 0) {
557-
if (arrsize < 8) { loadmask = (0x01 << arrsize) - 0x01; }
558-
__m256 in_ymm = ymm_vector<float>::maskz_loadu(loadmask, arr);
559-
__m256i nanmask = _mm256_castps_si256(
560-
_mm256_cmp_ps(in_ymm, in_ymm, _CMP_NEQ_UQ));
561-
nan_count += _mm_popcnt_u32(avx2_mask_helper32(nanmask));
562-
ymm_vector<float>::mask_storeu(arr, nanmask, YMM_MAX_FLOAT);
563-
arr += 8;
564-
arrsize -= 8;
565-
}
566-
return nan_count;
567-
}
568-
569-
X86_SIMD_SORT_INLINE void
570-
replace_inf_with_nan(float *arr, arrsize_t arrsize, arrsize_t nan_count)
571-
{
572-
for (arrsize_t ii = arrsize - 1; nan_count > 0; --ii) {
573-
arr[ii] = std::nan("1");
574-
nan_count -= 1;
575-
}
576-
}
577-
578522
struct avx2_32bit_swizzle_ops{
579523
template <typename vtype, int scale>
580524
X86_SIMD_SORT_INLINE typename vtype::reg_t swap_n(typename vtype::reg_t reg){
@@ -635,7 +579,4 @@ struct avx2_32bit_swizzle_ops{
635579
return vtype::cast_from(v1);
636580
}
637581
};
638-
639-
} // namespace avx2
640-
} // namespace xss
641582
#endif

src/avx2-emu-funcs.hpp

Lines changed: 16 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -5,9 +5,6 @@
55
#include <utility>
66
#include "xss-common-qsort.h"
77

8-
namespace xss {
9-
namespace avx2 {
10-
118
constexpr auto avx2_mask_helper_lut32 = [] {
129
std::array<std::array<int32_t, 8>, 256> lut {};
1310
for (int64_t i = 0; i <= 0xFF; i++) {
@@ -97,9 +94,9 @@ static __m256i operator~(const avx2_mask_helper32 x)
9794

9895
// Emulators for intrinsics missing from AVX2 compared to AVX512
9996
template <typename T>
100-
T avx2_emu_reduce_max32(typename ymm_vector<T>::reg_t x)
97+
T avx2_emu_reduce_max32(typename avx2_vector<T>::reg_t x)
10198
{
102-
using vtype = ymm_vector<T>;
99+
using vtype = avx2_vector<T>;
103100
using reg_t = typename vtype::reg_t;
104101

105102
reg_t inter1 = vtype::max(x, vtype::template shuffle<SHUFFLE_MASK(2, 3, 0, 1)>(x));
@@ -110,9 +107,9 @@ T avx2_emu_reduce_max32(typename ymm_vector<T>::reg_t x)
110107
}
111108

112109
template <typename T>
113-
T avx2_emu_reduce_min32(typename ymm_vector<T>::reg_t x)
110+
T avx2_emu_reduce_min32(typename avx2_vector<T>::reg_t x)
114111
{
115-
using vtype = ymm_vector<T>;
112+
using vtype = avx2_vector<T>;
116113
using reg_t = typename vtype::reg_t;
117114

118115
reg_t inter1 = vtype::min(x, vtype::template shuffle<SHUFFLE_MASK(2, 3, 0, 1)>(x));
@@ -124,10 +121,10 @@ T avx2_emu_reduce_min32(typename ymm_vector<T>::reg_t x)
124121

125122
template <typename T>
126123
void avx2_emu_mask_compressstoreu(void *base_addr,
127-
typename ymm_vector<T>::opmask_t k,
128-
typename ymm_vector<T>::reg_t reg)
124+
typename avx2_vector<T>::opmask_t k,
125+
typename avx2_vector<T>::reg_t reg)
129126
{
130-
using vtype = ymm_vector<T>;
127+
using vtype = avx2_vector<T>;
131128

132129
T *leftStore = (T *)base_addr;
133130

@@ -145,10 +142,10 @@ void avx2_emu_mask_compressstoreu(void *base_addr,
145142
template <typename T>
146143
int32_t avx2_double_compressstore32(void *left_addr,
147144
void *right_addr,
148-
typename ymm_vector<T>::opmask_t k,
149-
typename ymm_vector<T>::reg_t reg)
145+
typename avx2_vector<T>::opmask_t k,
146+
typename avx2_vector<T>::reg_t reg)
150147
{
151-
using vtype = ymm_vector<T>;
148+
using vtype = avx2_vector<T>;
152149

153150
T *leftStore = (T *)left_addr;
154151
T *rightStore = (T *)right_addr;
@@ -168,27 +165,25 @@ int32_t avx2_double_compressstore32(void *left_addr,
168165
}
169166

170167
template <typename T>
171-
typename ymm_vector<T>::reg_t avx2_emu_max(typename ymm_vector<T>::reg_t x,
172-
typename ymm_vector<T>::reg_t y)
168+
typename avx2_vector<T>::reg_t avx2_emu_max(typename avx2_vector<T>::reg_t x,
169+
typename avx2_vector<T>::reg_t y)
173170
{
174-
using vtype = ymm_vector<T>;
171+
using vtype = avx2_vector<T>;
175172
typename vtype::opmask_t nlt = vtype::ge(x, y);
176173
return _mm256_castpd_si256(_mm256_blendv_pd(_mm256_castsi256_pd(y),
177174
_mm256_castsi256_pd(x),
178175
_mm256_castsi256_pd(nlt)));
179176
}
180177

181178
template <typename T>
182-
typename ymm_vector<T>::reg_t avx2_emu_min(typename ymm_vector<T>::reg_t x,
183-
typename ymm_vector<T>::reg_t y)
179+
typename avx2_vector<T>::reg_t avx2_emu_min(typename avx2_vector<T>::reg_t x,
180+
typename avx2_vector<T>::reg_t y)
184181
{
185-
using vtype = ymm_vector<T>;
182+
using vtype = avx2_vector<T>;
186183
typename vtype::opmask_t nlt = vtype::ge(x, y);
187184
return _mm256_castpd_si256(_mm256_blendv_pd(_mm256_castsi256_pd(x),
188185
_mm256_castsi256_pd(y),
189186
_mm256_castsi256_pd(nlt)));
190187
}
191-
} // namespace avx2
192-
} // namespace x86_simd_sort
193188

194189
#endif

src/avx512-16bit-common.h

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,6 @@
88
#define AVX512_16BIT_COMMON
99

1010
#include "xss-common-qsort.h"
11-
#include "xss-network-qsort.hpp"
1211

1312
/*
1413
* Constants used in sorting 32 elements in a ZMM registers. Based on Bitonic

src/avx512-32bit-qsort.hpp

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,6 @@
99
#define AVX512_QSORT_32BIT
1010

1111
#include "xss-common-qsort.h"
12-
#include "xss-network-qsort.hpp"
1312

1413
/*
1514
* Constants used in sorting 16 elements in a ZMM registers. Based on Bitonic

src/xss-common-qsort.h

Lines changed: 4 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -114,18 +114,8 @@ struct zmm_vector;
114114
template <typename type>
115115
struct ymm_vector;
116116

117-
namespace xss{
118-
namespace avx2{
119117
template <typename type>
120-
struct ymm_vector;
121-
122-
inline arrsize_t replace_nan_with_inf(float *arr, int64_t arrsize);
123-
}
124-
}
125-
126-
// key-value sort routines
127-
template <typename T1, typename T2>
128-
void avx512_qsort_kv(T1 *keys, T2 *indexes, int64_t arrsize);
118+
struct avx2_vector;
129119

130120
template <typename T>
131121
bool is_a_nan(T elem)
@@ -897,12 +887,12 @@ X86_SIMD_SORT_INLINE void avx512_qsort(T *arr, arrsize_t arrsize)
897887
template <typename T>
898888
void avx2_qsort(T *arr, arrsize_t arrsize)
899889
{
900-
using vtype = xss::avx2::ymm_vector<T>;
890+
using vtype = avx2_vector<T>;
901891
if (arrsize > 1) {
902892
/* std::is_floating_point_v<_Float16> == False, unless c++-23*/
903893
if constexpr (std::is_floating_point_v<T>) {
904894
arrsize_t nan_count
905-
= xss::avx2::replace_nan_with_inf(arr, arrsize);
895+
= replace_nan_with_inf<vtype>(arr, arrsize);
906896
qsort_<vtype, T>(
907897
arr, 0, arrsize - 1, 2 * (int64_t)log2(arrsize));
908898
replace_inf_with_nan(arr, arrsize, nan_count);
@@ -944,7 +934,7 @@ void avx2_qselect(T *arr, arrsize_t k, arrsize_t arrsize, bool hasnan = false)
944934
}
945935
UNUSED(hasnan);
946936
if (indx_last_elem >= k) {
947-
qselect_<xss::avx2::ymm_vector<T>, T>(
937+
qselect_<avx2_vector<T>, T>(
948938
arr, k, 0, indx_last_elem, 2 * (int64_t)log2(indx_last_elem));
949939
}
950940
}

0 commit comments

Comments
 (0)