Skip to content

Commit afd31d5

Browse files
committed
[SYCL][LIT] Check device allocated USM pointer passed to reduction
Signed-off-by: Vyacheslav N Klochkov <vyacheslav.n.klochkov@intel.com>
1 parent 5f8b75e commit afd31d5

File tree

1 file changed

+31
-4
lines changed

1 file changed

+31
-4
lines changed

sycl/test/reduction/reduction_usm.cpp

Lines changed: 31 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,8 @@ using namespace cl::sycl;
2020

2121
template <typename T, int Dim, class BinaryOperation>
2222
class SomeClass;
23+
template <typename T, int Dim, class BinaryOperation>
24+
class Copy1;
2325

2426
template <typename T, int Dim, class BinaryOperation>
2527
void test(T Identity, size_t WGSize, size_t NWItems, usm::alloc AllocType) {
@@ -32,11 +34,23 @@ void test(T Identity, size_t WGSize, size_t NWItems, usm::alloc AllocType) {
3234
if (AllocType == usm::alloc::host &&
3335
!Dev.get_info<info::device::usm_host_allocations>())
3436
return;
37+
if (AllocType == usm::alloc::device &&
38+
!Dev.get_info<info::device::usm_device_allocations>())
39+
return;
3540

3641
T *ReduVarPtr = (T *)malloc(sizeof(T), Dev, Q.get_context(), AllocType);
3742
if (ReduVarPtr == nullptr)
3843
return;
39-
*ReduVarPtr = Identity;
44+
if (AllocType == usm::alloc::device) {
45+
event E = Q.submit([&](handler &CGH) {
46+
CGH.single_task<class Copy1<T, Dim, BinaryOperation>>([=]() {
47+
*ReduVarPtr = Identity;
48+
});
49+
});
50+
E.wait();
51+
} else {
52+
*ReduVarPtr = Identity;
53+
}
4054

4155
// Initialize.
4256
T CorrectOut;
@@ -60,10 +74,22 @@ void test(T Identity, size_t WGSize, size_t NWItems, usm::alloc AllocType) {
6074
Q.wait();
6175

6276
// Check correctness.
63-
if (*ReduVarPtr != CorrectOut) {
77+
T ComputedOut;
78+
if (AllocType == usm::alloc::device) {
79+
buffer<T, 1> Buf(&ComputedOut, range<1>(1));
80+
event E = Q.submit([&](handler &CGH) {
81+
auto OutAcc = Buf.template get_access<access::mode::discard_write>(CGH);
82+
CGH.copy(ReduVarPtr, OutAcc);
83+
});
84+
ComputedOut = (Buf.template get_access<access::mode::read>())[0];
85+
} else {
86+
ComputedOut = *ReduVarPtr;
87+
}
88+
if (ComputedOut != CorrectOut) {
6489
std::cout << "NWItems = " << NWItems << ", WGSize = " << WGSize << "\n";
65-
std::cout << "Computed value: " << *ReduVarPtr
66-
<< ", Expected value: " << CorrectOut << "\n";
90+
std::cout << "Computed value: " << ComputedOut
91+
<< ", Expected value: " << CorrectOut
92+
<< ", AllocMode: " << static_cast<int>(AllocType) << "\n";
6793
assert(0 && "Wrong value.");
6894
}
6995

@@ -74,6 +100,7 @@ template <typename T, int Dim, class BinaryOperation>
74100
void testUSM(T Identity, size_t WGSize, size_t NWItems) {
75101
test<T, Dim, BinaryOperation>(Identity, WGSize, NWItems, usm::alloc::shared);
76102
test<T, Dim, BinaryOperation>(Identity, WGSize, NWItems, usm::alloc::host);
103+
test<T, Dim, BinaryOperation>(Identity, WGSize, NWItems, usm::alloc::device);
77104
}
78105

79106
int main() {

0 commit comments

Comments
 (0)