Skip to content
This repository was archived by the owner on Mar 28, 2023. It is now read-only.

[SYCL] Add test for classes implicitly converted from items in parallel_for #607

Merged
merged 5 commits into from
Dec 19, 2021
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
85 changes: 85 additions & 0 deletions SYCL/Basic/parallel_for_user_types.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,85 @@
// RUN: %clangxx -fsycl -fsycl-targets=%sycl_triple %s -o %t.out
// RUN: %HOST_RUN_PLACEHOLDER %t.out
// RUN: %CPU_RUN_PLACEHOLDER %t.out
// RUN: %GPU_RUN_PLACEHOLDER %t.out
// RUN: %ACC_RUN_PLACEHOLDER %t.out

// This test performs basic check of supporting user defined class that are
// implicitly converted from sycl::item/sycl::nd_item in parallel_for.

#include <CL/sycl.hpp>
#include <iostream>

template <int Dimensions> class item_wrapper {
public:
item_wrapper(sycl::item<Dimensions> it) : m_item(it) {}

size_t get() { return m_item; }

private:
sycl::item<Dimensions> m_item;
};

template <int Dimensions> class nd_item_wrapper {
public:
nd_item_wrapper(sycl::nd_item<Dimensions> it) : m_item(it) {}

size_t get() { return m_item.get_global_linear_id(); }

private:
sycl::nd_item<Dimensions> m_item;
};

int main() {
sycl::queue q;

// Initialize data array
const int sz = 16;
int data[sz] = {0};
for (int i = 0; i < sz; ++i) {
data[i] = i;
}

// Check user defined sycl::item wrapper
sycl::buffer<int> data_buf(data, sz);
q.submit([&](sycl::handler &h) {
auto buf_acc = data_buf.get_access<sycl::access::mode::read_write>(h);
h.parallel_for(sycl::range<1>{sz},
[=](item_wrapper<1> item) { buf_acc[item.get()] += 1; });
});
q.wait();
bool failed = false;

{
auto buf_acc = data_buf.get_access<sycl::access::mode::read>();
for (int i = 0; i < sz; ++i) {
failed |= (buf_acc[i] != i + 1);
}
if (failed) {
std::cout << "item_wrapper check failed" << std::endl;
return 1;
}
}

// Check user defined sycl::nd_item wrapper
q.submit([&](sycl::handler &h) {
auto buf_acc = data_buf.get_access<sycl::access::mode::read_write>(h);
h.parallel_for(sycl::nd_range<1>{sz, 2},
[=](nd_item_wrapper<1> item) { buf_acc[item.get()] += 1; });
});
q.wait();

{
auto buf_acc = data_buf.get_access<sycl::access::mode::read>();
for (int i = 0; i < sz; ++i) {
failed |= (buf_acc[i] != i + 2);
}
if (failed) {
std::cout << "nd_item_wrapper check failed" << std::endl;
return 1;
}
}

std::cout << "Test passed" << std::endl;
return 0;
}