@@ -32,6 +32,7 @@
 
 #include "pybind11/pybind11.h"
 #include "utils/offset_utils.hpp"
+#include "utils/sycl_utils.hpp"
 #include "utils/type_dispatch.hpp"
 #include "utils/type_utils.hpp"
 
@@ -150,35 +151,6 @@ struct ReductionOverGroupWithAtomicFunctor
     }
 };
 
-template <size_t f = 4>
-size_t choose_workgroup_size(const size_t reduction_nelems,
-                             const std::vector<size_t> &sg_sizes)
-{
-    std::vector<size_t> wg_choices;
-    wg_choices.reserve(f * sg_sizes.size());
-
-    for (const auto &sg_size : sg_sizes) {
-#pragma unroll
-        for (size_t i = 1; i <= f; ++i) {
-            wg_choices.push_back(sg_size * i);
-        }
-    }
-    std::sort(std::begin(wg_choices), std::end(wg_choices));
-
-    size_t wg = 1;
-    for (size_t i = 0; i < wg_choices.size(); ++i) {
-        if (wg_choices[i] == wg) {
-            continue;
-        }
-        wg = wg_choices[i];
-        size_t n_groups = ((reduction_nelems + wg - 1) / wg);
-        if (n_groups == 1)
-            break;
-    }
-
-    return wg;
-}
-
 typedef sycl::event (*sum_reduction_strided_impl_fn_ptr)(
     sycl::queue,
     size_t,
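
The helper removed above built its candidate list from multiples 1 through f of each supported sub-group size, scanned the candidates in increasing order, and returned the first one that covers the whole reduction with a single work-group, falling back to the largest candidate for big inputs. A minimal standalone sketch of the same selection logic (the sub-group sizes below are made up for illustration):

#include <algorithm>
#include <cstddef>
#include <iostream>
#include <vector>

// Mirrors the deleted helper with f = 4: candidates are {1..4} * sg_size.
std::size_t pick_wg(std::size_t nelems, const std::vector<std::size_t> &sg_sizes)
{
    std::vector<std::size_t> choices;
    choices.reserve(4 * sg_sizes.size());
    for (std::size_t sg : sg_sizes)
        for (std::size_t i = 1; i <= 4; ++i)
            choices.push_back(sg * i);
    std::sort(choices.begin(), choices.end());

    std::size_t wg = 1;
    for (std::size_t c : choices) {
        if (c == wg)
            continue; // skip duplicate candidates
        wg = c;
        if ((nelems + wg - 1) / wg == 1)
            break; // smallest size that needs only one work-group
    }
    return wg;
}

int main()
{
    const std::vector<std::size_t> sg_sizes{8, 16, 32}; // hypothetical device
    std::cout << pick_wg(40, sg_sizes) << '\n';      // 48: one group suffices
    std::cout << pick_wg(1000000, sg_sizes) << '\n'; // 128: largest candidate
}
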
@@ -200,6 +172,8 @@ class sum_reduction_over_group_with_atomics_krn;
 template <typename T1, typename T2, typename T3>
 class sum_reduction_over_group_with_atomics_1d_krn;
 
+using dpctl::tensor::sycl_utils::choose_workgroup_size;
+
 template <typename argTy, typename resTy>
 sycl::event sum_reduction_over_group_with_atomics_strided_impl(
     sycl::queue exec_q,
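
With the local copy gone, the implementation picks up the shared helper from utils/sycl_utils.hpp through the using declaration above. A hedged sketch of a typical call site, assuming the shared helper keeps the same template <size_t f> signature as the deleted one; the device query itself is standard SYCL:

const sycl::device &d = exec_q.get_device();
const auto &sg_sizes = d.get_info<sycl::info::device::sub_group_sizes>();
// f = 4 caps candidates at four times the largest sub-group size
size_t wg = choose_workgroup_size<4>(reduction_nelems, sg_sizes);
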
@@ -548,13 +522,22 @@ sycl::event sum_reduction_over_group_temps_strided_impl(
             (preferrered_reductions_per_wi * wg);
         assert(reduction_groups > 1);
 
-        resTy *partially_reduced_tmp =
-            sycl::malloc_device<resTy>(iter_nelems * reduction_groups, exec_q);
+        size_t second_iter_reduction_groups_ =
+            (reduction_groups + preferrered_reductions_per_wi * wg - 1) /
+            (preferrered_reductions_per_wi * wg);
+
+        resTy *partially_reduced_tmp = sycl::malloc_device<resTy>(
+            iter_nelems * (reduction_groups + second_iter_reduction_groups_),
+            exec_q);
         resTy *partially_reduced_tmp2 = nullptr;
 
         if (partially_reduced_tmp == nullptr) {
             throw std::runtime_error("Unabled to allocate device_memory");
         }
+        else {
+            partially_reduced_tmp2 =
+                partially_reduced_tmp + reduction_groups * iter_nelems;
+        }
 
         sycl::event first_reduction_ev = exec_q.submit([&](sycl::handler &cgh) {
             cgh.depends_on(depends);
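
The reworked allocation folds what used to be two malloc_device calls into one: the first reduction round needs iter_nelems * reduction_groups elements, and every later round needs at most iter_nelems * second_iter_reduction_groups_ (group counts only shrink from round to round), so a single block covers both ping-pong buffers and partially_reduced_tmp2 becomes an interior pointer. A small arithmetic sketch with made-up sizes:

#include <cstddef>
#include <iostream>

int main()
{
    // Hypothetical sizes, for illustration only.
    const std::size_t reduction_nelems = 1000000, wg = 128, per_wi = 4;
    const std::size_t chunk = per_wi * wg; // elements reduced per group

    const std::size_t groups1 = (reduction_nelems + chunk - 1) / chunk; // 1954
    const std::size_t groups2 = (groups1 + chunk - 1) / chunk;          // 4

    // One allocation of iter_nelems * (groups1 + groups2) elements serves
    // both buffers; tmp2 begins at offset iter_nelems * groups1.
    std::cout << groups1 << " + " << groups2 << " = " << groups1 + groups2
              << " groups' worth of temporaries per iteration\n";
}
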
@@ -610,21 +593,6 @@ sycl::event sum_reduction_over_group_temps_strided_impl(
                 (preferrered_reductions_per_wi * wg);
             assert(reduction_groups_ > 1);
 
-            if (partially_reduced_tmp2 == nullptr) {
-                partially_reduced_tmp2 = sycl::malloc_device<resTy>(
-                    iter_nelems * reduction_groups_, exec_q);
-
-                if (partially_reduced_tmp2 == nullptr) {
-                    dependent_ev.wait();
-                    sycl::free(partially_reduced_tmp, exec_q);
-
-                    throw std::runtime_error(
-                        "Unable to allocate device memory");
-                }
-
-                temp2_arg = partially_reduced_tmp2;
-            }
-
             // keep reducing
             sycl::event partial_reduction_ev =
                 exec_q.submit([&](sycl::handler &cgh) {
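
Because partially_reduced_tmp2 is now carved out up front, the multi-round loop no longer allocates, waits, or handles allocation failure mid-flight; it can simply alternate the two halves of the block between source and destination. A generic ping-pong sketch of that pattern, not the library's exact loop (launch_partial_reduction is a hypothetical stand-in for the submit call above):

#include <cstddef>
#include <utility>

void reduce_rounds(std::size_t n, std::size_t chunk, float *tmp, float *tmp2)
{
    float *src = tmp;
    float *dst = tmp2;
    while (n > chunk) { // more than one work-group still needed
        const std::size_t groups = (n + chunk - 1) / chunk;
        // launch_partial_reduction(src, dst, n, groups); // hypothetical
        n = groups;          // next round reduces the partial results
        std::swap(src, dst); // previous output feeds the next round
    }
    // a final single-group round reads the n values left in `src`
}
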
@@ -727,13 +695,9 @@ sycl::event sum_reduction_over_group_temps_strided_impl(
             cgh.depends_on(final_reduction_ev);
             sycl::context ctx = exec_q.get_context();
 
-            cgh.host_task(
-                [ctx, partially_reduced_tmp, partially_reduced_tmp2] {
-                    sycl::free(partially_reduced_tmp, ctx);
-                    if (partially_reduced_tmp2) {
-                        sycl::free(partially_reduced_tmp2, ctx);
-                    }
-                });
+            cgh.host_task([ctx, partially_reduced_tmp] {
+                sycl::free(partially_reduced_tmp, ctx);
+            });
         });
 
         // FIXME: do not return host-task event
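
The cleanup host task shrinks to match: partially_reduced_tmp2 now points into the middle of the single USM allocation rather than at a second one, and only pointers returned by a USM allocation function may be passed to sycl::free, so freeing the base pointer once releases both temporaries.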