Skip to content

Commit 1d5228a

Browse files
Transition sum-reduction from nd_range<2> to nd_range<1>
This improves performance roughly 8-fold:

```
In [1]: import dpctl.tensor as dpt

In [2]: x = dpt.ones((4096, 4096), dtype="f4")

In [3]: y = dpt.sum(x, axis=0)

In [4]: %time y = dpt.sum(x, axis=0)
CPU times: user 2.64 ms, sys: 4.4 ms, total: 7.04 ms
Wall time: 10 ms

In [5]: %time y = dpt.sum(x, axis=0)
CPU times: user 1.93 ms, sys: 3.22 ms, total: 5.16 ms
Wall time: 4.74 ms

In [6]: %time y = dpt.sum(x, axis=0)
CPU times: user 1.7 ms, sys: 2.83 ms, total: 4.53 ms
Wall time: 4.1 ms

In [7]: %time y = dpt.sum(x, axis=0)
CPU times: user 1.98 ms, sys: 3.3 ms, total: 5.28 ms
Wall time: 4.7 ms
```

Before this change, the same call took around 38 ms.
1 parent 616c21e commit 1d5228a

File tree

1 file changed

+17
-15
lines changed

1 file changed

+17
-15
lines changed

dpctl/tensor/libtensor/include/kernels/reductions.hpp

Lines changed: 17 additions & 15 deletions
Original file line number | Diff line number | Diff line change
@@ -122,6 +122,7 @@ struct ReductionOverGroupWithAtomicFunctor
122122
InputOutputIterIndexerT inp_out_iter_indexer_;
123123
InputRedIndexerT inp_reduced_dims_indexer_;
124124
size_t reduction_max_gid_ = 0;
125+
size_t iter_gws_ = 1;
125126
size_t reductions_per_wi = 16;
126127

127128
public:
@@ -133,22 +134,23 @@ struct ReductionOverGroupWithAtomicFunctor
133134
InputOutputIterIndexerT arg_res_iter_indexer,
134135
InputRedIndexerT arg_reduced_dims_indexer,
135136
size_t reduction_size,
137+
size_t iter_gws,
136138
size_t reduction_size_per_wi)
137139
: inp_(data), out_(res), reduction_op_(reduction_op),
138140
identity_(identity_val), inp_out_iter_indexer_(arg_res_iter_indexer),
139141
inp_reduced_dims_indexer_(arg_reduced_dims_indexer),
140-
reduction_max_gid_(reduction_size),
142+
reduction_max_gid_(reduction_size), iter_gws_(iter_gws),
141143
reductions_per_wi(reduction_size_per_wi)
142144
{
143145
}
144146

145-
void operator()(sycl::nd_item<2> it) const
147+
void operator()(sycl::nd_item<1> it) const
146148
{
147-
148-
size_t iter_gid = it.get_global_id(0);
149-
size_t reduction_batch_id = it.get_group(1);
150-
size_t reduction_lid = it.get_local_id(1);
151-
size_t wg = it.get_local_range(1); // 0 <= reduction_lid < wg
149+
const size_t red_gws_ = it.get_global_range(0) / iter_gws_;
150+
const size_t iter_gid = it.get_global_id(0) / red_gws_;
151+
const size_t reduction_batch_id = it.get_group(0) / iter_gws_;
152+
const size_t reduction_lid = it.get_local_id(0);
153+
const size_t wg = it.get_local_range(0); // 0 <= reduction_lid < wg
152154

153155
// work-items sums over input with indices
154156
// inp_data_id = reduction_batch_id * wg * reductions_per_wi + m * wg
@@ -343,21 +345,21 @@ sycl::event sum_reduction_over_group_with_atomics_strided_impl(
343345
}
344346

345347
auto globalRange =
346-
sycl::range<2>{iter_nelems, reduction_groups * wg};
347-
auto localRange = sycl::range<2>{1, wg};
348+
sycl::range<1>{iter_nelems * reduction_groups * wg};
349+
auto localRange = sycl::range<1>{wg};
348350

349351
using KernelName = class sum_reduction_over_group_with_atomics_krn<
350352
argTy, resTy, ReductionOpT, InputOutputIterIndexerT,
351353
ReductionIndexerT>;
352354

353355
cgh.parallel_for<KernelName>(
354-
sycl::nd_range<2>(globalRange, localRange),
356+
sycl::nd_range<1>(globalRange, localRange),
355357
ReductionOverGroupWithAtomicFunctor<argTy, resTy, ReductionOpT,
356358
InputOutputIterIndexerT,
357359
ReductionIndexerT>(
358360
arg_tp, res_tp, ReductionOpT(), identity_val,
359361
in_out_iter_indexer, reduction_indexer, reduction_nelems,
360-
reductions_per_wi));
362+
iter_nelems, reductions_per_wi));
361363
});
362364

363365
return comp_ev;
@@ -480,21 +482,21 @@ sycl::event sum_reduction_over_group_with_atomics_contig_impl(
480482
}
481483

482484
auto globalRange =
483-
sycl::range<2>{iter_nelems, reduction_groups * wg};
484-
auto localRange = sycl::range<2>{1, wg};
485+
sycl::range<1>{iter_nelems * reduction_groups * wg};
486+
auto localRange = sycl::range<1>{wg};
485487

486488
using KernelName = class sum_reduction_over_group_with_atomics_krn<
487489
argTy, resTy, ReductionOpT, InputOutputIterIndexerT,
488490
ReductionIndexerT>;
489491

490492
cgh.parallel_for<KernelName>(
491-
sycl::nd_range<2>(globalRange, localRange),
493+
sycl::nd_range<1>(globalRange, localRange),
492494
ReductionOverGroupWithAtomicFunctor<argTy, resTy, ReductionOpT,
493495
InputOutputIterIndexerT,
494496
ReductionIndexerT>(
495497
arg_tp, res_tp, ReductionOpT(), identity_val,
496498
in_out_iter_indexer, reduction_indexer, reduction_nelems,
497-
reductions_per_wi));
499+
iter_nelems, reductions_per_wi));
498500
});
499501

500502
return comp_ev;

0 commit comments

Comments
 (0)