[Kernel] Initial Activation Quantization Support (vllm-project#4525)
Co-authored-by: Varun Sundar Rabindranath <varunsundar08@gmail.com>
Co-authored-by: Varun Sundar Rabindranath <varun@neuralmagic.com>
3 people authored and dtrifiro committed May 31, 2024
1 parent e8ccf26 commit 2603b1d
Showing 17 changed files with 683 additions and 94 deletions.
1 change: 1 addition & 0 deletions CMakeLists.txt
@@ -167,6 +167,7 @@ set(VLLM_EXT_SRC
"csrc/layernorm_kernels.cu"
"csrc/quantization/squeezellm/quant_cuda_kernel.cu"
"csrc/quantization/gptq/q_gemm.cu"
"csrc/quantization/compressed_tensors/int8_quant_kernels.cu"
"csrc/quantization/fp8/common.cu"
"csrc/cuda_utils_kernels.cu"
"csrc/moe_align_block_size_kernels.cu"
3 changes: 3 additions & 0 deletions csrc/ops.h
@@ -93,6 +93,9 @@ int cutlass_scaled_mm_dq(torch::Tensor& out, torch::Tensor const& a,

#endif

void static_scaled_int8_quant(torch::Tensor& out, torch::Tensor& input,
                              float scale);

void squeezellm_gemm(torch::Tensor vec, torch::Tensor mat, torch::Tensor mul,
                     torch::Tensor lookup_table);

3 changes: 3 additions & 0 deletions csrc/pybind.cpp
@@ -67,6 +67,9 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
"Aligning the number of tokens to be processed by each expert such "
"that it is divisible by the block size.");

ops.def("static_scaled_int8_quant", &static_scaled_int8_quant,
"Compute int8 quantized tensor for given scaling factor");

  // Cache ops
  pybind11::module cache_ops = m.def_submodule("cache_ops", "vLLM cache ops");
  cache_ops.def("swap_blocks", &swap_blocks,
59 changes: 59 additions & 0 deletions csrc/quantization/compressed_tensors/int8_quant_kernels.cu
@@ -0,0 +1,59 @@
#include <ATen/cuda/CUDAContext.h>
#include <torch/extension.h>
#include <cmath>

#include "../../dispatch_utils.h"

static inline __device__ int8_t float_to_int8_rn(float x) {
#ifdef USE_ROCM
  static const float i8_min =
      static_cast<float>(std::numeric_limits<int8_t>::min());
  static const float i8_max =
      static_cast<float>(std::numeric_limits<int8_t>::max());
  // round
  float dst = std::nearbyint(x);
  // saturate
  dst = std::clamp(dst, i8_min, i8_max);
  return static_cast<int8_t>(dst);
#else
  // CUDA path
  uint32_t dst;
  asm volatile("cvt.rni.sat.s8.f32 %0, %1;" : "=r"(dst) : "f"(x));
  return reinterpret_cast<const int8_t&>(dst);
#endif
}

namespace vllm {

template <typename scalar_t, typename scale_type>
__global__ void static_scaled_int8_quant_kernel(
    const scalar_t* __restrict__ input, int8_t* __restrict__ out,
    scale_type scale, const int hidden_size) {
  const int tid = threadIdx.x;
  const int token_idx = blockIdx.x;

  for (int i = tid; i < hidden_size; i += blockDim.x) {
    out[token_idx * hidden_size + i] =
        float_to_int8_rn(((float)input[token_idx * hidden_size + i]) / scale);
  }
}
}  // namespace vllm

void static_scaled_int8_quant(torch::Tensor& out,    // [..., hidden_size]
                              torch::Tensor& input,  // [..., hidden_size]
                              float scale) {
  TORCH_CHECK(input.is_contiguous());
  TORCH_CHECK(out.is_contiguous());
  int hidden_size = input.size(-1);
  int num_tokens = input.numel() / hidden_size;
  dim3 grid(num_tokens);
  dim3 block(std::min(hidden_size, 1024));
  const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
  VLLM_DISPATCH_FLOATING_TYPES(
      input.scalar_type(), "static_scaled_int8_quant_kernel", [&] {
        vllm::static_scaled_int8_quant_kernel<scalar_t, float>
            <<<grid, block, 0, stream>>>(input.data_ptr<scalar_t>(),
                                         out.data_ptr<int8_t>(), scale,
                                         hidden_size);
      });
}
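
For reference, a minimal PyTorch sketch of what this kernel computes per element (not part of the commit; the function name is illustrative): each value is divided by the static scale, rounded to the nearest integer with ties to even, and saturated to the int8 range.

import torch

def static_scaled_int8_quant_ref(x: torch.Tensor, scale: float) -> torch.Tensor:
    # Divide by the static scale, round to nearest (ties to even, matching the
    # CUDA cvt.rni.sat.s8.f32 instruction and std::nearbyint on ROCm), then
    # saturate to the int8 range before casting.
    q = torch.round(x.float() / scale)
    q = q.clamp(torch.iinfo(torch.int8).min, torch.iinfo(torch.int8).max)
    return q.to(torch.int8)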
31 changes: 31 additions & 0 deletions tests/kernels/test_int8_quant.py
@@ -0,0 +1,31 @@
import pytest
import torch

from vllm._C import ops

DTYPES = [torch.half, torch.bfloat16, torch.float]
HIDDEN_SIZES = [16, 67, 768, 2048, 5120, 8192] # Arbitrary values for testing
NUM_TOKENS = [1, 7, 83, 4096] # Arbitrary values for testing
SEEDS = [0]
SCALE = [0.1, 0.5, 0.8, 1.2, 2.1]


@pytest.mark.parametrize("num_tokens", NUM_TOKENS)
@pytest.mark.parametrize("hidden_size", HIDDEN_SIZES)
@pytest.mark.parametrize("dtype", DTYPES)
@pytest.mark.parametrize("seed", SEEDS)
@pytest.mark.parametrize("scale", SCALE)
@torch.inference_mode()
def test_quant(num_tokens: int, hidden_size: int, dtype: torch.dtype,
               seed: int, scale: float) -> None:
    torch.random.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    x = torch.rand(num_tokens, hidden_size, dtype=dtype, device="cuda") * 1000

    out1 = (x / scale).round().clamp(
        torch.iinfo(torch.int8).min,
        torch.iinfo(torch.int8).max).to(torch.int8)
    out2 = torch.empty_like(x, dtype=torch.int8)
    ops.static_scaled_int8_quant(out2, x, scale)
    assert torch.allclose(out1, out2,
                          atol=1)  # big atol to account for rounding errors
36 changes: 36 additions & 0 deletions tests/quantization/test_compressed_tensors.py
@@ -0,0 +1,36 @@
"""Test model set-up and weight loading for sparseml-quantized models.
Run `pytest tests/quantization/test_compressed_tensors.py`.
"""

import torch

from vllm.model_executor.layers.quantization.compressed_tensors.compressed_tensors import (  # noqa: E501
    CompressedTensorsLinearMethod, CompressedTensorsW8A8StaticTensor)


def test_compressed_tensors_w8a8_static_setup(vllm_runner):
    model_path = "nm-testing/tinyllama-one-shot-static-quant-test-compressed"
    llm = vllm_runner(model_path, quantization="sparseml", enforce_eager=True)
    model = llm.model.llm_engine.model_executor.driver_worker.model_runner.model
    layer = model.model.layers[0]

    qkv_proj = layer.self_attn.qkv_proj
    o_proj = layer.self_attn.o_proj
    gate_up_proj = layer.mlp.gate_up_proj
    down_proj = layer.mlp.down_proj

    assert isinstance(qkv_proj.quant_method, CompressedTensorsLinearMethod)
    assert isinstance(o_proj.quant_method, CompressedTensorsLinearMethod)
    assert isinstance(gate_up_proj.quant_method, CompressedTensorsLinearMethod)
    assert isinstance(down_proj.quant_method, CompressedTensorsLinearMethod)

    assert isinstance(qkv_proj.scheme, CompressedTensorsW8A8StaticTensor)

    assert qkv_proj.weight.dtype is torch.int8
    assert o_proj.weight.dtype is torch.int8
    assert gate_up_proj.weight.dtype is torch.int8

    assert qkv_proj.weight_scale.shard_splitter is not None
    assert qkv_proj.weight_scale.logical_widths is not None
    assert qkv_proj.input_scale.dtype is torch.float32
18 changes: 18 additions & 0 deletions vllm/_custom_ops.py
@@ -251,6 +251,24 @@ def scaled_fp8_quant(
    return output, scale


# int8
def static_scaled_int8_quant(input: torch.Tensor,
                             scale: float) -> torch.Tensor:
    """
    Quantize the input tensor to int8 and return the quantized tensor.
    Args:
        input: The input tensor to be quantized to int8.
        scale: Scaling factor for the int8 quantization.
    Returns:
        torch.Tensor: Output tensor in int8.
    """
    q = torch.empty_like(input, dtype=torch.int8)
    vllm_ops.static_scaled_int8_quant(q, input, scale)
    return q


# moe
def moe_align_block_size(topk_ids: torch.Tensor, num_experts: int,
                         block_size: int, sorted_token_ids: torch.Tensor,
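
A minimal usage sketch of the new static_scaled_int8_quant wrapper added above (assumes a CUDA device and a vLLM build that includes this commit; the tensor shape and scale value are illustrative):

import torch
from vllm import _custom_ops as custom_ops

# Quantize activations with a statically chosen per-tensor scale.
x = torch.randn(16, 4096, dtype=torch.float16, device="cuda")
q = custom_ops.static_scaled_int8_quant(x, scale=0.05)
assert q.dtype == torch.int8 and q.shape == x.shape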