cuda : use amd wave sharing intrinsics for warp_reduce functions #6522

Open · wants to merge 3 commits into master

118 changes: 111 additions & 7 deletions ggml-cuda/common.cuh
@@ -315,6 +315,104 @@ static __device__ __forceinline__ int __dp4a(const int a, const int b, int c) {
#endif
return c;
}

#ifdef __HIP_PLATFORM_AMD__
#define AMD_SWIZZLE_MASK(and_mask, or_mask, xor_mask) ((and_mask) | ((or_mask)<<5) | ((xor_mask)<<10)) // 5-bit masks applied sequentially to the thread id
#define AMD_DPP_ROW_RR(x) (0x120+(x)) // 0x121-0x12F - row rotate right by 1-15 threads - a row is 16 threads
#define hip_move_dppf(src, dpp_ctrl, row_mask, bank_mask, bound_ctrl) \
hip_move_dppf_N<(dpp_ctrl), (row_mask), (bank_mask), (bound_ctrl)>((src))
#define hip_move_dpph2(src, dpp_ctrl, row_mask, bank_mask, bound_ctrl) \
hip_move_dpph2_N<(dpp_ctrl), (row_mask), (bank_mask), (bound_ctrl)>((src))
#define hip_ds_swizzleh2(src, pattern) hip_ds_swizzleh2_N<(pattern)>((src))
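// Illustrative note (not part of the upstream diff): in the swizzle bit-mask mode encoded by
// AMD_SWIZZLE_MASK, each lane reads its value from a source lane computed within its group of
// 32 lanes roughly as
//     src_lane = (((lane & and_mask) | or_mask) ^ xor_mask)
// so AMD_SWIZZLE_MASK(0x1F, 0, 0x10) makes lane i read from lane i ^ 16, i.e. it swaps the
// two 16-lane halves of each 32-lane group.
//
// The templated helpers below wrap __builtin_amdgcn_mov_dpp / __builtin_amdgcn_ds_swizzle,
// which operate on 32-bit integer values and require their control operands to be
// compile-time constants: the operands are type-punned through a 32-bit union and the
// controls are passed as non-type template parameters (the macros above forward to them).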

template <int dpp_ctrl, int row_mask, int bank_mask, bool bound_ctrl>
static __device__ __forceinline__ float hip_move_dppf_N(float x) {
typedef union float_b32 {
float val;
int b32;
} float_b32_t;
float_b32_t tmp;
tmp.val = x;
tmp.b32 = __builtin_amdgcn_mov_dpp(tmp.b32, dpp_ctrl, row_mask, bank_mask, bound_ctrl);
return tmp.val;
}

template <int dpp_ctrl, int row_mask, int bank_mask, bool bound_ctrl>
static __device__ __forceinline__ half2 hip_move_dpph2_N(half2 x) {
typedef union half2_b32 {
half2 val;
int b32;
} half2_b32_t;
half2_b32_t tmp;
tmp.val = x;
tmp.b32 = __builtin_amdgcn_mov_dpp(tmp.b32, dpp_ctrl, row_mask, bank_mask, bound_ctrl);
return tmp.val;
}

template <int pattern>
static __device__ __forceinline__ half2 hip_ds_swizzleh2_N(half2 src) {
typedef union half2_b32 {
half2 val;
int b32;
} half2_b32_t;
half2_b32_t tmp;
tmp.val = src;
tmp.b32 = __builtin_amdgcn_ds_swizzle(tmp.b32, pattern);
return tmp.val;
}
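
// Reduction pattern shared by the warp_reduce_*_impl_amd functions below: the xor-16 swizzle
// first combines each lane with its partner in the other 16-lane half of the 32-lane group,
// then the row-rotate-right steps (8, 4, 2, 1) fold the 16-lane row, so every lane ends up
// holding the reduction over all 32 lanes - the same result as the __shfl_xor_sync(..., 32)
// loop on the CUDA path.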

static __device__ __forceinline__ float warp_reduce_sum_impl_amd(float x) {
x += __hip_ds_swizzlef(x, AMD_SWIZZLE_MASK(0x1F, 0, 0x10)); // swap neighbouring groups of 16 lanes
x += hip_move_dppf(x, AMD_DPP_ROW_RR(8), 0xF, 0xF, true);
x += hip_move_dppf(x, AMD_DPP_ROW_RR(4), 0xF, 0xF, true);
x += hip_move_dppf(x, AMD_DPP_ROW_RR(2), 0xF, 0xF, true);
x += hip_move_dppf(x, AMD_DPP_ROW_RR(1), 0xF, 0xF, true);
return x;
}

static __device__ __forceinline__ float2 warp_reduce_sum_impl_amd(float2 a) {
a.x += __hip_ds_swizzlef(a.x, AMD_SWIZZLE_MASK(0x1F, 0, 0x10));
a.y += __hip_ds_swizzlef(a.y, AMD_SWIZZLE_MASK(0x1F, 0, 0x10));
a.x += hip_move_dppf(a.x, AMD_DPP_ROW_RR(8), 0xF, 0xF, true);
a.y += hip_move_dppf(a.y, AMD_DPP_ROW_RR(8), 0xF, 0xF, true);
a.x += hip_move_dppf(a.x, AMD_DPP_ROW_RR(4), 0xF, 0xF, true);
a.y += hip_move_dppf(a.y, AMD_DPP_ROW_RR(4), 0xF, 0xF, true);
a.x += hip_move_dppf(a.x, AMD_DPP_ROW_RR(2), 0xF, 0xF, true);
a.y += hip_move_dppf(a.y, AMD_DPP_ROW_RR(2), 0xF, 0xF, true);
a.x += hip_move_dppf(a.x, AMD_DPP_ROW_RR(1), 0xF, 0xF, true);
a.y += hip_move_dppf(a.y, AMD_DPP_ROW_RR(1), 0xF, 0xF, true);
return a;
}
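
// Same pattern for half2: the swizzle/DPP moves shuffle the packed 32-bit value, and the two
// fp16 components are then accumulated separately through the .data.x / .data.y members.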

static __device__ __forceinline__ half2 warp_reduce_sum_impl_amd(half2 a) {
half2 tmp;
tmp = hip_ds_swizzleh2(a, AMD_SWIZZLE_MASK(0x1F, 0, 0x10));
a.data.x += tmp.data.x;
a.data.y += tmp.data.y;
tmp = hip_move_dpph2(a, AMD_DPP_ROW_RR(8), 0xF, 0xF, true);
a.data.x += tmp.data.x;
a.data.y += tmp.data.y;
tmp = hip_move_dpph2(a, AMD_DPP_ROW_RR(4), 0xF, 0xF, true);
a.data.x += tmp.data.x;
a.data.y += tmp.data.y;
tmp = hip_move_dpph2(a, AMD_DPP_ROW_RR(2), 0xF, 0xF, true);
a.data.x += tmp.data.x;
a.data.y += tmp.data.y;
tmp = hip_move_dpph2(a, AMD_DPP_ROW_RR(1), 0xF, 0xF, true);
a.data.x += tmp.data.x;
a.data.y += tmp.data.y;
return a;
}
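
// Note on bound_ctrl: the sum variants above pass bound_ctrl = true, so a lane whose DPP
// source is invalid/disabled reads 0, the additive identity; warp_reduce_max_impl_amd below
// passes bound_ctrl = false, presumably so such a lane keeps its previous value instead of
// having 0 injected, which would be wrong for a max over all-negative inputs.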

static __device__ __forceinline__ float warp_reduce_max_impl_amd(float x) {
x = fmaxf(x, __hip_ds_swizzlef(x, AMD_SWIZZLE_MASK(0x1F, 0, 0x10)));
x = fmaxf(x, hip_move_dppf(x, AMD_DPP_ROW_RR(8), 0xF, 0xF, false));
x = fmaxf(x, hip_move_dppf(x, AMD_DPP_ROW_RR(4), 0xF, 0xF, false));
x = fmaxf(x, hip_move_dppf(x, AMD_DPP_ROW_RR(2), 0xF, 0xF, false));
x = fmaxf(x, hip_move_dppf(x, AMD_DPP_ROW_RR(1), 0xF, 0xF, false));
return x;
}
#endif // __HIP_PLATFORM_AMD__
#endif // defined(GGML_USE_HIPBLAS)

#define FP16_AVAILABLE (defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) || __CUDA_ARCH__ >= CC_PASCAL
@@ -349,33 +447,35 @@ static __device__ void no_device_code(
#endif // __CUDA_ARCH__

static __device__ __forceinline__ float warp_reduce_sum(float x) {
#if defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)
return warp_reduce_sum_impl_amd(x);
#else
#pragma unroll
for (int mask = 16; mask > 0; mask >>= 1) {
x += __shfl_xor_sync(0xffffffff, x, mask, 32);
}
return x;
#endif // defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)
}
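
// Illustrative usage sketch (not part of this change), assuming the usual WARP_SIZE = 32
// constant from this header: a single-warp kernel that sums n floats, with every lane
// holding the full result after warp_reduce_sum.
//
//   static __global__ void sum_sketch(const float * x, float * dst, const int n) {
//       float partial = 0.0f;
//       for (int i = threadIdx.x; i < n; i += WARP_SIZE) {
//           partial += x[i];
//       }
//       partial = warp_reduce_sum(partial); // all 32 lanes now hold the total
//       if (threadIdx.x == 0) {
//           *dst = partial;
//       }
//   }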

static __device__ __forceinline__ float2 warp_reduce_sum(float2 a) {
#if defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)
return warp_reduce_sum_impl_amd(a);
#else
#pragma unroll
for (int mask = 16; mask > 0; mask >>= 1) {
a.x += __shfl_xor_sync(0xffffffff, a.x, mask, 32);
a.y += __shfl_xor_sync(0xffffffff, a.y, mask, 32);
}
return a;
#endif // defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)
}

static __device__ __forceinline__ half2 warp_reduce_sum(half2 a) {
#if FP16_AVAILABLE

#if defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)
#pragma unroll
for (int mask = 16; mask > 0; mask >>= 1) {
const half2 a_other = __shfl_xor_sync(0xffffffff, a, mask, 32);
reinterpret_cast<half&>(a.x) += __low2half(a_other);
reinterpret_cast<half&>(a.y) += __high2half(a_other);
}
return a;
return warp_reduce_sum_impl_amd(a);
#else
#pragma unroll
for (int mask = 16; mask > 0; mask >>= 1) {
@@ -391,11 +491,15 @@ static __device__ __forceinline__ half2 warp_reduce_sum(half2 a) {
}

static __device__ __forceinline__ float warp_reduce_max(float x) {
#if defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)
return warp_reduce_max_impl_amd(x);
#else
#pragma unroll
for (int mask = 16; mask > 0; mask >>= 1) {
x = fmaxf(x, __shfl_xor_sync(0xffffffff, x, mask, 32));
}
return x;
#endif // defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)
}

static __device__ __forceinline__ half ggml_cuda_hmax(const half a, const half b) {