@@ -234,7 +234,10 @@ template <typename T1, typename T2, typename T3, typename T4, typename T5>
234
234
class sum_reduction_seq_contig_krn ;
235
235
236
236
template <typename T1, typename T2, typename T3, typename T4, typename T5>
237
- class sum_reduction_over_group_with_atomics_contig_krn ;
237
+ class sum_reduction_axis0_over_group_with_atomics_contig_krn ;
238
+
239
+ template <typename T1, typename T2, typename T3, typename T4, typename T5>
240
+ class sum_reduction_axis1_over_group_with_atomics_contig_krn ;
238
241
239
242
using dpctl::tensor::sycl_utils::choose_workgroup_size;
240
243
@@ -390,7 +393,7 @@ typedef sycl::event (*sum_reduction_contig_impl_fn_ptr)(
390
393
391
394
/* @brief Reduce rows in a matrix */
392
395
template <typename argTy, typename resTy>
393
- sycl::event sum_reduction_over_group_with_atomics_contig_impl (
396
+ sycl::event sum_reduction_axis1_over_group_with_atomics_contig_impl (
394
397
sycl::queue exec_q,
395
398
size_t iter_nelems, // number of reductions (num. of rows in a matrix
396
399
// when reducing over rows)
@@ -458,11 +461,11 @@ sycl::event sum_reduction_over_group_with_atomics_contig_impl(
458
461
RowsIndexerT, NoOpIndexerT>;
459
462
using ReductionIndexerT = NoOpIndexerT;
460
463
461
- RowsIndexerT columns_indexer {
464
+ RowsIndexerT rows_indexer {
462
465
0 , static_cast <py::ssize_t >(iter_nelems),
463
466
static_cast <py::ssize_t >(reduction_nelems)};
464
467
NoOpIndexerT result_indexer{};
465
- InputOutputIterIndexerT in_out_iter_indexer{columns_indexer ,
468
+ InputOutputIterIndexerT in_out_iter_indexer{rows_indexer ,
466
469
result_indexer};
467
470
ReductionIndexerT reduction_indexer{};
468
471
@@ -495,7 +498,102 @@ sycl::event sum_reduction_over_group_with_atomics_contig_impl(
495
498
auto localRange = sycl::range<1 >{wg};
496
499
497
500
using KernelName =
498
- class sum_reduction_over_group_with_atomics_contig_krn <
501
+ class sum_reduction_axis1_over_group_with_atomics_contig_krn <
502
+ argTy, resTy, ReductionOpT, InputOutputIterIndexerT,
503
+ ReductionIndexerT>;
504
+
505
+ cgh.parallel_for <KernelName>(
506
+ sycl::nd_range<1 >(globalRange, localRange),
507
+ ReductionOverGroupWithAtomicFunctor<argTy, resTy, ReductionOpT,
508
+ InputOutputIterIndexerT,
509
+ ReductionIndexerT>(
510
+ arg_tp, res_tp, ReductionOpT (), identity_val,
511
+ in_out_iter_indexer, reduction_indexer, reduction_nelems,
512
+ iter_nelems, reductions_per_wi));
513
+ });
514
+
515
+ return comp_ev;
516
+ }
517
+ }
518
+
519
+ /* @brief Reduce columns in a matrix */
520
+ template <typename argTy, typename resTy>
521
+ sycl::event sum_reduction_axis0_over_group_with_atomics_contig_impl (
522
+ sycl::queue exec_q,
523
+ size_t iter_nelems, // number of reductions (num. of cols in a matrix
524
+ // when reducing over cols)
525
+ size_t reduction_nelems, // size of each reduction (length of cols, i.e.
526
+ // number of rows)
527
+ const char *arg_cp,
528
+ char *res_cp,
529
+ py::ssize_t iter_arg_offset,
530
+ py::ssize_t iter_res_offset,
531
+ py::ssize_t reduction_arg_offset,
532
+ const std::vector<sycl::event> &depends)
533
+ {
534
+ const argTy *arg_tp = reinterpret_cast <const argTy *>(arg_cp) +
535
+ iter_arg_offset + reduction_arg_offset;
536
+ resTy *res_tp = reinterpret_cast <resTy *>(res_cp) + iter_res_offset;
537
+
538
+ using ReductionOpT = sycl::plus<resTy>;
539
+ constexpr resTy identity_val = resTy{0 };
540
+
541
+ const sycl::device &d = exec_q.get_device ();
542
+ const auto &sg_sizes = d.get_info <sycl::info::device::sub_group_sizes>();
543
+ size_t wg = choose_workgroup_size<4 >(reduction_nelems, sg_sizes);
544
+
545
+ {
546
+ sycl::event res_init_ev = exec_q.fill <resTy>(
547
+ res_tp, resTy (identity_val), iter_nelems, depends);
548
+
549
+ sycl::event comp_ev = exec_q.submit ([&](sycl::handler &cgh) {
550
+ cgh.depends_on (res_init_ev);
551
+
552
+ using NoOpIndexerT = dpctl::tensor::offset_utils::NoOpIndexer;
553
+ using ColsIndexerT = dpctl::tensor::offset_utils::Strided1DIndexer;
554
+ using InputOutputIterIndexerT =
555
+ dpctl::tensor::offset_utils::TwoOffsets_CombinedIndexer<
556
+ NoOpIndexerT, NoOpIndexerT>;
557
+ using ReductionIndexerT = ColsIndexerT;
558
+
559
+ NoOpIndexerT columns_indexer{};
560
+ NoOpIndexerT result_indexer{};
561
+ InputOutputIterIndexerT in_out_iter_indexer{columns_indexer,
562
+ result_indexer};
563
+ ReductionIndexerT reduction_indexer{
564
+ 0 , /* size */ static_cast <py::ssize_t >(reduction_nelems),
565
+ /* step */ static_cast <py::ssize_t >(iter_nelems)};
566
+
567
+ constexpr size_t preferrered_reductions_per_wi = 8 ;
568
+ size_t reductions_per_wi =
569
+ (reduction_nelems < preferrered_reductions_per_wi * wg)
570
+ ? std::max<size_t >(1 , (reduction_nelems + wg - 1 ) / wg)
571
+ : preferrered_reductions_per_wi;
572
+
573
+ size_t reduction_groups =
574
+ (reduction_nelems + reductions_per_wi * wg - 1 ) /
575
+ (reductions_per_wi * wg);
576
+
577
+ if (reduction_groups > 1 ) {
578
+ const size_t &max_wg =
579
+ d.get_info <sycl::info::device::max_work_group_size>();
580
+
581
+ if (reduction_nelems < preferrered_reductions_per_wi * max_wg) {
582
+ wg = max_wg;
583
+ reductions_per_wi =
584
+ std::max<size_t >(1 , (reduction_nelems + wg - 1 ) / wg);
585
+ reduction_groups =
586
+ (reduction_nelems + reductions_per_wi * wg - 1 ) /
587
+ (reductions_per_wi * wg);
588
+ }
589
+ }
590
+
591
+ auto globalRange =
592
+ sycl::range<1 >{iter_nelems * reduction_groups * wg};
593
+ auto localRange = sycl::range<1 >{wg};
594
+
595
+ using KernelName =
596
+ class sum_reduction_axis0_over_group_with_atomics_contig_krn <
499
597
argTy, resTy, ReductionOpT, InputOutputIterIndexerT,
500
598
ReductionIndexerT>;
501
599
@@ -1075,15 +1173,34 @@ struct SumOverAxisTempsStridedFactory
1075
1173
};
1076
1174
1077
1175
template <typename fnT, typename srcTy, typename dstTy>
1078
- struct SumOverAxisAtomicContigFactory
1176
+ struct SumOverAxis1AtomicContigFactory
1177
+ {
1178
+ fnT get () const
1179
+ {
1180
+ if constexpr (TypePairSupportDataForSumReductionAtomic<
1181
+ srcTy, dstTy>::is_defined)
1182
+ {
1183
+ return dpctl::tensor::kernels::
1184
+ sum_reduction_axis1_over_group_with_atomics_contig_impl<srcTy,
1185
+ dstTy>;
1186
+ }
1187
+ else {
1188
+ return nullptr ;
1189
+ }
1190
+ }
1191
+ };
1192
+
1193
+ template <typename fnT, typename srcTy, typename dstTy>
1194
+ struct SumOverAxis0AtomicContigFactory
1079
1195
{
1080
1196
fnT get () const
1081
1197
{
1082
1198
if constexpr (TypePairSupportDataForSumReductionAtomic<
1083
1199
srcTy, dstTy>::is_defined)
1084
1200
{
1085
1201
return dpctl::tensor::kernels::
1086
- sum_reduction_over_group_with_atomics_contig_impl<srcTy, dstTy>;
1202
+ sum_reduction_axis0_over_group_with_atomics_contig_impl<srcTy,
1203
+ dstTy>;
1087
1204
}
1088
1205
else {
1089
1206
return nullptr ;
0 commit comments