@@ -28,7 +28,7 @@ namespace flashinfer {
 namespace norm {
 
 template <uint32_t VEC_SIZE, typename T>
-__global__ void RMSNormKernel(T* __restrict__ x, T* __restrict__ w, T* __restrict__ y,
+__global__ void RMSNormKernel(T* __restrict__ input, T* __restrict__ weight, T* __restrict__ output,
                               const uint32_t d, float eps) {
   const uint32_t bx = blockIdx.x;
   const uint32_t tx = threadIdx.x, ty = threadIdx.y;
@@ -43,14 +43,14 @@ __global__ void RMSNormKernel(T* __restrict__ x, T* __restrict__ w, T* __restric
   float sum_sq = 0.f;
 
   for (uint32_t i = 0; i < rounds; i++) {
-    vec_t<T, VEC_SIZE> x_vec;
-    x_vec.fill(0);
+    vec_t<T, VEC_SIZE> input_vec;
+    input_vec.fill(0);
     if ((i * num_threads + thread_id) * VEC_SIZE < d) {
-      x_vec.load(x + bx * d + i * num_threads * VEC_SIZE + thread_id * VEC_SIZE);
+      input_vec.load(input + bx * d + i * num_threads * VEC_SIZE + thread_id * VEC_SIZE);
     }
 #pragma unroll
     for (uint32_t j = 0; j < VEC_SIZE; j++) {
-      sum_sq += float(x_vec[j]) * float(x_vec[j]);
+      sum_sq += float(input_vec[j]) * float(input_vec[j]);
     }
   }
 
@@ -76,36 +76,36 @@ __global__ void RMSNormKernel(T* __restrict__ x, T* __restrict__ w, T* __restric
   float rms_rcp = math::rsqrt(smem[0] / float(d) + eps);
 
   for (uint32_t i = 0; i < rounds; i++) {
-    vec_t<T, VEC_SIZE> x_vec;
-    vec_t<T, VEC_SIZE> w_vec;
-    vec_t<T, VEC_SIZE> y_vec;
-    x_vec.fill(0);
-    w_vec.fill(0);
+    vec_t<T, VEC_SIZE> input_vec;
+    vec_t<T, VEC_SIZE> weight_vec;
+    vec_t<T, VEC_SIZE> output_vec;
+    input_vec.fill(0);
+    weight_vec.fill(0);
     if ((i * num_threads + thread_id) * VEC_SIZE < d) {
-      x_vec.load(x + bx * d + i * num_threads * VEC_SIZE + thread_id * VEC_SIZE);
-      w_vec.load(w + i * num_threads * VEC_SIZE + thread_id * VEC_SIZE);
+      input_vec.load(input + bx * d + i * num_threads * VEC_SIZE + thread_id * VEC_SIZE);
+      weight_vec.load(weight + i * num_threads * VEC_SIZE + thread_id * VEC_SIZE);
     }
 #pragma unroll
     for (uint32_t j = 0; j < VEC_SIZE; j++) {
-      y_vec[j] = float(x_vec[j]) * rms_rcp * float(w_vec[j]);
+      output_vec[j] = float(input_vec[j]) * rms_rcp * float(weight_vec[j]);
     }
     if ((i * num_threads + thread_id) * VEC_SIZE < d) {
-      y_vec.store(y + bx * d + i * num_threads * VEC_SIZE + thread_id * VEC_SIZE);
+      output_vec.store(output + bx * d + i * num_threads * VEC_SIZE + thread_id * VEC_SIZE);
     }
   }
 }
 
 template <typename T>
-cudaError_t RMSNorm(T* x, T* w, T* y, uint32_t batch_size, uint32_t d, float eps = 1e-5,
-                    cudaStream_t stream = 0) {
+cudaError_t RMSNorm(T* input, T* weight, T* output, uint32_t batch_size, uint32_t d,
+                    float eps = 1e-5, cudaStream_t stream = 0) {
   const uint32_t vec_size = std::gcd(16 / sizeof(T), d);
 
   const uint32_t block_size = std::min<uint32_t>(1024, d / vec_size);
   const uint32_t num_warps = ceil_div(block_size, 32);
   dim3 nblks(batch_size);
   dim3 nthrs(32, num_warps);
   const uint32_t smem_size = num_warps * sizeof(float);
-  void* args[] = {&x, &w, &y, &d, &eps};
+  void* args[] = {&input, &weight, &output, &d, &eps};
 
   DISPATCH_ALIGNED_VEC_SIZE(vec_size, VEC_SIZE, {
     auto kernel = RMSNormKernel<VEC_SIZE, T>;
@@ -114,6 +114,104 @@ cudaError_t RMSNorm(T* x, T* w, T* y, uint32_t batch_size, uint32_t d, float eps
   return cudaSuccess;
 }
 
+template <uint32_t VEC_SIZE, typename T>
+__global__ void FusedAddRMSNormKernel(T* __restrict__ input, T* __restrict__ residual,
+                                      T* __restrict__ weight, const uint32_t d, float eps) {
+  const uint32_t bx = blockIdx.x;
+  const uint32_t tx = threadIdx.x, ty = threadIdx.y;
+  constexpr uint32_t warp_size = 32;
+  const uint32_t num_warps = blockDim.y;
+  const uint32_t thread_id = tx + ty * warp_size;
+  const uint32_t num_threads = num_warps * warp_size;
+  const uint32_t rounds = ceil_div(d, VEC_SIZE * num_threads);
+  extern __shared__ float smem[];
+
+  float sum_sq = 0.f;
+
+  for (uint32_t i = 0; i < rounds; i++) {
+    vec_t<T, VEC_SIZE> input_vec;
+    input_vec.fill(0);
+    vec_t<T, VEC_SIZE> residual_vec;
+    residual_vec.fill(0);
+    if ((i * num_threads + thread_id) * VEC_SIZE < d) {
+      input_vec.load(input + bx * d + i * num_threads * VEC_SIZE + thread_id * VEC_SIZE);
+      residual_vec.load(residual + bx * d + i * num_threads * VEC_SIZE + thread_id * VEC_SIZE);
+    }
+#pragma unroll
+    for (uint32_t j = 0; j < VEC_SIZE; j++) {
+      float x = float(input_vec[j]);
+      x += float(residual_vec[j]);
+      sum_sq += x * x;
+      residual_vec[j] = (T)x;
+    }
+    if ((i * num_threads + thread_id) * VEC_SIZE < d) {
+      residual_vec.store(residual + bx * d + i * num_threads * VEC_SIZE + thread_id * VEC_SIZE);
+    }
+  }
+
+  // first, warp reduce sum
+#pragma unroll
+  for (uint32_t offset = warp_size / 2; offset > 0; offset /= 2) {
+    sum_sq += math::shfl_xor_sync(sum_sq, offset);
+  }
+
+  smem[ty] = sum_sq;
+  __syncthreads();
+  // then, cross warp reduce sum using only the first warp
+  if (ty == 0) {
+    sum_sq = (tx < num_warps) ? smem[tx] : 0.f;
+#pragma unroll
+    for (uint32_t offset = warp_size / 2; offset > 0; offset /= 2) {
+      sum_sq += math::shfl_xor_sync(sum_sq, offset);
+    }
+    smem[0] = sum_sq;
+  }
+  __syncthreads();
+
+  float rms_rcp = math::rsqrt(smem[0] / float(d) + eps);
+
+  for (uint32_t i = 0; i < rounds; i++) {
+    vec_t<T, VEC_SIZE> input_vec;
+    vec_t<T, VEC_SIZE> weight_vec;
+    vec_t<T, VEC_SIZE> residual_vec;
+    input_vec.fill(0);
+    weight_vec.fill(0);
+    residual_vec.fill(0);
+    if ((i * num_threads + thread_id) * VEC_SIZE < d) {
+      input_vec.load(input + bx * d + i * num_threads * VEC_SIZE + thread_id * VEC_SIZE);
+      weight_vec.load(weight + i * num_threads * VEC_SIZE + thread_id * VEC_SIZE);
+      residual_vec.load(residual + bx * d + i * num_threads * VEC_SIZE + thread_id * VEC_SIZE);
+    }
+#pragma unroll
+    for (uint32_t j = 0; j < VEC_SIZE; j++) {
+      input_vec[j] = float(residual_vec[j]) * rms_rcp * float(weight_vec[j]);
+    }
+    if ((i * num_threads + thread_id) * VEC_SIZE < d) {
+      input_vec.store(input + bx * d + i * num_threads * VEC_SIZE + thread_id * VEC_SIZE);
+    }
+  }
+}
+
+template <typename T>
+cudaError_t FusedAddRMSNorm(T* input, T* residual, T* weight, uint32_t batch_size, uint32_t d,
+                            float eps = 1e-5, cudaStream_t stream = 0) {
+  const uint32_t vec_size = std::gcd(16 / sizeof(T), d);
+
+  const uint32_t block_size = std::min<uint32_t>(1024, d / vec_size);
+  const uint32_t num_warps = ceil_div(block_size, 32);
+  dim3 nblks(batch_size);
+  dim3 nthrs(32, num_warps);
+  const uint32_t smem_size = num_warps * sizeof(float);
+  void* args[] = {&input, &residual, &weight, &d, &eps};
+
+  DISPATCH_ALIGNED_VEC_SIZE(vec_size, VEC_SIZE, {
+    auto kernel = FusedAddRMSNormKernel<VEC_SIZE, T>;
+    FLASHINFER_CUDA_CALL(cudaLaunchKernel((void*)kernel, nblks, nthrs, args, smem_size, stream));
+  });
+
+  return cudaSuccess;
+}
+
 }  // namespace norm
 
 }  // namespace flashinfer
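
For reference, below is a minimal host-side sketch of how the new FusedAddRMSNorm entry point could be exercised. The include path "flashinfer/norm.cuh", the float instantiation, and the shapes are assumptions made for illustration, not part of this commit. Note that the call updates both buffers in place: input + residual is written back into residual, and the RMS-normalized, weighted result is written into input, which is why the wrapper takes no separate output pointer.

// Hypothetical standalone driver (not part of the commit above); the header
// location and the float instantiation are assumptions.
#include <cstdint>
#include <cstdio>
#include <vector>

#include <cuda_runtime.h>

#include "flashinfer/norm.cuh"

int main() {
  const uint32_t batch_size = 4, d = 4096;
  const size_t n = size_t(batch_size) * d;

  // Host buffers: batch_size rows of activations/residuals, one row of weights.
  std::vector<float> h_input(n, 1.f), h_residual(n, 0.5f), h_weight(d, 1.f);

  float *input, *residual, *weight;
  cudaMalloc(&input, n * sizeof(float));
  cudaMalloc(&residual, n * sizeof(float));
  cudaMalloc(&weight, d * sizeof(float));
  cudaMemcpy(input, h_input.data(), n * sizeof(float), cudaMemcpyHostToDevice);
  cudaMemcpy(residual, h_residual.data(), n * sizeof(float), cudaMemcpyHostToDevice);
  cudaMemcpy(weight, h_weight.data(), d * sizeof(float), cudaMemcpyHostToDevice);

  // In-place fused residual add + RMSNorm: afterwards `residual` holds
  // input + residual and `input` holds the normalized, weighted rows.
  cudaError_t status =
      flashinfer::norm::FusedAddRMSNorm(input, residual, weight, batch_size, d, 1e-5f);
  if (status != cudaSuccess) {
    std::printf("FusedAddRMSNorm failed: %s\n", cudaGetErrorString(status));
    return 1;
  }
  cudaDeviceSynchronize();

  cudaMemcpy(h_input.data(), input, n * sizeof(float), cudaMemcpyDeviceToHost);
  std::printf("normalized[0] = %f\n", h_input[0]);

  cudaFree(input);
  cudaFree(residual);
  cudaFree(weight);
  return 0;
}

Both wrappers use the same launch shape: one block per row (batch_size blocks) with 32 x num_warps threads, where block_size = min(1024, d / vec_size). With float and d = 4096 in the sketch, vec_size = gcd(16 / sizeof(float), d) = 4, so each row is handled by a full 1024-thread block.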