Commit 8b5f18b

feat(vit_cuda_kernels): add norm quant and some fused ops (#886)
# Operator optimizations for ViT FP8 W8A8 quantized inference

## New operators
1. rmsnorm_bf16: substantially faster than the PyTorch implementation
2. pre_tp_norm: fuses the pre-communication part of tp_norm
3. post_tp_norm: fuses the post-communication part of tp_norm
4. pre_token_quant: per-token FP8 quantization; much faster than vLLM's quant kernel and faster than SGL's
5. gelu_per_token_quant: fuses the GELU activation with per-token FP8 quantization (see the sketch after this list)
6. add_norm_quant: fuses the add, norm, and quant operations between the attention and MLP modules
7. cutlass_scaled_mm_bias_ls: fuses the quantized matmul, dequantization, and optional bias and ls weight
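A quick sketch of the math behind items 4 and 5, as implemented in the quant kernel shown further down (the symbols $x_{ij}$, $s_i$, $q_{ij}$ are notation introduced here for illustration): for each row (token) $i$ of an $M \times N$ BF16 activation,

$$s_i = \frac{\max_j \lvert \mathrm{GELU}(x_{ij}) \rvert}{448}, \qquad q_{ij} = \mathrm{cast}_{\mathrm{FP8\,E4M3}}\!\left( \frac{\mathrm{GELU}(x_{ij})}{s_i + 10^{-7}} \right)$$

where 448 is the largest finite FP8 E4M3 value and GELU is either the exact form $0.5\,x\,(1 + \mathrm{erf}(x/\sqrt{2}))$ or the tanh approximation $0.5\,x\,\bigl(1 + \tanh(\sqrt{2/\pi}\,(x + 0.044715\,x^3))\bigr)$, depending on the kernel variant; pre_token_quant applies the same per-token quantization without the activation.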
1 parent 009e972 commit 8b5f18b

File tree

702 files changed (+554067, -112 lines)


lightllm-kernel/csrc/fusion/add_norm_quant.cu

Lines changed: 551 additions & 0 deletions
Large diffs are not rendered by default.
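The add_norm_quant.cu diff is not rendered on this page. As a hedged sketch only — assuming the fusion composes a residual add with the rmsnorm_bf16 from item 1 and the per-token FP8 quantization from item 4, which the file itself may do differently — the fused step for one token row would be roughly

$$h_i = x_i + r_i, \qquad \hat{h}_{ij} = \frac{h_{ij}}{\sqrt{\frac{1}{N}\sum_k h_{ik}^2 + \varepsilon}}\, w_j, \qquad s_i = \frac{\max_j \lvert \hat{h}_{ij} \rvert}{448}, \qquad q_{ij} = \mathrm{cast}_{\mathrm{FP8\,E4M3}}\!\left(\frac{\hat{h}_{ij}}{s_i}\right)$$

with $r_i$ the residual, $w$ the norm weight, and $\varepsilon$ a small constant; none of these symbols are taken from the file.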
Lines changed: 367 additions & 0 deletions
#include "ops_common.h"
#include "reduce/sm70.cuh"


namespace lightllm {
namespace ops {

using namespace lightllm;

template<int32_t TPB, int32_t N>
__global__ void device_gelu_per_token_quant_bf16_to_fp8(
    const bf16_t* __restrict__ input,   // Input tensor in BF16 format
    fp8_e4m3_t* __restrict__ output,    // Output tensor in FP8 format
    fp32_t* __restrict__ scales,        // Output scales, one per token (row)
    const int64_t M                     // Number of rows in the input tensor
) {
    constexpr int32_t VPT = 8;

    static_assert(N % 2 == 0, "N must be even.");
    static_assert(N % VPT == 0, "N must be a multiple of VPT.");

    const int32_t bid = blockIdx.x;
    const int32_t tid = threadIdx.x;
    constexpr fp32_t FP8_E4M3_MAX = 448.0f; // Maximum value representable in FP8 E4M3 format
    const bf16x2_t one = _float22bf162_rn(make_float2(1.0f, 1.0f));
    const bf16x2_t one_2 = _float22bf162_rn(make_float2(0.5f, 0.5f));

    const bf16_t* _input = input + bid * N;     // Input pointer for this row
    fp8_e4m3_t* _output = output + bid * N;     // Output pointer for this row

    fp32_t* _scales;
    _scales = scales + bid;

    // Local arrays for intermediate storage
    fp8x4_e4m3_t local_f8[VPT / 4];
    bf16x2_t local_bf16[VPT / 2];

    __shared__ bf16x2_t workspace[N / 2];

    fp32_t local_max = -FLT_MAX;
    for (int32_t i = tid * VPT; i < N; i += TPB * VPT) {
        vec_copy<sizeof(bf16_t) * VPT>(_input + i, local_bf16);

        // GELU (erf form): 0.5 * x * (1 + erf(x / sqrt(2)))
        #pragma unroll
        for (int32_t j = 0; j < VPT / 2; j++) {
            fp32x2_t tmp = bf16x2_to_fp32x2(local_bf16[j]);
            tmp.x = erf(tmp.x * 0.7071067811f);
            tmp.y = erf(tmp.y * 0.7071067811f);
            bf16x2_t tan = _float22bf162_rn(tmp);
            tan = __hadd2(tan, one);
            tan = __hmul2(tan, local_bf16[j]);
            tan = __hmul2(tan, one_2);
            local_bf16[j] = tan;
        }

        vec_copy<sizeof(bf16_t) * VPT>(local_bf16, workspace + (i >> 1));

        // Track the per-row absolute maximum of the activated values.
        #pragma unroll
        for (int32_t j = 0; j < VPT / 2; j++) {
            fp32x2_t tmp = bf16x2_to_fp32x2(local_bf16[j]);
            fp32_t max = fmaxf(fabsf(tmp.x), fabsf(tmp.y));
            local_max = fmaxf(local_max, max);
        }
    }

    // Reduce the maximum value across the thread block
    const fp32_t reduced_max = lightllm::reduce::sm70::sync_block_reduce_max_f32<TPB>(local_max);

    // Compute the scale factor with epsilon to avoid division by zero
    constexpr fp32_t epsilon = 1e-7f;
    const fp32_t scale = reduced_max / FP8_E4M3_MAX;
    const fp32_t inv_scale = 1.0f / (scale + epsilon);

    for (int32_t i = tid * VPT; i < N; i += TPB * VPT) {
        vec_copy<sizeof(bf16_t) * VPT>(workspace + (i >> 1), local_bf16);

        #pragma unroll
        for (int32_t j = 0; j < VPT / 4; j++) {
            fp32x2_t x = bf16x2_to_fp32x2(local_bf16[2 * j + 0]);
            fp32x2_t y = bf16x2_to_fp32x2(local_bf16[2 * j + 1]);
            fp32x4_t ret = make_float4(
                x.x * inv_scale,
                x.y * inv_scale,
                y.x * inv_scale,
                y.y * inv_scale
            );
            local_f8[j] = fp8x4_e4m3_t(ret);
        }

        vec_copy<sizeof(fp8_e4m3_t) * VPT>(local_f8, _output + i);
    }

    if (tid == 0) {
        *_scales = scale;
    }
}


template<int32_t TPB>
__global__ void gelu_per_token_quant_bf16_to_fp8_vpt(
    const bf16_t* __restrict__ input,   // Input tensor in BF16 format
    fp8_e4m3_t* __restrict__ output,    // Output tensor in FP8 format
    fp32_t* __restrict__ scales,        // Output scales, one per token (row)
    const int64_t M,                    // Number of rows in the input tensor
    const int32_t N
) {
    constexpr int32_t VPT = 8;

    const int32_t bid = blockIdx.x;
    const int32_t tid = threadIdx.x;
    constexpr fp32_t FP8_E4M3_MAX = 448.0f; // Maximum value representable in FP8 E4M3 format
    constexpr fp32_t sqrt_2_over_pi = 0.7978845608028654f;
    constexpr fp32_t coeff = 0.044715f;

    const bf16_t* _input = input + bid * N;     // Input pointer for this row
    fp8_e4m3_t* _output = output + bid * N;     // Output pointer for this row

    fp32_t* _scales;
    _scales = scales + bid;

    // Local arrays for intermediate storage
    fp8x4_e4m3_t local_f8[VPT / 4];
    bf16x2_t local_bf16[VPT / 2];

    extern __shared__ bf16x2_t workspace[];

    fp32_t local_max = -FLT_MAX;
    for (int32_t i = tid * VPT; i < N; i += TPB * VPT) {
        vec_copy<sizeof(bf16_t) * VPT>(_input + i, local_bf16);

        // GELU (tanh approximation): 0.5 * x * (1 + tanh(sqrt(2/pi) * (x + 0.044715 * x^3)))
        #pragma unroll
        for (int32_t j = 0; j < VPT / 2; j++) {
            fp32x2_t tmp = bf16x2_to_fp32x2(local_bf16[j]);

            fp32_t tanh_arg1 = sqrt_2_over_pi * (tmp.x + coeff * tmp.x * tmp.x * tmp.x);
            fp32_t tanh_arg2 = sqrt_2_over_pi * (tmp.y + coeff * tmp.y * tmp.y * tmp.y);
            tmp.x = 0.5f * tmp.x * (1.0f + tanhf(tanh_arg1));
            tmp.y = 0.5f * tmp.y * (1.0f + tanhf(tanh_arg2));

            local_bf16[j] = _float22bf162_rn(tmp);
        }

        vec_copy<sizeof(bf16_t) * VPT>(local_bf16, workspace + (i >> 1));

        // Compute the max for the VPT elements.
        #pragma unroll
        for (int32_t j = 0; j < VPT / 2; j++) {
            fp32x2_t tmp = bf16x2_to_fp32x2(local_bf16[j]);
            fp32_t max = fmaxf(fabsf(tmp.x), fabsf(tmp.y));
            local_max = fmaxf(local_max, max);
        }
    }

    // Reduce the maximum value across the thread block
    const fp32_t reduced_max = lightllm::reduce::sm70::sync_block_reduce_max_f32<TPB>(local_max);

    // Compute the scale factor with epsilon to avoid division by zero
    constexpr fp32_t epsilon = 1e-7f;
    const fp32_t scale = reduced_max / FP8_E4M3_MAX;
    const fp32_t inv_scale = 1.0f / (scale + epsilon);

    for (int32_t i = tid * VPT; i < N; i += TPB * VPT) {
        vec_copy<sizeof(bf16_t) * VPT>(workspace + (i >> 1), local_bf16);

        #pragma unroll
        for (int32_t j = 0; j < VPT / 4; j++) {
            fp32x2_t x = bf16x2_to_fp32x2(local_bf16[2 * j + 0]);
            fp32x2_t y = bf16x2_to_fp32x2(local_bf16[2 * j + 1]);
            fp32x4_t ret = make_float4(
                x.x * inv_scale,
                x.y * inv_scale,
                y.x * inv_scale,
                y.y * inv_scale
            );
            local_f8[j] = fp8x4_e4m3_t(ret);
        }

        vec_copy<sizeof(fp8_e4m3_t) * VPT>(local_f8, _output + i);
    }

    if (tid == 0) {
        *_scales = scale;
    }
}


template<int32_t TPB>
__global__ void gelu_per_token_quant_bf16_to_fp8_general(
    const bf16_t* __restrict__ input,   // Input tensor in BF16 format
    fp8_e4m3_t* __restrict__ output,    // Output tensor in FP8 format
    fp32_t* __restrict__ scales,        // Output scales, one per token (row)
    const int64_t M,                    // Number of rows in the input tensor
    const int32_t N
) {
    const int32_t bid = blockIdx.x;
    const int32_t tid = threadIdx.x;
    constexpr fp32_t FP8_E4M3_MAX = 448.0f; // Maximum value representable in FP8 E4M3 format
    constexpr fp32_t sqrt_2_over_pi = 0.7978845608028654f;
    constexpr fp32_t coeff = 0.044715f;

    const bf16_t* _input = input + bid * N;     // Input pointer for this row
    fp8_e4m3_t* _output = output + bid * N;     // Output pointer for this row

    fp32_t* _scales;
    _scales = scales + bid;

    extern __shared__ bf16_t workspace_[];

    fp32_t local_max = -FLT_MAX;

    for (int32_t i = tid; i < N; i += TPB) {
        fp32_t tmp = cvt_bf16_f32(_input[i]);
        fp32_t tanh_arg = sqrt_2_over_pi * (tmp + coeff * tmp * tmp * tmp);
        tmp = 0.5f * tmp * (1.0f + tanhf(tanh_arg));
        local_max = fmaxf(local_max, fabsf(tmp));
        workspace_[i] = cvt_f32_bf16(tmp);
    }

    // Reduce the maximum value across the thread block
    const fp32_t reduced_max = lightllm::reduce::sm70::sync_block_reduce_max_f32<TPB>(local_max);

    // Compute the scale factor with epsilon to avoid division by zero
    constexpr fp32_t epsilon = 1e-7f;
    const fp32_t scale = reduced_max / FP8_E4M3_MAX;
    const fp32_t inv_scale = 1.0f / (scale + epsilon);

    for (int32_t i = tid; i < N; i += TPB) {
        // Load the GELU result previously stored in shared memory.
        fp32_t x = cvt_bf16_f32(workspace_[i]);
        // Quantize: multiply by the inverse per-token scale and cast to FP8.
        fp32_t ret = x * inv_scale;
        _output[i] = fp8_e4m3_t(ret);
    }

    if (tid == 0) {
        *_scales = scale;
    }
}

void gelu_per_token_quant_bf16_fp8(
    Tensor& output,
    const Tensor& input,
    Tensor& scales
) {
    TORCH_CHECK(input.is_cuda(), "Input must be a CUDA tensor");
    TORCH_CHECK(input.dim() == 2, "Input must be 2-dimensional");
    TORCH_CHECK(input.scalar_type() == c10::kBFloat16, "Input must be BF16 type");

    Tensor contiguous_input = input.is_contiguous() ? input : input.contiguous();
    Tensor contiguous_scales = scales.is_contiguous() ? scales : scales.contiguous();

    const int64_t M = input.size(0);
    const int64_t N = input.size(1);

    const int32_t blocks = M;

    switch (N) {
        case 16:
            device_gelu_per_token_quant_bf16_to_fp8<64, 16>
                <<<blocks, 64, 0, at::cuda::getCurrentCUDAStream()>>>(
                    PTR<bf16_t>(contiguous_input),
                    PTR<fp8_e4m3_t>(output),
                    PTR<fp32_t>(contiguous_scales),
                    M
                );
            break;
        case 32:
            device_gelu_per_token_quant_bf16_to_fp8<64, 32>
                <<<blocks, 64, 0, at::cuda::getCurrentCUDAStream()>>>(
                    PTR<bf16_t>(contiguous_input),
                    PTR<fp8_e4m3_t>(output),
                    PTR<fp32_t>(contiguous_scales),
                    M
                );
            break;
        case 64:
            device_gelu_per_token_quant_bf16_to_fp8<64, 64>
                <<<blocks, 64, 0, at::cuda::getCurrentCUDAStream()>>>(
                    PTR<bf16_t>(contiguous_input),
                    PTR<fp8_e4m3_t>(output),
                    PTR<fp32_t>(contiguous_scales),
                    M
                );
            break;
        case 512:
            device_gelu_per_token_quant_bf16_to_fp8<64, 512>
                <<<blocks, 64, 0, at::cuda::getCurrentCUDAStream()>>>(
                    PTR<bf16_t>(contiguous_input),
                    PTR<fp8_e4m3_t>(output),
                    PTR<fp32_t>(contiguous_scales),
                    M
                );
            break;
        case 1024:
            device_gelu_per_token_quant_bf16_to_fp8<128, 1024>
                <<<blocks, 128, 0, at::cuda::getCurrentCUDAStream()>>>(
                    PTR<bf16_t>(contiguous_input),
                    PTR<fp8_e4m3_t>(output),
                    PTR<fp32_t>(contiguous_scales),
                    M
                );
            break;
        case 2048:
            device_gelu_per_token_quant_bf16_to_fp8<128, 2048>
                <<<blocks, 128, 0, at::cuda::getCurrentCUDAStream()>>>(
                    PTR<bf16_t>(contiguous_input),
                    PTR<fp8_e4m3_t>(output),
                    PTR<fp32_t>(contiguous_scales),
                    M
                );
            break;
        case 3200:
            device_gelu_per_token_quant_bf16_to_fp8<128, 3200>
                <<<blocks, 128, 0, at::cuda::getCurrentCUDAStream()>>>(
                    PTR<bf16_t>(contiguous_input),
                    PTR<fp8_e4m3_t>(output),
                    PTR<fp32_t>(contiguous_scales),
                    M
                );
            break;
        case 4096:
            device_gelu_per_token_quant_bf16_to_fp8<256, 4096>
                <<<blocks, 256, 0, at::cuda::getCurrentCUDAStream()>>>(
                    PTR<bf16_t>(contiguous_input),
                    PTR<fp8_e4m3_t>(output),
                    PTR<fp32_t>(contiguous_scales),
                    M
                );
            break;
        case 12800:
            device_gelu_per_token_quant_bf16_to_fp8<256, 12800>
                <<<blocks, 256, 0, at::cuda::getCurrentCUDAStream()>>>(
                    PTR<bf16_t>(contiguous_input),
                    PTR<fp8_e4m3_t>(output),
                    PTR<fp32_t>(contiguous_scales),
                    M
                );
            break;
        default: {
            static constexpr int32_t TPB = 128;
            int32_t sharedmem = N / 2 * sizeof(bf16x2_t);
            if (N % 8 == 0) {
                gelu_per_token_quant_bf16_to_fp8_vpt<128>
                    <<<blocks, TPB, sharedmem, at::cuda::getCurrentCUDAStream()>>>(
                        PTR<bf16_t>(contiguous_input),
                        PTR<fp8_e4m3_t>(output),
                        PTR<fp32_t>(contiguous_scales),
                        M, N
                    );
            }
            else {
                gelu_per_token_quant_bf16_to_fp8_general<128>
                    <<<blocks, TPB, sharedmem, at::cuda::getCurrentCUDAStream()>>>(
                        PTR<bf16_t>(contiguous_input),
                        PTR<fp8_e4m3_t>(output),
                        PTR<fp32_t>(contiguous_scales),
                        M, N
                    );
            }
        }
    }
    return;
}

} // namespace ops
} // namespace lightllm
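A usage sketch for the host wrapper above: it expects a contiguous [M, N] BF16 CUDA input, an FP8 E4M3 output of the same shape, and one FP32 scale per row, which follows from the TORCH_CHECKs and the per-row indexing in the kernels. The snippet below is illustrative only; it assumes the function is linked into a libtorch program and that the PyTorch build exposes the Float8_e4m3fn dtype (torch::kFloat8_e4m3fn).

#include <torch/torch.h>

// Illustrative driver (not part of the commit): allocates buffers whose shapes
// and dtypes match what gelu_per_token_quant_bf16_fp8 indexes.
void example(int64_t M, int64_t N) {
    auto cuda = torch::TensorOptions().device(torch::kCUDA);

    torch::Tensor input  = torch::randn({M, N}, cuda.dtype(torch::kBFloat16));
    torch::Tensor output = torch::empty({M, N}, cuda.dtype(torch::kFloat8_e4m3fn));  // assumes a PyTorch build with FP8 dtypes
    torch::Tensor scales = torch::empty({M},    cuda.dtype(torch::kFloat32));        // one scale per token

    lightllm::ops::gelu_per_token_quant_bf16_fp8(output, input, scales);

    // Rough check: dequantized output should approximate GELU(input), i.e.
    // output.to(torch::kFloat) * scales.unsqueeze(1)  ~  torch::gelu(input.to(torch::kFloat))
}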
