Skip to content

Commit

Permalink
[OpenMP][OMPX] Add shfl_down_sync (llvm#93311)
Browse files Browse the repository at this point in the history
  • Loading branch information
shiltian authored May 24, 2024
1 parent d07362f commit 4fb02de
Show file tree
Hide file tree
Showing 5 changed files with 154 additions and 6 deletions.
2 changes: 2 additions & 0 deletions offload/DeviceRTL/include/Utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,8 @@ int32_t shuffle(uint64_t Mask, int32_t Var, int32_t SrcLane);

int32_t shuffleDown(uint64_t Mask, int32_t Var, uint32_t Delta, int32_t Width);

/// 64-bit overload of shuffleDown. The definition unpacks \p Var into its low
/// and high 32-bit halves, shuffles each half with the same \p Mask, \p Delta
/// and \p Width, and packs the two results back into one 64-bit value.
int64_t shuffleDown(uint64_t Mask, int64_t Var, uint32_t Delta, int32_t Width);

uint64_t ballotSync(uint64_t Mask, int32_t Pred);

/// Return \p LowBits and \p HighBits packed into a single 64 bit value.
Expand Down
24 changes: 23 additions & 1 deletion offload/DeviceRTL/src/Mapping.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -364,8 +364,30 @@ _TGT_KERNEL_LANGUAGE(block_id, getBlockIdInKernel)
_TGT_KERNEL_LANGUAGE(block_dim, getNumberOfThreadsInBlock)
_TGT_KERNEL_LANGUAGE(grid_dim, getNumberOfBlocksInKernel)

extern "C" uint64_t ompx_ballot_sync(uint64_t mask, int pred) {
extern "C" {
uint64_t ompx_ballot_sync(uint64_t mask, int pred) {
return utils::ballotSync(mask, pred);
}

int ompx_shfl_down_sync_i(uint64_t mask, int var, unsigned delta, int width) {
return utils::shuffleDown(mask, var, delta, width);
}

float ompx_shfl_down_sync_f(uint64_t mask, float var, unsigned delta,
int width) {
return utils::convertViaPun<float>(utils::shuffleDown(
mask, utils::convertViaPun<int32_t>(var), delta, width));
}

long ompx_shfl_down_sync_l(uint64_t mask, long var, unsigned delta, int width) {
return utils::shuffleDown(mask, var, delta, width);
}

double ompx_shfl_down_sync_d(uint64_t mask, double var, unsigned delta,
int width) {
return utils::convertViaPun<double>(utils::shuffleDown(
mask, utils::convertViaPun<int64_t>(var), delta, width));
}
}

#pragma omp end declare target
15 changes: 10 additions & 5 deletions offload/DeviceRTL/src/Utils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -113,6 +113,15 @@ int32_t utils::shuffleDown(uint64_t Mask, int32_t Var, uint32_t Delta,
return impl::shuffleDown(Mask, Var, Delta, Width);
}

/// 64-bit shuffle-down built from two 32-bit shuffles: \p Var is split into
/// its low and high halves, each half is shuffled with the same \p Mask,
/// \p Delta and \p Width, and the results are packed back together.
int64_t utils::shuffleDown(uint64_t Mask, int64_t Var, uint32_t Delta,
                           int32_t Width) {
  uint32_t LowBits, HighBits;
  utils::unpack(Var, LowBits, HighBits);

  // Same order as before: high half first, then low half.
  HighBits = impl::shuffleDown(Mask, HighBits, Delta, Width);
  LowBits = impl::shuffleDown(Mask, LowBits, Delta, Width);

  return utils::pack(LowBits, HighBits);
}

/// Public wrapper around the target-specific ballot; passes \p Mask and
/// \p Pred straight through to impl::ballotSync and returns its result.
uint64_t utils::ballotSync(uint64_t Mask, int32_t Pred) {
  const uint64_t Ballot = impl::ballotSync(Mask, Pred);
  return Ballot;
}
Expand All @@ -125,11 +134,7 @@ int32_t __kmpc_shuffle_int32(int32_t Val, int16_t Delta, int16_t SrcLane) {
}

/// Runtime entry point for a 64-bit shuffle-down across all lanes.
///
/// Delegates to utils::shuffleDown with the lanes::All mask, which performs
/// the split / shuffle-each-half / re-pack sequence this function used to
/// spell out by hand. \p Delta and \p Width are widened from int16_t to the
/// parameter types of utils::shuffleDown by the call.
int64_t __kmpc_shuffle_int64(int64_t Val, int16_t Delta, int16_t Width) {
  return utils::shuffleDown(lanes::All, Val, Delta, Width);
}
}

Expand Down
67 changes: 67 additions & 0 deletions offload/test/offloading/ompx_bare_shfl_down_sync.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,67 @@
// RUN: %libomptarget-compilexx-run-and-check-generic
//
// UNSUPPORTED: x86_64-pc-linux-gnu
// UNSUPPORTED: x86_64-pc-linux-gnu-LTO
// UNSUPPORTED: aarch64-unknown-linux-gnu
// UNSUPPORTED: aarch64-unknown-linux-gnu-LTO
// UNSUPPORTED: s390x-ibm-linux-gnu
// UNSUPPORTED: s390x-ibm-linux-gnu-LTO

#ifdef __AMDGCN_WAVEFRONT_SIZE
#define WARP_SIZE __AMDGCN_WAVEFRONT_SIZE
#else
#define WARP_SIZE 32
#endif

#include <cassert>
#include <cmath>
#include <cstdint>
#include <cstdio>
#include <limits>
#include <ompx.h>
#include <type_traits>

/// Equality for integral types: the two values must match exactly.
template <typename T, std::enable_if_t<std::is_integral<T>::value, bool> = true>
bool equal(T LHS, T RHS) {
  const bool Identical = !(LHS != RHS);
  return Identical;
}

/// Equality for floating-point types: true when the absolute difference of
/// the operands is strictly below std::numeric_limits<T>::epsilon().
template <typename T,
          std::enable_if_t<std::is_floating_point<T>::value, bool> = true>
bool equal(T LHS, T RHS) {
  const auto Distance = std::abs(LHS - RHS);
  const T Tolerance = std::numeric_limits<T>::epsilon();
  return Distance < Tolerance;
}

/// Runs ompx::shfl_down_sync with a delta of 1 over one bare block of 256
/// threads for element type \p T and verifies the result on the host.
///
/// The buffer starts as data[i] == i. After the shuffle each lane is expected
/// to hold the value of the next lane in its warp, while the last lane of a
/// warp (which has no lane after it) keeps its own value.
template <typename T> void test() {
  constexpr const int num_blocks = 1;
  constexpr const int block_size = 256;
  constexpr const int N = num_blocks * block_size;
  T *data = new T[N];

  for (int i = 0; i < N; ++i)
    data[i] = i;

#pragma omp target teams ompx_bare num_teams(num_blocks)                       \
    thread_limit(block_size) map(tofrom : data[0 : N])
  {
    int tid = ompx_thread_id_x();
    data[tid] = ompx::shfl_down_sync(~0U, data[tid], 1);
  }

  // Walk the buffer one warp at a time, starting from each warp's last lane.
  // The previous version declared the inner index but never used it, so only
  // the last lane of each warp was actually checked.
  // NOTE(review): WARP_SIZE is evaluated in this (host-side) code; confirm it
  // matches the warp size the device code was compiled with.
  for (int i = N - 1; i > 0; i -= WARP_SIZE) {
    // Last lane of the warp: nothing to read from, keeps its own value.
    assert(equal(data[i], static_cast<T>(i)));
    // Every other lane must now hold its right-hand neighbour's value.
    for (int j = i - 1; j > i - WARP_SIZE; --j)
      assert(equal(data[j], static_cast<T>(j + 1)));
  }

  delete[] data;
}

/// Test driver: exercises the shuffle for every supported element type and
/// prints the marker line that FileCheck looks for.
int main(int argc, char *argv[]) {
  // 32-bit and 64-bit integers.
  test<int32_t>();
  test<int64_t>();

  // Single and double precision floating point.
  test<float>();
  test<double>();

  // CHECK: PASS
  std::printf("PASS\n");

  return 0;
}
52 changes: 52 additions & 0 deletions openmp/runtime/src/include/ompx.h.var
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,12 @@
#ifndef __OMPX_H
#define __OMPX_H

/// Warp width used as the default `width` argument of the shfl_down_sync
/// functions in this header: the AMDGCN wavefront size when the compiler
/// defines __AMDGCN_WAVEFRONT_SIZE, and 32 otherwise.
#ifdef __AMDGCN_WAVEFRONT_SIZE
#define __WARP_SIZE __AMDGCN_WAVEFRONT_SIZE
#else
#define __WARP_SIZE 32
#endif

typedef unsigned long uint64_t;

#ifdef __cplusplus
Expand Down Expand Up @@ -87,6 +93,22 @@ static inline uint64_t ompx_ballot_sync(uint64_t mask, int pred) {
__builtin_trap();
}

/// ompx_shfl_down_sync_{i,f,l,d}
///
/// Stub definitions generated per type (int, float, long, double). Each stub
/// is a static inline function with the shuffle signature whose body only
/// calls __builtin_trap(), so it never returns a value. The helper macro is
/// undefined again once the four stubs have been emitted.
/// NOTE(review): judging by the macro name and the enclosing declare-variant
/// region these are the host fallbacks -- confirm against the variant's
/// opening pragma, which is outside this excerpt.
///{
#define _TGT_KERNEL_LANGUAGE_SHFL_DOWN_SYNC_HOST_IMPL(TYPE, TY) \
static inline TYPE ompx_shfl_down_sync_##TY(uint64_t mask, TYPE var, \
unsigned delta, int width) { \
__builtin_trap(); \
}

