Skip to content

Reduction performance #1364

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 4 commits into from
Sep 5, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
174 changes: 122 additions & 52 deletions dpctl/tensor/libtensor/include/kernels/reductions.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -146,9 +146,9 @@ struct ReductionOverGroupWithAtomicFunctor

void operator()(sycl::nd_item<1> it) const
{
const size_t red_gws_ = it.get_global_range(0) / iter_gws_;
const size_t iter_gid = it.get_global_id(0) / red_gws_;
const size_t reduction_batch_id = get_reduction_batch_id(it);
const size_t iter_gid = it.get_group(0) % iter_gws_;
const size_t reduction_batch_id = it.get_group(0) / iter_gws_;

const size_t reduction_lid = it.get_local_id(0);
const size_t wg = it.get_local_range(0); // 0 <= reduction_lid < wg

Expand Down Expand Up @@ -204,14 +204,6 @@ struct ReductionOverGroupWithAtomicFunctor
}
}
}

private:
// Maps this work-group to the reduction batch it participates in.
// The 1-D launch packs iter_gws_ independent iterations, each spanning
// n_reduction_groups work-groups, into a single nd_range; this recovers
// the within-iteration group index from the flat group id.
// NOTE(review): assumes get_group_range(0) is an exact multiple of
// iter_gws_ — TODO confirm against the launch-geometry computation.
size_t get_reduction_batch_id(sycl::nd_item<1> const &it) const
{
// number of work-groups cooperating on one reduction
const size_t n_reduction_groups = it.get_group_range(0) / iter_gws_;
// position of this group within its reduction (0 .. n_reduction_groups-1)
const size_t reduction_batch_id = it.get_group(0) % n_reduction_groups;
return reduction_batch_id;
}
};

