Skip to content

Commit 9f54428

Browse files
Specify name for the atomic reduction initialization kernel
1 parent 63b2799 commit 9f54428

File tree

1 file changed

+11
-5
lines changed

1 file changed

+11
-5
lines changed

dpctl/tensor/libtensor/include/kernels/reductions.hpp

Lines changed: 11 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -232,6 +232,9 @@ typedef sycl::event (*sum_reduction_strided_impl_fn_ptr)(
232232
template <typename T1, typename T2, typename T3, typename T4, typename T5>
233233
class sum_reduction_over_group_with_atomics_krn;
234234

235+
template <typename T1, typename T2>
236+
class sum_reduction_over_group_with_atomics_init_krn;
237+
235238
template <typename T1, typename T2, typename T3, typename T4, typename T5>
236239
class sum_reduction_seq_strided_krn;
237240

@@ -305,13 +308,16 @@ sycl::event sum_reduction_over_group_with_atomics_strided_impl(
305308
iter_shape_and_strides + 2 * iter_nd;
306309
IndexerT res_indexer(iter_nd, iter_res_offset, res_shape,
307310
res_strides);
308-
311+
using InitKernelName =
312+
class sum_reduction_over_group_with_atomics_init_krn<resTy,
313+
argTy>;
309314
cgh.depends_on(depends);
310315

311-
cgh.parallel_for(sycl::range<1>(iter_nelems), [=](sycl::id<1> id) {
312-
auto res_offset = res_indexer(id[0]);
313-
res_tp[res_offset] = identity_val;
314-
});
316+
cgh.parallel_for<InitKernelName>(
317+
sycl::range<1>(iter_nelems), [=](sycl::id<1> id) {
318+
auto res_offset = res_indexer(id[0]);
319+
res_tp[res_offset] = identity_val;
320+
});
315321
});
316322

317323
sycl::event comp_ev = exec_q.submit([&](sycl::handler &cgh) {

0 commit comments

Comments
 (0)