_TGT_KERNEL_LANGUAGE_SHFL_DOWN_SYNC_HOST_IMPL(int, i)
_TGT_KERNEL_LANGUAGE_SHFL_DOWN_SYNC_HOST_IMPL(float, f)
_TGT_KERNEL_LANGUAGE_SHFL_DOWN_SYNC_HOST_IMPL(long, l)
_TGT_KERNEL_LANGUAGE_SHFL_DOWN_SYNC_HOST_IMPL(double, d)

#undef _TGT_KERNEL_LANGUAGE_SHFL_DOWN_SYNC_HOST_IMPL
///}

#pragma omp end declare variant

/// ompx_{sync_block}_{,divergent}
Expand Down Expand Up @@ -117,6 +139,20 @@ _TGT_KERNEL_LANGUAGE_DECL_GRID_C(grid_dim)

uint64_t ompx_ballot_sync(uint64_t mask, int pred);

/// ompx_shfl_down_sync_{i,f,l,d}
///
/// Declarations of the C-linkage shuffle-down entry points, one per type
/// (int, float, long, double). In C++ the trailing `width` parameter defaults
/// to __WARP_SIZE. C has no default arguments, so when this header is parsed
/// as C the parameter is declared without a default and callers pass the
/// width explicitly.
///{
#ifdef __cplusplus
#define _TGT_KERNEL_LANGUAGE_SHFL_DOWN_SYNC_WIDTH int width = __WARP_SIZE
#else
#define _TGT_KERNEL_LANGUAGE_SHFL_DOWN_SYNC_WIDTH int width
#endif

