#include "ops_common.h"
#include "reduce/sm70.cuh"


namespace lightllm {
namespace ops {

using namespace lightllm;
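
// Fused GELU + per-token dynamic quantization from BF16 to FP8 E4M3.
// Each CUDA block processes one row (token): it applies GELU, computes the
// row's maximum absolute value, derives scale = row_max / FP8_E4M3_MAX, and
// writes gelu(x) / scale as FP8 together with the per-row scale.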

template <int32_t TPB, int32_t N>
__global__ void device_gelu_per_token_quant_bf16_to_fp8(
    const bf16_t* __restrict__ input,   // Input tensor in BF16 format
    fp8_e4m3_t* __restrict__ output,    // Output tensor in FP8 format
    fp32_t* __restrict__ scales,        // Output scales for each group
    const int64_t M                     // Number of rows in the input tensor
) {
    constexpr int32_t VPT = 8;

    static_assert(N % 2 == 0, "N must be even.");
    static_assert(N % VPT == 0, "N must be a multiple of VPT.");

    const int32_t bid = blockIdx.x;
    const int32_t tid = threadIdx.x;
    constexpr fp32_t FP8_E4M3_MAX = 448.0f; // Maximum value representable in FP8 E4M3 format
    const bf16x2_t one = _float22bf162_rn(make_float2(1.0f, 1.0f));
    const bf16x2_t one_2 = _float22bf162_rn(make_float2(0.5f, 0.5f));

    const bf16_t* _input = input + bid * N;   // Input pointer for the group
    fp8_e4m3_t* _output = output + bid * N;   // Output pointer for the group

    fp32_t* _scales = scales + bid;

    // Local arrays for intermediate storage
    fp8x4_e4m3_t local_f8[VPT / 4];
    bf16x2_t local_bf16[VPT / 2];

    __shared__ bf16x2_t workspace[N / 2];
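    // The GELU results for the row are staged in this shared workspace so the
    // quantization pass below can reuse them without re-reading global memory.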

    fp32_t local_max = -FLT_MAX;
    for (int32_t i = tid * VPT; i < N; i += TPB * VPT) {
        vec_copy<sizeof(bf16_t) * VPT>(_input + i, local_bf16);

        // Exact GELU: gelu(x) = 0.5 * x * (1 + erf(x / sqrt(2))); 0.7071067811f ~= 1/sqrt(2)
        #pragma unroll
        for (int32_t j = 0; j < VPT / 2; j++) {
            fp32x2_t tmp = bf16x2_to_fp32x2(local_bf16[j]);
            tmp.x = erf(tmp.x * 0.7071067811f);
            tmp.y = erf(tmp.y * 0.7071067811f);
            bf16x2_t gelu = _float22bf162_rn(tmp);
            gelu = __hadd2(gelu, one);            // 1 + erf(x / sqrt(2))
            gelu = __hmul2(gelu, local_bf16[j]);  // x * (1 + erf(x / sqrt(2)))
            gelu = __hmul2(gelu, one_2);          // 0.5 * x * (1 + erf(x / sqrt(2)))
            local_bf16[j] = gelu;
        }

        vec_copy<sizeof(bf16_t) * VPT>(local_bf16, workspace + (i >> 1));

        // Track the per-thread maximum absolute value of the GELU output.
        #pragma unroll
        for (int32_t j = 0; j < VPT / 2; j++) {
            fp32x2_t tmp = bf16x2_to_fp32x2(local_bf16[j]);
            fp32_t max = fmaxf(fabsf(tmp.x), fabsf(tmp.y));
            local_max = fmaxf(local_max, max);
        }
    }

    // Reduce the maximum value across the thread group
    const fp32_t reduced_max = lightllm::reduce::sm70::sync_block_reduce_max_f32<TPB>(local_max);

    // Compute the scale factor with epsilon to avoid division by zero
    constexpr fp32_t epsilon = 1e-7f;
    const fp32_t scale = reduced_max / FP8_E4M3_MAX;
    const fp32_t inv_scale = 1.0f / (scale + epsilon);
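
    // Second pass: reload the GELU results from shared memory, multiply by
    // inv_scale, and store the quantized values as FP8 E4M3.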
    for (int32_t i = tid * VPT; i < N; i += TPB * VPT) {
        vec_copy<sizeof(bf16_t) * VPT>(workspace + (i >> 1), local_bf16);

        #pragma unroll
        for (int32_t j = 0; j < VPT / 4; j++) {
            fp32x2_t x = bf16x2_to_fp32x2(local_bf16[2 * j + 0]);
            fp32x2_t y = bf16x2_to_fp32x2(local_bf16[2 * j + 1]);
            fp32x4_t ret = make_float4(
                x.x * inv_scale,
                x.y * inv_scale,
                y.x * inv_scale,
                y.y * inv_scale
            );
            local_f8[j] = fp8x4_e4m3_t(ret);
        }

        vec_copy<sizeof(fp8_e4m3_t) * VPT>(local_f8, _output + i);
    }

    if (tid == 0) {
        *_scales = scale;
    }
}


template <int32_t TPB>
__global__ void gelu_per_token_quant_bf16_to_fp8_vpt(
    const bf16_t* __restrict__ input,   // Input tensor in BF16 format
    fp8_e4m3_t* __restrict__ output,    // Output tensor in FP8 format
    fp32_t* __restrict__ scales,        // Output scales for each group
    const int64_t M,                    // Number of rows in the input tensor
    const int32_t N
) {
    constexpr int32_t VPT = 8;

    const int32_t bid = blockIdx.x;
    const int32_t tid = threadIdx.x;
    constexpr fp32_t FP8_E4M3_MAX = 448.0f; // Maximum value representable in FP8 E4M3 format
    constexpr fp32_t sqrt_2_over_pi = 0.7978845608028654f;
    constexpr fp32_t coeff = 0.044715f;
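    // Tanh approximation of GELU used below:
    //   gelu(x) ~= 0.5 * x * (1 + tanh(sqrt(2/pi) * (x + 0.044715 * x^3)))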

    const bf16_t* _input = input + bid * N;   // Input pointer for the group
    fp8_e4m3_t* _output = output + bid * N;   // Output pointer for the group

    fp32_t* _scales = scales + bid;

    // Local arrays for intermediate storage
    fp8x4_e4m3_t local_f8[VPT / 4];
    bf16x2_t local_bf16[VPT / 2];

    extern __shared__ bf16x2_t workspace[];
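    // Dynamically sized shared workspace: the launcher provides N / 2 bf16x2
    // elements (N bf16 values) of dynamic shared memory per block.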

    fp32_t local_max = -FLT_MAX;
    for (int32_t i = tid * VPT; i < N; i += TPB * VPT) {
        vec_copy<sizeof(bf16_t) * VPT>(_input + i, local_bf16);

        #pragma unroll
        for (int32_t j = 0; j < VPT / 2; j++) {
            fp32x2_t tmp = bf16x2_to_fp32x2(local_bf16[j]);

            fp32_t tanh_arg1 = sqrt_2_over_pi * (tmp.x + coeff * tmp.x * tmp.x * tmp.x);
            fp32_t tanh_arg2 = sqrt_2_over_pi * (tmp.y + coeff * tmp.y * tmp.y * tmp.y);
            tmp.x = 0.5f * tmp.x * (1.0f + tanhf(tanh_arg1));
            tmp.y = 0.5f * tmp.y * (1.0f + tanhf(tanh_arg2));

            local_bf16[j] = _float22bf162_rn(tmp);
        }

        vec_copy<sizeof(bf16_t) * VPT>(local_bf16, workspace + (i >> 1));

        // Compute the max for the VPT elements.
        #pragma unroll
        for (int32_t j = 0; j < VPT / 2; j++) {
            fp32x2_t tmp = bf16x2_to_fp32x2(local_bf16[j]);
            fp32_t max = fmaxf(fabsf(tmp.x), fabsf(tmp.y));
            local_max = fmaxf(local_max, max);
        }
    }

    // Reduce the maximum value across the thread group
    const fp32_t reduced_max = lightllm::reduce::sm70::sync_block_reduce_max_f32<TPB>(local_max);

    // Compute the scale factor with epsilon to avoid division by zero
    constexpr fp32_t epsilon = 1e-7f;
    const fp32_t scale = reduced_max / FP8_E4M3_MAX;
    const fp32_t inv_scale = 1.0f / (scale + epsilon);

    for (int32_t i = tid * VPT; i < N; i += TPB * VPT) {
        vec_copy<sizeof(bf16_t) * VPT>(workspace + (i >> 1), local_bf16);

        #pragma unroll
        for (int32_t j = 0; j < VPT / 4; j++) {
            fp32x2_t x = bf16x2_to_fp32x2(local_bf16[2 * j + 0]);
            fp32x2_t y = bf16x2_to_fp32x2(local_bf16[2 * j + 1]);
            fp32x4_t ret = make_float4(
                x.x * inv_scale,
                x.y * inv_scale,
                y.x * inv_scale,
                y.y * inv_scale
            );
            local_f8[j] = fp8x4_e4m3_t(ret);
        }

        vec_copy<sizeof(fp8_e4m3_t) * VPT>(local_f8, _output + i);
    }

    if (tid == 0) {
        *_scales = scale;
    }
}


template <int32_t TPB>
__global__ void gelu_per_token_quant_bf16_to_fp8_general(
    const bf16_t* __restrict__ input,   // Input tensor in BF16 format
    fp8_e4m3_t* __restrict__ output,    // Output tensor in FP8 format
    fp32_t* __restrict__ scales,        // Output scales for each group
    const int64_t M,                    // Number of rows in the input tensor
    const int32_t N
) {
    const int32_t bid = blockIdx.x;
    const int32_t tid = threadIdx.x;
    constexpr fp32_t FP8_E4M3_MAX = 448.0f; // Maximum value representable in FP8 E4M3 format
    constexpr fp32_t sqrt_2_over_pi = 0.7978845608028654f;
    constexpr fp32_t coeff = 0.044715f;

    const bf16_t* _input = input + bid * N;   // Input pointer for the group
    fp8_e4m3_t* _output = output + bid * N;   // Output pointer for the group

    fp32_t* _scales = scales + bid;

    extern __shared__ bf16_t workspace_[];
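    // Scalar fallback for shapes where N is not a multiple of 8: one element is
    // processed per thread per iteration, using N bf16 values of dynamic shared memory.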

    fp32_t local_max = -FLT_MAX;

    for (int32_t i = tid; i < N; i += TPB) {
        fp32_t tmp = cvt_bf16_f32(_input[i]);
        fp32_t tanh_arg = sqrt_2_over_pi * (tmp + coeff * tmp * tmp * tmp);
        tmp = 0.5f * tmp * (1.0f + tanhf(tanh_arg));
        local_max = fmaxf(local_max, fabsf(tmp));
        workspace_[i] = cvt_f32_bf16(tmp);
    }

    // Reduce the maximum value across the thread group
    const fp32_t reduced_max = lightllm::reduce::sm70::sync_block_reduce_max_f32<TPB>(local_max);

    // Compute the scale factor with epsilon to avoid division by zero
    constexpr fp32_t epsilon = 1e-7f;
    const fp32_t scale = reduced_max / FP8_E4M3_MAX;
    const fp32_t inv_scale = 1.0f / (scale + epsilon);

    for (int32_t i = tid; i < N; i += TPB) {
        // Load the GELU result previously stored in shared memory.
        fp32_t x = cvt_bf16_f32(workspace_[i]);
        // Quantize: multiply by the inverse per-row scale and convert to FP8.
        fp32_t ret = x * inv_scale;
        _output[i] = fp8_e4m3_t(ret);
    }

    if (tid == 0) {
        *_scales = scale;
    }
}

void gelu_per_token_quant_bf16_fp8(
    Tensor& output,
    const Tensor& input,
    Tensor& scales
) {
    TORCH_CHECK(input.is_cuda(), "Input must be a CUDA tensor");
    TORCH_CHECK(input.dim() == 2, "Input must be 2-dimensional");
    TORCH_CHECK(input.scalar_type() == c10::kBFloat16, "Input must be BF16 type");

    Tensor contiguous_input = input.is_contiguous() ? input : input.contiguous();
    Tensor contiguous_scales = scales.is_contiguous() ? scales : scales.contiguous();

    const int64_t M = input.size(0);
    const int64_t N = input.size(1);

    const int32_t blocks = M;
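
    // Dispatch: common hidden sizes get fully specialized kernels with a
    // compile-time N and static shared memory; any other N falls back to the
    // dynamic-shared-memory kernels above.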
    switch (N) {
        case 16:
            device_gelu_per_token_quant_bf16_to_fp8<64, 16>
            <<<blocks, 64, 0, at::cuda::getCurrentCUDAStream()>>>(
                PTR<bf16_t>(contiguous_input),
                PTR<fp8_e4m3_t>(output),
                PTR<fp32_t>(contiguous_scales),
                M
            );
            break;
        case 32:
            device_gelu_per_token_quant_bf16_to_fp8<64, 32>
            <<<blocks, 64, 0, at::cuda::getCurrentCUDAStream()>>>(
                PTR<bf16_t>(contiguous_input),
                PTR<fp8_e4m3_t>(output),
                PTR<fp32_t>(contiguous_scales),
                M
            );
            break;
        case 64:
            device_gelu_per_token_quant_bf16_to_fp8<64, 64>
            <<<blocks, 64, 0, at::cuda::getCurrentCUDAStream()>>>(
                PTR<bf16_t>(contiguous_input),
                PTR<fp8_e4m3_t>(output),
                PTR<fp32_t>(contiguous_scales),
                M
            );
            break;
        case 512:
            device_gelu_per_token_quant_bf16_to_fp8<64, 512>
            <<<blocks, 64, 0, at::cuda::getCurrentCUDAStream()>>>(
                PTR<bf16_t>(contiguous_input),
                PTR<fp8_e4m3_t>(output),
                PTR<fp32_t>(contiguous_scales),
                M
            );
            break;
        case 1024:
            device_gelu_per_token_quant_bf16_to_fp8<128, 1024>
            <<<blocks, 128, 0, at::cuda::getCurrentCUDAStream()>>>(
                PTR<bf16_t>(contiguous_input),
                PTR<fp8_e4m3_t>(output),
                PTR<fp32_t>(contiguous_scales),
                M
            );
            break;
        case 2048:
            device_gelu_per_token_quant_bf16_to_fp8<128, 2048>
            <<<blocks, 128, 0, at::cuda::getCurrentCUDAStream()>>>(
                PTR<bf16_t>(contiguous_input),
                PTR<fp8_e4m3_t>(output),
                PTR<fp32_t>(contiguous_scales),
                M
            );
            break;
        case 3200:
            device_gelu_per_token_quant_bf16_to_fp8<128, 3200>
            <<<blocks, 128, 0, at::cuda::getCurrentCUDAStream()>>>(
                PTR<bf16_t>(contiguous_input),
                PTR<fp8_e4m3_t>(output),
                PTR<fp32_t>(contiguous_scales),
                M
            );
            break;
        case 4096:
            device_gelu_per_token_quant_bf16_to_fp8<256, 4096>
            <<<blocks, 256, 0, at::cuda::getCurrentCUDAStream()>>>(
                PTR<bf16_t>(contiguous_input),
                PTR<fp8_e4m3_t>(output),
                PTR<fp32_t>(contiguous_scales),
                M
            );
            break;
        case 12800:
            device_gelu_per_token_quant_bf16_to_fp8<256, 12800>
            <<<blocks, 256, 0, at::cuda::getCurrentCUDAStream()>>>(
                PTR<bf16_t>(contiguous_input),
                PTR<fp8_e4m3_t>(output),
                PTR<fp32_t>(contiguous_scales),
                M
            );
            break;
        default: {
            static constexpr int32_t TPB = 128;
            // Dynamic shared memory: N bf16 values per block (2 * N bytes); sizing by
            // element count also covers odd N in the scalar fallback kernel.
            int32_t sharedmem = N * sizeof(bf16_t);
            if (N % 8 == 0) {
                gelu_per_token_quant_bf16_to_fp8_vpt<TPB>
                <<<blocks, TPB, sharedmem, at::cuda::getCurrentCUDAStream()>>>(
                    PTR<bf16_t>(contiguous_input),
                    PTR<fp8_e4m3_t>(output),
                    PTR<fp32_t>(contiguous_scales),
                    M, N
                );
            }
            else {
                gelu_per_token_quant_bf16_to_fp8_general<TPB>
                <<<blocks, TPB, sharedmem, at::cuda::getCurrentCUDAStream()>>>(
                    PTR<bf16_t>(contiguous_input),
                    PTR<fp8_e4m3_t>(output),
                    PTR<fp32_t>(contiguous_scales),
                    M, N
                );
            }
        }
    }
    return;
}
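
// Illustrative usage sketch (assumption: the op is invoked directly from C++ with
// ATen tensors; the allocation calls and dtype names below are illustrative and
// depend on the PyTorch version, they are not part of this file):
//
//   at::Tensor input  = at::randn({M, N}, at::dtype(at::kBFloat16).device(at::kCUDA));
//   at::Tensor output = at::empty({M, N}, at::dtype(at::kFloat8_e4m3fn).device(at::kCUDA));
//   at::Tensor scales = at::empty({M, 1}, at::dtype(at::kFloat).device(at::kCUDA));
//   lightllm::ops::gelu_per_token_quant_bf16_fp8(output, input, scales);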

} // namespace ops
} // namespace lightllm