@@ -29,7 +29,7 @@ namespace norm {
 
 template <uint32_t VEC_SIZE, typename T>
 __global__ void RMSNormKernel(T* __restrict__ input, T* __restrict__ weight, T* __restrict__ output,
-                              const uint32_t d, float eps) {
+                              const uint32_t d, float weight_bias, float eps) {
   const uint32_t bx = blockIdx.x;
   const uint32_t tx = threadIdx.x, ty = threadIdx.y;
   constexpr uint32_t warp_size = 32;
@@ -87,7 +87,7 @@ __global__ void RMSNormKernel(T* __restrict__ input, T* __restrict__ weight, T*
   }
 #pragma unroll
   for (uint32_t j = 0; j < VEC_SIZE; j++) {
-    output_vec[j] = float(input_vec[j]) * rms_rcp * float(weight_vec[j]);
+    output_vec[j] = float(input_vec[j]) * rms_rcp * (weight_bias + float(weight_vec[j]));
   }
   if ((i * num_threads + thread_id) * VEC_SIZE < d) {
     output_vec.store(output + bx * d + i * num_threads * VEC_SIZE + thread_id * VEC_SIZE);
@@ -105,7 +105,8 @@ cudaError_t RMSNorm(T* input, T* weight, T* output, uint32_t batch_size, uint32_
   dim3 nblks(batch_size);
   dim3 nthrs(32, num_warps);
   const uint32_t smem_size = num_warps * sizeof(float);
-  void* args[] = {&input, &weight, &output, &d, &eps};
+  float weight_bias = 0.f;
+  void* args[] = {&input, &weight, &output, &d, &weight_bias, &eps};
 
   DISPATCH_ALIGNED_VEC_SIZE(vec_size, VEC_SIZE, {
     auto kernel = RMSNormKernel<VEC_SIZE, T>;
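Taken together, the hunks above fold the Gemma variant into the shared kernel: the per-element update becomes x * rms_rcp * (weight_bias + w), so the RMSNorm wrapper's weight_bias = 0.f reproduces the old x * rms_rcp * w behavior, while weight_bias = 1.f (used by the Gemma wrappers further down) gives Gemma's (1 + w) scaling. Below is a minimal CPU sketch of that per-row math, for reference only; the (d + eps) normalization term mirrors the removed Gemma kernels shown later in this diff, since the retained kernel's own rms_rcp line is not part of the changed hunks.

    // Reference sketch, not code from the diff.
    #include <cmath>
    #include <cstdint>
    #include <vector>

    std::vector<float> rms_norm_ref(const std::vector<float>& x, const std::vector<float>& w,
                                    float weight_bias, float eps) {
      const uint32_t d = x.size();
      float sum_sq = 0.f;
      for (uint32_t j = 0; j < d; j++) sum_sq += x[j] * x[j];
      const float rms_rcp = 1.f / std::sqrt(sum_sq / (float(d) + eps));
      std::vector<float> out(d);
      for (uint32_t j = 0; j < d; j++) {
        // weight_bias = 0.f -> classic RMSNorm: x * rms_rcp * w
        // weight_bias = 1.f -> Gemma variant:   x * rms_rcp * (1 + w)
        out[j] = x[j] * rms_rcp * (weight_bias + w[j]);
      }
      return out;
    }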
@@ -116,7 +117,8 @@ cudaError_t RMSNorm(T* input, T* weight, T* output, uint32_t batch_size, uint32_
 
 template <uint32_t VEC_SIZE, typename T>
 __global__ void FusedAddRMSNormKernel(T* __restrict__ input, T* __restrict__ residual,
-                                      T* __restrict__ weight, const uint32_t d, float eps) {
+                                      T* __restrict__ weight, const uint32_t d, float weight_bias,
+                                      float eps) {
   const uint32_t bx = blockIdx.x;
   const uint32_t tx = threadIdx.x, ty = threadIdx.y;
   constexpr uint32_t warp_size = 32;
@@ -187,7 +189,7 @@ __global__ void FusedAddRMSNormKernel(T* __restrict__ input, T* __restrict__ res
   }
 #pragma unroll
   for (uint32_t j = 0; j < VEC_SIZE; j++) {
-    input_vec[j] = x_vec[j] * rms_rcp * float(weight_vec[j]);
+    input_vec[j] = x_vec[j] * rms_rcp * (weight_bias + float(weight_vec[j]));
   }
   if ((i * num_threads + thread_id) * VEC_SIZE < d) {
     input_vec.store(input + bx * d + i * num_threads * VEC_SIZE + thread_id * VEC_SIZE);
@@ -205,7 +207,8 @@ cudaError_t FusedAddRMSNorm(T* input, T* residual, T* weight, uint32_t batch_siz
   dim3 nblks(batch_size);
   dim3 nthrs(32, num_warps);
   const uint32_t smem_size = (num_warps + d) * sizeof(float);
-  void* args[] = {&input, &residual, &weight, &d, &eps};
+  float weight_bias = 0.f;
+  void* args[] = {&input, &residual, &weight, &d, &weight_bias, &eps};
 
   DISPATCH_ALIGNED_VEC_SIZE(vec_size, VEC_SIZE, {
     auto kernel = FusedAddRMSNormKernel<VEC_SIZE, T>;
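The fused hunks apply the same change to the add-then-normalize path. Judging from the removed GemmaFusedAddRMSNormKernel later in this diff, the fused kernel writes the sum x = input + residual back into residual and then overwrites input with the normalized result, which is why the wrapper reserves (num_warps + d) floats of dynamic shared memory. A CPU sketch of that in-place contract follows; it is my reading of the kernels in this file, not code from the diff.

    // Reference sketch of the fused-add path, not code from the diff.
    #include <cmath>
    #include <cstdint>
    #include <vector>

    void fused_add_rms_norm_ref(std::vector<float>& input, std::vector<float>& residual,
                                const std::vector<float>& w, float weight_bias, float eps) {
      const uint32_t d = input.size();
      float sum_sq = 0.f;
      for (uint32_t j = 0; j < d; j++) {
        const float x = input[j] + residual[j];
        residual[j] = x;  // residual now holds the pre-norm sum
        sum_sq += x * x;
      }
      const float rms_rcp = 1.f / std::sqrt(sum_sq / (float(d) + eps));
      for (uint32_t j = 0; j < d; j++) {
        input[j] = residual[j] * rms_rcp * (weight_bias + w[j]);  // normalized output, in place
      }
    }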
@@ -215,73 +218,6 @@ cudaError_t FusedAddRMSNorm(T* input, T* residual, T* weight, uint32_t batch_siz
   return cudaSuccess;
 }
 
-template <uint32_t VEC_SIZE, typename T>
-__global__ void GemmaRMSNormKernel(T* __restrict__ input, T* __restrict__ weight,
-                                   T* __restrict__ output, const uint32_t d, float eps) {
-  const uint32_t bx = blockIdx.x;
-  const uint32_t tx = threadIdx.x, ty = threadIdx.y;
-  constexpr uint32_t warp_size = 32;
-  const uint32_t num_warps = blockDim.y;
-  const uint32_t thread_id = tx + ty * warp_size;
-  const uint32_t num_threads = num_warps * warp_size;
-  const uint32_t rounds = ceil_div(d, VEC_SIZE * num_threads);
-  extern __shared__ float smem[];
-
-  float sum_sq = 0.f;
-
-  for (uint32_t i = 0; i < rounds; i++) {
-    vec_t<T, VEC_SIZE> input_vec;
-    input_vec.fill(0.f);
-    if ((i * num_threads + thread_id) * VEC_SIZE < d) {
-      input_vec.load(input + bx * d + i * num_threads * VEC_SIZE + thread_id * VEC_SIZE);
-    }
-#pragma unroll
-    for (uint32_t j = 0; j < VEC_SIZE; j++) {
-      sum_sq += float(input_vec[j]) * float(input_vec[j]);
-    }
-  }
-
-  // first, warp reduce sum
-#pragma unroll
-  for (uint32_t offset = warp_size / 2; offset > 0; offset /= 2) {
-    sum_sq += math::shfl_xor_sync(sum_sq, offset);
-  }
-
-  smem[ty] = sum_sq;
-  __syncthreads();
-  // then, cross warp reduce sum using only the first warp
-  if (ty == 0) {
-    sum_sq = (tx < num_warps) ? smem[tx] : 0.f;
-#pragma unroll
-    for (uint32_t offset = warp_size / 2; offset > 0; offset /= 2) {
-      sum_sq += math::shfl_xor_sync(sum_sq, offset);
-    }
-    smem[0] = sum_sq;
-  }
-  __syncthreads();
-
-  float rms_rcp = math::rsqrt(smem[0] / (float(d) + eps));
-
-  for (uint32_t i = 0; i < rounds; i++) {
-    vec_t<T, VEC_SIZE> input_vec;
-    vec_t<T, VEC_SIZE> weight_vec;
-    vec_t<T, VEC_SIZE> output_vec;
-    input_vec.fill(0.f);
-    weight_vec.fill(0.f);
-    if ((i * num_threads + thread_id) * VEC_SIZE < d) {
-      input_vec.load(input + bx * d + i * num_threads * VEC_SIZE + thread_id * VEC_SIZE);
-      weight_vec.load(weight + i * num_threads * VEC_SIZE + thread_id * VEC_SIZE);
-    }
-#pragma unroll
-    for (uint32_t j = 0; j < VEC_SIZE; j++) {
-      output_vec[j] = float(input_vec[j]) * rms_rcp * (1.0f + float(weight_vec[j]));
-    }
-    if ((i * num_threads + thread_id) * VEC_SIZE < d) {
-      output_vec.store(output + bx * d + i * num_threads * VEC_SIZE + thread_id * VEC_SIZE);
-    }
-  }
-}
-
 template <typename T>
 cudaError_t GemmaRMSNorm(T* input, T* weight, T* output, uint32_t batch_size, uint32_t d,
                          float eps = 1e-5, cudaStream_t stream = 0) {
@@ -292,92 +228,16 @@ cudaError_t GemmaRMSNorm(T* input, T* weight, T* output, uint32_t batch_size, ui
   dim3 nblks(batch_size);
   dim3 nthrs(32, num_warps);
   const uint32_t smem_size = num_warps * sizeof(float);
-  void* args[] = {&input, &weight, &output, &d, &eps};
+  float weight_bias = 1.f;
+  void* args[] = {&input, &weight, &output, &d, &weight_bias, &eps};
 
   DISPATCH_ALIGNED_VEC_SIZE(vec_size, VEC_SIZE, {
-    auto kernel = GemmaRMSNormKernel<VEC_SIZE, T>;
+    auto kernel = RMSNormKernel<VEC_SIZE, T>;
     FLASHINFER_CUDA_CALL(cudaLaunchKernel((void*)kernel, nblks, nthrs, args, smem_size, stream));
   });
   return cudaSuccess;
 }
 
-template <uint32_t VEC_SIZE, typename T>
-__global__ void GemmaFusedAddRMSNormKernel(T* __restrict__ input, T* __restrict__ residual,
-                                           T* __restrict__ weight, const uint32_t d, float eps) {
-  const uint32_t bx = blockIdx.x;
-  const uint32_t tx = threadIdx.x, ty = threadIdx.y;
-  constexpr uint32_t warp_size = 32;
-  const uint32_t num_warps = blockDim.y;
-  const uint32_t thread_id = tx + ty * warp_size;
-  const uint32_t num_threads = num_warps * warp_size;
-  const uint32_t rounds = ceil_div(d, VEC_SIZE * num_threads);
-  extern __shared__ float smem[];
-
-  float sum_sq = 0.f;
-
-  for (uint32_t i = 0; i < rounds; i++) {
-    vec_t<T, VEC_SIZE> input_vec;
-    input_vec.fill(0.f);
-    vec_t<T, VEC_SIZE> residual_vec;
-    residual_vec.fill(0.f);
-    if ((i * num_threads + thread_id) * VEC_SIZE < d) {
-      input_vec.load(input + bx * d + i * num_threads * VEC_SIZE + thread_id * VEC_SIZE);
-      residual_vec.load(residual + bx * d + i * num_threads * VEC_SIZE + thread_id * VEC_SIZE);
-    }
-#pragma unroll
-    for (uint32_t j = 0; j < VEC_SIZE; j++) {
-      float x = float(input_vec[j]);
-      x += float(residual_vec[j]);
-      sum_sq += x * x;
-      residual_vec[j] = (T)x;
-    }
-    if ((i * num_threads + thread_id) * VEC_SIZE < d) {
-      residual_vec.store(residual + bx * d + i * num_threads * VEC_SIZE + thread_id * VEC_SIZE);
-    }
-  }
-
-  // first, warp reduce sum
-#pragma unroll
-  for (uint32_t offset = warp_size / 2; offset > 0; offset /= 2) {
-    sum_sq += math::shfl_xor_sync(sum_sq, offset);
-  }
-
-  smem[ty] = sum_sq;
-  __syncthreads();
-  // then, cross warp reduce sum using only the first warp
-  if (ty == 0) {
-    sum_sq = (tx < num_warps) ? smem[tx] : 0.f;
-#pragma unroll
-    for (uint32_t offset = warp_size / 2; offset > 0; offset /= 2) {
-      sum_sq += math::shfl_xor_sync(sum_sq, offset);
-    }
-    smem[0] = sum_sq;
-  }
-  __syncthreads();
-
-  float rms_rcp = math::rsqrt(smem[0] / (float(d) + eps));
-
-  for (uint32_t i = 0; i < rounds; i++) {
-    vec_t<T, VEC_SIZE> input_vec;
-    vec_t<T, VEC_SIZE> weight_vec;
-    vec_t<T, VEC_SIZE> residual_vec;
-    input_vec.fill(0.f);
-    weight_vec.fill(0.f);
-    residual_vec.fill(0.f);
-    if ((i * num_threads + thread_id) * VEC_SIZE < d) {
-      weight_vec.load(weight + i * num_threads * VEC_SIZE + thread_id * VEC_SIZE);
-      residual_vec.load(residual + bx * d + i * num_threads * VEC_SIZE + thread_id * VEC_SIZE);
-    }
-#pragma unroll
-    for (uint32_t j = 0; j < VEC_SIZE; j++) {
-      input_vec[j] = float(residual_vec[j]) * rms_rcp * (1.0f + float(weight_vec[j]));
-    }
-    if ((i * num_threads + thread_id) * VEC_SIZE < d) {
-      input_vec.store(input + bx * d + i * num_threads * VEC_SIZE + thread_id * VEC_SIZE);
-    }
-  }
-}
-
 template <typename T>
 cudaError_t GemmaFusedAddRMSNorm(T* input, T* residual, T* weight, uint32_t batch_size, uint32_t d,
                                  float eps = 1e-5, cudaStream_t stream = 0) {
@@ -387,11 +247,12 @@ cudaError_t GemmaFusedAddRMSNorm(T* input, T* residual, T* weight, uint32_t batc
   const uint32_t num_warps = ceil_div(block_size, 32);
   dim3 nblks(batch_size);
   dim3 nthrs(32, num_warps);
-  const uint32_t smem_size = num_warps * sizeof(float);
-  void* args[] = {&input, &residual, &weight, &d, &eps};
+  const uint32_t smem_size = (num_warps + d) * sizeof(float);
+  float weight_bias = 1.f;
+  void* args[] = {&input, &residual, &weight, &d, &weight_bias, &eps};
 
   DISPATCH_ALIGNED_VEC_SIZE(vec_size, VEC_SIZE, {
-    auto kernel = GemmaFusedAddRMSNormKernel<VEC_SIZE, T>;
+    auto kernel = FusedAddRMSNormKernel<VEC_SIZE, T>;
     FLASHINFER_CUDA_CALL(cudaLaunchKernel((void*)kernel, nblks, nthrs, args, smem_size, stream));
   });
 
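Note that the last hunk also corrects the Gemma fused wrapper's shared-memory request from num_warps * sizeof(float) to (num_warps + d) * sizeof(float), matching FusedAddRMSNorm above, since the shared kernel needs d additional staged floats. For large d this request can exceed the default 48 KB dynamic shared-memory limit, in which case a launch would typically need an explicit opt-in, roughly like the hypothetical helper sketched below (not part of this diff).

    // Hypothetical helper, not code from the diff: raise the dynamic shared-memory
    // cap for `kernel` when the requested size exceeds the default 48 KB limit.
    #include <cuda_runtime.h>
    #include <cstdint>

    inline cudaError_t maybe_raise_smem_limit(const void* kernel, uint32_t smem_size) {
      if (smem_size > 48 * 1024) {
        return cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize,
                                    static_cast<int>(smem_size));
      }
      return cudaSuccess;
    }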