@@ -105,4 +105,195 @@ get_mad_kernel(sycl::queue &q, size_t n, T *in1, T *in2, T *out, scT val)
105
105
return program.get_kernel <mad_kern<T, scT>>();
106
106
};
107
107
108
+ template <typename name,
109
+ typename localAccessorT,
110
+ class KernelFuncArgs ,
111
+ class KernelFunctor >
112
+ auto make_cgh_nd_function_with_local_memory (const sycl::nd_range<1 > &nd_range,
113
+ size_t slm_size,
114
+ KernelFuncArgs kern_params)
115
+ {
116
+ auto Kernel = [&](sycl::handler &cgh) {
117
+ localAccessorT lm (slm_size, cgh);
118
+ cgh.parallel_for <name>(nd_range, KernelFunctor (kern_params, lm));
119
+ };
120
+ return Kernel;
121
+ };
122
+
123
+ template <typename name, class KernelFunctor >
124
+ auto make_cgh_nd_function (const sycl::nd_range<1 > &nd_range, KernelFunctor kern)
125
+ {
126
+ auto Kernel = [&](sycl::handler &cgh) {
127
+ cgh.parallel_for <name>(nd_range, kern);
128
+ };
129
+ return Kernel;
130
+ };
131
+
132
/**
 * @brief Plain parameter pack describing the array handled by the local sort
 *        kernel.
 *
 * Holds a non-owning pointer to the data, the total number of elements, and
 * the number of elements assigned to each work-group. No special member
 * functions are declared, so the type stays trivially copyable and trivially
 * destructible.
 */
