Skip to content

Commit 527e7ea

Browse files
committed
[OpenMP][OMPX] Add ballot_sync
1 parent 477b48e commit 527e7ea

File tree

5 files changed

+73
-0
lines changed

5 files changed

+73
-0
lines changed

offload/DeviceRTL/include/Utils.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,8 @@ int32_t shuffle(uint64_t Mask, int32_t Var, int32_t SrcLane);
2525

2626
int32_t shuffleDown(uint64_t Mask, int32_t Var, uint32_t Delta, int32_t Width);
2727

28+
/// Ballot vote: return a bitmask built from \p Pred across lanes, restricted
/// by \p Mask. Implemented once per target in src/Utils.cpp.
/// NOTE(review): the NVPTX implementation casts \p Mask to uint32_t, so only
/// its low 32 bits are used there.
uint64_t ballotSync(uint64_t Mask, int32_t Pred);
29+
2830
/// Return \p LowBits and \p HighBits packed into a single 64 bit value.
2931
uint64_t pack(uint32_t LowBits, uint32_t HighBits);
3032

offload/DeviceRTL/src/Mapping.cpp

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -364,4 +364,8 @@ _TGT_KERNEL_LANGUAGE(block_id, getBlockIdInKernel)
364364
_TGT_KERNEL_LANGUAGE(block_dim, getNumberOfThreadsInBlock)
365365
_TGT_KERNEL_LANGUAGE(grid_dim, getNumberOfBlocksInKernel)
366366

367+
extern "C" uint64_t ompx_ballot_sync(uint64_t mask, int pred) {
368+
return utils::ballotSync(mask, pred);
369+
}
370+
367371
#pragma omp end declare target

offload/DeviceRTL/src/Utils.cpp

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,8 @@ int32_t shuffle(uint64_t Mask, int32_t Var, int32_t SrcLane);
3737
int32_t shuffleDown(uint64_t Mask, int32_t Var, uint32_t LaneDelta,
3838
int32_t Width);
3939

40+
/// Forward declaration; a definition is provided by each target
/// implementation section further down in this file.
uint64_t ballotSync(uint64_t Mask, int32_t Pred);
41+
4042
/// AMDGCN Implementation
4143
///
4244
///{
@@ -57,6 +59,10 @@ int32_t shuffleDown(uint64_t Mask, int32_t Var, uint32_t LaneDelta,
5759
return __builtin_amdgcn_ds_bpermute(Index << 2, Var);
5860
}
5961

