1010
1111#include " block_reduce.h"
1212#include " ../common/micros.h"
13- #include " utils/cuda_type_utils.h"
13+ #include " funcs/cast_functor.h"
14+ #include " funcs/op_functor.h"
1415
1516using colossalAI::cuda::utils::block_reduce;
1617using colossalAI::cuda::utils::ReduceType;
18+ using colossalAI::cuda::funcs::TypeConverter;
19+ using colossalAI::cuda::funcs::CastFunctor;
20+ using colossalAI::cuda::funcs::BinaryOpFunctor;
21+ using colossalAI::cuda::funcs::BinaryOpType;
1722
1823#define DISPATCH_RMSNORM_FLOAT_HALF_AND_BFLOAT (DATA_SIZE, TYPE, NAME, ...) \
1924 if (DATA_SIZE == 2 ) { \
@@ -53,6 +58,9 @@ __global__ void rms_layernorm_kernel(
5358 const int num_tokens,
5459 const int hidden_size) {
5560 using scalar2_t = typename TypeConverter<scalar_t >::Type;
61+ BinaryOpFunctor<scalar2_t , scalar2_t , scalar2_t , BinaryOpType::kMul > mul_scalar2t ();
62+ CastFunctor<scalar2_t , float > cast_scalar2t_2_float ();
63+ CastFunctor<float , scalar2_t > cast_float_2_scalar2t ();
5664 __shared__ float s_variance;
5765
5866 /*
@@ -72,12 +80,13 @@ __global__ void rms_layernorm_kernel(
7280 float variance = 0 .0f ;
7381 int row_offset = blockIdx .x * hidden_size / 2 ;
7482
83+
7584#pragma unroll unroll_factor
7685 for (int idx = threadIdx .x , cnt = 0 ; idx < hidden_size / 2 ; idx += blockDim .x , cnt++) {
7786 int id = row_offset + idx;
7887 x_local[cnt] = input_ptr[id];
79- float v1 = cuda_cast< float > (x_local[cnt].x );
80- float v2 = cuda_cast< float > (x_local[cnt].y );
88+ float v1 = cast_scalar2t_2_float (x_local[cnt].x );
89+ float v2 = cast_scalar2t_2_float (x_local[cnt].y );
8190 variance += v1 * v1 + v2 * v2;
8291 }
8392 block_reduce<float , ReduceType::kSum ,1 >(&variance);
@@ -86,11 +95,11 @@ __global__ void rms_layernorm_kernel(
8695 }
8796 __syncthreads ();
8897
89- scalar2_t s_variance_2 = cuda_cast< scalar2_t > (s_variance);
98+ scalar2_t s_variance_2 = cast_float_2_scalar2t (s_variance);
9099#pragma unroll unroll_factor
91100 for (int idx = threadIdx .x , cnt = 0 ; idx < hidden_size / 2 ; idx += blockDim .x , cnt++) {
92101 int id = row_offset + idx;
93- out_ptr[id] = mul ( x_local[cnt], s_variance_2, weight_ptr[idx]);
102+ out_ptr[id] = mul_scalar2t ( mul_scalar2t ( x_local[cnt], s_variance_2) , weight_ptr[idx]);
94103 }
95104}
96105
@@ -137,6 +146,11 @@ __global__ void fused_add_rms_layernorm_kernel(
137146 const int num_tokens,
138147 const int hidden_size) {
139148 using scalar2_t = typename TypeConverter<scalar_t >::Type;
149+ BinaryOpFunctor<scalar2_t , scalar2_t , scalar2_t , BinaryOpType::kAdd > add_scalar2t ();
150+ CastFunctor<scalar2_t , float > cast_scalar2t_2_float ();
151+ CastFunctor<float , scalar2_t > cast_float_2_scalar2t ();
152+ BinaryOpFunctor<scalar2_t , scalar2_t , scalar2_t , BinaryOpType::kMul > mul_scalar2t ();
153+
140154 __shared__ float s_variance;
141155 scalar2_t x_local[4 ];
142156
@@ -151,9 +165,9 @@ __global__ void fused_add_rms_layernorm_kernel(
151165 for (int idx = threadIdx .x , cnt = 0 ; idx < hidden_size / 2 ; idx += blockDim .x , cnt++) {
152166 int id = row_offset + idx;
153167 x_local[cnt] = input_ptr[id];
154- x_local[cnt] = add (x_local[cnt], residual_ptr[id]);
155- float v1 = cuda_cast< float > (x_local[cnt].x );
156- float v2 = cuda_cast< float > (x_local[cnt].y );
168+ x_local[cnt] = add_scalar2t (x_local[cnt], residual_ptr[id]);
169+ float v1 = cast_scalar2t_2_float (x_local[cnt].x );
170+ float v2 = cast_scalar2t_2_float (x_local[cnt].y );
157171 variance += v1 * v1 + v2 * v2;
158172 residual_ptr[id] = x_local[cnt];
159173 }
@@ -163,11 +177,11 @@ __global__ void fused_add_rms_layernorm_kernel(
163177 }
164178 __syncthreads ();
165179
166- scalar2_t s_variance_2 = cuda_cast< scalar2_t > (s_variance);
180+ scalar2_t s_variance_2 = cast_float_2_scalar2t (s_variance);
167181#pragma unroll unroll_factor
168182 for (int idx = threadIdx .x , cnt = 0 ; idx < hidden_size / 2 ; idx += blockDim .x , cnt++) {
169183 int id = row_offset + idx;
170- input_ptr[id] = mul ( x_local[cnt], s_variance_2, weight_ptr[idx]);
184+ input_ptr[id] = mul_scalar2t ( mul_scalar2t ( x_local[cnt], s_variance_2) , weight_ptr[idx]);
171185 }
172186}
173187
0 commit comments