Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
60 changes: 59 additions & 1 deletion ggml/src/ggml-cuda/norm.cu
Original file line number Diff line number Diff line change
Expand Up @@ -202,6 +202,59 @@ static __global__ void rms_norm_back_f32(
}
}

// Wave32-optimized RMSNorm for RDNA2. The host wrapper launches it with
// 128 threads (4 x Wave32), but the kernel only requires that blockDim.x is a
// multiple of 32 and at most 1024.
//
// Uses a manual warp-shuffle + shared-memory cross-warp reduce instead of
// the generic block_reduce<> path (which expects extern shared mem).
//
// Launch layout:
//   grid  = (nrows, nchannels, nsamples), one block per row
//   block = (N, 1, 1) with N % 32 == 0 and N <= 1024
//   shared memory: static only (no dynamic shared memory needed)
//
// Parameters:
//   x              input rows, addressed with the three element strides below
//   dst            output, written contiguously as [sample][channel][row][col]
//   ncols          number of elements in one row
//   stride_row     element stride between consecutive rows of x
//   stride_channel element stride between consecutive channels of x
//   stride_sample  element stride between consecutive samples of x
//   eps            added to the mean of squares before the reciprocal sqrt
static __global__ void rms_norm_f32_wave32(
        const float * x, float * dst, const int ncols,
        const int64_t stride_row, const int64_t stride_channel,
        const int64_t stride_sample, const float eps) {
    constexpr int wave_size = 32;  // lanes per wavefront this kernel is written for
    constexpr int max_waves = 32;  // 1024 threads / 32 lanes

    const int nrows     = gridDim.x;
    const int nchannels = gridDim.y;

    const int row     = blockIdx.x;
    const int channel = blockIdx.y;
    const int sample  = blockIdx.z;

    const int tid      = threadIdx.x;
    const int nthreads = blockDim.x;
    const int warp_id  = tid / wave_size;
    const int lane_id  = tid % wave_size;
    const int nwaves   = nthreads / wave_size;

    x += sample*stride_sample + channel*stride_channel + row*stride_row;
    // Widen before multiplying: the product of four ints can exceed INT_MAX
    // for large tensors.
    dst += (((int64_t) sample*nchannels + channel)*nrows + row)*ncols;

    // Phase 1: each thread accumulates the squares of the columns it owns.
    // The stride is the real block size, so a launch with a block size other
    // than 128 does not double-count or skip columns.
    float tmp = 0.0f;
    for (int col = tid; col < ncols; col += nthreads) {
        const float xi = x[col];
        tmp += xi * xi;
    }

    // Phase 2a: butterfly reduce inside each wavefront. Every lane of every
    // wave reaches this point, so the full mask is valid.
#pragma unroll
    for (int offset = wave_size/2; offset > 0; offset >>= 1) {
        tmp += __shfl_xor_sync(0xffffffff, tmp, offset, wave_size);
    }

    // Phase 2b: combine the per-wave partial sums through shared memory.
    __shared__ float warp_sums[max_waves];
    __shared__ float s_variance;

    if (lane_id == 0) {
        warp_sums[warp_id] = tmp;
    }
    __syncthreads();

    if (warp_id == 0) {
        // All 32 lanes of wave 0 take part in the shuffle so that the set of
        // executing lanes matches the 0xffffffff mask. Lanes without a
        // partial sum contribute 0. For a 4-wave block the additions happen
        // in the same order as a 4-lane butterfly, so results are unchanged.
        float val = lane_id < nwaves ? warp_sums[lane_id] : 0.0f;
#pragma unroll
        for (int offset = wave_size/2; offset > 0; offset >>= 1) {
            val += __shfl_xor_sync(0xffffffff, val, offset, wave_size);
        }
        if (lane_id == 0) {
            // Despite the name this holds the scale 1/sqrt(mean(x^2) + eps).
            s_variance = rsqrtf(val / ncols + eps);
        }
    }
    __syncthreads();

    // Phase 3: apply the scale to every column of the row.
    const float scale = s_variance;
    for (int col = tid; col < ncols; col += nthreads) {
        dst[col] = scale * x[col];
    }
}

// template <int block_size>
// static __global__ void l2_norm_f32(const float * x, float * dst, const int ncols, const float eps) {
// const int row = blockIdx.x*blockDim.y + threadIdx.y;
Expand Down Expand Up @@ -298,7 +351,12 @@ static void rms_norm_f32_cuda(
const float * x, float * dst, const int ncols, const int nrows, const int nchannels, const int nsamples,
const int64_t stride_row, const int64_t stride_channel, const int64_t stride_sample, const float eps, cudaStream_t stream) {
const dim3 blocks_num(nrows, nchannels, nsamples);
if (ncols < 1024) {
const int cc = ggml_cuda_info().devices[ggml_cuda_get_device()].cc;
if (GGML_CUDA_CC_IS_RDNA2(cc)) {
// 128 threads = 4 × Wave32: optimal occupancy on RDNA2 (gfx1030/gfx1031)
const dim3 block_dims(128, 1, 1);
rms_norm_f32_wave32<<<blocks_num, block_dims, 0, stream>>>(x, dst, ncols, stride_row, stride_channel, stride_sample, eps);
} else if (ncols < 1024) {
const dim3 block_dims(256, 1, 1);
rms_norm_f32<256, false><<<blocks_num, block_dims, block_dims.x > WARP_SIZE ? 32 * sizeof(float): 0, stream>>>(x, dst, ncols, stride_row, stride_channel, stride_sample, eps);
} else {
Expand Down
7 changes: 5 additions & 2 deletions ggml/src/ggml-cuda/rope.cu
Original file line number Diff line number Diff line change
Expand Up @@ -394,8 +394,11 @@ static void rope_neox_cuda(const T * x,
const int set_rows_stride,
cudaStream_t stream) {
GGML_ASSERT(ne00 % 2 == 0);
const dim3 block_dims(1, CUDA_ROPE_BLOCK_SIZE, 1);
const int n_blocks_x = (ne00 + 2 * CUDA_ROPE_BLOCK_SIZE - 1) / (2 * CUDA_ROPE_BLOCK_SIZE);
// RDNA2 Wave32: use 32 threads in Y to align with one Wave32 wavefront
const int cc = ggml_cuda_info().devices[ggml_cuda_get_device()].cc;
const int rope_block = GGML_CUDA_CC_IS_RDNA2(cc) ? 32 : CUDA_ROPE_BLOCK_SIZE;
const dim3 block_dims(1, rope_block, 1);
const int n_blocks_x = (ne00 + 2 * rope_block - 1) / (2 * rope_block);
const dim3 block_nums(nr, n_blocks_x, 1);

const float theta_scale = powf(freq_base, -2.0f / n_dims);
Expand Down