#define _TGT_KERNEL_LANGUAGE_SHFL_DOWN_SYNC(TYPE, TY)                          \
  TYPE ompx_shfl_down_sync_##TY(uint64_t mask, TYPE var, unsigned delta,       \
                                _TGT_KERNEL_LANGUAGE_SHFL_DOWN_SYNC_WIDTH);

_TGT_KERNEL_LANGUAGE_SHFL_DOWN_SYNC(int, i)
_TGT_KERNEL_LANGUAGE_SHFL_DOWN_SYNC(float, f)
_TGT_KERNEL_LANGUAGE_SHFL_DOWN_SYNC(long, l)
_TGT_KERNEL_LANGUAGE_SHFL_DOWN_SYNC(double, d)

#undef _TGT_KERNEL_LANGUAGE_SHFL_DOWN_SYNC
#undef _TGT_KERNEL_LANGUAGE_SHFL_DOWN_SYNC_WIDTH
///}

#ifdef __cplusplus
}
#endif
Expand Down Expand Up @@ -172,6 +208,22 @@ static inline uint64_t ballot_sync(uint64_t mask, int pred) {
return ompx_ballot_sync(mask, pred);
}

/// shfl_down_sync
///
/// Overloads of shfl_down_sync generated per type (int, float, long, double).
/// Each overload forwards its arguments to the matching
/// ompx_shfl_down_sync_{i,f,l,d} function; `width` defaults to __WARP_SIZE.
/// The helper macro is undefined again once the four overloads are emitted.
///{
#define _TGT_KERNEL_LANGUAGE_SHFL_DOWN_SYNC(TYPE, TY) \
static inline TYPE shfl_down_sync(uint64_t mask, TYPE var, unsigned delta, \
int width = __WARP_SIZE) { \
return ompx_shfl_down_sync_##TY(mask, var, delta, width); \
}

_TGT_KERNEL_LANGUAGE_SHFL_DOWN_SYNC(int, i)
_TGT_KERNEL_LANGUAGE_SHFL_DOWN_SYNC(float, f)
_TGT_KERNEL_LANGUAGE_SHFL_DOWN_SYNC(long, l)
_TGT_KERNEL_LANGUAGE_SHFL_DOWN_SYNC(double, d)

#undef _TGT_KERNEL_LANGUAGE_SHFL_DOWN_SYNC
///}

} // namespace ompx
#endif

Expand Down

0 comments on commit 4fb02de

Please sign in to comment.