62+
uint64_t ballotSync(uint64_t Mask, int32_t Pred) {
63+
return Mask & __builtin_amdgcn_ballot_w64(Pred);
64+
}
65+
6066
bool isSharedMemPtr(const void *Ptr) {
6167
return __builtin_amdgcn_is_shared(
6268
(const __attribute__((address_space(0))) void *)Ptr);
@@ -80,6 +86,10 @@ int32_t shuffleDown(uint64_t Mask, int32_t Var, uint32_t Delta, int32_t Width) {
8086
return __nvvm_shfl_sync_down_i32(Mask, Var, Delta, T);
8187
}
8288

89+
uint64_t ballotSync(uint64_t Mask, int32_t Pred) {
90+
return __nvvm_vote_ballot_sync(static_cast<uint32_t>(Mask), Pred);
91+
}
92+
8393
/// NVPTX: report whatever __nvvm_isspacep_shared says about \p Ptr.
bool isSharedMemPtr(const void *Ptr) {
  const bool InShared = __nvvm_isspacep_shared(Ptr);
  return InShared;
}
8494

8595
#pragma omp end declare variant
Lines changed: 45 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,45 @@
1+
// RUN: %libomptarget-compilexx-run-and-check-generic
2+
//
3+
// UNSUPPORTED: x86_64-pc-linux-gnu
4+
// UNSUPPORTED: x86_64-pc-linux-gnu-LTO
5+
// UNSUPPORTED: aarch64-unknown-linux-gnu
6+
// UNSUPPORTED: aarch64-unknown-linux-gnu-LTO
7+
// UNSUPPORTED: s390x-ibm-linux-gnu
8+
// UNSUPPORTED: s390x-ibm-linux-gnu-LTO
9+
10+
#if defined __AMDGCN_WAVEFRONT_SIZE && __AMDGCN_WAVEFRONT_SIZE == 64
11+
#define MASK 0xaaaaaaaaaaaaaaaa
12+
#else
13+
#define MASK 0xaaaaaaaa
14+
#endif
15+
16+
#include <assert.h>
17+
#include <ompx.h>
18+
#include <stdint.h>
19+
#include <stdio.h>
20+
#include <stdlib.h>
21+
22+
int main(int argc, char *argv[]) {
23+
const int num_blocks = 1;
24+
const int block_size = 64;
25+
const int N = num_blocks * block_size;
26+
uint64_t *data = (int *)malloc(N * sizeof(uint64_t));
27+
28+
for (int i = 0; i < N; ++i)
29+
data[i] = i & 0x1;
30+
31+
#pragma omp target teams ompx_bare num_teams(num_blocks) thread_limit(block_size) map(tofrom: data[0:N])
32+
{
33+
int tid = ompx_thread_id_x();
34+
uint64_t mask = ompx_ballot_sync(~0U, data[tid]);
35+
data[tid] += mask;
36+
}
37+
38+
for (int i = 0; i < N; ++i)
39+
assert(data[i] == ((i & 0x1) + MASK));
40+
41+
// CHECK: PASS
42+
printf("PASS\n");
43+
44+
return 0;
45+
}

openmp/runtime/src/include/ompx.h.var

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,8 @@
99
#ifndef __OMPX_H
1010
#define __OMPX_H
1111

12+
/* Take uint64_t from the standard header instead of hand-rolling it:
 * `unsigned long` is only 32 bits wide on LLP64 targets, and a private typedef
 * conflicts with <stdint.h> on platforms where uint64_t is
 * `unsigned long long`. */
#include <stdint.h>
13+
1214
#ifdef __cplusplus
1315
extern "C" {
1416
#endif
@@ -81,6 +83,10 @@ _TGT_KERNEL_LANGUAGE_HOST_IMPL_SYNC_C(void, sync_block_divergent, int Ordering,
8183
#undef _TGT_KERNEL_LANGUAGE_HOST_IMPL_SYNC_C
8284
///}
8385

86+
/// Stub variant of ompx_ballot_sync that unconditionally traps; neither
/// argument is used and no value is ever returned.
/// NOTE(review): this sits next to the HOST_IMPL helpers, so it is presumably
/// the host fallback. Confirm against the enclosing declare variant.
static INLINE uint64_t ompx_ballot_sync(uint64_t mask, int pred) {
  (void)mask;
  (void)pred;
  __builtin_trap();
}
89+
8490
#pragma omp end declare variant
8591

8692
/// ompx_{sync_block}_{,divergent}
@@ -109,6 +115,8 @@ _TGT_KERNEL_LANGUAGE_DECL_GRID_C(grid_dim)
109115
#undef _TGT_KERNEL_LANGUAGE_DECL_GRID_C
110116
///}
111117

118+
/// Ballot vote over \p pred, restricted by \p mask; returns the resulting
/// bitmask. The DeviceRTL definition (src/Mapping.cpp) forwards to
/// utils::ballotSync.
uint64_t ompx_ballot_sync(uint64_t mask, int pred);
119+
112120
#ifdef __cplusplus
113121
}
114122
#endif
@@ -160,6 +168,10 @@ _TGT_KERNEL_LANGUAGE_HOST_IMPL_SYNC_CXX(void, sync_block_divergent,
160168
#undef _TGT_KERNEL_LANGUAGE_HOST_IMPL_SYNC_CXX
161169
///}
162170

171+
/// C++ spelling of the C API: ompx::ballot_sync(mask, pred) forwards both
/// arguments to ompx_ballot_sync and returns its result unchanged.
static INLINE uint64_t ballot_sync(uint64_t mask, int pred) {
  const uint64_t ballot = ompx_ballot_sync(mask, pred);
  return ballot;
}
174+
163175
} // namespace ompx
164176
#endif
165177

0 commit comments

Comments
 (0)