Skip to content

Commit 38eaa01

Browse files
author
Raghuveer Devulapalli
committed
Move pivot selection to its own file
1 parent 6fb8f01 commit 38eaa01

File tree

3 files changed

+159
-185
lines changed

3 files changed

+159
-185
lines changed

src/avx512-common-qsort.h

Lines changed: 2 additions & 134 deletions
Original file line numberDiff line numberDiff line change
@@ -103,6 +103,8 @@
103103

104104
// Unsigned integer type used for array indices and element counts.
using arrsize_t = size_t;
105105

106+
#include "xss-pivot-selection.hpp"
107+
106108
// Forward declaration of the zmm_vector class template; no definition is
// given at this point.
template <typename type>
107109
struct zmm_vector;
108110

@@ -719,143 +721,9 @@ X86_SIMD_SORT_INLINE arrsize_t partition_avx512(type_t1 *keys,
719721
return l_store;
720722
}
721723

722-
template <typename vtype, typename type_t>
723-
X86_SIMD_SORT_INLINE type_t get_pivot_scalar(type_t *arr,
724-
const arrsize_t left,
725-
const arrsize_t right)
726-
{
727-
constexpr arrsize_t numSamples = vtype::numlanes;
728-
type_t samples[numSamples];
729-
730-
arrsize_t delta = (right - left) / numSamples;
731-
732-
for (int i = 0; i < numSamples; i++) {
733-
samples[i] = arr[left + i * delta];
734-
}
735-
736-
auto vec = vtype::loadu(samples);
737-
vec = vtype::sort_vec(vec);
738-
return ((type_t *)&vec)[numSamples / 2];
739-
}
740-
741-
template <typename vtype, typename type_t>
742-
X86_SIMD_SORT_INLINE type_t get_pivot_16bit(type_t *arr,
743-
const arrsize_t left,
744-
const arrsize_t right)
745-
{
746-
// median of 32
747-
arrsize_t size = (right - left) / 32;
748-
type_t vec_arr[32] = {arr[left],
749-
arr[left + size],
750-
arr[left + 2 * size],
751-
arr[left + 3 * size],
752-
arr[left + 4 * size],
753-
arr[left + 5 * size],
754-
arr[left + 6 * size],
755-
arr[left + 7 * size],
756-
arr[left + 8 * size],
757-
arr[left + 9 * size],
758-
arr[left + 10 * size],
759-
arr[left + 11 * size],
760-
arr[left + 12 * size],
761-
arr[left + 13 * size],
762-
arr[left + 14 * size],
763-
arr[left + 15 * size],
764-
arr[left + 16 * size],
765-
arr[left + 17 * size],
766-
arr[left + 18 * size],
767-
arr[left + 19 * size],
768-
arr[left + 20 * size],
769-
arr[left + 21 * size],
770-
arr[left + 22 * size],
771-
arr[left + 23 * size],
772-
arr[left + 24 * size],
773-
arr[left + 25 * size],
774-
arr[left + 26 * size],
775-
arr[left + 27 * size],
776-
arr[left + 28 * size],
777-
arr[left + 29 * size],
778-
arr[left + 30 * size],
779-
arr[left + 31 * size]};
780-
typename vtype::reg_t rand_vec = vtype::loadu(vec_arr);
781-
typename vtype::reg_t sort = vtype::sort_vec(rand_vec);
782-
return ((type_t *)&sort)[16];
783-
}
784-
785-
template <typename vtype, typename type_t>
786-
X86_SIMD_SORT_INLINE type_t get_pivot_32bit(type_t *arr,
787-
const arrsize_t left,
788-
const arrsize_t right)
789-
{
790-
// median of 16
791-
arrsize_t size = (right - left) / 16;
792-
using reg_t = typename vtype::reg_t;
793-
type_t vec_arr[16] = {arr[left + size],
794-
arr[left + 2 * size],
795-
arr[left + 3 * size],
796-
arr[left + 4 * size],
797-
arr[left + 5 * size],
798-
arr[left + 6 * size],
799-
arr[left + 7 * size],
800-
arr[left + 8 * size],
801-
arr[left + 9 * size],
802-
arr[left + 10 * size],
803-
arr[left + 11 * size],
804-
arr[left + 12 * size],
805-
arr[left + 13 * size],
806-
arr[left + 14 * size],
807-
arr[left + 15 * size],
808-
arr[left + 16 * size]};
809-
reg_t rand_vec = vtype::loadu(vec_arr);
810-
reg_t sort = vtype::sort_vec(rand_vec);
811-
// pivot will never be a nan, since there are no nan's!
812-
return ((type_t *)&sort)[8];
813-
}
814-
815-
template <typename vtype, typename type_t>
816-
X86_SIMD_SORT_INLINE type_t get_pivot_64bit(type_t *arr,
817-
const arrsize_t left,
818-
const arrsize_t right)
819-
{
820-
// median of 8
821-
arrsize_t size = (right - left) / 8;
822-
using reg_t = typename vtype::reg_t;
823-
reg_t rand_vec = vtype::set(arr[left + size],
824-
arr[left + 2 * size],
825-
arr[left + 3 * size],
826-
arr[left + 4 * size],
827-
arr[left + 5 * size],
828-
arr[left + 6 * size],
829-
arr[left + 7 * size],
830-
arr[left + 8 * size]);
831-
// pivot will never be a nan, since there are no nan's!
832-
reg_t sort = vtype::sort_vec(rand_vec);
833-
return ((type_t *)&sort)[4];
834-
}
835-
836-
template <typename vtype, typename type_t>
837-
X86_SIMD_SORT_INLINE type_t get_pivot(type_t *arr,
838-
const arrsize_t left,
839-
const arrsize_t right)
840-
{
841-
if constexpr (vtype::numlanes == 8)
842-
return get_pivot_64bit<vtype>(arr, left, right);
843-
else if constexpr (vtype::numlanes == 16)
844-
return get_pivot_32bit<vtype>(arr, left, right);
845-
else if constexpr (vtype::numlanes == 32)
846-
return get_pivot_16bit<vtype>(arr, left, right);
847-
else
848-
return get_pivot_scalar<vtype>(arr, left, right);
849-
}
850-
851724
// Declaration only: sort_n takes a pointer to vtype::type_t elements and an
// int N; maxN is a compile-time template argument. No body is given here.
template <typename vtype, int maxN>
852725
void sort_n(typename vtype::type_t *arr, int N);
853726

