
Commit b6d1035

[Kernel] Layernorm performance optimization (#3662)
1 parent 51c31bc commit b6d1035

4 files changed: +285 -47 lines changed


cmake/utils.cmake

Lines changed: 5 additions & 0 deletions

@@ -100,6 +100,11 @@ function (get_torch_gpu_compiler_flags OUT_GPU_FLAGS GPU_LANG)
 
     if (CUDA_VERSION VERSION_GREATER_EQUAL 11.8)
       list(APPEND GPU_FLAGS "-DENABLE_FP8_E5M2")
+      list(REMOVE_ITEM GPU_FLAGS
+        "-D__CUDA_NO_HALF_OPERATORS__"
+        "-D__CUDA_NO_HALF_CONVERSIONS__"
+        "-D__CUDA_NO_BFLOAT16_CONVERSIONS__"
+        "-D__CUDA_NO_HALF2_OPERATORS__")
     endif()
 
   elseif(${GPU_LANG} STREQUAL "HIP")
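
Note on the CMake change: the flags inherited from PyTorch define `__CUDA_NO_HALF_OPERATORS__` and related macros, which strip the FP16/BF16 operator overloads and conversions out of `<cuda_fp16.h>`/`<cuda_bf16.h>`. Removing those defines is what lets the new kernel code below use packed half-precision arithmetic directly. A minimal sketch of the kind of code this unlocks (illustrative only, not part of the commit; assumes a GPU architecture with native FP16 math, sm_53 or newer):

// half2_add_demo.cu -- builds only when -D__CUDA_NO_HALF2_OPERATORS__ is absent,
// since the packed __half2 operator overload used below comes from <cuda_fp16.h>.
#include <cuda_fp16.h>

__global__ void half2_add_demo(__half2* __restrict__ out,
                               const __half2* __restrict__ a,
                               const __half2* __restrict__ b,
                               const int n) {
  const int i = blockIdx.x * blockDim.x + threadIdx.x;
  if (i < n) {
    // Two FP16 additions per instruction via the native __half2 operator+.
    out[i] = a[i] + b[i];
  }
}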

csrc/layernorm_kernels.cu

Lines changed: 250 additions & 20 deletions

@@ -4,6 +4,16 @@
 
 #include "dispatch_utils.h"
 #include "reduction_utils.cuh"
+#ifndef USE_ROCM
+  #include <cuda_bf16.h>
+  #include <cuda_fp16.h>
+#else
+  #include <hip/hip_bf16.h>
+  #include <hip/hip_fp16.h>
+
+  using __nv_bfloat16 = __hip_bfloat16;
+  using __nv_bfloat162 = __hip_bfloat162;
+#endif
 
 namespace vllm {
 
@@ -35,9 +45,199 @@ __global__ void rms_norm_kernel(
   }
 }
 
-// TODO: Further optimize this kernel.
-template<typename scalar_t>
-__global__ void fused_add_rms_norm_kernel(
+
+/* Converter structs for the conversion from torch types to HIP/CUDA types,
+   and the associated type conversions within HIP/CUDA. These helpers need
+   to be implemented for now because the relevant type conversion
+   operators/constructors are not consistently implemented by HIP/CUDA, so
+   a generic conversion via type casts cannot be implemented.
+
+   Each struct should have the member static constexpr bool `exists`:
+   If false, the optimized kernel is not used for the corresponding torch type.
+   If true, the struct should be fully defined as shown in the examples below.
+ */
+template<typename torch_type>
+struct _typeConvert { static constexpr bool exists = false; };
+
+template<>
+struct _typeConvert<c10::Half> {
+  static constexpr bool exists = true;
+  using hip_type = __half;
+  using packed_hip_type = __half2;
+
+  __device__ static inline float convert(hip_type x) { return __half2float(x); }
+  __device__ static inline float2 convert(packed_hip_type x) { return __half22float2(x); }
+  __device__ static inline hip_type convert(float x) { return __float2half_rn(x); }
+  __device__ static inline packed_hip_type convert(float2 x) { return __float22half2_rn(x); }
+};
+
+#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
+// CUDA_ARCH < 800 does not have BF16 support
+// TODO: Add in ROCm support once public headers handle bf16 maturely
+template<>
+struct _typeConvert<c10::BFloat16> {
+  static constexpr bool exists = true;
+  using hip_type = __nv_bfloat16;
+  using packed_hip_type = __nv_bfloat162;
+
+  __device__ static inline float convert(hip_type x) { return __bfloat162float(x); }
+  __device__ static inline float2 convert(packed_hip_type x) { return __bfloat1622float2(x); }
+  __device__ static inline hip_type convert(float x) { return __float2bfloat16(x); }
+  __device__ static inline packed_hip_type convert(float2 x) { return __float22bfloat162_rn(x); }
+};
+#endif
+
+
+/* Vector POD struct to generate vectorized and packed FP16/BF16 ops
+   for appropriate specializations of fused_add_rms_norm_kernel.
+   Only functions that are necessary in that kernel are implemented.
+   Alignment to 16 bytes is required to use 128-bit global memory ops.
+ */
+template<typename scalar_t, int width>
+struct alignas(16) _f16Vec {
+  /* Not theoretically necessary that width is a power of 2 but should
+     almost always be the case for optimization purposes */
+  static_assert(width > 0 && (width & (width - 1)) == 0,
+                "Width is not a positive power of 2!");
+  using Converter = _typeConvert<scalar_t>;
+  using T1 = typename Converter::hip_type;
+  using T2 = typename Converter::packed_hip_type;
+  T1 data[width];
+
+  __device__ _f16Vec& operator+=(const _f16Vec<scalar_t, width>& other) {
+    if constexpr (width % 2 == 0) {
+      #pragma unroll
+      for (int i = 0; i < width; i += 2) {
+        T2 temp{data[i], data[i+1]};
+        temp += T2{other.data[i], other.data[i+1]};
+        data[i] = temp.x;
+        data[i+1] = temp.y;
+      }
+    } else {
+      #pragma unroll
+      for (int i = 0; i < width; ++i)
+        data[i] += other.data[i];
+    }
+    return *this;
+  }
+
+  __device__ _f16Vec& operator*=(const _f16Vec<scalar_t, width>& other) {
+    if constexpr (width % 2 == 0) {
+      #pragma unroll
+      for (int i = 0; i < width; i += 2) {
+        T2 temp{data[i], data[i+1]};
+        temp *= T2{other.data[i], other.data[i+1]};
+        data[i] = temp.x;
+        data[i+1] = temp.y;
+      }
+    } else {
+      #pragma unroll
+      for (int i = 0; i < width; ++i)
+        data[i] *= other.data[i];
+    }
+    return *this;
+  }
+
+  __device__ _f16Vec& operator*=(const float scale) {
+    if constexpr (width % 2 == 0) {
+      #pragma unroll
+      for (int i = 0; i < width; i += 2) {
+        float2 temp_f = Converter::convert(T2{data[i], data[i+1]});
+        temp_f.x *= scale;
+        temp_f.y *= scale;
+        T2 temp = Converter::convert(temp_f);
+        data[i] = temp.x;
+        data[i+1] = temp.y;
+      }
+    } else {
+      #pragma unroll
+      for (int i = 0; i < width; ++i) {
+        float temp = Converter::convert(data[i]) * scale;
+        data[i] = Converter::convert(temp);
+      }
+    }
+    return *this;
+  }
+
+  __device__ float sum_squares() const {
+    float result = 0.0f;
+    if constexpr (width % 2 == 0) {
+      #pragma unroll
+      for (int i = 0; i < width; i += 2) {
+        float2 z = Converter::convert(T2{data[i], data[i+1]});
+        result += z.x * z.x + z.y * z.y;
+      }
+    } else {
+      #pragma unroll
+      for (int i = 0; i < width; ++i) {
+        float x = Converter::convert(data[i]);
+        result += x * x;
+      }
+    }
+    return result;
+  }
+};
+
+/* Function specialization in the case of FP16/BF16 tensors.
+   Additional optimizations we can make in this case are
+   packed and vectorized operations, which help with the
+   memory latency bottleneck. */
+template<typename scalar_t, int width>
+__global__ std::enable_if_t<
+  (width > 0) && _typeConvert<scalar_t>::exists> fused_add_rms_norm_kernel(
+  scalar_t* __restrict__ input,           // [..., hidden_size]
+  scalar_t* __restrict__ residual,        // [..., hidden_size]
+  const scalar_t* __restrict__ weight,    // [hidden_size]
+  const float epsilon,
+  const int num_tokens,
+  const int hidden_size) {
+  // Sanity checks on our vector struct and type-punned pointer arithmetic
+  static_assert(std::is_pod_v<_f16Vec<scalar_t, width>>);
+  static_assert(sizeof(_f16Vec<scalar_t, width>) == sizeof(scalar_t) * width);
+
+  const int vec_hidden_size = hidden_size / width;
+  __shared__ float s_variance;
+  float variance = 0.0f;
+  /* These and the argument pointers are all declared `restrict` as they are
+     not aliased in practice. Argument pointers should not be dereferenced
+     in this kernel as that would be undefined behavior */
+  auto* __restrict__ input_v = reinterpret_cast<_f16Vec<scalar_t, width>*>(input);
+  auto* __restrict__ residual_v = reinterpret_cast<_f16Vec<scalar_t, width>*>(residual);
+  auto* __restrict__ weight_v = reinterpret_cast<const _f16Vec<scalar_t, width>*>(weight);
+
+  for (int idx = threadIdx.x; idx < vec_hidden_size; idx += blockDim.x) {
+    int id = blockIdx.x * vec_hidden_size + idx;
+    _f16Vec<scalar_t, width> temp = input_v[id];
+    temp += residual_v[id];
+    variance += temp.sum_squares();
+    residual_v[id] = temp;
+  }
+  /* Keep the following if-else block in sync with the
+     calculation of max_block_size in fused_add_rms_norm */
+  if (num_tokens < 256) {
+    variance = blockReduceSum<float, 1024>(variance);
+  } else variance = blockReduceSum<float, 256>(variance);
+  if (threadIdx.x == 0) {
+    s_variance = rsqrtf(variance / hidden_size + epsilon);
+  }
+  __syncthreads();
+
+  for (int idx = threadIdx.x; idx < vec_hidden_size; idx += blockDim.x) {
+    int id = blockIdx.x * vec_hidden_size + idx;
+    _f16Vec<scalar_t, width> temp = residual_v[id];
+    temp *= s_variance;
+    temp *= weight_v[idx];
+    input_v[id] = temp;
+  }
+}
+
+
+/* Generic fused_add_rms_norm_kernel
+   The width field is not used here but necessary for other specializations.
+*/
+template<typename scalar_t, int width>
+__global__ std::enable_if_t<
+  (width == 0) || !_typeConvert<scalar_t>::exists> fused_add_rms_norm_kernel(
   scalar_t* __restrict__ input,           // [..., hidden_size]
   scalar_t* __restrict__ residual,        // [..., hidden_size]
   const scalar_t* __restrict__ weight,    // [hidden_size]
@@ -48,12 +248,17 @@ __global__ void fused_add_rms_norm_kernel(
   float variance = 0.0f;
 
   for (int idx = threadIdx.x; idx < hidden_size; idx += blockDim.x) {
-    float x = (float) input[blockIdx.x * hidden_size + idx];
-    x += (float) residual[blockIdx.x * hidden_size + idx];
+    scalar_t z = input[blockIdx.x * hidden_size + idx];
+    z += residual[blockIdx.x * hidden_size + idx];
+    float x = (float) z;
     variance += x * x;
-    residual[blockIdx.x * hidden_size + idx] = (scalar_t) x;
+    residual[blockIdx.x * hidden_size + idx] = z;
   }
-  variance = blockReduceSum<float>(variance);
+  /* Keep the following if-else block in sync with the
+     calculation of max_block_size in fused_add_rms_norm */
+  if (num_tokens < 256) {
+    variance = blockReduceSum<float, 1024>(variance);
+  } else variance = blockReduceSum<float, 256>(variance);
   if (threadIdx.x == 0) {
     s_variance = rsqrtf(variance / hidden_size + epsilon);
   }
@@ -93,6 +298,21 @@ void rms_norm(
     });
 }
 
+#define LAUNCH_FUSED_ADD_RMS_NORM(width)                          \
+  VLLM_DISPATCH_FLOATING_TYPES(                                   \
+    input.scalar_type(),                                          \
+    "fused_add_rms_norm_kernel",                                  \
+    [&] {                                                         \
+      vllm::fused_add_rms_norm_kernel                             \
+      <scalar_t, width><<<grid, block, 0, stream>>>(              \
+        input.data_ptr<scalar_t>(),                               \
+        residual.data_ptr<scalar_t>(),                            \
+        weight.data_ptr<scalar_t>(),                              \
+        epsilon,                                                  \
+        num_tokens,                                               \
+        hidden_size);                                             \
+    });
+
 void fused_add_rms_norm(
   torch::Tensor& input,     // [..., hidden_size]
   torch::Tensor& residual,  // [..., hidden_size]
@@ -102,19 +322,29 @@ void fused_add_rms_norm(
   int num_tokens = input.numel() / hidden_size;
 
   dim3 grid(num_tokens);
-  dim3 block(std::min(hidden_size, 1024));
+  /* This kernel is memory-latency bound in many scenarios.
+     When num_tokens is large, a smaller block size allows
+     for increased block occupancy on CUs and better latency
+     hiding on global mem ops. */
+  const int max_block_size = (num_tokens < 256) ? 1024 : 256;
+  dim3 block(std::min(hidden_size, max_block_size));
   const at::cuda::OptionalCUDAGuard device_guard(device_of(input));
   const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
-  VLLM_DISPATCH_FLOATING_TYPES(
-    input.scalar_type(),
-    "fused_add_rms_norm_kernel",
-    [&] {
-      vllm::fused_add_rms_norm_kernel<scalar_t><<<grid, block, 0, stream>>>(
-        input.data_ptr<scalar_t>(),
-        residual.data_ptr<scalar_t>(),
-        weight.data_ptr<scalar_t>(),
-        epsilon,
-        num_tokens,
-        hidden_size);
-  });
+  /*If the tensor types are FP16/BF16, try to use the optimized kernel
+    with packed + vectorized ops.
+    Max optimization is achieved with a width-8 vector of FP16/BF16s
+    since we can load at most 128 bits at once in a global memory op.
+    However, this requires each tensor's data to be aligned to 16
+    bytes.
+   */
+  auto inp_ptr = reinterpret_cast<std::uintptr_t>(input.data_ptr());
+  auto res_ptr = reinterpret_cast<std::uintptr_t>(residual.data_ptr());
+  auto wt_ptr = reinterpret_cast<std::uintptr_t>(weight.data_ptr());
+  bool ptrs_are_aligned = inp_ptr % 16 == 0 && res_ptr % 16 == 0 \
+                          && wt_ptr % 16 == 0;
+  if (ptrs_are_aligned && hidden_size % 8 == 0) {
+    LAUNCH_FUSED_ADD_RMS_NORM(8);
+  } else {
+    LAUNCH_FUSED_ADD_RMS_NORM(0);
+  }
 }
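
The heart of the specialized kernel above is type-punning the FP16/BF16 tensor pointers to a 16-byte-aligned vector struct, so each global memory transaction moves 128 bits (eight half-precision values) rather than a single element. A self-contained sketch of that access pattern, under the same preconditions the launcher checks (pointers 16-byte aligned, element count divisible by 8); the names `Vec8` and `scale_fp16_vectorized` are illustrative, not from the commit:

// Sketch of the 128-bit vectorized access pattern. Assumes out/in are
// 16-byte aligned and n is a multiple of 8.
#include <cuda_fp16.h>

struct alignas(16) Vec8 {      // 8 x FP16 = 128 bits per global memory op
  __half data[8];
};

__global__ void scale_fp16_vectorized(__half* __restrict__ out,
                                      const __half* __restrict__ in,
                                      const float scale,
                                      const int n) {
  // Reinterpret element pointers as vector pointers: n / 8 vector slots.
  auto* out_v = reinterpret_cast<Vec8*>(out);
  auto* in_v = reinterpret_cast<const Vec8*>(in);
  const int num_vecs = n / 8;

  for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < num_vecs;
       i += gridDim.x * blockDim.x) {
    Vec8 v = in_v[i];             // one 128-bit load
    #pragma unroll
    for (int j = 0; j < 8; ++j)   // do the arithmetic in FP32 for accuracy
      v.data[j] = __float2half_rn(__half2float(v.data[j]) * scale);
    out_v[i] = v;                 // one 128-bit store
  }
}

When those preconditions do not hold, the launcher in the diff simply falls back to the generic width-0 kernel, so correctness never depends on the alignment check.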

csrc/reduction_utils.cuh

Lines changed: 28 additions & 26 deletions

@@ -20,43 +20,45 @@
 #include "cuda_compat.h"
 
 namespace vllm {
-
-template<typename T>
+template<typename T, int numLanes = WARP_SIZE>
 __inline__ __device__ T warpReduceSum(T val) {
-#pragma unroll
-  for (int mask = WARP_SIZE/2; mask > 0; mask >>= 1)
+  static_assert(numLanes > 0 && (numLanes & (numLanes - 1)) == 0,
+                "numLanes is not a positive power of 2!");
+  static_assert(numLanes <= WARP_SIZE);
+  #pragma unroll
+  for (int mask = numLanes >> 1; mask > 0; mask >>= 1)
     val += VLLM_SHFL_XOR_SYNC(val, mask);
   return val;
 }
 
-__inline__ __device__ constexpr int _calculateLaneMask(int warp_size) {
-  return warp_size - 1;
-}
-
-__inline__ __device__ constexpr int _calculateWidShift(int warp_size) {
-  return 5 + (warp_size >> 6);
+// Helper function to return the next largest power of 2
+static constexpr int _nextPow2(unsigned int num) {
+  if (num <= 1) return num;
+  return 1 << (CHAR_BIT * sizeof(num) - __builtin_clz(num - 1));
 }
 
 /* Calculate the sum of all elements in a block */
-template<typename T>
+template<typename T, int maxBlockSize = 1024>
 __inline__ __device__ T blockReduceSum(T val) {
-  static __shared__ T shared[WARP_SIZE];
-  constexpr auto LANE_MASK = _calculateLaneMask(WARP_SIZE);
-  constexpr auto WID_SHIFT = _calculateWidShift(WARP_SIZE);
-  int lane = threadIdx.x & LANE_MASK;
-  int wid = threadIdx.x >> WID_SHIFT;
-
-  val = warpReduceSum<T>(val);
-
-  if (lane == 0)
-    shared[wid] = val;
+  static_assert(maxBlockSize <= 1024);
+  if constexpr (maxBlockSize > WARP_SIZE) {
+    val = warpReduceSum<T>(val);
+    // Calculates max number of lanes that need to participate in the last warpReduce
+    constexpr int maxActiveLanes = (maxBlockSize + WARP_SIZE - 1) / WARP_SIZE;
+    static __shared__ T shared[maxActiveLanes];
+    int lane = threadIdx.x % WARP_SIZE;
+    int wid = threadIdx.x / WARP_SIZE;
+    if (lane == 0)
+      shared[wid] = val;
 
-  __syncthreads();
+    __syncthreads();
 
-  // Modify from blockDim.x << 5 to blockDim.x / 32. to prevent
-  // blockDim.x is not divided by 32
-  val = (threadIdx.x < (blockDim.x / (WARP_SIZE * 1.0f))) ? shared[lane] : (T)(0.0f);
-  val = warpReduceSum<T>(val);
+    val = (threadIdx.x < blockDim.x / float(WARP_SIZE)) ? shared[lane] : (T)(0.0f);
+    val = warpReduceSum<T, _nextPow2(maxActiveLanes)>(val);
+  } else {
+    // A single warpReduce is equal to blockReduce
+    val = warpReduceSum<T, _nextPow2(maxBlockSize)>(val);
+  }
   return val;
 }
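
With the templated `blockReduceSum<T, maxBlockSize>` above, the caller promises a compile-time upper bound on `blockDim.x`: a bound at or below `WARP_SIZE` collapses the reduction to a single shuffle pass, and a larger bound only reserves `(maxBlockSize + WARP_SIZE - 1) / WARP_SIZE` shared-memory slots. That is why the kernels in `layernorm_kernels.cu` keep their `blockReduceSum<float, 1024>` / `<float, 256>` branch in sync with the host-side `max_block_size`. A minimal usage sketch (illustrative kernel, not part of the commit), assuming the launch block size never exceeds the template bound:

// Usage sketch for the templated block reduction; the template argument must
// be an upper bound on the launch block size, mirroring how fused_add_rms_norm
// picks max_block_size on the host.
#include "reduction_utils.cuh"

__global__ void row_sum_demo(const float* __restrict__ in,
                             float* __restrict__ out,
                             const int row_len) {
  float partial = 0.0f;
  for (int i = threadIdx.x; i < row_len; i += blockDim.x)
    partial += in[blockIdx.x * row_len + i];

  // Launched with 256 threads per block, so <float, 256> is a valid bound and
  // only 256 / WARP_SIZE shared-memory slots are reserved for the reduction.
  partial = vllm::blockReduceSum<float, 256>(partial);
  if (threadIdx.x == 0)
    out[blockIdx.x] = partial;
}
// Host side: row_sum_demo<<<num_rows, 256, 0, stream>>>(in, out, row_len);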