Skip to content

Commit c25661e

Browse files
author
Raghuveer Devulapalli
committed
Fix scalar methods to treat NAN correctlt
1 parent b49a0f8 commit c25661e

File tree

4 files changed

+60
-56
lines changed

4 files changed

+60
-56
lines changed

lib/x86simdsort-scalar.h

Lines changed: 17 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -1,36 +1,40 @@
11
#include <algorithm>
22
#include <numeric>
3-
#define UNUSED(x) (void)(x)
3+
#include "custom-compare.h"
4+
45
namespace xss {
56
namespace scalar {
6-
/* TODO: handle NAN */
77
template <typename T>
88
void qsort(T *arr, int64_t arrsize)
99
{
10-
std::sort(arr, arr + arrsize);
10+
std::sort(arr, arr + arrsize, compare<T, std::less<T>>());
1111
}
1212
template <typename T>
1313
void qselect(T *arr, int64_t k, int64_t arrsize, bool hasnan)
1414
{
15-
UNUSED(hasnan);
16-
std::nth_element(arr, arr + k, arr + arrsize);
15+
if (hasnan) {
16+
std::nth_element(arr, arr + k, arr + arrsize, compare<T, std::less<T>>());
17+
}
18+
else {
19+
std::nth_element(arr, arr + k, arr + arrsize);
20+
}
1721
}
1822
template <typename T>
1923
void partial_qsort(T *arr, int64_t k, int64_t arrsize, bool hasnan)
2024
{
21-
UNUSED(hasnan);
22-
std::partial_sort(arr, arr + k, arr + arrsize);
25+
if (hasnan) {
26+
std::partial_sort(arr, arr + k, arr + arrsize, compare<T, std::less<T>>());
27+
}
28+
else {
29+
std::partial_sort(arr, arr + k, arr + arrsize);
30+
}
2331
}
2432
template <typename T>
2533
std::vector<int64_t> argsort(T *arr, int64_t arrsize)
2634
{
2735
std::vector<int64_t> arg(arrsize);
2836
std::iota(arg.begin(), arg.end(), 0);
29-
std::sort(arg.begin(),
30-
arg.end(),
31-
[arr](int64_t left, int64_t right) -> bool {
32-
return arr[left] < arr[right];
33-
});
37+
std::sort(arg.begin(), arg.end(), compare_arg<T, std::less<T>>(arr));
3438
return arg;
3539
}
3640
template <typename T>
@@ -41,9 +45,7 @@ namespace scalar {
4145
std::nth_element(arg.begin(),
4246
arg.begin() + k,
4347
arg.end(),
44-
[arr](int64_t left, int64_t right) -> bool {
45-
return arr[left] < arr[right];
46-
});
48+
compare_arg<T, std::less<T>>(arr));
4749
return arg;
4850
}
4951

meson.build

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,7 @@ subdir('benchmarks')
2626

2727
libsimdsort = shared_library('x86simdsort',
2828
'lib/x86simdsort.cpp',
29-
include_directories : [lib],
29+
include_directories : [utils, lib],
3030
link_whole : [libtargets],
3131
cpp_args : ['-O3'],
3232
)

tests/test-qsort-common.h

Lines changed: 1 addition & 40 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
#define AVX512_TEST_COMMON
33

44
#include "rand_array.h"
5+
#include "custom-compare.h"
56
#include "x86simdsort.h"
67
#include <gtest/gtest.h>
78

@@ -17,46 +18,6 @@
1718
ASSERT_TRUE(false) << msg << ". arr size = " << size \
1819
<< ", type = " << type << ", k = " << k;
1920

20-
/*
21-
* Custom comparator class to handle NAN's: treats NAN > INF
22-
*/
23-
template <typename T, typename Comparator>
24-
struct compare {
25-
static constexpr auto op = Comparator {};
26-
bool operator()(const T a, const T b)
27-
{
28-
if constexpr (std::is_floating_point_v<T>) {
29-
T inf = std::numeric_limits<T>::infinity();
30-
if (!std::isunordered(a, b)) { return op(a, b); }
31-
else if ((std::isnan(a)) && (!std::isnan(b))) {
32-
return b == inf ? op(inf, 1.) : op(inf, b);
33-
}
34-
else if ((!std::isnan(a)) && (std::isnan(b))) {
35-
return a == inf ? op(1., inf) : op(a, inf);
36-
}
37-
else {
38-
return op(1., 1.);
39-
}
40-
}
41-
else {
42-
return op(a, b);
43-
}
44-
}
45-
};
46-
47-
//template <typename T, typename Comparator>
48-
//struct compare_arg {
49-
// compare_arg(std::vector<T> arr)
50-
// {
51-
// this->arr = arr;
52-
// }
53-
// bool operator()(const int64_t a, const int64_t b)
54-
// {
55-
// return compare<T, Comparator>()(arr[a], arr[b]);
56-
// }
57-
// std::vector<T> arr;
58-
//};
59-
6021
template <typename T>
6122
void IS_SORTED(std::vector<T> sorted, std::vector<T> arr, std::string type)
6223
{

utils/custom-compare.h

Lines changed: 41 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,41 @@
1+
#include <limits>
2+
#include <cmath>
3+
/*
4+
* Custom comparator class to handle NAN's: treats NAN > INF
5+
*/
6+
template <typename T, typename Comparator>
7+
struct compare {
8+
static constexpr auto op = Comparator {};
9+
bool operator()(const T a, const T b)
10+
{
11+
if constexpr (std::is_floating_point_v<T>) {
12+
T inf = std::numeric_limits<T>::infinity();
13+
if (!std::isunordered(a, b)) { return op(a, b); }
14+
else if ((std::isnan(a)) && (!std::isnan(b))) {
15+
return b == inf ? op(inf, 1.) : op(inf, b);
16+
}
17+
else if ((!std::isnan(a)) && (std::isnan(b))) {
18+
return a == inf ? op(1., inf) : op(a, inf);
19+
}
20+
else {
21+
return op(1., 1.);
22+
}
23+
}
24+
else {
25+
return op(a, b);
26+
}
27+
}
28+
};
29+
30+
template <typename T, typename Comparator>
31+
struct compare_arg {
32+
compare_arg(const T* arr)
33+
{
34+
this->arr = arr;
35+
}
36+
bool operator()(const int64_t a, const int64_t b)
37+
{
38+
return compare<T, Comparator>()(arr[a], arr[b]);
39+
}
40+
const T* arr;
41+
};

0 commit comments

Comments
 (0)