#include "dispatch_utils.h"
#include "reduction_utils.cuh"
+ #ifndef USE_ROCM
+   #include <cuda_bf16.h>
+   #include <cuda_fp16.h>
+ #else
+   #include <hip/hip_bf16.h>
+   #include <hip/hip_fp16.h>
+
+   using __nv_bfloat16 = __hip_bfloat16;
+   using __nv_bfloat162 = __hip_bfloat162;
+ #endif

namespace vllm {
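The `using` aliases in the hunk above are what let the rest of the file keep the CUDA spellings `__nv_bfloat16`/`__nv_bfloat162` while still compiling under ROCm. A minimal sketch of the idea, not part of the commit (the struct name `bf16x8` is illustrative), assuming only that the HIP and CUDA bf16 types share the same 2-byte layout:

```cuda
// Illustrative sketch: the alias makes the CUDA spelling valid under ROCm,
// so storage-only code can be written once for both toolchains.
#ifndef USE_ROCM
  #include <cuda_bf16.h>
#else
  #include <hip/hip_bf16.h>
  using __nv_bfloat16 = __hip_bfloat16;  // HIP type under the CUDA name
#endif

// A 16-byte-aligned pack of eight bf16 values, mirroring the layout the
// vectorized kernel below relies on: 8 x 2 bytes = 128 bits per memory op.
struct alignas(16) bf16x8 {
  __nv_bfloat16 data[8];
};
static_assert(sizeof(bf16x8) == 16, "eight bf16 values should pack into 128 bits");
```

The same arithmetic gives the width-8 limit used at the bottom of the diff: one 128-bit global load or store moves exactly eight FP16/BF16 elements.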
@@ -35,9 +45,199 @@ __global__ void rms_norm_kernel(
  }
}

- // TODO: Further optimize this kernel.
- template<typename scalar_t>
- __global__ void fused_add_rms_norm_kernel(
+
+ /* Converter structs for the conversion from torch types to HIP/CUDA types,
+    and the associated type conversions within HIP/CUDA. These helpers need
+    to be implemented for now because the relevant type conversion
+    operators/constructors are not consistently implemented by HIP/CUDA, so
+    a generic conversion via type casts cannot be implemented.
+
+    Each struct should have the member static constexpr bool `exists`:
+    If false, the optimized kernel is not used for the corresponding torch type.
+    If true, the struct should be fully defined as shown in the examples below.
+  */
+ template<typename torch_type>
+ struct _typeConvert { static constexpr bool exists = false; };
+
+ template<>
+ struct _typeConvert<c10::Half> {
+   static constexpr bool exists = true;
+   using hip_type = __half;
+   using packed_hip_type = __half2;
+
+   __device__ static inline float convert(hip_type x) { return __half2float(x); }
+   __device__ static inline float2 convert(packed_hip_type x) { return __half22float2(x); }
+   __device__ static inline hip_type convert(float x) { return __float2half_rn(x); }
+   __device__ static inline packed_hip_type convert(float2 x) { return __float22half2_rn(x); }
+ };
+
+ #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
+ // CUDA_ARCH < 800 does not have BF16 support
+ // TODO: Add in ROCm support once public headers handle bf16 maturely
+ template<>
+ struct _typeConvert<c10::BFloat16> {
+   static constexpr bool exists = true;
+   using hip_type = __nv_bfloat16;
+   using packed_hip_type = __nv_bfloat162;
+
+   __device__ static inline float convert(hip_type x) { return __bfloat162float(x); }
+   __device__ static inline float2 convert(packed_hip_type x) { return __bfloat1622float2(x); }
+   __device__ static inline hip_type convert(float x) { return __float2bfloat16(x); }
+   __device__ static inline packed_hip_type convert(float2 x) { return __float22bfloat162_rn(x); }
+ };
+ #endif
+
+
+ /* Vector POD struct to generate vectorized and packed FP16/BF16 ops
+    for appropriate specializations of fused_add_rms_norm_kernel.
+    Only functions that are necessary in that kernel are implemented.
+    Alignment to 16 bytes is required to use 128-bit global memory ops.
+  */
+ template<typename scalar_t, int width>
+ struct alignas(16) _f16Vec {
+   /* Not theoretically necessary that width is a power of 2 but should
+      almost always be the case for optimization purposes */
+   static_assert(width > 0 && (width & (width - 1)) == 0,
+                 "Width is not a positive power of 2!");
+   using Converter = _typeConvert<scalar_t>;
+   using T1 = typename Converter::hip_type;
+   using T2 = typename Converter::packed_hip_type;
+   T1 data[width];
+
+   __device__ _f16Vec& operator+=(const _f16Vec<scalar_t, width>& other) {
+     if constexpr (width % 2 == 0) {
+       #pragma unroll
+       for (int i = 0; i < width; i += 2) {
+         T2 temp{data[i], data[i+1]};
+         temp += T2{other.data[i], other.data[i+1]};
+         data[i] = temp.x;
+         data[i+1] = temp.y;
+       }
+     } else {
+       #pragma unroll
+       for (int i = 0; i < width; ++i)
+         data[i] += other.data[i];
+     }
+     return *this;
+   }
+
+   __device__ _f16Vec& operator*=(const _f16Vec<scalar_t, width>& other) {
+     if constexpr (width % 2 == 0) {
+       #pragma unroll
+       for (int i = 0; i < width; i += 2) {
+         T2 temp{data[i], data[i+1]};
+         temp *= T2{other.data[i], other.data[i+1]};
+         data[i] = temp.x;
+         data[i+1] = temp.y;
+       }
+     } else {
+       #pragma unroll
+       for (int i = 0; i < width; ++i)
+         data[i] *= other.data[i];
+     }
+     return *this;
+   }
+
+   __device__ _f16Vec& operator*=(const float scale) {
+     if constexpr (width % 2 == 0) {
+       #pragma unroll
+       for (int i = 0; i < width; i += 2) {
+         float2 temp_f = Converter::convert(T2{data[i], data[i+1]});
+         temp_f.x *= scale;
+         temp_f.y *= scale;
+         T2 temp = Converter::convert(temp_f);
+         data[i] = temp.x;
+         data[i+1] = temp.y;
+       }
+     } else {
+       #pragma unroll
+       for (int i = 0; i < width; ++i) {
+         float temp = Converter::convert(data[i]) * scale;
+         data[i] = Converter::convert(temp);
+       }
+     }
+     return *this;
+   }
+
+   __device__ float sum_squares() const {
+     float result = 0.0f;
+     if constexpr (width % 2 == 0) {
+       #pragma unroll
+       for (int i = 0; i < width; i += 2) {
+         float2 z = Converter::convert(T2{data[i], data[i+1]});
+         result += z.x * z.x + z.y * z.y;
+       }
+     } else {
+       #pragma unroll
+       for (int i = 0; i < width; ++i) {
+         float x = Converter::convert(data[i]);
+         result += x * x;
+       }
+     }
+     return result;
+   }
+ };
+
+ /* Function specialization in the case of FP16/BF16 tensors.
+    Additional optimizations we can make in this case are
+    packed and vectorized operations, which help with the
+    memory latency bottleneck. */
+ template<typename scalar_t, int width>
+ __global__ std::enable_if_t<
+   (width > 0) && _typeConvert<scalar_t>::exists> fused_add_rms_norm_kernel(
+   scalar_t* __restrict__ input,           // [..., hidden_size]
+   scalar_t* __restrict__ residual,        // [..., hidden_size]
+   const scalar_t* __restrict__ weight,    // [hidden_size]
+   const float epsilon,
+   const int num_tokens,
+   const int hidden_size) {
+   // Sanity checks on our vector struct and type-punned pointer arithmetic
+   static_assert(std::is_pod_v<_f16Vec<scalar_t, width>>);
+   static_assert(sizeof(_f16Vec<scalar_t, width>) == sizeof(scalar_t) * width);
+
+   const int vec_hidden_size = hidden_size / width;
+   __shared__ float s_variance;
+   float variance = 0.0f;
+   /* These and the argument pointers are all declared `restrict` as they are
+      not aliased in practice. Argument pointers should not be dereferenced
+      in this kernel as that would be undefined behavior */
+   auto* __restrict__ input_v = reinterpret_cast<_f16Vec<scalar_t, width>*>(input);
+   auto* __restrict__ residual_v = reinterpret_cast<_f16Vec<scalar_t, width>*>(residual);
+   auto* __restrict__ weight_v = reinterpret_cast<const _f16Vec<scalar_t, width>*>(weight);
+
+   for (int idx = threadIdx.x; idx < vec_hidden_size; idx += blockDim.x) {
+     int id = blockIdx.x * vec_hidden_size + idx;
+     _f16Vec<scalar_t, width> temp = input_v[id];
+     temp += residual_v[id];
+     variance += temp.sum_squares();
+     residual_v[id] = temp;
+   }
+   /* Keep the following if-else block in sync with the
+      calculation of max_block_size in fused_add_rms_norm */
+   if (num_tokens < 256) {
+     variance = blockReduceSum<float, 1024>(variance);
+   } else variance = blockReduceSum<float, 256>(variance);
+   if (threadIdx.x == 0) {
+     s_variance = rsqrtf(variance / hidden_size + epsilon);
+   }
+   __syncthreads();
+
+   for (int idx = threadIdx.x; idx < vec_hidden_size; idx += blockDim.x) {
+     int id = blockIdx.x * vec_hidden_size + idx;
+     _f16Vec<scalar_t, width> temp = residual_v[id];
+     temp *= s_variance;
+     temp *= weight_v[idx];
+     input_v[id] = temp;
+   }
+ }
+
+
+ /* Generic fused_add_rms_norm_kernel
+    The width field is not used here but necessary for other specializations.
+  */
+ template<typename scalar_t, int width>
+ __global__ std::enable_if_t<
+   (width == 0) || !_typeConvert<scalar_t>::exists> fused_add_rms_norm_kernel(
  scalar_t* __restrict__ input,           // [..., hidden_size]
  scalar_t* __restrict__ residual,        // [..., hidden_size]
  const scalar_t* __restrict__ weight,    // [hidden_size]
@@ -48,12 +248,17 @@ __global__ void fused_add_rms_norm_kernel(
  float variance = 0.0f;

  for (int idx = threadIdx.x; idx < hidden_size; idx += blockDim.x) {
-     float x = (float) input[blockIdx.x * hidden_size + idx];
-     x += (float) residual[blockIdx.x * hidden_size + idx];
+     scalar_t z = input[blockIdx.x * hidden_size + idx];
+     z += residual[blockIdx.x * hidden_size + idx];
+     float x = (float) z;
    variance += x * x;
-     residual[blockIdx.x * hidden_size + idx] = (scalar_t) x;
+     residual[blockIdx.x * hidden_size + idx] = z;
  }
-   variance = blockReduceSum<float>(variance);
+   /* Keep the following if-else block in sync with the
+      calculation of max_block_size in fused_add_rms_norm */
+   if (num_tokens < 256) {
+     variance = blockReduceSum<float, 1024>(variance);
+   } else variance = blockReduceSum<float, 256>(variance);
  if (threadIdx.x == 0) {
    s_variance = rsqrtf(variance / hidden_size + epsilon);
  }
@@ -93,6 +298,21 @@ void rms_norm(
  });
}

+ #define LAUNCH_FUSED_ADD_RMS_NORM(width) \
+   VLLM_DISPATCH_FLOATING_TYPES( \
+     input.scalar_type(), \
+     "fused_add_rms_norm_kernel", \
+     [&] { \
+       vllm::fused_add_rms_norm_kernel \
+       <scalar_t, width><<<grid, block, 0, stream>>>( \
+         input.data_ptr<scalar_t>(), \
+         residual.data_ptr<scalar_t>(), \
+         weight.data_ptr<scalar_t>(), \
+         epsilon, \
+         num_tokens, \
+         hidden_size); \
+     });
+
void fused_add_rms_norm(
  torch::Tensor& input,     // [..., hidden_size]
  torch::Tensor& residual,  // [..., hidden_size]
@@ -102,19 +322,29 @@ void fused_add_rms_norm(
  int num_tokens = input.numel() / hidden_size;

  dim3 grid(num_tokens);
-   dim3 block(std::min(hidden_size, 1024));
+   /* This kernel is memory-latency bound in many scenarios.
+      When num_tokens is large, a smaller block size allows
+      for increased block occupancy on CUs and better latency
+      hiding on global mem ops. */
+   const int max_block_size = (num_tokens < 256) ? 1024 : 256;
+   dim3 block(std::min(hidden_size, max_block_size));
  const at::cuda::OptionalCUDAGuard device_guard(device_of(input));
  const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
-   VLLM_DISPATCH_FLOATING_TYPES(
-     input.scalar_type(),
-     "fused_add_rms_norm_kernel",
-     [&] {
-       vllm::fused_add_rms_norm_kernel<scalar_t><<<grid, block, 0, stream>>>(
-         input.data_ptr<scalar_t>(),
-         residual.data_ptr<scalar_t>(),
-         weight.data_ptr<scalar_t>(),
-         epsilon,
-         num_tokens,
-         hidden_size);
-     });
+   /* If the tensor types are FP16/BF16, try to use the optimized kernel
+      with packed + vectorized ops.
+      Max optimization is achieved with a width-8 vector of FP16/BF16s
+      since we can load at most 128 bits at once in a global memory op.
+      However, this requires each tensor's data to be aligned to 16
+      bytes.
+    */
+   auto inp_ptr = reinterpret_cast<std::uintptr_t>(input.data_ptr());
+   auto res_ptr = reinterpret_cast<std::uintptr_t>(residual.data_ptr());
+   auto wt_ptr = reinterpret_cast<std::uintptr_t>(weight.data_ptr());
+   bool ptrs_are_aligned = inp_ptr % 16 == 0 && res_ptr % 16 == 0 \
+                           && wt_ptr % 16 == 0;
+   if (ptrs_are_aligned && hidden_size % 8 == 0) {
+     LAUNCH_FUSED_ADD_RMS_NORM(8);
+   } else {
+     LAUNCH_FUSED_ADD_RMS_NORM(0);
+   }
}
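Worth noting is how the two `fused_add_rms_norm_kernel` templates coexist: their return types are complementary `std::enable_if_t` conditions over `(width, scalar_t)`, so for any instantiation exactly one overload is viable, and `LAUNCH_FUSED_ADD_RMS_NORM(width)` can name the kernel without branching on the element type. A stripped-down sketch of that pattern, using hypothetical names (`has_fast_path`, `demo_kernel`) rather than code from the commit:

```cuda
#include <cuda_fp16.h>
#include <cuda_runtime.h>
#include <type_traits>

// Plays the role of _typeConvert<scalar_t>::exists in the commit.
template <typename T> struct has_fast_path : std::false_type {};
template <> struct has_fast_path<__half> : std::true_type {};

// Optimized overload: only viable when a fast path exists and width > 0.
template <typename T, int width>
__global__ std::enable_if_t<(width > 0) && has_fast_path<T>::value>
demo_kernel(T* __restrict__ out, const T* __restrict__ in, int n) {
  // The real kernel would use packed 128-bit accesses here; a plain copy
  // is enough to show which overload gets instantiated.
  int i = blockIdx.x * blockDim.x + threadIdx.x;
  if (i < n) out[i] = in[i];
}

// Generic overload: the complementary condition, so the two never collide.
template <typename T, int width>
__global__ std::enable_if_t<(width == 0) || !has_fast_path<T>::value>
demo_kernel(T* __restrict__ out, const T* __restrict__ in, int n) {
  int i = blockIdx.x * blockDim.x + threadIdx.x;
  if (i < n) out[i] = in[i];
}

int main() {
  const int n = 256;
  __half *in = nullptr, *out = nullptr;
  cudaMalloc(reinterpret_cast<void**>(&in), n * sizeof(__half));
  cudaMalloc(reinterpret_cast<void**>(&out), n * sizeof(__half));
  demo_kernel<__half, 8><<<1, n>>>(out, in, n);  // resolves to the optimized overload
  demo_kernel<__half, 0><<<1, n>>>(out, in, n);  // resolves to the generic overload
  cudaDeviceSynchronize();
  cudaFree(in);
  cudaFree(out);
  return 0;
}
```

The host side then only has to decide which `width` to request: `8` when all three pointers are 16-byte aligned and `hidden_size` is divisible by 8, `0` otherwise, exactly as the alignment check at the end of `fused_add_rms_norm` does.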