Skip to content

Commit 649bd19

Browse files
Expanded types tests by TestQueueSubmitRange
Added tests for TestQueueSubmitNDRange
1 parent cf05f16 commit 649bd19

File tree

3 files changed

+402
-0
lines changed

3 files changed

+402
-0
lines changed

dpctl-capi/tests/dpcpp_kernels.cpp

Lines changed: 73 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,12 +5,85 @@
55
template sycl::kernel
66
dpcpp_kernels::get_fill_kernel<int>(sycl::queue &, size_t, int *, int);
77

8+
template sycl::kernel
9+
dpcpp_kernels::get_fill_kernel<unsigned int>(sycl::queue &,
10+
size_t,
11+
unsigned int *,
12+
unsigned int);
13+
14+
template sycl::kernel
15+
dpcpp_kernels::get_fill_kernel<double>(sycl::queue &, size_t, double *, double);
16+
17+
template sycl::kernel
18+
dpcpp_kernels::get_fill_kernel<float>(sycl::queue &, size_t, float *, float);
19+
820
template sycl::kernel
921
dpcpp_kernels::get_range_kernel<int>(sycl::queue &, size_t, int *);
1022

23+
template sycl::kernel
24+
dpcpp_kernels::get_range_kernel<unsigned int>(sycl::queue &,
25+
size_t,
26+
unsigned int *);
27+
28+
template sycl::kernel
29+
dpcpp_kernels::get_range_kernel<float>(sycl::queue &, size_t, float *);
30+
31+
template sycl::kernel
32+
dpcpp_kernels::get_range_kernel<double>(sycl::queue &, size_t, double *);
33+
1134
template sycl::kernel dpcpp_kernels::get_mad_kernel<int, int>(sycl::queue &,
1235
size_t,
1336
int *,
1437
int *,
1538
int *,
1639
int);
40+
41+
template sycl::kernel
42+
dpcpp_kernels::get_mad_kernel<unsigned int, unsigned int>(sycl::queue &,
43+
size_t,
44+
unsigned int *,
45+
unsigned int *,
46+
unsigned int *,
47+
unsigned int);
48+
49+
template sycl::kernel dpcpp_kernels::get_local_sort_kernel<int>(sycl::queue &,
50+
size_t,
51+
size_t,
52+
int *,
53+
size_t);
54+
55+
template sycl::kernel
56+
dpcpp_kernels::get_local_count_exceedance_kernel<int>(sycl::queue &,
57+
size_t,
58+
size_t,
59+
int *,
60+
size_t,
61+
int,
62+
int *);
63+
64+
template sycl::kernel
65+
dpcpp_kernels::get_local_count_exceedance_kernel<unsigned int>(sycl::queue &,
66+
size_t,
67+
size_t,
68+
unsigned int *,
69+
size_t,
70+
unsigned int,
71+
int *);
72+
73+
template sycl::kernel
74+
dpcpp_kernels::get_local_count_exceedance_kernel<float>(sycl::queue &,
75+
size_t,
76+
size_t,
77+
float *,
78+
size_t,
79+
float,
80+
int *);
81+
82+
template sycl::kernel
83+
dpcpp_kernels::get_local_count_exceedance_kernel<double>(sycl::queue &,
84+
size_t,
85+
size_t,
86+
double *,
87+
size_t,
88+
double,
89+
int *);

dpctl-capi/tests/dpcpp_kernels.hpp

Lines changed: 191 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -105,4 +105,195 @@ get_mad_kernel(sycl::queue &q, size_t n, T *in1, T *in2, T *out, scT val)
105105
return program.get_kernel<mad_kern<T, scT>>();
106106
};
107107

