Skip to content

Commit

Permalink
[Kernel] AQ AZP 3/4: Asymmetric quantization kernels (vllm-project#7270)
Browse files Browse the repository at this point in the history
  • Loading branch information
ProExpertProg authored Sep 16, 2024
1 parent 781e3b9 commit 5d73ae4
Show file tree
Hide file tree
Showing 9 changed files with 339 additions and 57 deletions.
9 changes: 6 additions & 3 deletions csrc/cpu/quant.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -257,11 +257,13 @@ void int8_scaled_mm(torch::Tensor& c, // [M, OC], row-major
// static-per-tensor quantization.
void static_scaled_int8_quant(torch::Tensor& out, // [..., hidden_size]
const torch::Tensor& input, // [..., hidden_size]
const torch::Tensor& scale) {
const torch::Tensor& scale,
c10::optional<torch::Tensor> const& azp) {
CPU_KERNEL_GUARD_IN(static_scaled_int8_quant)
TORCH_CHECK(input.is_contiguous());
TORCH_CHECK(out.is_contiguous());
TORCH_CHECK(scale.numel() == 1);
TORCH_CHECK(!azp.has_value(), "Zero point is not supported on CPU.");

const int hidden_size = input.size(-1);
const int num_tokens = input.numel() / hidden_size;
Expand All @@ -277,11 +279,12 @@ void static_scaled_int8_quant(torch::Tensor& out, // [..., hidden_size]
void dynamic_scaled_int8_quant(
torch::Tensor& out, // [..., hidden_size]
const torch::Tensor& input, // [..., hidden_size]
torch::Tensor& scale // [..., 1]
) {
torch::Tensor& scale, // [..., 1]
c10::optional<torch::Tensor> const& azp) {
CPU_KERNEL_GUARD_IN(dynamic_scaled_int8_quant)
TORCH_CHECK(input.is_contiguous());
TORCH_CHECK(out.is_contiguous());
TORCH_CHECK(!azp.has_value(), "Zero point is not supported on CPU.");

int const hidden_size = input.size(-1);
int const num_tokens = input.numel() / hidden_size;
Expand Down
9 changes: 5 additions & 4 deletions csrc/cpu/torch_bindings.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -94,13 +94,14 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
#ifdef __AVX512F__
// Compute int8 quantized tensor for given scaling factor.
ops.def(
"static_scaled_int8_quant(Tensor! out, Tensor input, Tensor scale) -> "
"()");
"static_scaled_int8_quant(Tensor! out, Tensor input, Tensor scale,"
"Tensor? azp) -> ()");
ops.impl("static_scaled_int8_quant", torch::kCPU, &static_scaled_int8_quant);

// Compute int8 quantized tensor and scaling factor
ops.def(
"dynamic_scaled_int8_quant(Tensor! out, Tensor input, Tensor! scale) -> "
"()");
"dynamic_scaled_int8_quant(Tensor! out, Tensor input, Tensor! scale, "
"Tensor!? azp) -> ()");
ops.impl("dynamic_scaled_int8_quant", torch::kCPU,
&dynamic_scaled_int8_quant);
// W8A8 GEMM, supporting symmetric per-tensor or per-row/column
Expand Down
6 changes: 4 additions & 2 deletions csrc/ops.h
Original file line number Diff line number Diff line change
Expand Up @@ -184,10 +184,12 @@ torch::Tensor marlin_qqq_gemm(torch::Tensor const& a,
#endif

void static_scaled_int8_quant(torch::Tensor& out, torch::Tensor const& input,
torch::Tensor const& scale);
torch::Tensor const& scale,
c10::optional<torch::Tensor> const& azp);

void dynamic_scaled_int8_quant(torch::Tensor& out, torch::Tensor const& input,
torch::Tensor& scales);
torch::Tensor& scales,
c10::optional<torch::Tensor> const& azp);

torch::Tensor gptq_gemm(torch::Tensor a, torch::Tensor b_q_weight,
torch::Tensor b_gptq_qzeros,
Expand Down
173 changes: 160 additions & 13 deletions csrc/quantization/compressed_tensors/int8_quant_kernels.cu
Original file line number Diff line number Diff line change
Expand Up @@ -14,12 +14,17 @@

static inline __device__ int8_t float_to_int8_rn(float x) {
#ifdef USE_ROCM
static const float i8_min =
static constexpr auto i8_min =
static_cast<float>(std::numeric_limits<int8_t>::min());
static const float i8_max =
static constexpr auto i8_max =
static_cast<float>(std::numeric_limits<int8_t>::max());
// round

// To match the rounding mode of CUDA, we use nearbyint.
// It uses the current rounding mode, which is always FE_TONEAREST on HIP.
// If that changes in the future, we may need to set the rounding mode
// explicitly, either at runtime or compile time.
float dst = std::nearbyint(x);

// saturate
dst = std::clamp(dst, i8_min, i8_max);
return static_cast<int8_t>(dst);
Expand All @@ -31,6 +36,59 @@ static inline __device__ int8_t float_to_int8_rn(float x) {
#endif
}

static inline __device__ int32_t float_to_int32_rn(float x) {
#ifdef USE_ROCM
// int32_max is not exactly representable as float.
// Therefore, we need to be careful and manually return int32_max on overflow.
// For symmetry, we also do the same for int32_min, even though it is exactly
// representable as float and the conversion should be exact.
static constexpr auto i32_min = std::numeric_limits<int32_t>::min();
static constexpr auto i32_min_f = static_cast<float>(i32_min);
static constexpr auto i32_max = std::numeric_limits<int32_t>::max();
static constexpr auto i32_max_f = static_cast<float>(i32_max);

// To match the rounding mode of CUDA, we use nearbyint.
// It uses the current rounding mode, which is always FE_TONEAREST on HIP.
// If that changes in the future, we may need to set the rounding mode
// explicitly, either at runtime or compile time.
float dst = std::nearbyint(x);

// saturate on the higher end.
if (dst >= i32_max_f) {
return i32_max;
}
// saturate on the lower end.
if (dst <= i32_min_f) {
return i32_min;
}

return static_cast<int32_t>(dst);
#else
// CUDA path
uint32_t dst;
asm volatile("cvt.rni.sat.s32.f32 %0, %1;" : "=r"(dst) : "f"(x));
return reinterpret_cast<const int32_t&>(dst);
#endif
}

static inline __device__ int8_t int32_to_int8(int32_t x) {
#ifdef USE_ROCM
static constexpr auto i8_min =
static_cast<int32_t>(std::numeric_limits<int8_t>::min());
static constexpr auto i8_max =
static_cast<int32_t>(std::numeric_limits<int8_t>::max());

// saturate
int32_t dst = std::clamp(x, i8_min, i8_max);
return static_cast<int8_t>(dst);
#else
// CUDA path
uint32_t dst;
asm volatile("cvt.sat.s8.s32 %0, %1;" : "=r"(dst) : "r"(x));
return reinterpret_cast<const int8_t&>(dst);
#endif
}

namespace vllm {

template <typename scalar_t, typename scale_type>
Expand All @@ -47,6 +105,23 @@ __global__ void static_scaled_int8_quant_kernel(
}
}

template <typename scalar_t, typename scale_type, typename azp_type>
__global__ void static_scaled_int8_azp_quant_kernel(
scalar_t const* __restrict__ input, int8_t* __restrict__ out,
scale_type const* scale_ptr, azp_type const* azp_ptr,
const int hidden_size) {
int const tid = threadIdx.x;
int const token_idx = blockIdx.x;
scale_type const scale = *scale_ptr;
azp_type const azp = *azp_ptr;

for (int i = tid; i < hidden_size; i += blockDim.x) {
auto const val = static_cast<float>(input[token_idx * hidden_size + i]);
auto const quant_val = int32_to_int8(float_to_int32_rn(val / scale) + azp);
out[token_idx * hidden_size + i] = quant_val;
}
}

template <typename scalar_t, typename scale_type>
__global__ void dynamic_scaled_int8_quant_kernel(
scalar_t const* __restrict__ input, int8_t* __restrict__ out,
Expand Down Expand Up @@ -80,14 +155,68 @@ __global__ void dynamic_scaled_int8_quant_kernel(
}
}

template <typename scalar_t, typename scale_type, typename azp_type>
__global__ void dynamic_scaled_int8_azp_quant_kernel(
scalar_t const* __restrict__ input, int8_t* __restrict__ out,
scale_type* scale, azp_type* azp, const int hidden_size) {
int const token_idx = blockIdx.x;

// Scan for the min and max value for this token
float max_val = std::numeric_limits<float>::min();
float min_val = std::numeric_limits<float>::max();
for (int i = threadIdx.x; i < hidden_size; i += blockDim.x) {
auto val = static_cast<float>(input[token_idx * hidden_size + i]);
max_val = std::max(max_val, val);
min_val = std::min(min_val, val);
}

// Reduce the max and min values across the block
using BlockReduce = cub::BlockReduce<float, 1024>;
__shared__ typename BlockReduce::TempStorage reduceStorage;
max_val = BlockReduce(reduceStorage).Reduce(max_val, cub::Max{}, blockDim.x);
__syncthreads(); // Make sure min doesn't mess with max shared memory
min_val = BlockReduce(reduceStorage).Reduce(min_val, cub::Min{}, blockDim.x);

__shared__ scale_type scale_sh;
__shared__ azp_type azp_sh;

// Compute the scale and zero point and store them, only on the first thread
if (threadIdx.x == 0) {
float const scale_val = (max_val - min_val) / 255.0f;
// Use rounding to even (same as torch.round)
auto const azp_float = std::nearbyint(-128.0f - min_val / scale_val);
auto const azp_val = static_cast<azp_type>(azp_float);

// Store the scale and azp into shared and global
scale[token_idx] = scale_sh = scale_val;
azp[token_idx] = azp_sh = azp_val;
}

// Wait for the scale and azp to be computed
__syncthreads();

float const scale_val = scale_sh;
azp_type const azp_val = azp_sh;

// Quantize the values
for (int i = threadIdx.x; i < hidden_size; i += blockDim.x) {
auto const val = static_cast<float>(input[token_idx * hidden_size + i]);
auto const quant_val =
int32_to_int8(float_to_int32_rn(val / scale_val) + azp_val);
out[token_idx * hidden_size + i] = quant_val;
}
}

} // namespace vllm

void static_scaled_int8_quant(torch::Tensor& out, // [..., hidden_size]
torch::Tensor const& input, // [..., hidden_size]
torch::Tensor const& scale) {
torch::Tensor const& scale,
c10::optional<torch::Tensor> const& azp) {
TORCH_CHECK(input.is_contiguous());
TORCH_CHECK(out.is_contiguous());
TORCH_CHECK(scale.numel() == 1);
TORCH_CHECK(!azp || azp->numel() == 1);

int const hidden_size = input.size(-1);
int const num_tokens = input.numel() / hidden_size;
Expand All @@ -96,19 +225,29 @@ void static_scaled_int8_quant(torch::Tensor& out, // [..., hidden_size]
const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
VLLM_DISPATCH_FLOATING_TYPES(
input.scalar_type(), "static_scaled_int8_quant_kernel", [&] {
vllm::static_scaled_int8_quant_kernel<scalar_t, float>
<<<grid, block, 0, stream>>>(input.data_ptr<scalar_t>(),
out.data_ptr<int8_t>(),
scale.data_ptr<float>(), hidden_size);
if (!azp) {
vllm::static_scaled_int8_quant_kernel<scalar_t, float>
<<<grid, block, 0, stream>>>(
input.data_ptr<scalar_t>(), out.data_ptr<int8_t>(),
scale.data_ptr<float>(), hidden_size);
} else {
vllm::static_scaled_int8_azp_quant_kernel<scalar_t, float, int32_t>
<<<grid, block, 0, stream>>>(
input.data_ptr<scalar_t>(), out.data_ptr<int8_t>(),
scale.data_ptr<float>(), azp->data_ptr<int32_t>(),
hidden_size);
}
});
}

void dynamic_scaled_int8_quant(
torch::Tensor& out, // [..., hidden_size]
torch::Tensor const& input, // [..., hidden_size]
torch::Tensor& scales) {
torch::Tensor& scales, c10::optional<torch::Tensor> const& azp) {
TORCH_CHECK(input.is_contiguous());
TORCH_CHECK(out.is_contiguous());
TORCH_CHECK(scales.is_contiguous());
TORCH_CHECK(!azp || azp->is_contiguous());

int const hidden_size = input.size(-1);
int const num_tokens = input.numel() / hidden_size;
Expand All @@ -117,9 +256,17 @@ void dynamic_scaled_int8_quant(
const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
VLLM_DISPATCH_FLOATING_TYPES(
input.scalar_type(), "dynamic_scaled_int8_quant_kernel", [&] {
vllm::dynamic_scaled_int8_quant_kernel<scalar_t, float>
<<<grid, block, 0, stream>>>(input.data_ptr<scalar_t>(),
out.data_ptr<int8_t>(),
scales.data_ptr<float>(), hidden_size);
if (!azp) {
vllm::dynamic_scaled_int8_quant_kernel<scalar_t, float>
<<<grid, block, 0, stream>>>(
input.data_ptr<scalar_t>(), out.data_ptr<int8_t>(),
scales.data_ptr<float>(), hidden_size);
} else {
vllm::dynamic_scaled_int8_azp_quant_kernel<scalar_t, float, int32_t>
<<<grid, block, 0, stream>>>(
input.data_ptr<scalar_t>(), out.data_ptr<int8_t>(),
scales.data_ptr<float>(), azp->data_ptr<int32_t>(),
hidden_size);
}
});
}
8 changes: 4 additions & 4 deletions csrc/torch_bindings.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -336,14 +336,14 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {

// Compute int8 quantized tensor for given scaling factor.
ops.def(
"static_scaled_int8_quant(Tensor! out, Tensor input, Tensor scale) -> "
"()");
"static_scaled_int8_quant(Tensor! out, Tensor input, Tensor scale,"
"Tensor? azp) -> ()");
ops.impl("static_scaled_int8_quant", torch::kCUDA, &static_scaled_int8_quant);

// Compute int8 quantized tensor and scaling factor
ops.def(
"dynamic_scaled_int8_quant(Tensor! out, Tensor input, Tensor! scale) -> "
"()");
"dynamic_scaled_int8_quant(Tensor! out, Tensor input, Tensor! scale, "
"Tensor!? azp) -> ()");
ops.impl("dynamic_scaled_int8_quant", torch::kCUDA,
&dynamic_scaled_int8_quant);
}
Expand Down
Loading

0 comments on commit 5d73ae4

Please sign in to comment.