add per_token_quant_bf16_int8 kernel #939
Merged: SangChengC merged 1 commit into ModelTC:add-lightllm-kernel from theNiemand:feature/add-int8-quant-kernel on Jun 24, 2025.
One file in this PR was renamed without changes.
lightllm-kernel/csrc/quant/per_token_quantize_bf16_int8.cu (338 additions, 0 deletions):

#include "ops_common.h" | ||
#include "reduce/sm70.cuh" | ||
|
||
|
||
namespace lightllm { | ||
namespace ops { | ||
|
||
using namespace lightllm; | ||
|
||
// CUDA kernel for per token quantization from BF16 to INT8 (general path, any N)
template<int32_t TPB>
__global__ void device_per_token_quant_bf16_to_int8_general(
    const bf16_t* __restrict__ input,   // Input tensor in BF16 format
    int8_t* __restrict__ output,        // Output tensor in INT8 format
    fp32_t* __restrict__ scales,        // Output scales for each token
    const int64_t M,                    // Number of rows in the input tensor
    const int64_t N                     // Number of elements per row (token)
) {
    const int32_t bid = blockIdx.x;
    const int32_t tid = threadIdx.x;
    constexpr fp32_t kINT8Max = 127.0f; // Maximum value representable in INT8 format

    const bf16_t* _input = input + bid * N;  // Input pointer for the token
    int8_t* _output = output + bid * N;      // Output pointer for the token

    fp32_t* _scales;
    _scales = scales + bid;

    // Local variables for intermediate storage
    int8_t local_int8;
    bf16_t local_bf16;

    extern __shared__ bf16_t workspace1[];

    fp32_t local_max = -FLT_MAX;
    for (int32_t i = tid; i < N; i += TPB) {
        local_bf16 = _input[i];
        workspace1[i] = local_bf16;

        fp32_t tmp = cvt_bf16_f32(local_bf16);
        local_max = fmaxf(local_max, tmp);
    }

    // Reduce the maximum value across the block
    const fp32_t reduced_max = lightllm::reduce::sm70::sync_block_reduce_max_f32<TPB>(local_max);

    // Compute the scale factor with epsilon to avoid division by zero
    constexpr fp32_t epsilon = 1e-7f;
    const fp32_t scale = reduced_max / kINT8Max;
    const fp32_t inv_scale = 1.0f / (scale + epsilon);

    for (int32_t i = tid; i < N; i += TPB) {
        local_bf16 = workspace1[i];

        fp32_t tmp = cvt_bf16_f32(local_bf16);
        fp32_t x = tmp * inv_scale;
        local_int8 = float_to_int8_rn(x);

        _output[i] = local_int8;
    }

    if (tid == 0) {
        *_scales = scale;
    }
}

// CUDA kernel for per token quantization from BF16 to INT8 (vectorized path, N must be a multiple of 8)
template<int32_t TPB>
__global__ void device_per_token_quant_bf16_to_int8_vpt(
    const bf16_t* __restrict__ input,   // Input tensor in BF16 format
    int8_t* __restrict__ output,        // Output tensor in INT8 format
    fp32_t* __restrict__ scales,        // Output scales for each token
    const int64_t M,                    // Number of rows in the input tensor
    const int32_t N                     // Number of elements per row (token)
) {
    constexpr int32_t VPT = 8;

    const int32_t bid = blockIdx.x;
    const int32_t tid = threadIdx.x;
    constexpr fp32_t kINT8Max = 127.0f; // Maximum value representable in INT8 format

    const bf16_t* _input = input + bid * N;  // Input pointer for the token
    int8_t* _output = output + bid * N;      // Output pointer for the token

    fp32_t* _scales;
    _scales = scales + bid;

    // Local arrays for intermediate storage
    int8_t local_int8[VPT];
    bf16x2_t local_bf16[VPT / 2];

    extern __shared__ bf16x2_t workspace2[];

    fp32_t local_max = -FLT_MAX;
    for (int32_t i = tid * VPT; i < N; i += TPB * VPT) {
        // Load VPT BF16 elements from global memory (_input) into the local vector (local_bf16).
        vec_copy<sizeof(bf16_t) * VPT>(_input + i, local_bf16);

        vec_copy<sizeof(bf16_t) * VPT>(local_bf16, workspace2 + (i >> 1));

        // Compute the max absolute value over the VPT elements.
        #pragma unroll
        for (int32_t j = 0; j < VPT / 2; j++) {
            fp32x2_t tmp = bf16x2_to_fp32x2(local_bf16[j]);
            fp32_t max = fmaxf(fabsf(tmp.x), fabsf(tmp.y));
            local_max = fmaxf(local_max, max);
        }
    }

    // Reduce the maximum value across the block
    const fp32_t reduced_max = lightllm::reduce::sm70::sync_block_reduce_max_f32<TPB>(local_max);

    // Compute the scale factor with epsilon to avoid division by zero
    constexpr fp32_t epsilon = 1e-7f;
    const fp32_t scale = reduced_max / kINT8Max;
    const fp32_t inv_scale = 1.0f / (scale + epsilon);

    for (int32_t i = tid * VPT; i < N; i += TPB * VPT) {
        vec_copy<sizeof(bf16_t) * VPT>(workspace2 + (i >> 1), local_bf16);

        #pragma unroll
        for (int32_t j = 0; j < VPT / 2; j++) {
            fp32x2_t x = bf16x2_to_fp32x2(local_bf16[j]);

            int8_t a = float_to_int8_rn(x.x * inv_scale);
            int8_t b = float_to_int8_rn(x.y * inv_scale);

            local_int8[2 * j] = a;
            local_int8[2 * j + 1] = b;
        }

        vec_copy<sizeof(int8_t) * VPT>(local_int8, _output + i);
    }

    if (tid == 0) {
        *_scales = scale;
    }
}

// CUDA kernel for per token quantization from BF16 to INT8 (N fixed at compile time)
template<int32_t TPB, int32_t N>
__global__ void device_per_token_quant_bf16_to_int8(
    const bf16_t* __restrict__ input,   // Input tensor in BF16 format
    int8_t* __restrict__ output,        // Output tensor in INT8 format
    fp32_t* __restrict__ scales,        // Output scales for each token
    const int64_t M                     // Number of rows in the input tensor
) {
    constexpr int32_t VPT = 8;

    static_assert(N % 2 == 0, "N must be even.");
    static_assert(N % VPT == 0, "N must be a multiple of VPT.");

    const int32_t bid = blockIdx.x;
    const int32_t tid = threadIdx.x;
    constexpr fp32_t kINT8Max = 127.0f; // Maximum value representable in INT8 format

    const bf16_t* _input = input + bid * N;  // Input pointer for the token
    int8_t* _output = output + bid * N;      // Output pointer for the token

    fp32_t* _scales;
    _scales = scales + bid;

    // Local arrays for intermediate storage
    int8_t local_int8[VPT];
    bf16x2_t local_bf16[VPT / 2];

    __shared__ bf16x2_t workspace[N / 2];

    fp32_t local_max = -FLT_MAX;
    for (int32_t i = tid * VPT; i < N; i += TPB * VPT) {
        // Load VPT BF16 elements from global memory (_input) into the local vector (local_bf16).
        vec_copy<sizeof(bf16_t) * VPT>(_input + i, local_bf16);

        vec_copy<sizeof(bf16_t) * VPT>(local_bf16, workspace + (i >> 1));

        // Compute the max absolute value over the VPT elements.
        #pragma unroll
        for (int32_t j = 0; j < VPT / 2; j++) {
            fp32x2_t tmp = bf16x2_to_fp32x2(local_bf16[j]);
            fp32_t max = fmaxf(fabsf(tmp.x), fabsf(tmp.y));
            local_max = fmaxf(local_max, max);
        }
    }

    // Reduce the maximum value across the block
    const fp32_t reduced_max = lightllm::reduce::sm70::sync_block_reduce_max_f32<TPB>(local_max);

    // Compute the scale factor with epsilon to avoid division by zero
    constexpr fp32_t epsilon = 1e-7f;
    const fp32_t scale = reduced_max / kINT8Max;
    const fp32_t inv_scale = 1.0f / (scale + epsilon);

    for (int32_t i = tid * VPT; i < N; i += TPB * VPT) {
        vec_copy<sizeof(bf16_t) * VPT>(workspace + (i >> 1), local_bf16);

        #pragma unroll
        for (int32_t j = 0; j < VPT / 2; j++) {
            fp32x2_t x = bf16x2_to_fp32x2(local_bf16[j]);

            int8_t a = float_to_int8_rn(x.x * inv_scale);
            int8_t b = float_to_int8_rn(x.y * inv_scale);

            local_int8[2 * j] = a;
            local_int8[2 * j + 1] = b;
        }

        vec_copy<sizeof(int8_t) * VPT>(local_int8, _output + i);
    }

    if (tid == 0) {
        *_scales = scale;
    }
}

void per_token_quant_bf16_int8(
    Tensor& output,
    const Tensor& input,
    Tensor& scales
) {
    TORCH_CHECK(input.is_cuda(), "Input must be a CUDA tensor");
    TORCH_CHECK(input.dim() == 2, "Input must be 2-dimensional");
    TORCH_CHECK(input.scalar_type() == c10::kBFloat16, "Input must be BF16 type");

    Tensor contiguous_input = input.is_contiguous() ? input : input.contiguous();
    Tensor contiguous_scales = scales.is_contiguous() ? scales : scales.contiguous();

    const int64_t M = input.size(0);
    const int64_t N = input.size(1);

    const int32_t blocks = M;

    switch (N) {
        case 16:
            device_per_token_quant_bf16_to_int8<128, 16>
                <<<blocks, 128, 0, at::cuda::getCurrentCUDAStream()>>>(
                    PTR<bf16_t>(contiguous_input),
                    PTR<int8_t>(output),
                    PTR<fp32_t>(contiguous_scales),
                    M
                );
            break;
        case 32:
            device_per_token_quant_bf16_to_int8<128, 32>
                <<<blocks, 128, 0, at::cuda::getCurrentCUDAStream()>>>(
                    PTR<bf16_t>(contiguous_input),
                    PTR<int8_t>(output),
                    PTR<fp32_t>(contiguous_scales),
                    M
                );
            break;
        case 64:
            device_per_token_quant_bf16_to_int8<128, 64>
                <<<blocks, 128, 0, at::cuda::getCurrentCUDAStream()>>>(
                    PTR<bf16_t>(contiguous_input),
                    PTR<int8_t>(output),
                    PTR<fp32_t>(contiguous_scales),
                    M
                );
            break;
        case 512:
            device_per_token_quant_bf16_to_int8<128, 512>
                <<<blocks, 128, 0, at::cuda::getCurrentCUDAStream()>>>(
                    PTR<bf16_t>(contiguous_input),
                    PTR<int8_t>(output),
                    PTR<fp32_t>(contiguous_scales),
                    M
                );
            break;
        case 1024:
            device_per_token_quant_bf16_to_int8<128, 1024>
                <<<blocks, 128, 0, at::cuda::getCurrentCUDAStream()>>>(
                    PTR<bf16_t>(contiguous_input),
                    PTR<int8_t>(output),
                    PTR<fp32_t>(contiguous_scales),
                    M
                );
            break;
        case 3200:
            device_per_token_quant_bf16_to_int8<128, 3200>
                <<<blocks, 128, 0, at::cuda::getCurrentCUDAStream()>>>(
                    PTR<bf16_t>(contiguous_input),
                    PTR<int8_t>(output),
                    PTR<fp32_t>(contiguous_scales),
                    M
                );
            break;
        case 4096:
            device_per_token_quant_bf16_to_int8<128, 4096>
                <<<blocks, 128, 0, at::cuda::getCurrentCUDAStream()>>>(
                    PTR<bf16_t>(contiguous_input),
                    PTR<int8_t>(output),
                    PTR<fp32_t>(contiguous_scales),
                    M
                );
            break;
        case 12800:
            device_per_token_quant_bf16_to_int8<256, 12800>
                <<<blocks, 256, 0, at::cuda::getCurrentCUDAStream()>>>(
                    PTR<bf16_t>(contiguous_input),
                    PTR<int8_t>(output),
                    PTR<fp32_t>(contiguous_scales),
                    M
                );
            break;
        default: {
            static constexpr int TPB = 128;
            const int64_t shared_mem_size = N * sizeof(bf16_t);
            if (N % 8 == 0) {
                device_per_token_quant_bf16_to_int8_vpt<TPB>
                    <<<blocks, TPB, shared_mem_size, at::cuda::getCurrentCUDAStream()>>>(
                        PTR<bf16_t>(contiguous_input),
                        PTR<int8_t>(output),
                        PTR<fp32_t>(contiguous_scales),
                        M,
                        N
                    );
            } else {
                device_per_token_quant_bf16_to_int8_general<TPB>
                    <<<blocks, TPB, shared_mem_size, at::cuda::getCurrentCUDAStream()>>>(
                        PTR<bf16_t>(contiguous_input),
                        PTR<int8_t>(output),
                        PTR<fp32_t>(contiguous_scales),
                        M,
                        N
                    );
            }
        }
    }

    return;
}

} // namespace ops
} // namespace lightllm
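
For reference, a minimal host-side usage sketch of the new op. This is not part of the PR: it assumes a libtorch C++ translation unit linked against the extension, and it assumes that "ops_common.h" declares lightllm::ops::per_token_quant_bf16_int8 (the header name, the main harness, and the shapes are illustrative choices; dtypes follow the TORCH_CHECKs and kernel signatures above):

#include <iostream>
#include <torch/torch.h>
#include "ops_common.h"  // assumed: declares lightllm::ops::per_token_quant_bf16_int8

int main() {
    const int64_t M = 4;     // number of tokens (rows)
    const int64_t N = 4096;  // hidden size; N = 4096 hits the <128, 4096> specialization in the dispatcher

    // BF16 activations on the GPU, plus the INT8 output buffer and per-token FP32 scales.
    torch::Tensor input  = torch::randn({M, N}, torch::dtype(torch::kBFloat16).device(torch::kCUDA));
    torch::Tensor output = torch::empty({M, N}, torch::dtype(torch::kInt8).device(torch::kCUDA));
    torch::Tensor scales = torch::empty({M},    torch::dtype(torch::kFloat32).device(torch::kCUDA));

    lightllm::ops::per_token_quant_bf16_int8(output, input, scales);

    // Dequantize (q * scale per row) and report the worst-case reconstruction error.
    torch::Tensor x_hat = output.to(torch::kFloat32) * scales.unsqueeze(1);
    std::cout << "max abs error: "
              << (x_hat - input.to(torch::kFloat32)).abs().max().item<float>() << std::endl;
    return 0;
}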

Review comment: The maximum value is calculated without considering the absolute values. For symmetric quantization, find the maximum of the absolute values to ensure correct scaling for tensors with negative values.
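
A minimal sketch of the change the reviewer is asking for, shown against the max-reduction loop of device_per_token_quant_bf16_to_int8_general above (the two vectorized kernels already take fabsf of both lanes). This is an illustrative fragment reusing the variables declared in that kernel, not part of the merged diff:

    fp32_t local_max = -FLT_MAX;
    for (int32_t i = tid; i < N; i += TPB) {
        local_bf16 = _input[i];
        workspace1[i] = local_bf16;

        // Track the maximum *absolute* value so rows dominated by negative
        // entries still produce a scale that covers their full magnitude.
        fp32_t tmp = cvt_bf16_f32(local_bf16);
        local_max = fmaxf(local_max, fabsf(tmp));
    }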