Skip to content

Commit

Permalink
Merge pull request #108 from r-devulap/kvsort-32bit
Browse files Browse the repository at this point in the history
Support key-value sort for 32-bit dtypes
  • Loading branch information
r-devulap authored Nov 28, 2023
2 parents 8187e9a + cb46165 commit da8ce10
Show file tree
Hide file tree
Showing 8 changed files with 132 additions and 57 deletions.
26 changes: 22 additions & 4 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -8,15 +8,33 @@ AVX2 specific implementations, please see
[README](https://github.com/intel/x86-simd-sort/blob/main/src/README.md) file under
`src/` directory. The following routines are currently supported:


### Sort routines on arrays
```cpp
x86simdsort::qsort(T* arr, size_t size, bool hasnan);
x86simdsort::qselect(T* arr, size_t k, size_t size, bool hasnan);
x86simdsort::partial_qsort(T* arr, size_t k, size_t size, bool hasnan);
```
Supported datatypes: `T` $\in$ `[_Float16, uint16_t, int16_t, float, uint32_t,
int32_t, double, uint64_t, int64_t]`
### Key-value sort routines on pairs of arrays
```cpp
x86simdsort::keyvalue_qsort(T1* key, T2* val, size_t size, bool hasnan);
```
Supported datatypes: `T1`, `T2` $\in$ `[float, uint32_t, int32_t, double,
uint64_t, int64_t]`. Note that key-value sort is not yet supported for 16-bit
data types.

### Arg sort routines on arrays
```cpp
std::vector<size_t> arg = x86simdsort::argsort(T* arr, size_t size, bool hasnan);
std::vector<size_t> arg = x86simdsort::argselect(T* arr, size_t k, size_t size, bool hasnan);
```
Supported datatypes: `T` $\in$ `[_Float16, uint16_t, int16_t, float, uint32_t,
int32_t, double, uint64_t, int64_t]`

### Build/Install
## Build/Install

[meson](https://github.com/mesonbuild/meson) is the build system used. Command
to build and install the library:
Expand All @@ -35,7 +53,7 @@ benchmark](https://github.com/google/benchmark) frameworks respectively. You
can configure meson to build them both by using `-Dbuild_tests=true` and
`-Dbuild_benchmarks=true`.

### Example usage
## Example usage

```cpp
#include "x86simdsort.h"
Expand All @@ -48,7 +66,7 @@ int main() {
```


### Details
## Details

- `x86simdsort::qsort` is equivalent to `qsort` in
[C](https://www.tutorialspoint.com/c_standard_library/c_function_qsort.htm)
Expand Down Expand Up @@ -77,7 +95,7 @@ argselect) will not use the SIMD based algorithms if they detect NAN's in the
array. You can read details of all the implementations
[here](https://github.com/intel/x86-simd-sort/blob/main/src/README.md).

### Downstream projects using x86-simd-sort
## Downstream projects using x86-simd-sort

- NumPy uses this as a [submodule](https://github.com/numpy/numpy/pull/22315) to accelerate `np.sort, np.argsort, np.partition and np.argpartition`.
- A slightly modified version of this library has been integrated into [openJDK](https://github.com/openjdk/jdk/pull/14227).
Expand Down
3 changes: 3 additions & 0 deletions benchmarks/bench-keyvalue.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -46,3 +46,6 @@ static void simdkvsort(benchmark::State &state, Args &&...args)
BENCH_BOTH_KVSORT(uint64_t)
BENCH_BOTH_KVSORT(int64_t)
BENCH_BOTH_KVSORT(double)
BENCH_BOTH_KVSORT(uint32_t)
BENCH_BOTH_KVSORT(int32_t)
BENCH_BOTH_KVSORT(float)
6 changes: 5 additions & 1 deletion examples/avx512-kv.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ int main() {
int64_t arr1[size];
uint64_t arr2[size];
double arr3[size];
float arr4[size];
avx512_qsort_kv(arr1, arr1, size);
avx512_qsort_kv(arr1, arr2, size);
avx512_qsort_kv(arr1, arr3, size);
Expand All @@ -13,6 +14,9 @@ int main() {
avx512_qsort_kv(arr2, arr3, size);
avx512_qsort_kv(arr3, arr1, size);
avx512_qsort_kv(arr3, arr2, size);
avx512_qsort_kv(arr3, arr3, size);
avx512_qsort_kv(arr1, arr4, size);
avx512_qsort_kv(arr2, arr4, size);
avx512_qsort_kv(arr3, arr4, size);
return 0;
return 0;
}
44 changes: 33 additions & 11 deletions lib/x86simdsort-skx.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -33,9 +33,34 @@
return avx512_argselect(arr, k, arrsize, hasnan); \
}

#define DEFINE_KEYVALUE_METHODS(type1, type2) \
#define DEFINE_KEYVALUE_METHODS(type) \
template <> \
void keyvalue_qsort(type1 *key, type2* val, size_t arrsize, bool hasnan) \
void keyvalue_qsort(type *key, uint64_t* val, size_t arrsize, bool hasnan) \
{ \
avx512_qsort_kv(key, val, arrsize, hasnan); \
} \
template <> \
void keyvalue_qsort(type *key, int64_t* val, size_t arrsize, bool hasnan) \
{ \
avx512_qsort_kv(key, val, arrsize, hasnan); \
} \
template <> \
void keyvalue_qsort(type *key, double* val, size_t arrsize, bool hasnan) \
{ \
avx512_qsort_kv(key, val, arrsize, hasnan); \
} \
template <> \
void keyvalue_qsort(type *key, uint32_t* val, size_t arrsize, bool hasnan) \
{ \
avx512_qsort_kv(key, val, arrsize, hasnan); \
} \
template <> \
void keyvalue_qsort(type *key, int32_t* val, size_t arrsize, bool hasnan) \
{ \
avx512_qsort_kv(key, val, arrsize, hasnan); \
} \
template <> \
void keyvalue_qsort(type *key, float* val, size_t arrsize, bool hasnan) \
{ \
avx512_qsort_kv(key, val, arrsize, hasnan); \
} \
Expand All @@ -49,14 +74,11 @@ namespace avx512 {
DEFINE_ALL_METHODS(uint64_t)
DEFINE_ALL_METHODS(int64_t)
DEFINE_ALL_METHODS(double)
DEFINE_KEYVALUE_METHODS(double, uint64_t)
DEFINE_KEYVALUE_METHODS(double, int64_t)
DEFINE_KEYVALUE_METHODS(double, double)
DEFINE_KEYVALUE_METHODS(uint64_t, uint64_t)
DEFINE_KEYVALUE_METHODS(uint64_t, int64_t)
DEFINE_KEYVALUE_METHODS(uint64_t, double)
DEFINE_KEYVALUE_METHODS(int64_t, uint64_t)
DEFINE_KEYVALUE_METHODS(int64_t, int64_t)
DEFINE_KEYVALUE_METHODS(int64_t, double)
DEFINE_KEYVALUE_METHODS(uint64_t)
DEFINE_KEYVALUE_METHODS(int64_t)
DEFINE_KEYVALUE_METHODS(double)
DEFINE_KEYVALUE_METHODS(uint32_t)
DEFINE_KEYVALUE_METHODS(int32_t)
DEFINE_KEYVALUE_METHODS(float)
} // namespace avx512
} // namespace xss
23 changes: 14 additions & 9 deletions lib/x86simdsort.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -196,14 +196,19 @@ DISPATCH_ALL(argselect,
(ISA_LIST("avx512_skx")),
(ISA_LIST("avx512_skx")))

DISPATCH_KEYVALUE_SORT(uint64_t, int64_t, (ISA_LIST("avx512_skx")))
DISPATCH_KEYVALUE_SORT(uint64_t, uint64_t, (ISA_LIST("avx512_skx")))
DISPATCH_KEYVALUE_SORT(uint64_t, double, (ISA_LIST("avx512_skx")))
DISPATCH_KEYVALUE_SORT(int64_t, int64_t, (ISA_LIST("avx512_skx")))
DISPATCH_KEYVALUE_SORT(int64_t, uint64_t, (ISA_LIST("avx512_skx")))
DISPATCH_KEYVALUE_SORT(int64_t, double, (ISA_LIST("avx512_skx")))
DISPATCH_KEYVALUE_SORT(double, int64_t, (ISA_LIST("avx512_skx")))
DISPATCH_KEYVALUE_SORT(double, double, (ISA_LIST("avx512_skx")))
DISPATCH_KEYVALUE_SORT(double, uint64_t, (ISA_LIST("avx512_skx")))
// Expands DISPATCH_KEYVALUE_SORT once for key type `type` paired with each of
// the six value types listed below (uint64_t, int64_t, double, uint32_t,
// int32_t, float). Every expansion passes "avx512_skx" as its only ISA entry.
//
// The final line deliberately has no trailing line continuation, so the macro
// ends here regardless of what follows it in the file.
#define DISPATCH_KEYVALUE_SORT_FORTYPE(type) \
    DISPATCH_KEYVALUE_SORT(type, uint64_t, (ISA_LIST("avx512_skx"))) \
    DISPATCH_KEYVALUE_SORT(type, int64_t, (ISA_LIST("avx512_skx"))) \
    DISPATCH_KEYVALUE_SORT(type, double, (ISA_LIST("avx512_skx"))) \
    DISPATCH_KEYVALUE_SORT(type, uint32_t, (ISA_LIST("avx512_skx"))) \
    DISPATCH_KEYVALUE_SORT(type, int32_t, (ISA_LIST("avx512_skx"))) \
    DISPATCH_KEYVALUE_SORT(type, float, (ISA_LIST("avx512_skx")))

// Invoke DISPATCH_KEYVALUE_SORT_FORTYPE once per key type: the 64-bit types
// (uint64_t, int64_t, double) followed by the 32-bit types (uint32_t,
// int32_t, float).
DISPATCH_KEYVALUE_SORT_FORTYPE(uint64_t)
DISPATCH_KEYVALUE_SORT_FORTYPE(int64_t)
DISPATCH_KEYVALUE_SORT_FORTYPE(double)
DISPATCH_KEYVALUE_SORT_FORTYPE(uint32_t)
DISPATCH_KEYVALUE_SORT_FORTYPE(int32_t)
DISPATCH_KEYVALUE_SORT_FORTYPE(float)

} // namespace x86simdsort
12 changes: 12 additions & 0 deletions src/avx512-64bit-common.h
Original file line number Diff line number Diff line change
Expand Up @@ -186,6 +186,10 @@ struct ymm_vector<float> {
// return _mm256_shuffle_ps(zmm, zmm, mask);
//}
}
// Returns sort_zmm_64bit, instantiated for this ymm_vector specialization,
// applied to x.
// NOTE(review): presumably the lanes of x in sorted order; confirm against
// sort_zmm_64bit, which is defined outside this hunk.
static reg_t sort_vec(reg_t x)
{
return sort_zmm_64bit<ymm_vector<type_t>>(x);
}
static void storeu(void *mem, reg_t x)
{
_mm256_storeu_ps((float *)mem, x);
Expand Down Expand Up @@ -342,6 +346,10 @@ struct ymm_vector<uint32_t> {
* 32-bit and 64-bit */
return _mm256_shuffle_epi32(zmm, 0b10110001);
}
// Returns sort_zmm_64bit, instantiated for this ymm_vector specialization,
// applied to x.
// NOTE(review): presumably the lanes of x in sorted order; confirm against
// sort_zmm_64bit, which is defined outside this hunk.
static reg_t sort_vec(reg_t x)
{
return sort_zmm_64bit<ymm_vector<type_t>>(x);
}
static void storeu(void *mem, reg_t x)
{
_mm256_storeu_si256((__m256i *)mem, x);
Expand Down Expand Up @@ -498,6 +506,10 @@ struct ymm_vector<int32_t> {
* 32-bit and 64-bit */
return _mm256_shuffle_epi32(zmm, 0b10110001);
}
// Returns sort_zmm_64bit, instantiated for this ymm_vector specialization,
// applied to x.
// NOTE(review): presumably the lanes of x in sorted order; confirm against
// sort_zmm_64bit, which is defined outside this hunk.
static reg_t sort_vec(reg_t x)
{
return sort_zmm_64bit<ymm_vector<type_t>>(x);
}
static void storeu(void *mem, reg_t x)
{
_mm256_storeu_si256((__m256i *)mem, x);
Expand Down
39 changes: 23 additions & 16 deletions src/avx512-64bit-keyvaluesort.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -558,7 +558,7 @@ template <typename vtype1,
X86_SIMD_SORT_INLINE void
heap_sort(type1_t *keys, type2_t *indexes, arrsize_t size)
{
for (arrsize_t i = size / 2 - 1; ; i--) {
for (arrsize_t i = size / 2 - 1;; i--) {
heapify<vtype1, vtype2>(keys, indexes, i, size);
if (i == 0) { break; }
}
Expand Down Expand Up @@ -617,26 +617,33 @@ template <typename T1, typename T2>
/*
 * Entry point for key-value quicksort: forwards `keys` and `indexes` to
 * qsort_64bit_ over the inclusive index range [0, arrsize - 1].
 *
 * keys    : array of keys
 * indexes : array of values, passed alongside `keys` to qsort_64bit_
 * arrsize : number of elements; nothing is done when arrsize <= 1
 * hasnan  : only consulted for floating point key types. When true, the keys
 *           go through replace_nan_with_inf before qsort_64bit_ is called.
 */
X86_SIMD_SORT_INLINE void
avx512_qsort_kv(T1 *keys, T2 *indexes, arrsize_t arrsize, bool hasnan = false)
{
    // Select ymm_vector for 4-byte element types and zmm_vector otherwise.
    using keytype = typename std::conditional<sizeof(T1) == sizeof(int32_t),
                                              ymm_vector<T1>,
                                              zmm_vector<T1>>::type;
    using valtype = typename std::conditional<sizeof(T2) == sizeof(int32_t),
                                              ymm_vector<T2>,
                                              zmm_vector<T2>>::type;
    if (arrsize > 1) {
        if constexpr (std::is_floating_point_v<T1>) {
            arrsize_t nan_count = 0;
            if (UNLIKELY(hasnan)) {
                // Use the vector type chosen for T1. A hard-coded
                // zmm_vector<double> would process float keys as if they
                // were doubles.
                // NOTE(review): assumes ymm_vector<float> provides every
                // member replace_nan_with_inf needs; confirm.
                nan_count = replace_nan_with_inf<keytype>(keys, arrsize);
            }
            qsort_64bit_<keytype, valtype>(keys,
                                           indexes,
                                           0,
                                           arrsize - 1,
                                           2 * (arrsize_t)log2(arrsize));
            // Called with the count returned above (0 when hasnan is false);
            // presumably turns that many infinities back into NaN.
            replace_inf_with_nan(keys, arrsize, nan_count);
        }
        else {
            // The flag is not consulted for non floating point key types.
            UNUSED(hasnan);
            qsort_64bit_<keytype, valtype>(keys,
                                           indexes,
                                           0,
                                           arrsize - 1,
                                           2 * (arrsize_t)log2(arrsize));
        }
    }
}
Expand Down
36 changes: 20 additions & 16 deletions tests/test-keyvalue.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -40,28 +40,32 @@ TYPED_TEST_P(simdkvsort, test_kvsort)
std::vector<T1> key_bckp = key;
std::vector<T2> val_bckp = val;
x86simdsort::keyvalue_qsort(key.data(), val.data(), size, hasnan);
xss::scalar::keyvalue_qsort(key_bckp.data(), val_bckp.data(), size, hasnan);
xss::scalar::keyvalue_qsort(
key_bckp.data(), val_bckp.data(), size, hasnan);
ASSERT_EQ(key, key_bckp);
const bool hasDuplicates = std::adjacent_find(key.begin(), key.end()) != key.end();
if (!hasDuplicates) {
ASSERT_EQ(val, val_bckp);
}
key.clear(); val.clear();
key_bckp.clear(); val_bckp.clear();
const bool hasDuplicates
= std::adjacent_find(key.begin(), key.end()) != key.end();
if (!hasDuplicates) { ASSERT_EQ(val, val_bckp); }
key.clear();
val.clear();
key_bckp.clear();
val_bckp.clear();
}
}
}

REGISTER_TYPED_TEST_SUITE_P(simdkvsort, test_kvsort);

using QKVSortTestTypes = testing::Types<std::tuple<double, double>,
std::tuple<double, uint64_t>,
std::tuple<double, int64_t>,
std::tuple<uint64_t, double>,
std::tuple<uint64_t, uint64_t>,
std::tuple<uint64_t, int64_t>,
std::tuple<int64_t, double>,
std::tuple<int64_t, uint64_t>,
std::tuple<int64_t, int64_t>>;
// Expands to six comma-separated std::tuple types. The first element ranges
// over double, uint64_t, int64_t, float, uint32_t and int32_t; the second
// element is always `type`.
#define CREATE_TUPLES(type) \
std::tuple<double, type>, std::tuple<uint64_t, type>, \
std::tuple<int64_t, type>, std::tuple<float, type>, \
std::tuple<uint32_t, type>, std::tuple<int32_t, type>

// Type list built from CREATE_TUPLES applied to each of the six element
// types, giving 36 tuple combinations in total.
using QKVSortTestTypes = testing::Types<CREATE_TUPLES(double),
CREATE_TUPLES(uint64_t),
CREATE_TUPLES(int64_t),
CREATE_TUPLES(uint32_t),
CREATE_TUPLES(int32_t),
CREATE_TUPLES(float)>;

INSTANTIATE_TYPED_TEST_SUITE_P(xss, simdkvsort, QKVSortTestTypes);

0 comments on commit da8ce10

Please sign in to comment.