Skip to content

Commit 115c1a0

Browse files
authored
[SYCL] Support lambda functions passed to reduction (#2190)
Signed-off-by: Vyacheslav N Klochkov <vyacheslav.n.klochkov@intel.com>
1 parent ce915ef commit 115c1a0

File tree

3 files changed

+134
-79
lines changed

3 files changed

+134
-79
lines changed

sycl/include/CL/sycl/intel/reduction.hpp

Lines changed: 24 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -147,18 +147,17 @@ using IsKnownIdentityOp =
147147
template <typename T, class BinaryOperation, typename Subst = void>
148148
class reducer {
149149
public:
150-
reducer(const T &Identity) : MValue(Identity), MIdentity(Identity) {}
151-
void combine(const T &Partial) {
152-
BinaryOperation BOp;
153-
MValue = BOp(MValue, Partial);
154-
}
150+
reducer(const T &Identity, BinaryOperation BOp)
151+
: MValue(Identity), MIdentity(Identity), MBinaryOp(BOp) {}
152+
void combine(const T &Partial) { MValue = MBinaryOp(MValue, Partial); }
155153

156154
T getIdentity() const { return MIdentity; }
157155

158156
T MValue;
159157

160158
private:
161159
const T MIdentity;
160+
BinaryOperation MBinaryOp;
162161
};
163162

164163
/// Specialization of the generic class 'reducer'. It is used for reductions
@@ -183,7 +182,7 @@ class reducer<T, BinaryOperation,
183182
enable_if_t<IsKnownIdentityOp<T, BinaryOperation>::value>> {
184183
public:
185184
reducer() : MValue(getIdentity()) {}
186-
reducer(const T &) : MValue(getIdentity()) {}
185+
reducer(const T &, BinaryOperation) : MValue(getIdentity()) {}
187186

188187
void combine(const T &Partial) {
189188
BinaryOperation BOp;
@@ -405,7 +404,7 @@ class reduction_impl {
405404
template <
406405
typename _T = T, class _BinaryOperation = BinaryOperation,
407406
enable_if_t<IsKnownIdentityOp<_T, _BinaryOperation>::value> * = nullptr>
408-
reduction_impl(accessor_type &Acc, const T &Identity)
407+
reduction_impl(accessor_type &Acc, const T &Identity, BinaryOperation)
409408
: MAcc(shared_ptr_class<accessor_type>(shared_ptr_class<accessor_type>{},
410409
&Acc)),
411410
MIdentity(getIdentity()) {
@@ -431,10 +430,10 @@ class reduction_impl {
431430
template <
432431
typename _T = T, class _BinaryOperation = BinaryOperation,
433432
enable_if_t<!IsKnownIdentityOp<_T, _BinaryOperation>::value> * = nullptr>
434-
reduction_impl(accessor_type &Acc, const T &Identity)
433+
reduction_impl(accessor_type &Acc, const T &Identity, BinaryOperation BOp)
435434
: MAcc(shared_ptr_class<accessor_type>(shared_ptr_class<accessor_type>{},
436435
&Acc)),
437-
MIdentity(Identity) {
436+
MIdentity(Identity), MBinaryOp(BOp) {
438437
assert(Acc.get_count() == 1 &&
439438
"Only scalar/1-element reductions are supported now.");
440439
}
@@ -456,7 +455,7 @@ class reduction_impl {
456455
template <
457456
typename _T = T, class _BinaryOperation = BinaryOperation,
458457
enable_if_t<IsKnownIdentityOp<_T, _BinaryOperation>::value> * = nullptr>
459-
reduction_impl(T *VarPtr, const T &Identity)
458+
reduction_impl(T *VarPtr, const T &Identity, BinaryOperation)
460459
: MIdentity(Identity), MUSMPointer(VarPtr) {
461460
// For now the implementation ignores the identity value given by user
462461
// when the implementation knows the identity.
@@ -478,8 +477,8 @@ class reduction_impl {
478477
template <
479478
typename _T = T, class _BinaryOperation = BinaryOperation,
480479
enable_if_t<!IsKnownIdentityOp<_T, _BinaryOperation>::value> * = nullptr>
481-
reduction_impl(T *VarPtr, const T &Identity)
482-
: MIdentity(Identity), MUSMPointer(VarPtr) {}
480+
reduction_impl(T *VarPtr, const T &Identity, BinaryOperation BOp)
481+
: MIdentity(Identity), MUSMPointer(VarPtr), MBinaryOp(BOp) {}
483482

484483
/// Associates reduction accessor with the given handler and saves reduction
485484
/// buffer so that it is alive until the command group finishes the work.
@@ -563,6 +562,9 @@ class reduction_impl {
563562
return OutPtr;
564563
}
565564

565+
/// Returns the binary operation associated with the reduction.
566+
BinaryOperation getBinaryOperation() const { return MBinaryOp; }
567+
566568
private:
567569
/// Identity of the BinaryOperation.
568570
/// The result of BinaryOperation(X, MIdentity) is equal to X for any X.
@@ -576,6 +578,8 @@ class reduction_impl {
576578
/// USM pointer referencing the memory to where the result of the reduction
577579
/// must be written. Applicable/used only for USM reductions.
578580
T *MUSMPointer = nullptr;
581+
582+
BinaryOperation MBinaryOp;
579583
};
580584

581585
/// These are the forward declaration for the classes that help to create
@@ -794,9 +798,10 @@ reduCGFuncImpl(handler &CGH, KernelType KernelFunc, const nd_range<Dims> &Range,
794798
typename Reduction::result_type ReduIdentity = Redu.getIdentity();
795799
using Name = typename get_reduction_main_kernel_name_t<
796800
KernelName, KernelType, Reduction::is_usm, UniformPow2WG, OutputT>::name;
801+
auto BOp = Redu.getBinaryOperation();
797802
CGH.parallel_for<Name>(Range, [=](nd_item<Dims> NDIt) {
798803
// Call user's functions. Reducer.MValue gets initialized there.
799-
typename Reduction::reducer_type Reducer(ReduIdentity);
804+
typename Reduction::reducer_type Reducer(ReduIdentity, BOp);
800805
KernelFunc(NDIt, Reducer);
801806

802807
size_t WGSize = NDIt.get_local_range().size();
@@ -811,7 +816,6 @@ reduCGFuncImpl(handler &CGH, KernelType KernelFunc, const nd_range<Dims> &Range,
811816
// Tree-reduction: reduce the local array LocalReds[:] to LocalReds[0]
812817
// LocalReds[WGSize] accumulates last/odd elements when the step
813818
// of tree-reduction loop is not even.
814-
typename Reduction::binary_operation BOp;
815819
size_t PrevStep = WGSize;
816820
for (size_t CurStep = PrevStep >> 1; CurStep > 0; CurStep >>= 1) {
817821
if (LID < CurStep)
@@ -925,6 +929,7 @@ reduAuxCGFuncImpl(handler &CGH, size_t NWorkItems, size_t NWorkGroups,
925929
auto LocalReds = Redu.getReadWriteLocalAcc(NumLocalElements, CGH);
926930

927931
auto ReduIdentity = Redu.getIdentity();
932+
auto BOp = Redu.getBinaryOperation();
928933
using Name = typename get_reduction_aux_kernel_name_t<
929934
KernelName, KernelType, Reduction::is_usm, UniformPow2WG, OutputT>::name;
930935
nd_range<1> Range{range<1>(NWorkItems), range<1>(WGSize)};
@@ -943,7 +948,6 @@ reduAuxCGFuncImpl(handler &CGH, size_t NWorkItems, size_t NWorkGroups,
943948
// Tree-reduction: reduce the local array LocalReds[:] to LocalReds[0]
944949
// LocalReds[WGSize] accumulates last/odd elements when the step
945950
// of tree-reduction loop is not even.
946-
typename Reduction::binary_operation BOp;
947951
size_t PrevStep = WGSize;
948952
for (size_t CurStep = PrevStep >> 1; CurStep > 0; CurStep >>= 1) {
949953
if (LID < CurStep)
@@ -1022,10 +1026,10 @@ template <typename T, class BinaryOperation, int Dims, access::mode AccMode,
10221026
access::placeholder IsPH>
10231027
detail::reduction_impl<T, BinaryOperation, Dims, false, AccMode, IsPH>
10241028
reduction(accessor<T, Dims, AccMode, access::target::global_buffer, IsPH> &Acc,
1025-
const T &Identity, BinaryOperation) {
1029+
const T &Identity, BinaryOperation BOp) {
10261030
// The Combiner argument was needed only to define the BinaryOperation param.
10271031
return detail::reduction_impl<T, BinaryOperation, Dims, false, AccMode, IsPH>(
1028-
Acc, Identity);
1032+
Acc, Identity, BOp);
10291033
}
10301034

10311035
/// Creates and returns an object implementing the reduction functionality.
@@ -1050,9 +1054,10 @@ reduction(accessor<T, Dims, AccMode, access::target::global_buffer, IsPH> &Acc,
10501054
/// \param Identity, and the binary operation used in the reduction.
10511055
template <typename T, class BinaryOperation>
10521056
detail::reduction_impl<T, BinaryOperation, 0, true, access::mode::read_write>
1053-
reduction(T *VarPtr, const T &Identity, BinaryOperation) {
1057+
reduction(T *VarPtr, const T &Identity, BinaryOperation BOp) {
10541058
return detail::reduction_impl<T, BinaryOperation, 0, true,
1055-
access::mode::read_write>(VarPtr, Identity);
1059+
access::mode::read_write>(VarPtr, Identity,
1060+
BOp);
10561061
}
10571062

10581063
/// Creates and returns an object implementing the reduction functionality.

sycl/test/reduction/reduction_ctor.cpp

Lines changed: 38 additions & 60 deletions
Original file line numberDiff line numberDiff line change
@@ -23,13 +23,12 @@ void test_reducer(Reduction &Redu, T A, T B) {
2323
"Wrong result of binary operation.");
2424
}
2525

26-
template <typename T, typename Reduction>
27-
void test_reducer(Reduction &Redu, T Identity, T A, T B) {
28-
typename Reduction::reducer_type Reducer(Identity);
26+
template <typename T, typename Reduction, typename BinaryOperation>
27+
void test_reducer(Reduction &Redu, T Identity, BinaryOperation BOp, T A, T B) {
28+
typename Reduction::reducer_type Reducer(Identity, BOp);
2929
Reducer.combine(A);
3030
Reducer.combine(B);
3131

32-
typename Reduction::binary_operation BOp;
3332
T ExpectedValue = BOp(A, B);
3433
assert(ExpectedValue == Reducer.MValue &&
3534
"Wrong result of binary operation.");
@@ -40,35 +39,8 @@ class Known;
4039
template <typename T, int Dim, class BinaryOperation>
4140
class Unknown;
4241

43-
template <typename T>
44-
struct Point {
45-
Point() : X(0), Y(0) {}
46-
Point(T X, T Y) : X(X), Y(Y) {}
47-
Point(T V) : X(V), Y(V) {}
48-
bool operator==(const Point &P) const {
49-
return P.X == X && P.Y == Y;
50-
}
51-
T X;
52-
T Y;
53-
};
54-
55-
template <typename T>
56-
bool operator==(const Point<T> &A, const Point<T> &B) {
57-
return A.X == B.X && A.Y == B.Y;
58-
}
59-
60-
template <class T>
61-
struct PointPlus {
62-
using P = Point<T>;
63-
P operator()(const P &A, const P &B) const {
64-
return P(A.X + B.X, A.Y + B.Y);
65-
}
66-
};
67-
6842
template <typename T, int Dim, class BinaryOperation>
69-
void testKnown(T Identity, T A, T B) {
70-
71-
BinaryOperation BOp;
43+
void testKnown(T Identity, BinaryOperation BOp, T A, T B) {
7244
buffer<T, 1> ReduBuf(1);
7345

7446
queue Q;
@@ -81,17 +53,15 @@ void testKnown(T Identity, T A, T B) {
8153
assert(Redu.getIdentity() == Identity &&
8254
"Failed getIdentity() check().");
8355
test_reducer(Redu, A, B);
84-
test_reducer(Redu, Identity, A, B);
56+
test_reducer(Redu, Identity, BOp, A, B);
8557

8658
// Command group must have at least one task in it. Use an empty one.
8759
CGH.single_task<Known<T, Dim, BinaryOperation>>([=]() {});
8860
});
8961
}
9062

91-
template <typename T, int Dim, class BinaryOperation>
92-
void testUnknown(T Identity, T A, T B) {
93-
94-
BinaryOperation BOp;
63+
template <typename T, int Dim, typename KernelName, class BinaryOperation>
64+
void testUnknown(T Identity, BinaryOperation BOp, T A, T B) {
9565
buffer<T, 1> ReduBuf(1);
9666
queue Q;
9767
Q.submit([&](handler &CGH) {
@@ -102,38 +72,46 @@ void testUnknown(T Identity, T A, T B) {
10272
auto Redu = intel::reduction(ReduAcc, Identity, BOp);
10373
assert(Redu.getIdentity() == Identity &&
10474
"Failed getIdentity() check().");
105-
test_reducer(Redu, Identity, A, B);
75+
test_reducer(Redu, Identity, BOp, A, B);
10676

10777
// Command group must have at least one task in it. Use an empty one.
108-
CGH.single_task<Unknown<T, Dim, BinaryOperation>>([=]() {});
78+
CGH.single_task<KernelName>([=]() {});
10979
});
11080
}
11181

11282
template <typename T, class BinaryOperation>
113-
void testBoth(T Identity, T A, T B) {
114-
testKnown<T, 0, BinaryOperation>(Identity, A, B);
115-
testKnown<T, 1, BinaryOperation>(Identity, A, B);
116-
testUnknown<T, 0, BinaryOperation>(Identity, A, B);
117-
testUnknown<T, 1, BinaryOperation>(Identity, A, B);
83+
void testBoth(T Identity, BinaryOperation BOp, T A, T B) {
84+
testKnown<T, 0>(Identity, BOp, A, B);
85+
testKnown<T, 1>(Identity, BOp, A, B);
86+
testUnknown<T, 0, Unknown<T, 0, BinaryOperation>>(Identity, BOp, A, B);
87+
testUnknown<T, 1, Unknown<T, 1, BinaryOperation>>(Identity, BOp, A, B);
11888
}
11989

12090
int main() {
121-
// testKnown does not pass identity to reduction ctor.
122-
testBoth<int, intel::plus<int>>(0, 1, 7);
123-
testBoth<int, std::multiplies<int>>(1, 1, 7);
124-
testBoth<int, intel::bit_or<int>>(0, 1, 8);
125-
testBoth<int, intel::bit_xor<int>>(0, 7, 3);
126-
testBoth<int, intel::bit_and<int>>(~0, 7, 3);
127-
testBoth<int, intel::minimum<int>>((std::numeric_limits<int>::max)(), 7, 3);
128-
testBoth<int, intel::maximum<int>>((std::numeric_limits<int>::min)(), 7, 3);
129-
130-
testBoth<float, intel::plus<float>>(0, 1, 7);
131-
testBoth<float, std::multiplies<float>>(1, 1, 7);
132-
testBoth<float, intel::minimum<float>>(getMaximumFPValue<float>(), 7, 3);
133-
testBoth<float, intel::maximum<float>>(getMinimumFPValue<float>(), 7, 3);
134-
135-
testUnknown<Point<float>, 0, PointPlus<float>>(Point<float>(0), Point<float>(1), Point<float>(7));
136-
testUnknown<Point<float>, 1, PointPlus<float>>(Point<float>(0), Point<float>(1), Point<float>(7));
91+
testBoth<int>(0, intel::plus<int>(), 1, 7);
92+
testBoth<int>(1, std::multiplies<int>(), 1, 7);
93+
testBoth<int>(0, intel::bit_or<int>(), 1, 8);
94+
testBoth<int>(0, intel::bit_xor<int>(), 7, 3);
95+
testBoth<int>(~0, intel::bit_and<int>(), 7, 3);
96+
testBoth<int>((std::numeric_limits<int>::max)(), intel::minimum<int>(), 7, 3);
97+
testBoth<int>((std::numeric_limits<int>::min)(), intel::maximum<int>(), 7, 3);
98+
99+
testBoth<float>(0, intel::plus<float>(), 1, 7);
100+
testBoth<float>(1, std::multiplies<float>(), 1, 7);
101+
testBoth<float>(getMaximumFPValue<float>(), intel::minimum<float>(), 7, 3);
102+
testBoth<float>(getMinimumFPValue<float>(), intel::maximum<float>(), 7, 3);
103+
104+
testUnknown<CustomVec<float>, 0,
105+
Unknown<CustomVec<float>, 0, CustomVecPlus<float>>>(
106+
CustomVec<float>(0), CustomVecPlus<float>(), CustomVec<float>(1),
107+
CustomVec<float>(7));
108+
testUnknown<CustomVec<float>, 1,
109+
Unknown<CustomVec<float>, 1, CustomVecPlus<float>>>(
110+
CustomVec<float>(0), CustomVecPlus<float>(), CustomVec<float>(1),
111+
CustomVec<float>(7));
112+
113+
testUnknown<int, 0, class BitOrName>(
114+
0, [](auto a, auto b) { return a | b; }, 1, 8);
137115

138116
std::cout << "Test passed\n";
139117
return 0;
Lines changed: 72 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,72 @@
1+
// UNSUPPORTED: cuda
2+
// Reductions use work-group builtins (e.g. intel::reduce()) not yet supported
3+
// by CUDA.
4+
//
5+
// RUN: %clangxx -fsycl -fsycl-targets=%sycl_triple %s -o %t.out
6+
// RUNx: env SYCL_DEVICE_TYPE=HOST %t.out
7+
// RUN: %CPU_RUN_PLACEHOLDER %t.out
8+
// RUN: %GPU_RUN_PLACEHOLDER %t.out
9+
// RUN: %ACC_RUN_PLACEHOLDER %t.out
10+
11+
// This test performs basic checks of parallel_for(nd_range, reduction, lambda)
12+
13+
#include "reduction_utils.hpp"
14+
#include <CL/sycl.hpp>
15+
#include <cassert>
16+
17+
using namespace cl::sycl;
18+
19+
template <class KernelName, typename T, class BinaryOperation>
20+
void test(T Identity, BinaryOperation BOp, size_t WGSize, size_t NWItems) {
21+
buffer<T, 1> InBuf(NWItems);
22+
buffer<T, 1> OutBuf(1);
23+
24+
// Initialize.
25+
T CorrectOut;
26+
initInputData(InBuf, CorrectOut, Identity, BOp, NWItems);
27+
28+
// Compute.
29+
queue Q;
30+
Q.submit([&](handler &CGH) {
31+
auto In = InBuf.template get_access<access::mode::read>(CGH);
32+
auto Out = OutBuf.template get_access<access::mode::discard_write>(CGH);
33+
auto Redu = intel::reduction(Out, Identity, BOp);
34+
35+
range<1> GlobalRange(NWItems);
36+
range<1> LocalRange(WGSize);
37+
nd_range<1> NDRange(GlobalRange, LocalRange);
38+
CGH.parallel_for<KernelName>(NDRange, Redu,
39+
[=](nd_item<1> NDIt, auto &Sum) {
40+
Sum.combine(In[NDIt.get_global_linear_id()]);
41+
});
42+
});
43+
44+
// Check correctness.
45+
auto Out = OutBuf.template get_access<access::mode::read>();
46+
T ComputedOut = *(Out.get_pointer());
47+
if (ComputedOut != CorrectOut) {
48+
std::cout << "NWItems = " << NWItems << ", WGSize = " << WGSize << "\n";
49+
std::cout << "Computed value: " << ComputedOut
50+
<< ", Expected value: " << CorrectOut << "\n";
51+
assert(0 && "Wrong value.");
52+
}
53+
}
54+
55+
int main() {
56+
test<class AddTestName, int>(
57+
0, [](auto x, auto y) { return (x + y); }, 8, 32);
58+
test<class MulTestName, int>(
59+
0, [](auto x, auto y) { return (x * y); }, 8, 32);
60+
61+
// Check with CUSTOM type.
62+
test<class CustomAddTestname, CustomVec<long long>>(
63+
CustomVec<long long>(0),
64+
[](auto x, auto y) {
65+
CustomVecPlus<long long> BOp;
66+
return BOp(x, y);
67+
},
68+
4, 64);
69+
70+
std::cout << "Test passed\n";
71+
return 0;
72+
}

0 commit comments

Comments
 (0)