108+
template <typename name,
109+
typename localAccessorT,
110+
class KernelFuncArgs,
111+
class KernelFunctor>
112+
auto make_cgh_nd_function_with_local_memory(const sycl::nd_range<1> &nd_range,
113+
size_t slm_size,
114+
KernelFuncArgs kern_params)
115+
{
116+
auto Kernel = [&](sycl::handler &cgh) {
117+
localAccessorT lm(slm_size, cgh);
118+
cgh.parallel_for<name>(nd_range, KernelFunctor(kern_params, lm));
119+
};
120+
return Kernel;
121+
};
122+
123+
template <typename name, class KernelFunctor>
124+
auto make_cgh_nd_function(const sycl::nd_range<1> &nd_range, KernelFunctor kern)
125+
{
126+
auto Kernel = [&](sycl::handler &cgh) {
127+
cgh.parallel_for<name>(nd_range, kern);
128+
};
129+
return Kernel;
130+
};
131+
132+
template <typename T> struct LocalSortArgs
133+
{
134+
T *arr;
135+
size_t global_array_size;
136+
size_t wg_chunk_size;
137+
LocalSortArgs(T *arr, size_t arr_len, size_t wg_len)
138+
: arr(arr), global_array_size(arr_len), wg_chunk_size(wg_len)
139+
{
140+
}
141+
~LocalSortArgs() {}
142+
143+
T *get_array_pointer() const
144+
{
145+
return arr;
146+
}
147+
size_t get_array_size() const
148+
{
149+
return global_array_size;
150+
}
151+
size_t get_chunk_size() const
152+
{
153+
return wg_chunk_size;
154+
}
155+
};
156+
157+
template <typename T, typename localAccessorT> struct LocalSortFunc
158+
{
159+
/*
160+
161+
*/
162+
T *arr;
163+
size_t global_array_size;
164+
size_t wg_chunk_size;
165+
localAccessorT lm;
166+
LocalSortFunc(T *arr, size_t arr_len, size_t wg_len, localAccessorT lm)
167+
: arr(arr), global_array_size(arr_len), wg_chunk_size(wg_len), lm(lm)
168+
{
169+
}
170+
template <class paramsT>
171+
LocalSortFunc(paramsT params, localAccessorT lm)
172+
: arr(params.get_array_pointer()),
173+
global_array_size(params.get_array_size()),
174+
wg_chunk_size(params.get_chunk_size()), lm(lm)
175+
{
176+
}
177+
~LocalSortFunc() {}
178+
void operator()(sycl::nd_item<1> item) const
179+
{
180+
/* Use odd-even merge sort to sort lws chunk of array */
181+
size_t group_id = item.get_group_linear_id();
182+
size_t chunk_size =
183+
sycl::min((group_id + 1) * wg_chunk_size, global_array_size) -
184+
group_id * wg_chunk_size;
185+
186+
// compute the greatest power of 2 less than chunk_size
187+
size_t sp2 = 1;
188+
while (sp2 < chunk_size) {
189+
sp2 <<= 1;
190+
}
191+
sp2 >>= 1;
192+
193+
size_t gid = item.get_global_linear_id();
194+
size_t lid = item.get_local_linear_id();
195+
196+
if (gid < global_array_size) {
197+
lm[lid] = arr[gid];
198+
}
199+
item.barrier(sycl::access::fence_space::local_space);
200+
201+
for (size_t p = sp2; p > 0; p >>= 1) {
202+
size_t q = sp2;
203+
size_t r = 0;
204+
for (size_t d = p; d > 0; d = q - p, q >>= 1, r = p) {
205+
if ((lid < chunk_size - d) && (lid & p) == r) {
206+
size_t i = lid;
207+
size_t j = i + d;
208+
T v1 = lm[i];
209+
T v2 = lm[j];
210+
if (v1 > v2) {
211+
lm[i] = v2;
212+
lm[j] = v1;
213+
}
214+
}
215+
item.barrier(sycl::access::fence_space::local_space);
216+
}
217+
}
218+
if (gid < global_array_size) {
219+
arr[gid] = lm[lid];
220+
}
221+
};
222+
};
223+
224+
template <typename T> class local_sort_kern;
225+
226+
template <typename T>
227+
sycl::kernel get_local_sort_kernel(sycl::queue &q,
228+
size_t gws,
229+
size_t lws,
230+
T *arr,
231+
size_t arr_len)
232+
{
233+
sycl::program program(q.get_context());
234+
235+
using local_accessor_t =
236+
sycl::accessor<T, 1, sycl::access::mode::read_write,
237+
sycl::access::target::local>;
238+
239+
[[maybe_unused]] auto cgh_fn = make_cgh_nd_function_with_local_memory<
240+
local_sort_kern<T>, local_accessor_t, LocalSortArgs<T>,
241+
LocalSortFunc<T, local_accessor_t>>(
242+
sycl::nd_range<1>(gws, lws), lws, LocalSortArgs<T>(arr, arr_len, lws));
243+
244+
program.build_with_kernel_type<local_sort_kern<T>>();
245+
return program.get_kernel<local_sort_kern<T>>();
246+
};
247+
248+
template <typename T> struct LocalCountExceedanceFunc
249+
{
250+
T *arr;
251+
size_t arr_len;
252+
T threshold_val;
253+
int *count_arr;
254+
LocalCountExceedanceFunc(T *arr,
255+
size_t arr_len,
256+
T threshold_val,
257+
int *count_arr)
258+
: arr(arr), arr_len(arr_len), threshold_val(threshold_val),
259+
count_arr(count_arr)
260+
{
261+
}
262+
263+
void operator()(sycl::nd_item<1> item) const
264+
{
265+
/* count number of array elements in group chunk that
266+
exceeds the threshold value */
267+
size_t gid = item.get_global_linear_id();
268+
int partial_sum = sycl::ONEAPI::reduce(
269+
item.get_group(),
270+
(gid < arr_len) ? int(arr[gid] > threshold_val) : int(0),
271+
std::plus<int>());
272+
count_arr[item.get_group_linear_id()] = partial_sum;
273+
}
274+
};
275+
276+
template <typename T> class local_exceedance_kern;
277+
278+
template <typename T>
279+
sycl::kernel get_local_count_exceedance_kernel(sycl::queue &q,
280+
size_t gws,
281+
size_t lws,
282+
T *arr,
283+
size_t arr_len,
284+
T threshold_val,
285+
int *counts)
286+
{
287+
sycl::program program(q.get_context());
288+
289+
[[maybe_unused]] auto cgh_fn =
290+
make_cgh_nd_function<local_exceedance_kern<T>,
291+
LocalCountExceedanceFunc<T>>(
292+
sycl::nd_range<1>(gws, lws),
293+
LocalCountExceedanceFunc<T>(arr, arr_len, threshold_val, counts));
294+
295+
program.build_with_kernel_type<local_exceedance_kern<T>>();
296+
return program.get_kernel<local_exceedance_kern<T>>();
297+
};
298+
108299
} // namespace dpcpp_kernels

0 commit comments

Comments
 (0)