add per_token_quant_bf16_int8 kernel #939

Merged
3 changes: 2 additions & 1 deletion lightllm-kernel/csrc/ops_bindings.cpp
@@ -10,7 +10,8 @@ PYBIND11_MODULE(_C, m) {
m.def("rmsnorm_align16_bf16", &rmsnorm_align16_bf16, "RMSNORM (CUDA)");
m.def("pre_tp_norm_bf16", &pre_tp_norm_bf16, "PRE TP NORM (CUDA)");
m.def("post_tp_norm_bf16", &post_tp_norm_bf16, "POST TP NORM (CUDA)");
m.def("per_token_quant_bf16_fp8", &per_token_quant_bf16_fp8, "PER TOKEN QUANT (CUDA)");
m.def("per_token_quant_bf16_fp8", &per_token_quant_bf16_fp8, "PER TOKEN QUANT FP8 (CUDA)");
m.def("per_token_quant_bf16_int8", &per_token_quant_bf16_int8, "PER TOKEN QUANT INT8 (CUDA)");
m.def("add_norm_quant_bf16_fp8", &add_norm_quant_bf16_fp8, "ADD NORM QUANT FUSED (CUDA)");
m.def("gelu_per_token_quant_bf16_fp8", &gelu_per_token_quant_bf16_fp8, "GELU QUANT FUSED (CUDA)");
m.def("cutlass_scaled_mm", &cutlass_scaled_mm, "CUTLASS SCALED MM (CUDA)");
338 changes: 338 additions & 0 deletions lightllm-kernel/csrc/quant/per_token_quantize_bf16_int8.cu
@@ -0,0 +1,338 @@
#include "ops_common.h"
#include "reduce/sm70.cuh"


namespace lightllm {
namespace ops {

using namespace lightllm;

// CUDA kernel for per token quantization from BF16 to INT8
template<int32_t TPB>
__global__ void device_per_token_quant_bf16_to_int8_general(
const bf16_t* __restrict__ input, // Input tensor in BF16 format
int8_t* __restrict__ output, // Output tensor in INT8 format
fp32_t* __restrict__ scales, // Output scales for each token
const int64_t M, // Number of rows in the input tensor
const int64_t N // Number of columns in the input tensor
) {
const int32_t bid = blockIdx.x;
const int32_t tid = threadIdx.x;
constexpr fp32_t kINT8Max = 127.0f; // Maximum value representable in INT8 format

const bf16_t* _input = input + bid * N; // Input pointer for the token
int8_t* _output = output + bid * N; // Output pointer for the token

fp32_t* _scales;
_scales = scales + bid;

// Local arrays for intermediate storage
int8_t local_int8;
bf16_t local_bf16;

extern __shared__ bf16_t workspace1[];

fp32_t local_max = -FLT_MAX;
for (int32_t i = tid; i < N; i += TPB) {
local_bf16 = _input[i];
workspace1[i] = local_bf16;

fp32_t tmp = cvt_bf16_f32(local_bf16);
local_max = fmaxf(local_max, tmp);

Review comment (critical): The maximum value is calculated without considering the absolute values. For symmetric quantization, find the maximum of the absolute values to ensure correct scaling for tensors with negative values.

Suggested change:

        local_max = fmaxf(local_max, fabsf(tmp));
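
For reference, here is a minimal host-side sketch of what per-token symmetric INT8 quantization should compute once the absolute value is taken into account. This is an illustrative Python/PyTorch snippet, not part of this PR; the function name is an assumption.

    import torch

    def per_token_quant_int8_ref(x: torch.Tensor, eps: float = 1e-7):
        # One scale per row (token), derived from the row-wise max of |x|.
        x_f32 = x.float()                                   # [M, N]
        row_absmax = x_f32.abs().amax(dim=1, keepdim=True)  # [M, 1]
        scale = row_absmax / 127.0
        q = torch.clamp(torch.round(x_f32 / (scale + eps)), -128, 127).to(torch.int8)
        return q, scale.squeeze(1)

Any comparison against the kernel should allow a small tolerance, since the kernel multiplies by a precomputed 1.0f / (scale + eps) in FP32 instead of dividing.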

}

// Reduce the maximum value across the block
const fp32_t reduced_max = lightllm::reduce::sm70::sync_block_reduce_max_f32<TPB>(local_max);

// Compute the scale factor with epsilon to avoid division by zero
constexpr fp32_t epsilon = 1e-7f;
const fp32_t scale = reduced_max / kINT8Max;
const fp32_t inv_scale = 1.0f / (scale + epsilon);

for (int32_t i = tid; i < N; i += TPB) {
local_bf16 = workspace1[i];

fp32_t tmp = cvt_bf16_f32(local_bf16);
fp32_t x = tmp * inv_scale;
local_int8 = float_to_int8_rn(x);

_output[i] = local_int8;
}

if(tid == 0){
*_scales = scale;
}

}

// CUDA kernel for per token quantization from BF16 to INT8
template<int32_t TPB>
__global__ void device_per_token_quant_bf16_to_int8_vpt(
const bf16_t* __restrict__ input, // Input tensor in BF16 format
int8_t* __restrict__ output, // Output tensor in INT8 format
fp32_t* __restrict__ scales, // Output scales for each token
const int64_t M, // Number of rows in the input tensor
const int32_t N // Number of columns in the input tensor
) {
constexpr int32_t VPT = 8;

const int32_t bid = blockIdx.x;
const int32_t tid = threadIdx.x;
constexpr fp32_t kINT8Max = 127.0f; // Maximum value representable in INT8 format

const bf16_t* _input = input + bid * N; // Input pointer for the token
int8_t* _output = output + bid * N; // Output pointer for the token

fp32_t* _scales;
_scales = scales + bid;

// Local arrays for intermediate storage
int8_t local_int8[VPT];
bf16x2_t local_bf16[VPT / 2];

extern __shared__ bf16x2_t workspace2[];

fp32_t local_max = -FLT_MAX;
for (int32_t i = tid * VPT; i < N; i += TPB * VPT) {
// Load VPT BF16 elements from global memory (_input) into the local vector (local_bf16).
vec_copy<sizeof(bf16_t) * VPT>(_input + i, local_bf16);

vec_copy<sizeof(bf16_t) * VPT>(local_bf16, workspace2 + (i >> 1));

// Compute the max for the VPT elements.
#pragma unroll
for(int32_t j = 0; j< VPT/2; j++){
fp32x2_t tmp = bf16x2_to_fp32x2(local_bf16[j]);
fp32_t max = fmaxf(fabsf(tmp.x), fabsf(tmp.y));
local_max = fmaxf(local_max, max);
}
}

// Reduce the maximum value across the block
const fp32_t reduced_max = lightllm::reduce::sm70::sync_block_reduce_max_f32<TPB>(local_max);

// Compute the scale factor with epsilon to avoid division by zero
constexpr fp32_t epsilon = 1e-7f;
const fp32_t scale = reduced_max / kINT8Max;
const fp32_t inv_scale = 1.0f / (scale + epsilon);

for (int32_t i = tid * VPT; i < N; i += TPB * VPT) {
vec_copy<sizeof(bf16_t) * VPT>(workspace2 + (i >> 1), local_bf16);

#pragma unroll
for (int32_t j = 0; j < VPT/2; j++) {
fp32x2_t x = bf16x2_to_fp32x2(local_bf16[j]);

int8_t a = float_to_int8_rn(x.x * inv_scale);
int8_t b = float_to_int8_rn(x.y * inv_scale);

local_int8[2 * j] = a;
local_int8[2 * j + 1] = b;
}

vec_copy<sizeof(int8_t) * VPT>(local_int8, _output + i);
}

if(tid == 0){
*_scales = scale;
}
}



// CUDA kernel for per token quantization from BF16 to INT8
template<int32_t TPB, int32_t N>
__global__ void device_per_token_quant_bf16_to_int8(
const bf16_t* __restrict__ input, // Input tensor in BF16 format
int8_t* __restrict__ output, // Output tensor in INT8 format
fp32_t* __restrict__ scales, // Output scales for each token
const int64_t M // Number of rows in the input tensor
) {
constexpr int32_t VPT = 8;

static_assert(N % 2 == 0, "N must be even.");
static_assert(N % VPT == 0, "N must be a multiple of VPT.");

const int32_t bid = blockIdx.x;
const int32_t tid = threadIdx.x;
constexpr fp32_t kINT8Max = 127.0f; // Maximum value representable in INT8 format

const bf16_t* _input = input + bid * N; // Input pointer for the token
int8_t* _output = output + bid * N; // Output pointer for the token

fp32_t* _scales;
_scales = scales + bid;

// Local arrays for intermediate storage
int8_t local_int8[VPT];
bf16x2_t local_bf16[VPT / 2];

__shared__ bf16x2_t workspace[N / 2];

fp32_t local_max = -FLT_MAX;
for (int32_t i = tid * VPT; i < N; i += TPB * VPT) {
// Load VPT BF16 elements from global memory (_input) into the local vector (local_bf16).
vec_copy<sizeof(bf16_t) * VPT>(_input + i, local_bf16);

vec_copy<sizeof(bf16_t) * VPT>(local_bf16, workspace + (i >> 1));

// Compute the max for the VPT elements.
#pragma unroll
for(int32_t j = 0; j< VPT/2; j++){
fp32x2_t tmp = bf16x2_to_fp32x2(local_bf16[j]);
fp32_t max = fmaxf(fabsf(tmp.x), fabsf(tmp.y));
local_max = fmaxf(local_max, max);
}
}

// Reduce the maximum value across the block
const fp32_t reduced_max = lightllm::reduce::sm70::sync_block_reduce_max_f32<TPB>(local_max);

// Compute the scale factor with epsilon to avoid division by zero
constexpr fp32_t epsilon = 1e-7f;
const fp32_t scale = reduced_max / kINT8Max;
const fp32_t inv_scale = 1.0f / (scale + epsilon);

for (int32_t i = tid * VPT; i < N; i += TPB * VPT) {
vec_copy<sizeof(bf16_t) * VPT>(workspace + (i >> 1), local_bf16);

#pragma unroll
for (int32_t j = 0; j < VPT/2; j++) {
fp32x2_t x = bf16x2_to_fp32x2(local_bf16[j]);

int8_t a = float_to_int8_rn(x.x * inv_scale);
int8_t b = float_to_int8_rn(x.y * inv_scale);

local_int8[2 * j] = a;
local_int8[2 * j + 1] = b;
}

vec_copy<sizeof(int8_t) * VPT>(local_int8, _output + i);
}

if(tid == 0){
*_scales = scale;
}
}


void per_token_quant_bf16_int8 (
Tensor& output,
const Tensor& input,
Tensor& scales
) {
TORCH_CHECK(input.is_cuda(), "Input must be a CUDA tensor");
TORCH_CHECK(input.dim() == 2, "Input must be 2-dimensional");
TORCH_CHECK(input.scalar_type() == c10::kBFloat16, "Input must be BF16 type");

Tensor contiguous_input = input.is_contiguous() ? input : input.contiguous();
Tensor contiguous_scales = scales.is_contiguous() ? scales : scales.contiguous();

const int64_t M = input.size(0);

Review comment (medium): The M variable is passed to the CUDA kernels but is unused. Remove the M parameter from the kernel definitions and call sites to avoid confusion.

const int64_t N = input.size(1);

const int32_t blocks = M;

switch (N) {
case 16:
device_per_token_quant_bf16_to_int8<128, 16>
<<<blocks, 128, 0, at::cuda::getCurrentCUDAStream()>>>(
PTR<bf16_t>(contiguous_input),
PTR<int8_t>(output),
PTR<fp32_t>(contiguous_scales),
M
);
break;
case 32:
device_per_token_quant_bf16_to_int8<128, 32>
<<<blocks, 128, 0, at::cuda::getCurrentCUDAStream()>>>(
PTR<bf16_t>(contiguous_input),
PTR<int8_t>(output),
PTR<fp32_t>(contiguous_scales),
M
);
break;
case 64:
device_per_token_quant_bf16_to_int8<128, 64>
<<<blocks, 128, 0, at::cuda::getCurrentCUDAStream()>>>(
PTR<bf16_t>(contiguous_input),
PTR<int8_t>(output),
PTR<fp32_t>(contiguous_scales),
M
);
break;
case 512:
device_per_token_quant_bf16_to_int8<128, 512>
<<<blocks, 128, 0, at::cuda::getCurrentCUDAStream()>>>(
PTR<bf16_t>(contiguous_input),
PTR<int8_t>(output),
PTR<fp32_t>(contiguous_scales),
M
);
break;
case 1024:
device_per_token_quant_bf16_to_int8<128, 1024>
<<<blocks, 128, 0, at::cuda::getCurrentCUDAStream()>>>(
PTR<bf16_t>(contiguous_input),
PTR<int8_t>(output),
PTR<fp32_t>(contiguous_scales),
M
);
break;
case 3200:
device_per_token_quant_bf16_to_int8<128, 3200>
<<<blocks, 128, 0, at::cuda::getCurrentCUDAStream()>>>(
PTR<bf16_t>(contiguous_input),
PTR<int8_t>(output),
PTR<fp32_t>(contiguous_scales),
M
);
break;
case 4096:
device_per_token_quant_bf16_to_int8<128, 4096>
<<<blocks, 128, 0, at::cuda::getCurrentCUDAStream()>>>(
PTR<bf16_t>(contiguous_input),
PTR<int8_t>(output),
PTR<fp32_t>(contiguous_scales),
M
);
break;
case 12800:
device_per_token_quant_bf16_to_int8<256, 12800>
<<<blocks, 256, 0, at::cuda::getCurrentCUDAStream()>>>(
PTR<bf16_t>(contiguous_input),
PTR<int8_t>(output),
PTR<fp32_t>(contiguous_scales),
M
);
break;
default: {
static constexpr int TPB = 128;
const int64_t shared_mem_size = N * sizeof(bf16_t);
if (N % 8 == 0) {
device_per_token_quant_bf16_to_int8_vpt<TPB>
<<<blocks, TPB, shared_mem_size, at::cuda::getCurrentCUDAStream()>>>(
PTR<bf16_t>(contiguous_input),
PTR<int8_t>(output),
PTR<fp32_t>(contiguous_scales),
M,
N
);
} else {
device_per_token_quant_bf16_to_int8_general<TPB>
<<<blocks, TPB, shared_mem_size, at::cuda::getCurrentCUDAStream()>>>(
PTR<bf16_t>(contiguous_input),
PTR<int8_t>(output),
PTR<fp32_t>(contiguous_scales),
M,
N
);
}
}
}

return;
}

} // namespace ops
} // namespace lightllm
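
For context, a minimal usage sketch at the Python level. It assumes the exported wrapper keeps the (output, input, scales) argument order of the C++ entry point and that scales holds one FP32 value per row, which is what the kernels' pointer arithmetic implies; neither is stated explicitly in this diff.

    import torch
    from lightllm_kernel.ops import per_token_quant_bf16_int8

    M, N = 4, 4096
    x = torch.randn(M, N, dtype=torch.bfloat16, device="cuda")

    # Pre-allocated outputs: INT8 values and one scale per token (row).
    q = torch.empty(M, N, dtype=torch.int8, device="cuda")
    scales = torch.empty(M, dtype=torch.float32, device="cuda")

    per_token_quant_bf16_int8(q, x, scales)

    # Dequantize and inspect the round-trip error.
    x_hat = q.float() * scales.view(M, 1)
    print((x.float() - x_hat).abs().max())

N = 4096 hits one of the specialized template instantiations; other widths fall back to the _vpt or _general kernels.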
6 changes: 6 additions & 0 deletions lightllm-kernel/include/ops_common.h
@@ -32,6 +32,12 @@ void per_token_quant_bf16_fp8(
Tensor& scales
);

void per_token_quant_bf16_int8(
Tensor& output,
const Tensor& input,
Tensor& scales
);

std::tuple<Tensor, Tensor> add_norm_quant_bf16_fp8(
Tensor& X, const Tensor &R, const Tensor &W,
const fp32_t eps
6 changes: 6 additions & 0 deletions lightllm-kernel/include/utils.h
@@ -68,6 +68,12 @@ __device__ inline bf16x2_t _float22bf162_rn(fp32x2_t val) {
return bf16x2_t(low, high);
}

__device__ inline int8_t float_to_int8_rn(fp32_t x) {
uint32_t dst;
asm volatile("cvt.rni.sat.s8.f32 %0, %1;" : "=r"(dst) : "f"(x));
return reinterpret_cast<const int8_t&>(dst);
}
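
For readers unfamiliar with the PTX here: cvt.rni.sat.s8.f32 converts FP32 to a signed 8-bit value using round-to-nearest-even and saturates the result to [-128, 127]. A rough NumPy equivalent, shown only for illustration (the name is hypothetical and nothing below is part of this PR):

    import numpy as np

    def float_to_int8_rn_ref(x: np.ndarray) -> np.ndarray:
        # np.rint rounds halfway cases to the nearest even integer,
        # mirroring the .rni rounding mode; clip mirrors .sat.
        return np.clip(np.rint(x), -128, 127).astype(np.int8)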

template <typename T>
__host__ __device__ T Cdiv(T numerator, T denominator) {
return (numerator + denominator - 1) / denominator;
3 changes: 2 additions & 1 deletion lightllm-kernel/lightllm_kernel/ops/__init__.py
@@ -77,14 +77,15 @@
allgather_register_graph_buffers,
allgather_get_graph_buffer_ipc_meta,
)
from .quant import per_token_quant_bf16_fp8
from .quant import per_token_quant_bf16_fp8, per_token_quant_bf16_int8
from .gemm import cutlass_scaled_mm_bias_ls
from .moe import grouped_topk
from .attention import group8_int8kv_flashdecoding_stage1, group_int8kv_decode_attention

__all__ = [
"rmsnorm_bf16",
"per_token_quant_bf16_fp8",
"per_token_quant_bf16_int8",
"pre_tp_norm_bf16",
"post_tp_norm_bf16",
"add_norm_quant_bf16_fp8",