forked from vllm-project/vllm
-
Notifications
You must be signed in to change notification settings - Fork 15
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
[Kernel] Initial Activation Quantization Support (vllm-project#4525)
Co-authored-by: Varun Sundar Rabindranath <varunsundar08@gmail.com> Co-authored-by: Varun Sundar Rabindranath <varun@neuralmagic.com>
- Loading branch information
Showing
17 changed files
with
683 additions
and
94 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
59 changes: 59 additions & 0 deletions
59
csrc/quantization/compressed_tensors/int8_quant_kernels.cu
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,59 @@ | ||
#include <ATen/cuda/CUDAContext.h> | ||
#include <torch/extension.h> | ||
#include <cmath> | ||
|
||
#include "../../dispatch_utils.h" | ||
|
||
static inline __device__ int8_t float_to_int8_rn(float x) { | ||
#ifdef USE_ROCM | ||
static const float i8_min = | ||
static_cast<float>(std::numeric_limits<int8_t>::min()); | ||
static const float i8_max = | ||
static_cast<float>(std::numeric_limits<int8_t>::max()); | ||
// round | ||
float dst = std::nearbyint(x); | ||
// saturate | ||
dst = std::clamp(dst, i8_min, i8_max); | ||
return static_cast<int8_t>(dst); | ||
#else | ||
// CUDA path | ||
uint32_t dst; | ||
asm volatile("cvt.rni.sat.s8.f32 %0, %1;" : "=r"(dst) : "f"(x)); | ||
return reinterpret_cast<const int8_t&>(dst); | ||
#endif | ||
} | ||
|
||
namespace vllm { | ||
|
||
template <typename scalar_t, typename scale_type> | ||
__global__ void static_scaled_int8_quant_kernel( | ||
const scalar_t* __restrict__ input, int8_t* __restrict__ out, | ||
scale_type scale, const int hidden_size) { | ||
const int tid = threadIdx.x; | ||
const int token_idx = blockIdx.x; | ||
|
||
for (int i = tid; i < hidden_size; i += blockDim.x) { | ||
out[token_idx * hidden_size + i] = | ||
float_to_int8_rn(((float)input[token_idx * hidden_size + i]) / scale); | ||
} | ||
} | ||
} // namespace vllm | ||
|
||
void static_scaled_int8_quant(torch::Tensor& out, // [..., hidden_size] | ||
torch::Tensor& input, // [..., hidden_size] | ||
float scale) { | ||
TORCH_CHECK(input.is_contiguous()); | ||
TORCH_CHECK(out.is_contiguous()); | ||
int hidden_size = input.size(-1); | ||
int num_tokens = input.numel() / hidden_size; | ||
dim3 grid(num_tokens); | ||
dim3 block(std::min(hidden_size, 1024)); | ||
const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); | ||
VLLM_DISPATCH_FLOATING_TYPES( | ||
input.scalar_type(), "static_scaled_int8_quant_kernel", [&] { | ||
vllm::static_scaled_int8_quant_kernel<scalar_t, float> | ||
<<<grid, block, 0, stream>>>(input.data_ptr<scalar_t>(), | ||
out.data_ptr<int8_t>(), scale, | ||
hidden_size); | ||
}); | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,31 @@ | ||
import pytest | ||
import torch | ||
|
||
from vllm._C import ops | ||
|
||
DTYPES = [torch.half, torch.bfloat16, torch.float] | ||
HIDDEN_SIZES = [16, 67, 768, 2048, 5120, 8192] # Arbitrary values for testing | ||
NUM_TOKENS = [1, 7, 83, 4096] # Arbitrary values for testing | ||
SEEDS = [0] | ||
SCALE = [0.1, 0.5, 0.8, 1.2, 2.1] | ||
|
||
|
||
@pytest.mark.parametrize("num_tokens", NUM_TOKENS) | ||
@pytest.mark.parametrize("hidden_size", HIDDEN_SIZES) | ||
@pytest.mark.parametrize("dtype", DTYPES) | ||
@pytest.mark.parametrize("seed", SEEDS) | ||
@pytest.mark.parametrize("scale", SCALE) | ||
@torch.inference_mode() | ||
def test_quant(num_tokens: int, hidden_size: int, dtype: torch.dtype, | ||
seed: int, scale: float) -> None: | ||
torch.random.manual_seed(seed) | ||
torch.cuda.manual_seed(seed) | ||
x = torch.rand(num_tokens, hidden_size, dtype=dtype, device="cuda") * 1000 | ||
|
||
out1 = (x / scale).round().clamp( | ||
torch.iinfo(torch.int8).min, | ||
torch.iinfo(torch.int8).max).to(torch.int8) | ||
out2 = torch.empty_like(x, dtype=torch.int8) | ||
ops.static_scaled_int8_quant(out2, x, scale) | ||
assert torch.allclose(out1, out2, | ||
atol=1) # big atol to account for rounding errors |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,36 @@ | ||
"""Test model set-up and weight loading for sparseml-quantized models. | ||
Run `pytest tests/quantization/test_compressed_tensors.py`. | ||
""" | ||
|
||
import torch | ||
|
||
from vllm.model_executor.layers.quantization.compressed_tensors.compressed_tensors import ( # noqa: E501 | ||
CompressedTensorsLinearMethod, CompressedTensorsW8A8StaticTensor) | ||
|
||
|
||
def test_compressed_tensors_w8a8_static_setup(vllm_runner): | ||
model_path = "nm-testing/tinyllama-one-shot-static-quant-test-compressed" | ||
llm = vllm_runner(model_path, quantization="sparseml", enforce_eager=True) | ||
model = llm.model.llm_engine.model_executor.driver_worker.model_runner.model | ||
layer = model.model.layers[0] | ||
|
||
qkv_proj = layer.self_attn.qkv_proj | ||
o_proj = layer.self_attn.o_proj | ||
gate_up_proj = layer.mlp.gate_up_proj | ||
down_proj = layer.mlp.down_proj | ||
|
||
assert isinstance(qkv_proj.quant_method, CompressedTensorsLinearMethod) | ||
assert isinstance(o_proj.quant_method, CompressedTensorsLinearMethod) | ||
assert isinstance(gate_up_proj.quant_method, CompressedTensorsLinearMethod) | ||
assert isinstance(down_proj.quant_method, CompressedTensorsLinearMethod) | ||
|
||
assert isinstance(qkv_proj.scheme, CompressedTensorsW8A8StaticTensor) | ||
|
||
assert qkv_proj.weight.dtype is torch.int8 | ||
assert o_proj.weight.dtype is torch.int8 | ||
assert gate_up_proj.weight.dtype is torch.int8 | ||
|
||
assert qkv_proj.weight_scale.shard_splitter is not None | ||
assert qkv_proj.weight_scale.logical_widths is not None | ||
assert qkv_proj.input_scale.dtype is torch.float32 |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.