Skip to content

Commit 189e8aa

Browse files
Introduce reduce_over_axis0 kernel for contiguous inputs
This achieves additional savings over the prior commit: ``` In [1]: import dpctl.tensor as dpt In [2]: x = dpt.reshape(dpt.asarray(1, dtype="f4")/dpt.square(dpt.arange(1, 1282200*128 + 1, dtype="f4")), (1282200, 128)) In [3]: %time y = dpt.sum(x, axis=0) CPU times: user 136 ms, sys: 9.52 ms, total: 145 ms Wall time: 158 ms In [4]: %time y = dpt.sum(x, axis=0) CPU times: user 18.8 ms, sys: 17.3 ms, total: 36.1 ms Wall time: 42 ms In [5]: %time y = dpt.sum(x, axis=0) CPU times: user 19.2 ms, sys: 16.9 ms, total: 36.1 ms Wall time: 38.4 ms In [6]: %time y = dpt.sum(x, axis=0) CPU times: user 1.69 ms, sys: 35.2 ms, total: 36.9 ms Wall time: 39.4 ms In [7]: quit ``` Prior to this the wall time stood at 49 ms.
1 parent 26b6d6d commit 189e8aa

File tree

2 files changed

+211
-32
lines changed

2 files changed

+211
-32
lines changed

dpctl/tensor/libtensor/include/kernels/reductions.hpp

Lines changed: 124 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -234,7 +234,10 @@ template <typename T1, typename T2, typename T3, typename T4, typename T5>
234234
class sum_reduction_seq_contig_krn;
235235

236236
template <typename T1, typename T2, typename T3, typename T4, typename T5>
237-
class sum_reduction_over_group_with_atomics_contig_krn;
237+
class sum_reduction_axis0_over_group_with_atomics_contig_krn;
238+
239+
template <typename T1, typename T2, typename T3, typename T4, typename T5>
240+
class sum_reduction_axis1_over_group_with_atomics_contig_krn;
238241

239242
using dpctl::tensor::sycl_utils::choose_workgroup_size;
240243

