@@ -181,15 +181,18 @@ static __global__ void rms_norm_f32(const float * x,
181181 // sum up partial sums
182182 tmp = warp_reduce_sum (tmp);
183183 if constexpr (block_size > WARP_SIZE) {
184- static_assert (block_size == 1024 , " unexpected block_size" );
184+ static_assert (( block_size <= 1024 ) && (block_size % 32 == 0 ) , " unexpected block_size" );
185185 __shared__ float s_sum[32 ];
186- const int warp_id = threadIdx . x / WARP_SIZE;
187- const int lane_id = threadIdx . x % WARP_SIZE;
186+ const int warp_id = tid / WARP_SIZE;
187+ const int lane_id = tid % WARP_SIZE;
188188 if (lane_id == 0 ) {
189189 s_sum[warp_id] = tmp;
190190 }
191191 __syncthreads ();
192- tmp = s_sum[lane_id];
192+ tmp = 0 .0f ;
193+ if (lane_id < (block_size / WARP_SIZE)) {
194+ tmp = s_sum[lane_id];
195+ }
193196 tmp = warp_reduce_sum (tmp);
194197 }
195198
@@ -370,8 +373,8 @@ static void rms_norm_f32_cuda(
370373 const int64_t stride_row, const int64_t stride_channel, const int64_t stride_sample, const float eps, cudaStream_t stream) {
371374 const dim3 blocks_num (nrows, nchannels, nsamples);
372375 if (ncols < 1024 ) {
373- const dim3 block_dims (WARP_SIZE , 1 , 1 );
374- rms_norm_f32<WARP_SIZE , false ><<<blocks_num, block_dims, 0 , stream>>> (x, dst, ncols, stride_row, stride_channel, stride_sample, eps);
376+ const dim3 block_dims (256 , 1 , 1 );
377+ rms_norm_f32<256 , false ><<<blocks_num, block_dims, 0 , stream>>> (x, dst, ncols, stride_row, stride_channel, stride_sample, eps);
375378 } else {
376379 const dim3 block_dims (1024 , 1 , 1 );
377380 rms_norm_f32<1024 , false ><<<blocks_num, block_dims, 0 , stream>>> (x, dst, ncols, stride_row, stride_channel, stride_sample, eps);
@@ -420,8 +423,8 @@ static void rms_norm_mul_f32_cuda(const float * x,
420423 uint32_t mp_mul_samples, L_mul_samples;
421424 init_fastdiv_values (mul_nsamples, mp_mul_samples, L_mul_samples);
422425 if (ncols < 1024 ) {
423- const dim3 block_dims (WARP_SIZE , 1 , 1 );
424- rms_norm_f32<WARP_SIZE , true ><<<blocks_num, block_dims, 0 , stream>>> (x,
426+ const dim3 block_dims (256 , 1 , 1 );
427+ rms_norm_f32<256 , true ><<<blocks_num, block_dims, 0 , stream>>> (x,
425428 dst,
426429 ncols,
427430 stride_row,
@@ -489,8 +492,8 @@ static void rms_norm_mul_f32_cuda(const float * x,
489492 uint32_t mp_add_samples, L_add_samples;
490493 init_fastdiv_values (add_nsamples, mp_add_samples, L_add_samples);
491494 if (ncols < 1024 ) {
492- const dim3 block_dims (WARP_SIZE , 1 , 1 );
493- rms_norm_f32<WARP_SIZE , true , true ><<<blocks_num, block_dims, 0 , stream>>> (x,
495+ const dim3 block_dims (256 , 1 , 1 );
496+ rms_norm_f32<256 , true , true ><<<blocks_num, block_dims, 0 , stream>>> (x,
494497 dst,
495498 ncols,
496499 stride_row,
0 commit comments