Skip to content

Commit d9c9737

Browse files
author
Raghuveer Devulapalli
authored
Merge pull request #103 from r-devulap/customsort-expt
Add API to sort array of custom objects
2 parents da8ce10 + fbc033e commit d9c9737

File tree

4 files changed

+154
-1
lines changed

4 files changed

+154
-1
lines changed

benchmarks/bench-all.cpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,3 +4,4 @@
44
#include "bench-qselect.hpp"
55
#include "bench-qsort.hpp"
66
#include "bench-keyvalue.hpp"
7+
#include "bench-objsort.hpp"

benchmarks/bench-objsort.hpp

Lines changed: 108 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,108 @@
1+
#include <cmath>
2+
3+
static constexpr char x[] = "x";
4+
static constexpr char euclidean[] = "euclidean";
5+
static constexpr char taxicab[] = "taxicab";
6+
static constexpr char chebyshev[] = "chebyshev";
7+
8+
template <const char* val>
9+
struct Point3D {
10+
double x;
11+
double y;
12+
double z;
13+
static constexpr std::string_view name {val};
14+
Point3D()
15+
{
16+
x = (double)rand() / RAND_MAX;
17+
y = (double)rand() / RAND_MAX;
18+
z = (double)rand() / RAND_MAX;
19+
}
20+
double distance()
21+
{
22+
if constexpr (name == "x") {
23+
return x;
24+
}
25+
else if constexpr (name == "euclidean") {
26+
return std::sqrt(x * x + y * y + z * z);
27+
}
28+
else if constexpr (name == "taxicab") {
29+
return abs(x) + abs(y) + abs(z);
30+
}
31+
else if constexpr (name == "chebyshev") {
32+
return std::max(std::max(x, y), z);
33+
}
34+
}
35+
};
36+
37+
template <typename T>
38+
std::vector<T> init_data(const int size)
39+
{
40+
srand(42);
41+
std::vector<T> arr;
42+
for (auto ii = 0; ii < size; ++ii) {
43+
T temp;
44+
arr.push_back(temp);
45+
}
46+
return arr;
47+
}
48+
49+
template <typename T>
50+
struct less_than_key {
51+
inline bool operator()(T &p1, T &p2)
52+
{
53+
return (p1.distance() < p2.distance());
54+
}
55+
};
56+
57+
template <typename T>
58+
static void scalarobjsort(benchmark::State &state)
59+
{
60+
// set up array
61+
std::vector<T> arr = init_data<T>(state.range(0));
62+
std::vector<T> arr_bkp = arr;
63+
// benchmark
64+
for (auto _ : state) {
65+
std::sort(arr.begin(), arr.end(), less_than_key<T>());
66+
state.PauseTiming();
67+
arr = arr_bkp;
68+
state.ResumeTiming();
69+
}
70+
}
71+
72+
template <typename T>
73+
static void simdobjsort(benchmark::State &state)
74+
{
75+
// set up array
76+
std::vector<T> arr = init_data<T>(state.range(0));
77+
std::vector<T> arr_bkp = arr;
78+
// benchmark
79+
for (auto _ : state) {
80+
x86simdsort::object_qsort(arr.data(), arr.size(), [](T p) -> double {
81+
return p.distance();
82+
});
83+
state.PauseTiming();
84+
if (!std::is_sorted(arr.begin(), arr.end(), less_than_key<T>())) {
85+
std::cout << "sorting failed \n";
86+
}
87+
arr = arr_bkp;
88+
state.ResumeTiming();
89+
}
90+
}
91+
92+
#define BENCHMARK_OBJSORT(func, T) \
93+
BENCHMARK_TEMPLATE(func, T) \
94+
->Arg(10e1) \
95+
->Arg(10e2) \
96+
->Arg(10e3) \
97+
->Arg(10e4) \
98+
->Arg(10e5) \
99+
->Arg(10e6);
100+
101+
BENCHMARK_OBJSORT(simdobjsort, Point3D<x>)
102+
BENCHMARK_OBJSORT(scalarobjsort, Point3D<x>)
103+
BENCHMARK_OBJSORT(simdobjsort, Point3D<taxicab>)
104+
BENCHMARK_OBJSORT(scalarobjsort, Point3D<taxicab>)
105+
BENCHMARK_OBJSORT(simdobjsort, Point3D<euclidean>)
106+
BENCHMARK_OBJSORT(scalarobjsort, Point3D<euclidean>)
107+
BENCHMARK_OBJSORT(simdobjsort, Point3D<chebyshev>)
108+
BENCHMARK_OBJSORT(scalarobjsort, Point3D<chebyshev>)

lib/x86simdsort.h

Lines changed: 42 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,8 @@
33
#include <stdint.h>
44
#include <vector>
55
#include <cstddef>
6+
#include <functional>
7+
#include <numeric>
68

79
#define XSS_EXPORT_SYMBOL __attribute__((visibility("default")))
810
#define XSS_HIDE_SYMBOL __attribute__((visibility("hidden")))
@@ -34,10 +36,49 @@ template <typename T>
3436
XSS_EXPORT_SYMBOL std::vector<size_t>
3537
argselect(T *arr, size_t k, size_t arrsize, bool hasnan = false);
3638

37-
// argselect
39+
// keyvalue sort
3840
template <typename T1, typename T2>
3941
XSS_EXPORT_SYMBOL void
4042
keyvalue_qsort(T1 *key, T2* val, size_t arrsize, bool hasnan = false);
4143

44+
// sort an object
45+
template <typename T, typename Func>
46+
XSS_EXPORT_SYMBOL void object_qsort(T *arr, size_t arrsize, Func key_func)
47+
{
48+
/* (1) Create a vector a keys */
49+
using return_type_of =
50+
typename decltype(std::function {key_func})::result_type;
51+
std::vector<return_type_of> keys;
52+
keys.reserve(arrsize);
53+
for (size_t ii = 0; ii < arrsize; ++ii) {
54+
keys[ii] = key_func(arr[ii]);
55+
}
56+
57+
/* (2) Call arg based on keys using the keyvalue sort */
58+
std::vector<size_t> arg(arrsize);
59+
std::iota(arg.begin(), arg.end(), 0);
60+
keyvalue_qsort(keys.data(), arg.data(), arrsize);
61+
62+
/* (3) Permute obj array in-place */
63+
std::vector<bool> done(arrsize);
64+
for (size_t i = 0; i < arrsize; ++i)
65+
{
66+
if (done[i])
67+
{
68+
continue;
69+
}
70+
done[i] = true;
71+
size_t prev_j = i;
72+
size_t j = arg[i];
73+
while (i != j)
74+
{
75+
std::swap(arr[prev_j], arr[j]);
76+
done[j] = true;
77+
prev_j = j;
78+
j = arg[j];
79+
}
80+
}
81+
}
82+
4283
} // namespace x86simdsort
4384
#endif

run-bench.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,9 @@
3737
elif "keyvalue" in args.benchcompare:
3838
baseline = "scalarkvsort.*" + filterb
3939
contender = "simdkvsort.*" + filterb
40+
elif "objsort" in args.benchcompare:
41+
baseline = "scalarobjsort.*" + filterb
42+
contender = "simdobjsort.*" + filterb
4043
else:
4144
parser.print_help(sys.stderr)
4245
parser.error("ERROR: Unknown argument '%s'" % args.benchcompare)

0 commit comments

Comments
 (0)