854-
// Forward declaration only; no body is provided at this point.
// Takes the array and two arrsize_t bounds and returns a value of the
// element type.
template <typename vtype, typename type_t>
855-
X86_SIMD_SORT_INLINE type_t get_pivot_blocks(type_t *arr,
856-
arrsize_t left,
857-
arrsize_t right);
858-
859727
template <typename vtype, typename type_t>
860728
static void
861729
qsort_(type_t *arr, arrsize_t left, arrsize_t right, arrsize_t max_iters)

src/xss-network-qsort.hpp

Lines changed: 1 addition & 51 deletions
Original file line numberDiff line numberDiff line change
@@ -187,54 +187,4 @@ X86_SIMD_SORT_INLINE void sort_n(typename vtype::type_t *arr, int N)
187187

188188
sort_n_vec<vtype, numVecs>(arr, N);
189189
}
190-
191-
// Forward declaration only; no body is provided at this point.
// get_pivot_blocks, defined right after, calls get_pivot for ranges where
// right - left <= 1024.
template <typename vtype, typename type_t>
192-
X86_SIMD_SORT_INLINE type_t get_pivot(type_t *arr,
193-
const arrsize_t left,
194-
const arrsize_t right);
195-
196-
template <typename vtype, typename type_t>
197-
X86_SIMD_SORT_INLINE type_t get_pivot_blocks(type_t *arr,
198-
arrsize_t left,
199-
arrsize_t right)
200-
{
201-
202-
if (right - left <= 1024) { return get_pivot<vtype>(arr, left, right); }
203-
204-
using reg_t = typename vtype::reg_t;
205-
constexpr int numVecs = 5;
206-
207-
arrsize_t width = (right - vtype::numlanes) - left;
208-
arrsize_t delta = width / numVecs;
209-
210-
reg_t vecs[numVecs];
211-
// Load data
212-
for (int i = 0; i < numVecs; i++) {
213-
vecs[i] = vtype::loadu(arr + left + delta * i);
214-
}
215-
216-
// Implement sorting network (from https://bertdobbelaere.github.io/sorting_networks.html)
217-
COEX<vtype>(vecs[0], vecs[3]);
218-
COEX<vtype>(vecs[1], vecs[4]);
219-
220-
COEX<vtype>(vecs[0], vecs[2]);
221-
COEX<vtype>(vecs[1], vecs[3]);
222-
223-
COEX<vtype>(vecs[0], vecs[1]);
224-
COEX<vtype>(vecs[2], vecs[4]);
225-
226-
COEX<vtype>(vecs[1], vecs[2]);
227-
COEX<vtype>(vecs[3], vecs[4]);
228-
229-
COEX<vtype>(vecs[2], vecs[3]);
230-
231-
// Calculate median of the middle vector
232-
reg_t &vec = vecs[numVecs / 2];
233-
vec = vtype::sort_vec(vec);
234-
235-
type_t data[vtype::numlanes];
236-
vtype::storeu(data, vec);
237-
return data[vtype::numlanes / 2];
238-
}
239-
240-
#endif
190+
#endif