typedef sycl::event (*sum_reduction_strided_impl_fn_ptr)(
Expand Down Expand Up @@ -241,6 +233,12 @@ class sum_reduction_seq_strided_krn;
template <typename T1, typename T2, typename T3, typename T4, typename T5>
class sum_reduction_seq_contig_krn;

template <typename T1, typename T2, typename T3, typename T4, typename T5>
class sum_reduction_axis0_over_group_with_atomics_contig_krn;

template <typename T1, typename T2, typename T3, typename T4, typename T5>
class sum_reduction_axis1_over_group_with_atomics_contig_krn;

using dpctl::tensor::sycl_utils::choose_workgroup_size;

template <typename argTy, typename resTy>
Expand Down Expand Up @@ -344,20 +342,6 @@ sycl::event sum_reduction_over_group_with_atomics_strided_impl(
(reduction_nelems + reductions_per_wi * wg - 1) /
(reductions_per_wi * wg);

if (reduction_groups > 1) {
const size_t &max_wg =
d.get_info<sycl::info::device::max_work_group_size>();

if (reduction_nelems < preferrered_reductions_per_wi * max_wg) {
wg = max_wg;
reductions_per_wi =
std::max<size_t>(1, (reduction_nelems + wg - 1) / wg);
reduction_groups =
(reduction_nelems + reductions_per_wi * wg - 1) /
(reductions_per_wi * wg);
}
}

auto globalRange =
sycl::range<1>{iter_nelems * reduction_groups * wg};
auto localRange = sycl::range<1>{wg};
Expand Down Expand Up @@ -395,7 +379,7 @@ typedef sycl::event (*sum_reduction_contig_impl_fn_ptr)(

/* @brief Reduce rows in a matrix */
template <typename argTy, typename resTy>
sycl::event sum_reduction_over_group_with_atomics_contig_impl(
sycl::event sum_reduction_axis1_over_group_with_atomics_contig_impl(
sycl::queue exec_q,
size_t iter_nelems, // number of reductions (num. of rows in a matrix
// when reducing over rows)
Expand All @@ -417,7 +401,7 @@ sycl::event sum_reduction_over_group_with_atomics_contig_impl(

const sycl::device &d = exec_q.get_device();
const auto &sg_sizes = d.get_info<sycl::info::device::sub_group_sizes>();
size_t wg = choose_workgroup_size<2>(reduction_nelems, sg_sizes);
size_t wg = choose_workgroup_size<4>(reduction_nelems, sg_sizes);

if (reduction_nelems < wg) {
sycl::event comp_ev = exec_q.submit([&](sycl::handler &cgh) {
Expand Down Expand Up @@ -463,11 +447,11 @@ sycl::event sum_reduction_over_group_with_atomics_contig_impl(
RowsIndexerT, NoOpIndexerT>;
using ReductionIndexerT = NoOpIndexerT;

RowsIndexerT columns_indexer{
RowsIndexerT rows_indexer{
0, static_cast<py::ssize_t>(iter_nelems),
static_cast<py::ssize_t>(reduction_nelems)};
NoOpIndexerT result_indexer{};
InputOutputIterIndexerT in_out_iter_indexer{columns_indexer,
InputOutputIterIndexerT in_out_iter_indexer{rows_indexer,
result_indexer};
ReductionIndexerT reduction_indexer{};

Expand All @@ -481,27 +465,95 @@ sycl::event sum_reduction_over_group_with_atomics_contig_impl(
(reduction_nelems + reductions_per_wi * wg - 1) /
(reductions_per_wi * wg);

if (reduction_groups > 1) {
const size_t &max_wg =
d.get_info<sycl::info::device::max_work_group_size>();

if (reduction_nelems < preferrered_reductions_per_wi * max_wg) {
wg = max_wg;
reductions_per_wi =
std::max<size_t>(1, (reduction_nelems + wg - 1) / wg);
reduction_groups =
(reduction_nelems + reductions_per_wi * wg - 1) /
(reductions_per_wi * wg);
}
}
auto globalRange =
sycl::range<1>{iter_nelems * reduction_groups * wg};
auto localRange = sycl::range<1>{wg};

using KernelName =
class sum_reduction_axis1_over_group_with_atomics_contig_krn<
argTy, resTy, ReductionOpT, InputOutputIterIndexerT,
ReductionIndexerT>;

cgh.parallel_for<KernelName>(
sycl::nd_range<1>(globalRange, localRange),
ReductionOverGroupWithAtomicFunctor<argTy, resTy, ReductionOpT,
InputOutputIterIndexerT,
ReductionIndexerT>(
arg_tp, res_tp, ReductionOpT(), identity_val,
in_out_iter_indexer, reduction_indexer, reduction_nelems,
iter_nelems, reductions_per_wi));
});

return comp_ev;
}
}

/* @brief Reduce columns in a matrix */
template <typename argTy, typename resTy>
sycl::event sum_reduction_axis0_over_group_with_atomics_contig_impl(
sycl::queue exec_q,
size_t iter_nelems, // number of reductions (num. of cols in a matrix
// when reducing over cols)
size_t reduction_nelems, // size of each reduction (length of cols, i.e.
// number of rows)
const char *arg_cp,
char *res_cp,
py::ssize_t iter_arg_offset,
py::ssize_t iter_res_offset,
py::ssize_t reduction_arg_offset,
const std::vector<sycl::event> &depends)
{
const argTy *arg_tp = reinterpret_cast<const argTy *>(arg_cp) +
iter_arg_offset + reduction_arg_offset;
resTy *res_tp = reinterpret_cast<resTy *>(res_cp) + iter_res_offset;

using ReductionOpT = sycl::plus<resTy>;
constexpr resTy identity_val = resTy{0};

const sycl::device &d = exec_q.get_device();
const auto &sg_sizes = d.get_info<sycl::info::device::sub_group_sizes>();
size_t wg = choose_workgroup_size<4>(reduction_nelems, sg_sizes);

{
sycl::event res_init_ev = exec_q.fill<resTy>(
res_tp, resTy(identity_val), iter_nelems, depends);

sycl::event comp_ev = exec_q.submit([&](sycl::handler &cgh) {
cgh.depends_on(res_init_ev);

using NoOpIndexerT = dpctl::tensor::offset_utils::NoOpIndexer;
using ColsIndexerT = dpctl::tensor::offset_utils::Strided1DIndexer;
using InputOutputIterIndexerT =
dpctl::tensor::offset_utils::TwoOffsets_CombinedIndexer<
NoOpIndexerT, NoOpIndexerT>;
using ReductionIndexerT = ColsIndexerT;

NoOpIndexerT columns_indexer{};
NoOpIndexerT result_indexer{};
InputOutputIterIndexerT in_out_iter_indexer{columns_indexer,
result_indexer};
ReductionIndexerT reduction_indexer{
0, /* size */ static_cast<py::ssize_t>(reduction_nelems),
/* step */ static_cast<py::ssize_t>(iter_nelems)};

constexpr size_t preferrered_reductions_per_wi = 8;
size_t reductions_per_wi =
(reduction_nelems < preferrered_reductions_per_wi * wg)
? std::max<size_t>(1, (reduction_nelems + wg - 1) / wg)
: preferrered_reductions_per_wi;

size_t reduction_groups =
(reduction_nelems + reductions_per_wi * wg - 1) /
(reductions_per_wi * wg);

auto globalRange =
sycl::range<1>{iter_nelems * reduction_groups * wg};
auto localRange = sycl::range<1>{wg};

using KernelName = class sum_reduction_over_group_with_atomics_krn<
argTy, resTy, ReductionOpT, InputOutputIterIndexerT,
ReductionIndexerT>;
using KernelName =
class sum_reduction_axis0_over_group_with_atomics_contig_krn<
argTy, resTy, ReductionOpT, InputOutputIterIndexerT,
ReductionIndexerT>;

cgh.parallel_for<KernelName>(
sycl::nd_range<1>(globalRange, localRange),
Expand Down Expand Up @@ -558,14 +610,13 @@ struct ReductionOverGroupNoAtomicFunctor

void operator()(sycl::nd_item<1> it) const
{

const size_t red_gws_ = it.get_global_range(0) / iter_gws_;
const size_t iter_gid = it.get_global_id(0) / red_gws_;
const size_t n_reduction_groups = it.get_group_range(0) / iter_gws_;
const size_t reduction_batch_id = it.get_group(0) % n_reduction_groups;
const size_t reduction_lid = it.get_local_id(0);
const size_t wg = it.get_local_range(0); // 0 <= reduction_lid < wg

const size_t iter_gid = it.get_group(0) % iter_gws_;
const size_t reduction_batch_id = it.get_group(0) / iter_gws_;
const size_t n_reduction_groups = it.get_group_range(0) / iter_gws_;

// work-items sums over input with indices
// inp_data_id = reduction_batch_id * wg * reductions_per_wi + m * wg
// + reduction_lid
Expand Down Expand Up @@ -1079,15 +1130,34 @@ struct SumOverAxisTempsStridedFactory
};

/* @brief Factory selecting the atomics-based contiguous sum kernel for
 * reductions along axis 1 (row-wise reduction of a C-contiguous matrix).
 *
 * @tparam fnT   Function-pointer type of the dispatch-table slot.
 * @tparam srcTy Input element type.
 * @tparam dstTy Output (accumulator) element type.
 *
 * Returns the implementation when the (srcTy, dstTy) pair supports
 * atomic accumulation, otherwise nullptr so the dispatcher falls back
 * to a non-atomic strategy.
 */
// Fix: the scraped diff left two consecutive struct-declaration lines
// (the pre-rename `SumOverAxisAtomicContigFactory` and the renamed
// `SumOverAxis1AtomicContigFactory`), which does not compile; keep only
// the post-PR name, matching the axis1 implementation it dispatches to.
template <typename fnT, typename srcTy, typename dstTy>
struct SumOverAxis1AtomicContigFactory
{
    fnT get() const
    {
        if constexpr (TypePairSupportDataForSumReductionAtomic<
                          srcTy, dstTy>::is_defined)
        {
            return dpctl::tensor::kernels::
                sum_reduction_axis1_over_group_with_atomics_contig_impl<srcTy,
                                                                        dstTy>;
        }
        else {
            // Unsupported type pair: signal "no atomic kernel" to the caller.
            return nullptr;
        }
    }
};

template <typename fnT, typename srcTy, typename dstTy>
struct SumOverAxis0AtomicContigFactory
{
fnT get() const
{
if constexpr (TypePairSupportDataForSumReductionAtomic<
srcTy, dstTy>::is_defined)
{
return dpctl::tensor::kernels::
sum_reduction_over_group_with_atomics_contig_impl<srcTy, dstTy>;
sum_reduction_axis0_over_group_with_atomics_contig_impl<srcTy,
dstTy>;
}
else {
return nullptr;
Expand Down
Loading