@@ -390,7 +393,7 @@ typedef sycl::event (*sum_reduction_contig_impl_fn_ptr)(
390393

391394
/* @brief Reduce rows in a matrix */
392395
template <typename argTy, typename resTy>
393-
sycl::event sum_reduction_over_group_with_atomics_contig_impl(
396+
sycl::event sum_reduction_axis1_over_group_with_atomics_contig_impl(
394397
sycl::queue exec_q,
395398
size_t iter_nelems, // number of reductions (num. of rows in a matrix
396399
// when reducing over rows)
@@ -458,11 +461,11 @@ sycl::event sum_reduction_over_group_with_atomics_contig_impl(
458461
RowsIndexerT, NoOpIndexerT>;
459462
using ReductionIndexerT = NoOpIndexerT;
460463

461-
RowsIndexerT columns_indexer{
464+
RowsIndexerT rows_indexer{
462465
0, static_cast<py::ssize_t>(iter_nelems),
463466
static_cast<py::ssize_t>(reduction_nelems)};
464467
NoOpIndexerT result_indexer{};
465-
InputOutputIterIndexerT in_out_iter_indexer{columns_indexer,
468+
InputOutputIterIndexerT in_out_iter_indexer{rows_indexer,
466469
result_indexer};
467470
ReductionIndexerT reduction_indexer{};
468471

@@ -495,7 +498,102 @@ sycl::event sum_reduction_over_group_with_atomics_contig_impl(
495498
auto localRange = sycl::range<1>{wg};
496499

497500
using KernelName =
498-
class sum_reduction_over_group_with_atomics_contig_krn<
501+
class sum_reduction_axis1_over_group_with_atomics_contig_krn<
502+
argTy, resTy, ReductionOpT, InputOutputIterIndexerT,
503+
ReductionIndexerT>;
504+
505+
cgh.parallel_for<KernelName>(
506+
sycl::nd_range<1>(globalRange, localRange),
507+
ReductionOverGroupWithAtomicFunctor<argTy, resTy, ReductionOpT,
508+
InputOutputIterIndexerT,
509+
ReductionIndexerT>(
510+
arg_tp, res_tp, ReductionOpT(), identity_val,
511+
in_out_iter_indexer, reduction_indexer, reduction_nelems,
512+
iter_nelems, reductions_per_wi));
513+
});
514+
515+
return comp_ev;
516+
}
517+
}
518+
519+
/* @brief Reduce rows in a matrix */
520+
template <typename argTy, typename resTy>
521+
sycl::event sum_reduction_axis0_over_group_with_atomics_contig_impl(
522+
sycl::queue exec_q,
523+
size_t iter_nelems, // number of reductions (num. of cols in a matrix
524+
// when reducing over cols)
525+
size_t reduction_nelems, // size of each reduction (length of cols, i.e.
526+
// number of rows)
527+
const char *arg_cp,
528+
char *res_cp,
529+
py::ssize_t iter_arg_offset,
530+
py::ssize_t iter_res_offset,
531+
py::ssize_t reduction_arg_offset,
532+
const std::vector<sycl::event> &depends)
533+
{
534+
const argTy *arg_tp = reinterpret_cast<const argTy *>(arg_cp) +
535+
iter_arg_offset + reduction_arg_offset;
536+
resTy *res_tp = reinterpret_cast<resTy *>(res_cp) + iter_res_offset;
537+
538+
using ReductionOpT = sycl::plus<resTy>;
539+
constexpr resTy identity_val = resTy{0};
540+
541+
const sycl::device &d = exec_q.get_device();
542+
const auto &sg_sizes = d.get_info<sycl::info::device::sub_group_sizes>();
543+
size_t wg = choose_workgroup_size<4>(reduction_nelems, sg_sizes);
544+
545+
{
546+
sycl::event res_init_ev = exec_q.fill<resTy>(
547+
res_tp, resTy(identity_val), iter_nelems, depends);
548+
549+
sycl::event comp_ev = exec_q.submit([&](sycl::handler &cgh) {
550+
cgh.depends_on(res_init_ev);
551+
552+
using NoOpIndexerT = dpctl::tensor::offset_utils::NoOpIndexer;
553+
using ColsIndexerT = dpctl::tensor::offset_utils::Strided1DIndexer;
554+
using InputOutputIterIndexerT =
555+
dpctl::tensor::offset_utils::TwoOffsets_CombinedIndexer<
556+
NoOpIndexerT, NoOpIndexerT>;
557+
using ReductionIndexerT = ColsIndexerT;
558+
559+
NoOpIndexerT columns_indexer{};
560+
NoOpIndexerT result_indexer{};
561+
InputOutputIterIndexerT in_out_iter_indexer{columns_indexer,
562+
result_indexer};
563+
ReductionIndexerT reduction_indexer{
564+
0, /* size */ static_cast<py::ssize_t>(reduction_nelems),
565+
/* step */ static_cast<py::ssize_t>(iter_nelems)};
566+
567+
constexpr size_t preferrered_reductions_per_wi = 8;
568+
size_t reductions_per_wi =
569+
(reduction_nelems < preferrered_reductions_per_wi * wg)
570+
? std::max<size_t>(1, (reduction_nelems + wg - 1) / wg)
571+
: preferrered_reductions_per_wi;
572+
573+
size_t reduction_groups =
574+
(reduction_nelems + reductions_per_wi * wg - 1) /
575+
(reductions_per_wi * wg);
576+
577+
if (reduction_groups > 1) {
578+
const size_t &max_wg =
579+
d.get_info<sycl::info::device::max_work_group_size>();
580+
581+
if (reduction_nelems < preferrered_reductions_per_wi * max_wg) {
582+
wg = max_wg;
583+
reductions_per_wi =
584+
std::max<size_t>(1, (reduction_nelems + wg - 1) / wg);
585+
reduction_groups =
586+
(reduction_nelems + reductions_per_wi * wg - 1) /
587+
(reductions_per_wi * wg);
588+
}
589+
}
590+
591+
auto globalRange =
592+
sycl::range<1>{iter_nelems * reduction_groups * wg};
593+
auto localRange = sycl::range<1>{wg};
594+
595+
using KernelName =
596+
class sum_reduction_axis0_over_group_with_atomics_contig_krn<
499597
argTy, resTy, ReductionOpT, InputOutputIterIndexerT,
500598
ReductionIndexerT>;
501599

@@ -1075,15 +1173,34 @@ struct SumOverAxisTempsStridedFactory
10751173
};
10761174

10771175
template <typename fnT, typename srcTy, typename dstTy>
1078-
struct SumOverAxisAtomicContigFactory
1176+
struct SumOverAxis1AtomicContigFactory
1177+
{
1178+
fnT get() const
1179+
{
1180+
if constexpr (TypePairSupportDataForSumReductionAtomic<
1181+
srcTy, dstTy>::is_defined)
1182+
{
1183+
return dpctl::tensor::kernels::
1184+
sum_reduction_axis1_over_group_with_atomics_contig_impl<srcTy,
1185+
dstTy>;
1186+
}
1187+
else {
1188+
return nullptr;
1189+
}
1190+
}
1191+
};
1192+
1193+
template <typename fnT, typename srcTy, typename dstTy>
1194+
struct SumOverAxis0AtomicContigFactory
10791195
{
10801196
fnT get() const
10811197
{
10821198
if constexpr (TypePairSupportDataForSumReductionAtomic<
10831199
srcTy, dstTy>::is_defined)
10841200
{
10851201
return dpctl::tensor::kernels::
1086-
sum_reduction_over_group_with_atomics_contig_impl<srcTy, dstTy>;
1202+
sum_reduction_axis0_over_group_with_atomics_contig_impl<srcTy,
1203+
dstTy>;
10871204
}
10881205
else {
10891206
return nullptr;

dpctl/tensor/libtensor/source/sum_reductions.cpp

Lines changed: 87 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -88,8 +88,11 @@ static sum_reduction_strided_impl_fn_ptr
8888

8989
using dpctl::tensor::kernels::sum_reduction_contig_impl_fn_ptr;
9090
static sum_reduction_contig_impl_fn_ptr
91-
sum_over_axis_contig_atomic_dispatch_table[td_ns::num_types]
92-
[td_ns::num_types];
91+
sum_over_axis1_contig_atomic_dispatch_table[td_ns::num_types]
92+
[td_ns::num_types];
93+
static sum_reduction_contig_impl_fn_ptr
94+
sum_over_axis0_contig_atomic_dispatch_table[td_ns::num_types]
95+
[td_ns::num_types];
9396

9497
std::pair<sycl::event, sycl::event> py_sum_over_axis(
9598
dpctl::tensor::usm_ndarray src,
@@ -194,8 +197,30 @@ std::pair<sycl::event, sycl::event> py_sum_over_axis(
194197
if ((is_src_c_contig && is_dst_c_contig) ||
195198
(is_src_f_contig && dst_nelems == 1))
196199
{
197-
auto fn = sum_over_axis_contig_atomic_dispatch_table[src_typeid]
198-
[dst_typeid];
200+
auto fn = sum_over_axis1_contig_atomic_dispatch_table[src_typeid]
201+
[dst_typeid];
202+
if (fn != nullptr) {
203+
size_t iter_nelems = dst_nelems;
204+
205+
constexpr py::ssize_t zero_offset = 0;
206+
207+
sycl::event sum_over_axis_contig_ev =
208+
fn(exec_q, iter_nelems, reduction_nelems, src.get_data(),
209+
dst.get_data(),
210+
zero_offset, // iteration_src_offset
211+
zero_offset, // iteration_dst_offset
212+
zero_offset, // reduction_src_offset
213+
depends);
214+
215+
sycl::event keep_args_event = dpctl::utils::keep_args_alive(
216+
exec_q, {src, dst}, {sum_over_axis_contig_ev});
217+
218+
return std::make_pair(keep_args_event, sum_over_axis_contig_ev);
219+
}
220+
}
221+
else if (is_src_f_contig & is_dst_c_contig) {
222+
auto fn = sum_over_axis0_contig_atomic_dispatch_table[src_typeid]
223+
[dst_typeid];
199224
if (fn != nullptr) {
200225
size_t iter_nelems = dst_nelems;
201226

@@ -271,27 +296,58 @@ std::pair<sycl::event, sycl::event> py_sum_over_axis(
271296
iteration_src_offset, iteration_dst_offset);
272297
}
273298

274-
if (supports_atomics && (reduction_nd == 1) &&
275-
(simplified_reduction_src_strides[0] == 1) && (iteration_nd == 1) &&
276-
((simplified_iteration_shape[0] == 1) ||
277-
((simplified_iteration_dst_strides[0] == 1) &&
278-
(static_cast<size_t>(simplified_iteration_src_strides[0]) ==
279-
reduction_nelems))))
280-
{
281-
auto fn =
282-
sum_over_axis_contig_atomic_dispatch_table[src_typeid][dst_typeid];
283-
if (fn != nullptr) {
284-
size_t iter_nelems = dst_nelems;
299+
if (supports_atomics && (reduction_nd == 1) && (iteration_nd == 1)) {
300+
bool mat_reduce_over_axis1 = false;
301+
bool mat_reduce_over_axis0 = false;
302+
bool array_reduce_all_elems = false;
303+
size_t iter_nelems = dst_nelems;
304+
305+
if (simplified_reduction_src_strides[0] == 1) {
306+
array_reduce_all_elems = (simplified_iteration_shape[0] == 1);
307+
mat_reduce_over_axis1 =
308+
(simplified_iteration_dst_strides[0] == 1) &&
309+
(static_cast<size_t>(simplified_iteration_src_strides[0]) ==
310+
reduction_nelems);
311+
}
312+
else if (static_cast<size_t>(simplified_reduction_src_strides[0]) ==
313+
iter_nelems)
314+
{
315+
mat_reduce_over_axis0 =
316+
(simplified_iteration_dst_strides[0] == 1) &&
317+
(simplified_iteration_src_strides[0] == 1);
318+
}
319+
320+
if (mat_reduce_over_axis1 || array_reduce_all_elems) {
321+
auto fn = sum_over_axis1_contig_atomic_dispatch_table[src_typeid]
322+
[dst_typeid];
323+
if (fn != nullptr) {
324+
sycl::event sum_over_axis1_contig_ev =
325+
fn(exec_q, iter_nelems, reduction_nelems, src.get_data(),
326+
dst.get_data(), iteration_src_offset,
327+
iteration_dst_offset, reduction_src_offset, depends);
285328

286-
sycl::event sum_over_axis_contig_ev =
287-
fn(exec_q, iter_nelems, reduction_nelems, src.get_data(),
288-
dst.get_data(), iteration_src_offset, iteration_dst_offset,
289-
reduction_src_offset, depends);
329+
sycl::event keep_args_event = dpctl::utils::keep_args_alive(
330+
exec_q, {src, dst}, {sum_over_axis1_contig_ev});
331+
332+
return std::make_pair(keep_args_event,
333+
sum_over_axis1_contig_ev);
334+
}
335+
}
336+
else if (mat_reduce_over_axis0) {
337+
auto fn = sum_over_axis0_contig_atomic_dispatch_table[src_typeid]
338+
[dst_typeid];
339+
if (fn != nullptr) {
340+
sycl::event sum_over_axis0_contig_ev =
341+
fn(exec_q, iter_nelems, reduction_nelems, src.get_data(),
342+
dst.get_data(), iteration_src_offset,
343+
iteration_dst_offset, reduction_src_offset, depends);
290344

291-
sycl::event keep_args_event = dpctl::utils::keep_args_alive(
292-
exec_q, {src, dst}, {sum_over_axis_contig_ev});
345+
sycl::event keep_args_event = dpctl::utils::keep_args_alive(
346+
exec_q, {src, dst}, {sum_over_axis0_contig_ev});
293347

294-
return std::make_pair(keep_args_event, sum_over_axis_contig_ev);
348+
return std::make_pair(keep_args_event,
349+
sum_over_axis0_contig_ev);
350+
}
295351
}
296352
}
297353

@@ -451,11 +507,17 @@ void populate_sum_over_axis_dispatch_table(void)
451507
dtb2;
452508
dtb2.populate_dispatch_table(sum_over_axis_strided_temps_dispatch_table);
453509

454-
using dpctl::tensor::kernels::SumOverAxisAtomicContigFactory;
510+
using dpctl::tensor::kernels::SumOverAxis1AtomicContigFactory;
455511
DispatchTableBuilder<sum_reduction_contig_impl_fn_ptr,
456-
SumOverAxisAtomicContigFactory, num_types>
512+
SumOverAxis1AtomicContigFactory, num_types>
457513
dtb3;
458-
dtb3.populate_dispatch_table(sum_over_axis_contig_atomic_dispatch_table);
514+
dtb3.populate_dispatch_table(sum_over_axis1_contig_atomic_dispatch_table);
515+
516+
using dpctl::tensor::kernels::SumOverAxis0AtomicContigFactory;
517+
DispatchTableBuilder<sum_reduction_contig_impl_fn_ptr,
518+
SumOverAxis0AtomicContigFactory, num_types>
519+
dtb4;
520+
dtb4.populate_dispatch_table(sum_over_axis0_contig_atomic_dispatch_table);
459521
}
460522

461523
namespace py = pybind11;

0 commit comments

Comments
 (0)