Skip to content

Commit 0290954

Browse files
Implemented PR feedback
1 parent 1f3bc74 commit 0290954

File tree

2 files changed

+88
-53
lines changed

2 files changed

+88
-53
lines changed

dpctl/tensor/libtensor/include/kernels/reductions.hpp

Lines changed: 17 additions & 53 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,7 @@
3232

3333
#include "pybind11/pybind11.h"
3434
#include "utils/offset_utils.hpp"
35+
#include "utils/sycl_utils.hpp"
3536
#include "utils/type_dispatch.hpp"
3637
#include "utils/type_utils.hpp"
3738

@@ -150,35 +151,6 @@ struct ReductionOverGroupWithAtomicFunctor
150151
}
151152
};
152153

153-
/*! @brief Pick a work-group size for reducing `reduction_nelems` elements.
 *
 * Candidate sizes are `sg_size * i` for each entry of `sg_sizes` and each
 * `i` in `[1, f]`. Candidates are visited in ascending order and the first
 * one for which `ceil(reduction_nelems / wg) == 1` is returned. If no
 * candidate satisfies that, the last (largest) candidate is returned; if
 * there are no candidates at all, the result is `1`.
 *
 * NOTE(review): these lines are on the removed side of this diff; the same
 * function is added in utils/sycl_utils.hpp and pulled in with a
 * using-declaration.
 */
