Commit 32d4419

Change WG traversal pattern in tree reduction kernel
Made changes similar to those made in the kernels for atomic reduction. The work-group's (WG's) location now changes fastest along the iteration dimension (previously it changed fastest along the reduction dimension). With this change, reduction performance improves 7-8x (wall time of the repeated call drops from 325 ms to 43 ms):

```
In [1]: import dpctl.tensor as dpt

In [2]: x = dpt.reshape(dpt.asarray(1, dtype="f4")/dpt.square(dpt.arange(1, 1282200*128 + 1, dtype="f2")), (1282200, 128))

In [3]: %time y = dpt.sum(x, axis=0, dtype="f2")
CPU times: user 284 ms, sys: 3.68 ms, total: 287 ms
Wall time: 316 ms

In [4]: %time y = dpt.sum(x, axis=0, dtype="f2")
CPU times: user 18.6 ms, sys: 18.9 ms, total: 37.5 ms
Wall time: 43 ms

In [5]: quit
```

While in the main branch:

```
In [1]: import dpctl.tensor as dpt

In [2]: x = dpt.reshape(dpt.asarray(1, dtype="f4")/dpt.square(dpt.arange(1, 1282200*128 + 1, dtype="f2")), (1282200, 128))

In [3]: %time y = dpt.sum(x, axis=0, dtype="f2")
CPU times: user 440 ms, sys: 129 ms, total: 569 ms
Wall time: 514 ms

In [4]: %time y = dpt.sum(x, axis=0, dtype="f2")
CPU times: user 142 ms, sys: 159 ms, total: 301 ms
Wall time: 325 ms

In [5]: %time y = dpt.sum(x, axis=0, dtype="f2")
CPU times: user 142 ms, sys: 154 ms, total: 296 ms
Wall time: 325 ms

In [6]: quit
```
1 parent c79445f commit 32d4419

1 file changed: +4 -5 lines changed

dpctl/tensor/libtensor/include/kernels/reductions.hpp

Lines changed: 4 additions & 5 deletions
@@ -610,14 +610,13 @@ struct ReductionOverGroupNoAtomicFunctor
 
     void operator()(sycl::nd_item<1> it) const
     {
-
-        const size_t red_gws_ = it.get_global_range(0) / iter_gws_;
-        const size_t iter_gid = it.get_global_id(0) / red_gws_;
-        const size_t n_reduction_groups = it.get_group_range(0) / iter_gws_;
-        const size_t reduction_batch_id = it.get_group(0) % n_reduction_groups;
         const size_t reduction_lid = it.get_local_id(0);
         const size_t wg = it.get_local_range(0); // 0 <= reduction_lid < wg
 
+        const size_t iter_gid = it.get_group(0) % iter_gws_;
+        const size_t reduction_batch_id = it.get_group(0) / iter_gws_;
+        const size_t n_reduction_groups = it.get_group_range(0) / iter_gws_;
+
         // work-items sums over input with indices
         // inp_data_id = reduction_batch_id * wg * reductions_per_wi + m * wg
         // + reduction_lid
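
To make the traversal change concrete, here is a minimal Python sketch of the two index mappings in the diff, modeled on the host with no SYCL involved. The function names `old_mapping` and `new_mapping` and the sizes `iter_gws = 4`, `n_reduction_groups = 3` are hypothetical and chosen only for illustration. The old formula `get_global_id(0) / red_gws_` is simplified to `group_id // n_reduction_groups`, which assumes the global range equals `iter_gws_ * n_reduction_groups * wg` and that every local id is below `wg`.

```python
def old_mapping(group_id, iter_gws, n_reduction_groups):
    # Pre-change arithmetic: get_global_id(0) / red_gws_ reduces to
    # group_id // n_reduction_groups under the assumptions stated above.
    iter_gid = group_id // n_reduction_groups
    reduction_batch_id = group_id % n_reduction_groups
    return iter_gid, reduction_batch_id


def new_mapping(group_id, iter_gws, n_reduction_groups):
    # Post-change arithmetic, taken from the added lines of the diff.
    iter_gid = group_id % iter_gws
    reduction_batch_id = group_id // iter_gws
    return iter_gid, reduction_batch_id


# Hypothetical sizes, for illustration only.
iter_gws, n_reduction_groups = 4, 3
n_groups = iter_gws * n_reduction_groups

old = [old_mapping(g, iter_gws, n_reduction_groups) for g in range(n_groups)]
new = [new_mapping(g, iter_gws, n_reduction_groups) for g in range(n_groups)]

# Print (iter_gid, reduction_batch_id) for each work-group id.
for g in range(n_groups):
    print(g, "old:", old[g], "new:", new[g])
```

Both mappings visit the same set of `(iter_gid, reduction_batch_id)` pairs. Only the order differs: in the new mapping, consecutive work-group ids step through `iter_gid` while sharing a `reduction_batch_id`, which is the "iteration dimension the fastest" behavior described in the commit message.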