template <typename T> struct LocalSortArgs
{
    T *arr;                   // non-owning pointer to the array
    size_t global_array_size; // total number of elements in the array
    size_t wg_chunk_size;     // number of elements per work-group chunk

    /**
     * @param arr     Pointer to the array (not owned).
     * @param arr_len Total number of elements in the array.
     * @param wg_len  Number of elements per work-group chunk.
     */
    LocalSortArgs(T *arr, size_t arr_len, size_t wg_len)
        : arr(arr), global_array_size(arr_len), wg_chunk_size(wg_len)
    {
    }

    /** @return The stored array pointer. */
    T *get_array_pointer() const
    {
        return arr;
    }

    /** @return The total number of elements in the array. */
    size_t get_array_size() const
    {
        return global_array_size;
    }

    /** @return The number of elements per work-group chunk. */
    size_t get_chunk_size() const
    {
        return wg_chunk_size;
    }
};
156
+
157
+ template <typename T, typename localAccessorT> struct LocalSortFunc
158
+ {
159
+ /*
160
+
161
+ */
162
+ T *arr;
163
+ size_t global_array_size;
164
+ size_t wg_chunk_size;
165
+ localAccessorT lm;
166
+ LocalSortFunc (T *arr, size_t arr_len, size_t wg_len, localAccessorT lm)
167
+ : arr(arr), global_array_size(arr_len), wg_chunk_size(wg_len), lm(lm)
168
+ {
169
+ }
170
+ template <class paramsT >
171
+ LocalSortFunc (paramsT params, localAccessorT lm)
172
+ : arr(params.get_array_pointer()),
173
+ global_array_size (params.get_array_size()),
174
+ wg_chunk_size(params.get_chunk_size()), lm(lm)
175
+ {
176
+ }
177
+ ~LocalSortFunc () {}
178
+ void operator ()(sycl::nd_item<1 > item) const
179
+ {
180
+ /* Use odd-even merge sort to sort lws chunk of array */
181
+ size_t group_id = item.get_group_linear_id ();
182
+ size_t chunk_size =
183
+ sycl::min ((group_id + 1 ) * wg_chunk_size, global_array_size) -
184
+ group_id * wg_chunk_size;
185
+
186
+ // compute the greatest power of 2 less than chunk_size
187
+ size_t sp2 = 1 ;
188
+ while (sp2 < chunk_size) {
189
+ sp2 <<= 1 ;
190
+ }
191
+ sp2 >>= 1 ;
192
+
193
+ size_t gid = item.get_global_linear_id ();
194
+ size_t lid = item.get_local_linear_id ();
195
+
196
+ if (gid < global_array_size) {
197
+ lm[lid] = arr[gid];
198
+ }
199
+ item.barrier (sycl::access::fence_space::local_space);
200
+
201
+ for (size_t p = sp2; p > 0 ; p >>= 1 ) {
202
+ size_t q = sp2;
203
+ size_t r = 0 ;
204
+ for (size_t d = p; d > 0 ; d = q - p, q >>= 1 , r = p) {
205
+ if ((lid < chunk_size - d) && (lid & p) == r) {
206
+ size_t i = lid;
207
+ size_t j = i + d;
208
+ T v1 = lm[i];
209
+ T v2 = lm[j];
210
+ if (v1 > v2) {
211
+ lm[i] = v2;
212
+ lm[j] = v1;
213
+ }
214
+ }
215
+ item.barrier (sycl::access::fence_space::local_space);
216
+ }
217
+ }
218
+ if (gid < global_array_size) {
219
+ arr[gid] = lm[lid];
220
+ }
221
+ };
222
+ };
223
+
224
+ template <typename T> class local_sort_kern ;
225
+
226
+ template <typename T>
227
+ sycl::kernel get_local_sort_kernel (sycl::queue &q,
228
+ size_t gws,
229
+ size_t lws,
230
+ T *arr,
231
+ size_t arr_len)
232
+ {
233
+ sycl::program program (q.get_context ());
234
+
235
+ using local_accessor_t =
236
+ sycl::accessor<T, 1 , sycl::access::mode::read_write,
237
+ sycl::access::target::local>;
238
+
239
+ [[maybe_unused]] auto cgh_fn = make_cgh_nd_function_with_local_memory<
240
+ local_sort_kern<T>, local_accessor_t , LocalSortArgs<T>,
241
+ LocalSortFunc<T, local_accessor_t >>(
242
+ sycl::nd_range<1 >(gws, lws), lws, LocalSortArgs<T>(arr, arr_len, lws));
243
+
244
+ program.build_with_kernel_type <local_sort_kern<T>>();
245
+ return program.get_kernel <local_sort_kern<T>>();
246
+ };
247
+
248
+ template <typename T> struct LocalCountExceedanceFunc
249
+ {
250
+ T *arr;
251
+ size_t arr_len;
252
+ T threshold_val;
253
+ int *count_arr;
254
+ LocalCountExceedanceFunc (T *arr,
255
+ size_t arr_len,
256
+ T threshold_val,
257
+ int *count_arr)
258
+ : arr(arr), arr_len(arr_len), threshold_val(threshold_val),
259
+ count_arr (count_arr)
260
+ {
261
+ }
262
+
263
+ void operator ()(sycl::nd_item<1 > item) const
264
+ {
265
+ /* count number of array elements in group chunk that
266
+ exceeds the threshold value */
267
+ size_t gid = item.get_global_linear_id ();
268
+ int partial_sum = sycl::ONEAPI::reduce (
269
+ item.get_group (),
270
+ (gid < arr_len) ? int (arr[gid] > threshold_val) : int (0 ),
271
+ std::plus<int >());
272
+ count_arr[item.get_group_linear_id ()] = partial_sum;
273
+ }
274
+ };
275
+
276
+ template <typename T> class local_exceedance_kern ;
277
+
278
+ template <typename T>
279
+ sycl::kernel get_local_count_exceedance_kernel (sycl::queue &q,
280
+ size_t gws,
281
+ size_t lws,
282
+ T *arr,
283
+ size_t arr_len,
284
+ T threshold_val,
285
+ int *counts)
286
+ {
287
+ sycl::program program (q.get_context ());
288
+
289
+ [[maybe_unused]] auto cgh_fn =
290
+ make_cgh_nd_function<local_exceedance_kern<T>,
291
+ LocalCountExceedanceFunc<T>>(
292
+ sycl::nd_range<1 >(gws, lws),
293
+ LocalCountExceedanceFunc<T>(arr, arr_len, threshold_val, counts));
294
+
295
+ program.build_with_kernel_type <local_exceedance_kern<T>>();
296
+ return program.get_kernel <local_exceedance_kern<T>>();
297
+ };
298
+
108
299
} // namespace dpcpp_kernels
0 commit comments