template <size_t f = 4>
154-
size_t choose_workgroup_size(const size_t reduction_nelems,
155-
const std::vector<size_t> &sg_sizes)
156-
{
157-
std::vector<size_t> wg_choices;
158-
wg_choices.reserve(f * sg_sizes.size());
159-
160-
// enumerate the first f multiples of every sub-group size
for (const auto &sg_size : sg_sizes) {
161-
#pragma unroll
162-
for (size_t i = 1; i <= f; ++i) {
163-
wg_choices.push_back(sg_size * i);
164-
}
165-
}
166-
std::sort(std::begin(wg_choices), std::end(wg_choices));
167-
168-
// wg starts at 1, so a candidate equal to 1 is skipped by the check below
size_t wg = 1;
169-
for (size_t i = 0; i < wg_choices.size(); ++i) {
170-
// skip a candidate equal to the value already held in wg (duplicates)
if (wg_choices[i] == wg) {
171-
continue;
172-
}
173-
wg = wg_choices[i];
174-
// number of work-groups of size wg needed to cover all elements
size_t n_groups = ((reduction_nelems + wg - 1) / wg);
175-
// stop at the first size for which a single work-group suffices
if (n_groups == 1)
176-
break;
177-
}
178-
179-
return wg;
180-
}
181-
182154
typedef sycl::event (*sum_reduction_strided_impl_fn_ptr)(
183155
sycl::queue,
184156
size_t,
@@ -200,6 +172,8 @@ class sum_reduction_over_group_with_atomics_krn;
200172
template <typename T1, typename T2, typename T3>
201173
class sum_reduction_over_group_with_atomics_1d_krn;
202174

175+
using dpctl::tensor::sycl_utils::choose_workgroup_size;
176+
203177
template <typename argTy, typename resTy>
204178
sycl::event sum_reduction_over_group_with_atomics_strided_impl(
205179
sycl::queue exec_q,
@@ -548,13 +522,22 @@ sycl::event sum_reduction_over_group_temps_strided_impl(
548522
(preferrered_reductions_per_wi * wg);
549523
assert(reduction_groups > 1);
550524

551-
resTy *partially_reduced_tmp =
552-
sycl::malloc_device<resTy>(iter_nelems * reduction_groups, exec_q);
525+
size_t second_iter_reduction_groups_ =
526+
(reduction_groups + preferrered_reductions_per_wi * wg - 1) /
527+
(preferrered_reductions_per_wi * wg);
528+
529+
resTy *partially_reduced_tmp = sycl::malloc_device<resTy>(
530+
iter_nelems * (reduction_groups + second_iter_reduction_groups_),
531+
exec_q);
553532
resTy *partially_reduced_tmp2 = nullptr;
554533

555534
if (partially_reduced_tmp == nullptr) {
556535
throw std::runtime_error("Unabled to allocate device_memory");
557536
}
537+
else {
538+
partially_reduced_tmp2 =
539+
partially_reduced_tmp + reduction_groups * iter_nelems;
540+
}
558541

559542
sycl::event first_reduction_ev = exec_q.submit([&](sycl::handler &cgh) {
560543
cgh.depends_on(depends);
@@ -610,21 +593,6 @@ sycl::event sum_reduction_over_group_temps_strided_impl(
610593
(preferrered_reductions_per_wi * wg);
611594
assert(reduction_groups_ > 1);
612595

613-
if (partially_reduced_tmp2 == nullptr) {
614-
partially_reduced_tmp2 = sycl::malloc_device<resTy>(
615-
iter_nelems * reduction_groups_, exec_q);
616-
617-
if (partially_reduced_tmp2 == nullptr) {
618-
dependent_ev.wait();
619-
sycl::free(partially_reduced_tmp, exec_q);
620-
621-
throw std::runtime_error(
622-
"Unable to allocate device memory");
623-
}
624-
625-
temp2_arg = partially_reduced_tmp2;
626-
}
627-
628596
// keep reducing
629597
sycl::event partial_reduction_ev =
630598
exec_q.submit([&](sycl::handler &cgh) {
@@ -727,13 +695,9 @@ sycl::event sum_reduction_over_group_temps_strided_impl(
727695
cgh.depends_on(final_reduction_ev);
728696
sycl::context ctx = exec_q.get_context();
729697

730-
cgh.host_task(
731-
[ctx, partially_reduced_tmp, partially_reduced_tmp2] {
732-
sycl::free(partially_reduced_tmp, ctx);
733-
if (partially_reduced_tmp2) {
734-
sycl::free(partially_reduced_tmp2, ctx);
735-
}
736-
});
698+
cgh.host_task([ctx, partially_reduced_tmp] {
699+
sycl::free(partially_reduced_tmp, ctx);
700+
});
737701
});
738702

739703
// FIXME: do not return host-task event
Lines changed: 71 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,71 @@
1+
//=== sycl_utils.hpp - Implementation of utilities ------- *-C++-*/===//
2+
//
3+
// Data Parallel Control (dpctl)
4+
//
5+
// Copyright 2020-2023 Intel Corporation
6+
//
7+
// Licensed under the Apache License, Version 2.0 (the "License");
8+
// you may not use this file except in compliance with the License.
9+
// You may obtain a copy of the License at
10+
//
11+
// http://www.apache.org/licenses/LICENSE-2.0
12+
//
13+
// Unless required by applicable law or agreed to in writing, software
14+
// distributed under the License is distributed on an "AS IS" BASIS,
15+
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
16+
// See the License for the specific language governing permissions and
17+
// limitations under the License.
18+
//
19+
//===----------------------------------------------------------------------===//
20+
///
21+
/// \file
22+
/// This file defines utilities used for kernel submission.
23+
//===----------------------------------------------------------------------===//
24+
25+
#pragma once
26+
#include <CL/sycl.hpp>
27+
#include <algorithm>
28+
#include <cstddef>
29+
#include <vector>
30+
31+
namespace dpctl
32+
{
33+
namespace tensor
34+
{
35+
namespace sycl_utils
36+
{
37+
38+
/*! @brief Find the smallest multiple of supported sub-group size larger than
39+
* nelems */
40+
template <size_t f = 4>
41+
size_t choose_workgroup_size(const size_t nelems,
42+
const std::vector<size_t> &sg_sizes)
43+
{
44+
std::vector<size_t> wg_choices;
45+
wg_choices.reserve(f * sg_sizes.size());
46+
47+
for (const auto &sg_size : sg_sizes) {
48+
#pragma unroll
49+
for (size_t i = 1; i <= f; ++i) {
50+
wg_choices.push_back(sg_size * i);
51+
}
52+
}
53+
std::sort(std::begin(wg_choices), std::end(wg_choices));
54+
55+
size_t wg = 1;
56+
for (size_t i = 0; i < wg_choices.size(); ++i) {
57+
if (wg_choices[i] == wg) {
58+
continue;
59+
}
60+
wg = wg_choices[i];
61+
size_t n_groups = ((nelems + wg - 1) / wg);
62+
if (n_groups == 1)
63+
break;
64+
}
65+
66+
return wg;
67+
}
68+
69+
} // namespace sycl_utils
70+
} // namespace tensor
71+
} // namespace dpctl

0 commit comments

Comments
 (0)