Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support key-value sort for 32-bit dtypes #108

Merged
merged 6 commits into from
Nov 28, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
26 changes: 22 additions & 4 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -8,15 +8,33 @@ AVX2 specific implementations, please see
[README](https://github.com/intel/x86-simd-sort/blob/main/src/README.md) file under
`src/` directory. The following routines are currently supported:


### Sort routines on arrays
```cpp
x86simdsort::qsort(T* arr, size_t size, bool hasnan);
x86simdsort::qselect(T* arr, size_t k, size_t size, bool hasnan);
x86simdsort::partial_qsort(T* arr, size_t k, size_t size, bool hasnan);
```
Supported datatypes: `T` $\in$ `[_Float16, uint16_t, int16_t, float, uint32_t,
int32_t, double, uint64_t, int64_t]`
### Key-value sort routines on pairs of arrays
```cpp
x86simdsort::keyvalue_qsort(T1* key, T2* val, size_t size, bool hasnan);
```
Supported datatypes: `T1`, `T2` $\in$ `[float, uint32_t, int32_t, double,
uint64_t, int64_t]`. Note that key-value sort is not yet supported for 16-bit
data types.

### Arg sort routines on arrays
```cpp
std::vector<size_t> arg = x86simdsort::argsort(T* arr, size_t size, bool hasnan);
std::vector<size_t> arg = x86simdsort::argselect(T* arr, size_t k, size_t size, bool hasnan);
```
Supported datatypes: `T` $\in$ `[_Float16, uint16_t, int16_t, float, uint32_t,
int32_t, double, uint64_t, int64_t]`

### Build/Install
## Build/Install

This project uses the [meson](https://github.com/mesonbuild/meson) build
system. Commands to build and install the library:
Expand All @@ -35,7 +53,7 @@ benchmark](https://github.com/google/benchmark) frameworks respectively. You
can configure meson to build them both by using `-Dbuild_tests=true` and
`-Dbuild_benchmarks=true`.

### Example usage
## Example usage

```cpp
#include "x86simdsort.h"
Expand All @@ -48,7 +66,7 @@ int main() {
```


### Details
## Details

- `x86simdsort::qsort` is equivalent to `qsort` in
[C](https://www.tutorialspoint.com/c_standard_library/c_function_qsort.htm)
Expand Down Expand Up @@ -77,7 +95,7 @@ argselect) will not use the SIMD based algorithms if they detect NAN's in the
array. You can read details of all the implementations
[here](https://github.com/intel/x86-simd-sort/blob/main/src/README.md).

### Downstream projects using x86-simd-sort
## Downstream projects using x86-simd-sort

- NumPy uses this as a [submodule](https://github.com/numpy/numpy/pull/22315) to accelerate `np.sort, np.argsort, np.partition and np.argpartition`.
- A slightly modified version of this library has been integrated into [openJDK](https://github.com/openjdk/jdk/pull/14227).
Expand Down
3 changes: 3 additions & 0 deletions benchmarks/bench-keyvalue.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -46,3 +46,6 @@ static void simdkvsort(benchmark::State &state, Args &&...args)
BENCH_BOTH_KVSORT(uint64_t)
BENCH_BOTH_KVSORT(int64_t)
BENCH_BOTH_KVSORT(double)
BENCH_BOTH_KVSORT(uint32_t)
BENCH_BOTH_KVSORT(int32_t)
BENCH_BOTH_KVSORT(float)
6 changes: 5 additions & 1 deletion examples/avx512-kv.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ int main() {
int64_t arr1[size];
uint64_t arr2[size];
double arr3[size];
float arr4[size];
avx512_qsort_kv(arr1, arr1, size);
avx512_qsort_kv(arr1, arr2, size);
avx512_qsort_kv(arr1, arr3, size);
Expand All @@ -13,6 +14,9 @@ int main() {
avx512_qsort_kv(arr2, arr3, size);
avx512_qsort_kv(arr3, arr1, size);
avx512_qsort_kv(arr3, arr2, size);
avx512_qsort_kv(arr3, arr3, size);
avx512_qsort_kv(arr1, arr4, size);
avx512_qsort_kv(arr2, arr4, size);
avx512_qsort_kv(arr3, arr4, size);
return 0;
return 0;
}
44 changes: 33 additions & 11 deletions lib/x86simdsort-skx.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -33,9 +33,34 @@
return avx512_argselect(arr, k, arrsize, hasnan); \
}

#define DEFINE_KEYVALUE_METHODS(type1, type2) \
#define DEFINE_KEYVALUE_METHODS(type) \
template <> \
void keyvalue_qsort(type1 *key, type2* val, size_t arrsize, bool hasnan) \
void keyvalue_qsort(type *key, uint64_t* val, size_t arrsize, bool hasnan) \
{ \
avx512_qsort_kv(key, val, arrsize, hasnan); \
} \
template <> \
void keyvalue_qsort(type *key, int64_t* val, size_t arrsize, bool hasnan) \
{ \
avx512_qsort_kv(key, val, arrsize, hasnan); \
} \
template <> \
void keyvalue_qsort(type *key, double* val, size_t arrsize, bool hasnan) \
{ \
avx512_qsort_kv(key, val, arrsize, hasnan); \
} \
template <> \
void keyvalue_qsort(type *key, uint32_t* val, size_t arrsize, bool hasnan) \
{ \
avx512_qsort_kv(key, val, arrsize, hasnan); \
} \
template <> \
void keyvalue_qsort(type *key, int32_t* val, size_t arrsize, bool hasnan) \
{ \
avx512_qsort_kv(key, val, arrsize, hasnan); \
} \
template <> \
void keyvalue_qsort(type *key, float* val, size_t arrsize, bool hasnan) \
{ \
avx512_qsort_kv(key, val, arrsize, hasnan); \
} \
Expand All @@ -49,14 +74,11 @@ namespace avx512 {
DEFINE_ALL_METHODS(uint64_t)
DEFINE_ALL_METHODS(int64_t)
DEFINE_ALL_METHODS(double)
DEFINE_KEYVALUE_METHODS(double, uint64_t)
DEFINE_KEYVALUE_METHODS(double, int64_t)
DEFINE_KEYVALUE_METHODS(double, double)
DEFINE_KEYVALUE_METHODS(uint64_t, uint64_t)
DEFINE_KEYVALUE_METHODS(uint64_t, int64_t)
DEFINE_KEYVALUE_METHODS(uint64_t, double)
DEFINE_KEYVALUE_METHODS(int64_t, uint64_t)
DEFINE_KEYVALUE_METHODS(int64_t, int64_t)
DEFINE_KEYVALUE_METHODS(int64_t, double)
DEFINE_KEYVALUE_METHODS(uint64_t)
DEFINE_KEYVALUE_METHODS(int64_t)
DEFINE_KEYVALUE_METHODS(double)
DEFINE_KEYVALUE_METHODS(uint32_t)
DEFINE_KEYVALUE_METHODS(int32_t)
DEFINE_KEYVALUE_METHODS(float)
} // namespace avx512
} // namespace xss
23 changes: 14 additions & 9 deletions lib/x86simdsort.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -196,14 +196,19 @@ DISPATCH_ALL(argselect,
(ISA_LIST("avx512_skx")),
(ISA_LIST("avx512_skx")))

DISPATCH_KEYVALUE_SORT(uint64_t, int64_t, (ISA_LIST("avx512_skx")))
DISPATCH_KEYVALUE_SORT(uint64_t, uint64_t, (ISA_LIST("avx512_skx")))
DISPATCH_KEYVALUE_SORT(uint64_t, double, (ISA_LIST("avx512_skx")))
DISPATCH_KEYVALUE_SORT(int64_t, int64_t, (ISA_LIST("avx512_skx")))
DISPATCH_KEYVALUE_SORT(int64_t, uint64_t, (ISA_LIST("avx512_skx")))
DISPATCH_KEYVALUE_SORT(int64_t, double, (ISA_LIST("avx512_skx")))
DISPATCH_KEYVALUE_SORT(double, int64_t, (ISA_LIST("avx512_skx")))
DISPATCH_KEYVALUE_SORT(double, double, (ISA_LIST("avx512_skx")))
DISPATCH_KEYVALUE_SORT(double, uint64_t, (ISA_LIST("avx512_skx")))
#define DISPATCH_KEYVALUE_SORT_FORTYPE(type) \
DISPATCH_KEYVALUE_SORT(type, uint64_t, (ISA_LIST("avx512_skx")))\
DISPATCH_KEYVALUE_SORT(type, int64_t, (ISA_LIST("avx512_skx")))\
DISPATCH_KEYVALUE_SORT(type, double, (ISA_LIST("avx512_skx")))\
DISPATCH_KEYVALUE_SORT(type, uint32_t, (ISA_LIST("avx512_skx")))\
DISPATCH_KEYVALUE_SORT(type, int32_t, (ISA_LIST("avx512_skx")))\
DISPATCH_KEYVALUE_SORT(type, float, (ISA_LIST("avx512_skx")))\

DISPATCH_KEYVALUE_SORT_FORTYPE(uint64_t)
DISPATCH_KEYVALUE_SORT_FORTYPE(int64_t)
DISPATCH_KEYVALUE_SORT_FORTYPE(double)
DISPATCH_KEYVALUE_SORT_FORTYPE(uint32_t)
DISPATCH_KEYVALUE_SORT_FORTYPE(int32_t)
DISPATCH_KEYVALUE_SORT_FORTYPE(float)

} // namespace x86simdsort
12 changes: 12 additions & 0 deletions src/avx512-64bit-common.h
Original file line number Diff line number Diff line change
Expand Up @@ -186,6 +186,10 @@ struct ymm_vector<float> {
// return _mm256_shuffle_ps(zmm, zmm, mask);
//}
}
static reg_t sort_vec(reg_t x)
{
return sort_zmm_64bit<ymm_vector<type_t>>(x);
}
static void storeu(void *mem, reg_t x)
{
_mm256_storeu_ps((float *)mem, x);
Expand Down Expand Up @@ -342,6 +346,10 @@ struct ymm_vector<uint32_t> {
* 32-bit and 64-bit */
return _mm256_shuffle_epi32(zmm, 0b10110001);
}
static reg_t sort_vec(reg_t x)
{
return sort_zmm_64bit<ymm_vector<type_t>>(x);
}
static void storeu(void *mem, reg_t x)
{
_mm256_storeu_si256((__m256i *)mem, x);
Expand Down Expand Up @@ -498,6 +506,10 @@ struct ymm_vector<int32_t> {
* 32-bit and 64-bit */
return _mm256_shuffle_epi32(zmm, 0b10110001);
}
static reg_t sort_vec(reg_t x)
{
return sort_zmm_64bit<ymm_vector<type_t>>(x);
}
static void storeu(void *mem, reg_t x)
{
_mm256_storeu_si256((__m256i *)mem, x);
Expand Down
39 changes: 23 additions & 16 deletions src/avx512-64bit-keyvaluesort.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -558,7 +558,7 @@ template <typename vtype1,
X86_SIMD_SORT_INLINE void
heap_sort(type1_t *keys, type2_t *indexes, arrsize_t size)
{
for (arrsize_t i = size / 2 - 1; ; i--) {
for (arrsize_t i = size / 2 - 1;; i--) {
heapify<vtype1, vtype2>(keys, indexes, i, size);
if (i == 0) { break; }
}
Expand Down Expand Up @@ -617,26 +617,33 @@ template <typename T1, typename T2>
X86_SIMD_SORT_INLINE void
avx512_qsort_kv(T1 *keys, T2 *indexes, arrsize_t arrsize, bool hasnan = false)
{
UNUSED(hasnan);
using keytype = typename std::conditional<sizeof(T1) == sizeof(int32_t),
ymm_vector<T1>,
zmm_vector<T1>>::type;
using valtype = typename std::conditional<sizeof(T2) == sizeof(int32_t),
ymm_vector<T2>,
zmm_vector<T2>>::type;
if (arrsize > 1) {
if constexpr (std::is_floating_point_v<T1>) {
arrsize_t nan_count
= replace_nan_with_inf<zmm_vector<double>>(keys, arrsize);
qsort_64bit_<zmm_vector<T1>, zmm_vector<T2>>(
keys,
indexes,
0,
arrsize - 1,
2 * (arrsize_t)log2(arrsize));
arrsize_t nan_count = 0;
if (UNLIKELY(hasnan)) {
nan_count = replace_nan_with_inf<zmm_vector<double>>(keys,
arrsize);
}
qsort_64bit_<keytype, valtype>(keys,
indexes,
0,
arrsize - 1,
2 * (arrsize_t)log2(arrsize));
replace_inf_with_nan(keys, arrsize, nan_count);
}
else {
qsort_64bit_<zmm_vector<T1>, zmm_vector<T2>>(
keys,
indexes,
0,
arrsize - 1,
2 * (arrsize_t)log2(arrsize));
UNUSED(hasnan);
qsort_64bit_<keytype, valtype>(keys,
indexes,
0,
arrsize - 1,
2 * (arrsize_t)log2(arrsize));
}
}
}
Expand Down
36 changes: 20 additions & 16 deletions tests/test-keyvalue.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -40,28 +40,32 @@ TYPED_TEST_P(simdkvsort, test_kvsort)
std::vector<T1> key_bckp = key;
std::vector<T2> val_bckp = val;
x86simdsort::keyvalue_qsort(key.data(), val.data(), size, hasnan);
xss::scalar::keyvalue_qsort(key_bckp.data(), val_bckp.data(), size, hasnan);
xss::scalar::keyvalue_qsort(
key_bckp.data(), val_bckp.data(), size, hasnan);
ASSERT_EQ(key, key_bckp);
const bool hasDuplicates = std::adjacent_find(key.begin(), key.end()) != key.end();
if (!hasDuplicates) {
ASSERT_EQ(val, val_bckp);
}
key.clear(); val.clear();
key_bckp.clear(); val_bckp.clear();
const bool hasDuplicates
= std::adjacent_find(key.begin(), key.end()) != key.end();
if (!hasDuplicates) { ASSERT_EQ(val, val_bckp); }
key.clear();
val.clear();
key_bckp.clear();
val_bckp.clear();
}
}
}

REGISTER_TYPED_TEST_SUITE_P(simdkvsort, test_kvsort);

using QKVSortTestTypes = testing::Types<std::tuple<double, double>,
std::tuple<double, uint64_t>,
std::tuple<double, int64_t>,
std::tuple<uint64_t, double>,
std::tuple<uint64_t, uint64_t>,
std::tuple<uint64_t, int64_t>,
std::tuple<int64_t, double>,
std::tuple<int64_t, uint64_t>,
std::tuple<int64_t, int64_t>>;
#define CREATE_TUPLES(type) \
std::tuple<double, type>, std::tuple<uint64_t, type>, \
std::tuple<int64_t, type>, std::tuple<float, type>, \
std::tuple<uint32_t, type>, std::tuple<int32_t, type>

using QKVSortTestTypes = testing::Types<CREATE_TUPLES(double),
CREATE_TUPLES(uint64_t),
CREATE_TUPLES(int64_t),
CREATE_TUPLES(uint32_t),
CREATE_TUPLES(int32_t),
CREATE_TUPLES(float)>;

INSTANTIATE_TYPED_TEST_SUITE_P(xss, simdkvsort, QKVSortTestTypes);
Loading