Skip to content

Commit 4d2c32f

Browse files
authored
[SYCL] Test for Group Mask feature (intel/llvm-test-suite#441)
1 parent c30fe57 commit 4d2c32f

File tree

1 file changed

+86
-0
lines changed

1 file changed

+86
-0
lines changed

SYCL/GroupMask/Basic.cpp

Lines changed: 86 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,86 @@
1+
// RUN: %clangxx -fsycl -fsycl-targets=%sycl_triple %s -o %t.out
2+
// REQUIRES: gpu
3+
// UNSUPPORTED: cuda, hip
4+
// GroupNonUniformBallot capability is supported on Intel GPU only
5+
// RUN: %GPU_RUN_PLACEHOLDER %t.out
6+
7+
//==---------- Basic.cpp - SYCL Group Mask basic test ----------*- C++ -*---==//
8+
//
9+
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
10+
// See https://llvm.org/LICENSE.txt for license information.
11+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
12+
//
13+
//===----------------------------------------------------------------------===//
14+
15+
#include <CL/sycl.hpp>
16+
using namespace sycl;
17+
constexpr int global_size = 128;
18+
constexpr int local_size = 32;
19+
int main() {
20+
#ifdef SYCL_EXT_ONEAPI_GROUP_MASK
21+
queue Queue;
22+
23+
try {
24+
nd_range<1> NdRange(global_size, local_size);
25+
int Res = 0;
26+
{
27+
buffer resbuf(&Res, range<1>(1));
28+
29+
Queue.submit([&](handler &cgh) {
30+
auto resacc = resbuf.get_access<access::mode::read_write>(cgh);
31+
32+
cgh.parallel_for<class group_mask>(NdRange, [=](nd_item<1> NdItem) {
33+
size_t GID = NdItem.get_global_linear_id();
34+
auto SG = NdItem.get_sub_group();
35+
auto gmask_gid2 =
36+
ext::oneapi::group_ballot(NdItem.get_sub_group(), GID % 2);
37+
auto gmask_gid3 =
38+
ext::oneapi::group_ballot(NdItem.get_sub_group(), GID % 3);
39+
NdItem.barrier();
40+
41+
if (!GID) {
42+
int res = 0;
43+
44+
for (size_t i = 0; i < SG.get_max_local_range()[0]; i++) {
45+
res |= !((gmask_gid2 | gmask_gid3)[i] == (i % 2 || i % 3)) << 1;
46+
res |= !((gmask_gid2 & gmask_gid3)[i] == (i % 2 && i % 3)) << 2;
47+
res |= !((gmask_gid2 ^ gmask_gid3)[i] ==
48+
((bool)(i % 2) ^ (bool)(i % 3)))
49+
<< 3;
50+
}
51+
gmask_gid2 <<= 32;
52+
res |= (gmask_gid2.extract_bits()[2] != 0xaaaaaaaa) << 4;
53+
res |= ((gmask_gid2 >> 8).extract_bits()[3] != 0xaa000000) << 5;
54+
res |= ((gmask_gid3 >> 8).extract_bits()[3] != 0xb6db6d) << 6;
55+
res |= (!gmask_gid2[32] && gmask_gid2[31]) << 7;
56+
gmask_gid3[0] = gmask_gid3[3] = gmask_gid3[6] = true;
57+
res |= (gmask_gid3.extract_bits()[3] != 0xb6db6dff) << 7;
58+
gmask_gid3.reset();
59+
res |= !(gmask_gid3.none() && gmask_gid2.any() && !gmask_gid2.all())
60+
<< 8;
61+
gmask_gid2.set();
62+
res |= !(gmask_gid3.none() && gmask_gid2.any() && gmask_gid2.all())
63+
<< 9;
64+
gmask_gid3.flip();
65+
res |= (gmask_gid3 != gmask_gid2) << 10;
66+
resacc[0] = res;
67+
}
68+
});
69+
});
70+
}
71+
if (Res) {
72+
std::cout << "Unexpected result for group_mask operation: " << Res
73+
<< std::endl;
74+
exit(1);
75+
}
76+
} catch (exception e) {
77+
std::cout << "SYCL exception caught: " << e.what();
78+
exit(1);
79+
}
80+
81+
std::cout << "Test passed." << std::endl;
82+
#else
83+
std::cout << "Test skipped due to missing extension." << std::endl;
84+
#endif
85+
return 0;
86+
}

0 commit comments

Comments
 (0)