@@ -232,6 +232,9 @@ typedef sycl::event (*sum_reduction_strided_impl_fn_ptr)(
232
232
template <typename T1, typename T2, typename T3, typename T4, typename T5>
233
233
class sum_reduction_over_group_with_atomics_krn ;
234
234
235
+ template <typename T1, typename T2>
236
+ class sum_reduction_over_group_with_atomics_init_krn ;
237
+
235
238
template <typename T1, typename T2, typename T3, typename T4, typename T5>
236
239
class sum_reduction_seq_strided_krn ;
237
240
@@ -305,13 +308,16 @@ sycl::event sum_reduction_over_group_with_atomics_strided_impl(
305
308
iter_shape_and_strides + 2 * iter_nd;
306
309
IndexerT res_indexer (iter_nd, iter_res_offset, res_shape,
307
310
res_strides);
308
-
311
+ using InitKernelName =
312
+ class sum_reduction_over_group_with_atomics_init_krn <resTy,
313
+ argTy>;
309
314
cgh.depends_on (depends);
310
315
311
- cgh.parallel_for (sycl::range<1 >(iter_nelems), [=](sycl::id<1 > id) {
312
- auto res_offset = res_indexer (id[0 ]);
313
- res_tp[res_offset] = identity_val;
314
- });
316
+ cgh.parallel_for <InitKernelName>(
317
+ sycl::range<1 >(iter_nelems), [=](sycl::id<1 > id) {
318
+ auto res_offset = res_indexer (id[0 ]);
319
+ res_tp[res_offset] = identity_val;
320
+ });
315
321
});
316
322
317
323
sycl::event comp_ev = exec_q.submit ([&](sycl::handler &cgh) {
0 commit comments