@@ -20,6 +20,8 @@ using namespace cl::sycl;
20
20
21
21
template <typename T, int Dim, class BinaryOperation >
22
22
class SomeClass ;
23
+ template <typename T, int Dim, class BinaryOperation >
24
+ class Copy1 ;
23
25
24
26
template <typename T, int Dim, class BinaryOperation >
25
27
void test (T Identity, size_t WGSize, size_t NWItems, usm::alloc AllocType) {
@@ -32,11 +34,23 @@ void test(T Identity, size_t WGSize, size_t NWItems, usm::alloc AllocType) {
32
34
if (AllocType == usm::alloc::host &&
33
35
!Dev.get_info <info::device::usm_host_allocations>())
34
36
return ;
37
+ if (AllocType == usm::alloc::device &&
38
+ !Dev.get_info <info::device::usm_device_allocations>())
39
+ return ;
35
40
36
41
T *ReduVarPtr = (T *)malloc (sizeof (T), Dev, Q.get_context (), AllocType);
37
42
if (ReduVarPtr == nullptr )
38
43
return ;
39
- *ReduVarPtr = Identity;
44
+ if (AllocType == usm::alloc::device) {
45
+ event E = Q.submit ([&](handler &CGH) {
46
+ CGH.single_task <class Copy1 <T, Dim, BinaryOperation>>([=]() {
47
+ *ReduVarPtr = Identity;
48
+ });
49
+ });
50
+ E.wait ();
51
+ } else {
52
+ *ReduVarPtr = Identity;
53
+ }
40
54
41
55
// Initialize.
42
56
T CorrectOut;
@@ -60,10 +74,22 @@ void test(T Identity, size_t WGSize, size_t NWItems, usm::alloc AllocType) {
60
74
Q.wait ();
61
75
62
76
// Check correctness.
63
- if (*ReduVarPtr != CorrectOut) {
77
+ T ComputedOut;
78
+ if (AllocType == usm::alloc::device) {
79
+ buffer<T, 1 > Buf (&ComputedOut, range<1 >(1 ));
80
+ event E = Q.submit ([&](handler &CGH) {
81
+ auto OutAcc = Buf.template get_access <access::mode::discard_write>(CGH);
82
+ CGH.copy (ReduVarPtr, OutAcc);
83
+ });
84
+ ComputedOut = (Buf.template get_access <access::mode::read>())[0 ];
85
+ } else {
86
+ ComputedOut = *ReduVarPtr;
87
+ }
88
+ if (ComputedOut != CorrectOut) {
64
89
std::cout << " NWItems = " << NWItems << " , WGSize = " << WGSize << " \n " ;
65
- std::cout << " Computed value: " << *ReduVarPtr
66
- << " , Expected value: " << CorrectOut << " \n " ;
90
+ std::cout << " Computed value: " << ComputedOut
91
+ << " , Expected value: " << CorrectOut
92
+ << " , AllocMode: " << static_cast <int >(AllocType) << " \n " ;
67
93
assert (0 && " Wrong value." );
68
94
}
69
95
@@ -74,6 +100,7 @@ template <typename T, int Dim, class BinaryOperation>
74
100
void testUSM (T Identity, size_t WGSize, size_t NWItems) {
75
101
test<T, Dim, BinaryOperation>(Identity, WGSize, NWItems, usm::alloc::shared);
76
102
test<T, Dim, BinaryOperation>(Identity, WGSize, NWItems, usm::alloc::host);
103
+ test<T, Dim, BinaryOperation>(Identity, WGSize, NWItems, usm::alloc::device);
77
104
}
78
105
79
106
int main () {
0 commit comments