src/xss-pivot-selection.hpp

Lines changed: 156 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,156 @@
1+
// Forward declaration of COEX; its definition is not part of this header.
// get_pivot_blocks below calls it on pairs of vector registers while running
// its sorting network (presumably a compare-and-exchange - confirm against
// the definition).
template <typename vtype, typename mm_t>
2+
X86_SIMD_SORT_INLINE void COEX(mm_t &a, mm_t &b);
3+
4+
template <typename vtype, typename type_t>
5+
X86_SIMD_SORT_INLINE type_t get_pivot_16bit(type_t *arr,
6+
const arrsize_t left,
7+
const arrsize_t right)
8+
{
9+
// median of 32
10+
arrsize_t size = (right - left) / 32;
11+
type_t vec_arr[32] = {arr[left],
12+
arr[left + size],
13+
arr[left + 2 * size],
14+
arr[left + 3 * size],
15+
arr[left + 4 * size],
16+
arr[left + 5 * size],
17+
arr[left + 6 * size],
18+
arr[left + 7 * size],
19+
arr[left + 8 * size],
20+
arr[left + 9 * size],
21+
arr[left + 10 * size],
22+
arr[left + 11 * size],
23+
arr[left + 12 * size],
24+
arr[left + 13 * size],
25+
arr[left + 14 * size],
26+
arr[left + 15 * size],
27+
arr[left + 16 * size],
28+
arr[left + 17 * size],
29+
arr[left + 18 * size],
30+
arr[left + 19 * size],
31+
arr[left + 20 * size],
32+
arr[left + 21 * size],
33+
arr[left + 22 * size],
34+
arr[left + 23 * size],
35+
arr[left + 24 * size],
36+
arr[left + 25 * size],
37+
arr[left + 26 * size],
38+
arr[left + 27 * size],
39+
arr[left + 28 * size],
40+
arr[left + 29 * size],
41+
arr[left + 30 * size],
42+
arr[left + 31 * size]};
43+
typename vtype::reg_t rand_vec = vtype::loadu(vec_arr);
44+
typename vtype::reg_t sort = vtype::sort_vec(rand_vec);
45+
return ((type_t *)&sort)[16];
46+
}
47+
48+
template <typename vtype, typename type_t>
49+
X86_SIMD_SORT_INLINE type_t get_pivot_32bit(type_t *arr,
50+
const arrsize_t left,
51+
const arrsize_t right)
52+
{
53+
// median of 16
54+
arrsize_t size = (right - left) / 16;
55+
using reg_t = typename vtype::reg_t;
56+
type_t vec_arr[16] = {arr[left + size],
57+
arr[left + 2 * size],
58+
arr[left + 3 * size],
59+
arr[left + 4 * size],
60+
arr[left + 5 * size],
61+
arr[left + 6 * size],
62+
arr[left + 7 * size],
63+
arr[left + 8 * size],
64+
arr[left + 9 * size],
65+
arr[left + 10 * size],
66+
arr[left + 11 * size],
67+
arr[left + 12 * size],
68+
arr[left + 13 * size],
69+
arr[left + 14 * size],
70+
arr[left + 15 * size],
71+
arr[left + 16 * size]};
72+
reg_t rand_vec = vtype::loadu(vec_arr);
73+
reg_t sort = vtype::sort_vec(rand_vec);
74+
// pivot will never be a nan, since there are no nan's!
75+
return ((type_t *)&sort)[8];
76+
}
77+
78+
template <typename vtype, typename type_t>
79+
X86_SIMD_SORT_INLINE type_t get_pivot_64bit(type_t *arr,
80+
const arrsize_t left,
81+
const arrsize_t right)
82+
{
83+
// median of 8
84+
arrsize_t size = (right - left) / 8;
85+
using reg_t = typename vtype::reg_t;
86+
reg_t rand_vec = vtype::set(arr[left + size],
87+
arr[left + 2 * size],
88+
arr[left + 3 * size],
89+
arr[left + 4 * size],
90+
arr[left + 5 * size],
91+
arr[left + 6 * size],
92+
arr[left + 7 * size],
93+
arr[left + 8 * size]);
94+
// pivot will never be a nan, since there are no nan's!
95+
reg_t sort = vtype::sort_vec(rand_vec);
96+
return ((type_t *)&sort)[4];
97+
}
98+
99+
template <typename vtype, typename type_t>
100+
X86_SIMD_SORT_INLINE type_t get_pivot(type_t *arr,
101+
const arrsize_t left,
102+
const arrsize_t right)
103+
{
104+
if constexpr (vtype::numlanes == 8)
105+
return get_pivot_64bit<vtype>(arr, left, right);
106+
else if constexpr (vtype::numlanes == 16)
107+
return get_pivot_32bit<vtype>(arr, left, right);
108+
else if constexpr (vtype::numlanes == 32)
109+
return get_pivot_16bit<vtype>(arr, left, right);
110+
else
111+
return arr[right];
112+
}
113+
114+
template <typename vtype, typename type_t>
115+
X86_SIMD_SORT_INLINE type_t get_pivot_blocks(type_t *arr,
116+
arrsize_t left,
117+
arrsize_t right)
118+
{
119+
120+
if (right - left <= 1024) { return get_pivot<vtype>(arr, left, right); }
121+
122+
using reg_t = typename vtype::reg_t;
123+
constexpr int numVecs = 5;
124+
125+
arrsize_t width = (right - vtype::numlanes) - left;
126+
arrsize_t delta = width / numVecs;
127+
128+
reg_t vecs[numVecs];
129+
// Load data
130+
for (int i = 0; i < numVecs; i++) {
131+
vecs[i] = vtype::loadu(arr + left + delta * i);
132+
}
133+
134+
// Implement sorting network (from https://bertdobbelaere.github.io/sorting_networks.html)
135+
COEX<vtype>(vecs[0], vecs[3]);
136+
COEX<vtype>(vecs[1], vecs[4]);
137+
138+
COEX<vtype>(vecs[0], vecs[2]);
139+
COEX<vtype>(vecs[1], vecs[3]);
140+
141+
COEX<vtype>(vecs[0], vecs[1]);
142+
COEX<vtype>(vecs[2], vecs[4]);
143+
144+
COEX<vtype>(vecs[1], vecs[2]);
145+
COEX<vtype>(vecs[3], vecs[4]);
146+
147+
COEX<vtype>(vecs[2], vecs[3]);
148+
149+
// Calculate median of the middle vector
150+
reg_t &vec = vecs[numVecs / 2];
151+
vec = vtype::sort_vec(vec);
152+
153+
type_t data[vtype::numlanes];
154+
vtype::storeu(data, vec);
155+
return data[vtype::numlanes / 2];
156+
}

0 commit comments

Comments
 (0)