Skip to content
This repository was archived by the owner on Mar 28, 2023. It is now read-only.

[SYCL] Add a check for interop_handle::get_backend() method #258

Merged
merged 1 commit into from
May 4, 2021
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
83 changes: 25 additions & 58 deletions SYCL/HostInteropTask/interop-task.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,18 @@ template <typename T> class Modifier;

template <typename T> class Init;

template <typename BufferT, typename ValueT>
void checkBufferValues(BufferT Buffer, ValueT Value) {
auto Acc = Buffer.template get_access<mode::read>();
for (size_t Idx = 0; Idx < Acc.get_count(); ++Idx) {
if (Acc[Idx] != Value) {
std::cerr << "buffer[" << Idx << "] = " << Acc[Idx]
<< ", expected val = " << Value << std::endl;
assert(0 && "Invalid data in the buffer");
}
}
}

template <typename DataT>
void copy(buffer<DataT, 1> &Src, buffer<DataT, 1> &Dst, queue &Q) {
Q.submit([&](handler &CGH) {
Expand All @@ -41,6 +53,11 @@ void copy(buffer<DataT, 1> &Src, buffer<DataT, 1> &Dst, queue &Q) {

if (RC != CL_SUCCESS)
throw runtime_error("Can't wait for event on buffer copy", RC);

if (Q.get_backend() != IH.get_backend())
throw runtime_error(
"interop_handle::get_backend() returned a wrong value",
CL_INVALID_VALUE);
});
});
}
Expand Down Expand Up @@ -89,22 +106,8 @@ void test1() {
copy(Buffer2, Buffer1, Q);
}

{
auto Acc = Buffer1.get_access<mode::read>();

for (size_t Idx = 0; Idx < Acc.get_count(); ++Idx) {
std::cout << "First buffer [" << Idx << "] = " << Acc[Idx] << std::endl;
assert((Acc[Idx] == COUNT - 1) && "Invalid data in the first buffer");
}
}
{
auto Acc = Buffer2.get_access<mode::read>();

for (size_t Idx = 0; Idx < Acc.get_count(); ++Idx) {
std::cout << "Second buffer [" << Idx << "] = " << Acc[Idx] << std::endl;
assert((Acc[Idx] == COUNT - 1) && "Invalid data in the second buffer");
}
}
checkBufferValues(Buffer1, COUNT - 1);
checkBufferValues(Buffer2, COUNT - 1);
}

// Same as above, but performing each command group on a separate SYCL queue
Expand All @@ -128,23 +131,8 @@ void test2() {
modify(Buffer2, Q);
copy(Buffer2, Buffer1, Q);
}

{
auto Acc = Buffer1.get_access<mode::read>();

for (size_t Idx = 0; Idx < Acc.get_count(); ++Idx) {
std::cout << "First buffer [" << Idx << "] = " << Acc[Idx] << std::endl;
assert((Acc[Idx] == COUNT - 1) && "Invalid data in the first buffer");
}
}
{
auto Acc = Buffer2.get_access<mode::read>();

for (size_t Idx = 0; Idx < Acc.get_count(); ++Idx) {
std::cout << "Second buffer [" << Idx << "] = " << Acc[Idx] << std::endl;
assert((Acc[Idx] == COUNT - 1) && "Invalid data in the second buffer");
}
}
checkBufferValues(Buffer1, COUNT - 1);
checkBufferValues(Buffer2, COUNT - 1);
}

// Same as above but with queue constructed out of context
Expand All @@ -168,23 +156,8 @@ void test2_1() {
modify(Buffer2, Q);
copy(Buffer2, Buffer1, Q);
}

{
auto Acc = Buffer1.get_access<mode::read>();

for (size_t Idx = 0; Idx < Acc.get_count(); ++Idx) {
std::cout << "First buffer [" << Idx << "] = " << Acc[Idx] << std::endl;
assert((Acc[Idx] == COUNT - 1) && "Invalid data in the first buffer");
}
}
{
auto Acc = Buffer2.get_access<mode::read>();

for (size_t Idx = 0; Idx < Acc.get_count(); ++Idx) {
std::cout << "Second buffer [" << Idx << "] = " << Acc[Idx] << std::endl;
assert((Acc[Idx] == COUNT - 1) && "Invalid data in the second buffer");
}
}
checkBufferValues(Buffer1, COUNT - 1);
checkBufferValues(Buffer2, COUNT - 1);
}

// A test that does a clEnqueueWait inside the interop scope, for an event
Expand Down Expand Up @@ -245,14 +218,7 @@ void test5() {

copy(Buffer1, Buffer2, Q);

{
auto Acc = Buffer2.get_access<mode::read>();

for (size_t Idx = 0; Idx < Acc.get_count(); ++Idx) {
std::cout << "Second buffer [" << Idx << "] = " << Acc[Idx] << std::endl;
assert(Acc[Idx] == 123);
}
}
checkBufferValues(Buffer2, static_cast<int>(123));
}

// The test checks that an exception which is thrown from host_task body
Expand Down Expand Up @@ -292,5 +258,6 @@ int main() {
test4();
test5();
test6();
std::cout << "Test PASSED" << std::endl;
return 0;
}