Commit

[Kernel] Pass a device pointer into the quantize kernel for the scales
tlrmchlsmth authored and blinkbear committed Jun 6, 2024
1 parent 625b4f9 commit dc60e72
Showing 5 changed files with 16 additions and 11 deletions.
4 changes: 2 additions & 2 deletions csrc/ops.h
@@ -94,8 +94,8 @@ int cutlass_scaled_mm_dq(torch::Tensor& out, torch::Tensor const& a,
 
 #endif
 
-void static_scaled_int8_quant(torch::Tensor& out, torch::Tensor& input,
-                              float scale);
+void static_scaled_int8_quant(torch::Tensor& out, torch::Tensor const& input,
+                              torch::Tensor const& scale);
 
 void squeezellm_gemm(torch::Tensor vec, torch::Tensor mat, torch::Tensor mul,
                      torch::Tensor lookup_table);
15 changes: 9 additions & 6 deletions csrc/quantization/compressed_tensors/int8_quant_kernels.cu
@@ -28,9 +28,10 @@ namespace vllm {
 template <typename scalar_t, typename scale_type>
 __global__ void static_scaled_int8_quant_kernel(
     const scalar_t* __restrict__ input, int8_t* __restrict__ out,
-    scale_type scale, const int hidden_size) {
+    const scale_type* scale_ptr, const int hidden_size) {
   const int tid = threadIdx.x;
   const int token_idx = blockIdx.x;
+  scale_type scale = *scale_ptr;
 
   for (int i = tid; i < hidden_size; i += blockDim.x) {
     out[token_idx * hidden_size + i] =
@@ -39,11 +40,13 @@ __global__ void static_scaled_int8_quant_kernel(
 }
 }  // namespace vllm
 
-void static_scaled_int8_quant(torch::Tensor& out,    // [..., hidden_size]
-                              torch::Tensor& input,  // [..., hidden_size]
-                              float scale) {
+void static_scaled_int8_quant(torch::Tensor& out,          // [..., hidden_size]
+                              torch::Tensor const& input,  // [..., hidden_size]
+                              torch::Tensor const& scale) {
   TORCH_CHECK(input.is_contiguous());
   TORCH_CHECK(out.is_contiguous());
+  TORCH_CHECK(scale.numel() == 1);
+
   int hidden_size = input.size(-1);
   int num_tokens = input.numel() / hidden_size;
   dim3 grid(num_tokens);
@@ -53,7 +56,7 @@ void static_scaled_int8_quant(torch::Tensor& out,          // [..., hidden_size]
       input.scalar_type(), "static_scaled_int8_quant_kernel", [&] {
         vllm::static_scaled_int8_quant_kernel<scalar_t, float>
             <<<grid, block, 0, stream>>>(input.data_ptr<scalar_t>(),
-                                         out.data_ptr<int8_t>(), scale,
-                                         hidden_size);
+                                         out.data_ptr<int8_t>(),
+                                         scale.data_ptr<float>(), hidden_size);
       });
 }
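
For reference, a minimal pure-PyTorch sketch of what the updated entry point computes, with the scale arriving as a one-element device tensor rather than a host float. The helper name and the round-to-nearest step are assumptions for illustration, not part of this diff; on the launch side the only change is forwarding scale.data_ptr<float>() instead of a by-value float.

import torch

def static_scaled_int8_quant_ref(x: torch.Tensor, scale: torch.Tensor) -> torch.Tensor:
    # `scale` is a one-element float32 tensor that stays on the GPU; the CUDA kernel
    # now reads it through a device pointer instead of taking a host-side float.
    assert scale.numel() == 1  # mirrors the new TORCH_CHECK in the C++ entry point
    q = torch.round(x.float() / scale)  # assumed round-to-nearest, as the test's atol=1 tolerates
    return q.clamp(torch.iinfo(torch.int8).min,
                   torch.iinfo(torch.int8).max).to(torch.int8)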
4 changes: 3 additions & 1 deletion tests/kernels/test_int8_quant.py
@@ -26,6 +26,8 @@ def test_quant(num_tokens: int, hidden_size: int, dtype: torch.dtype,
                       torch.iinfo(torch.int8).min,
                       torch.iinfo(torch.int8).max).to(torch.int8)
     out2 = torch.empty_like(x, dtype=torch.int8)
-    ops.static_scaled_int8_quant(out2, x, scale)
+    scale_argument = torch.tensor([scale], dtype=torch.float32, device="cuda")
+
+    ops.static_scaled_int8_quant(out2, x, scale_argument)
     assert torch.allclose(out1, out2,
                           atol=1)  # big atol to account for rounding errors
2 changes: 1 addition & 1 deletion vllm/_custom_ops.py
@@ -263,7 +263,7 @@ def scaled_fp8_quant(
 
 # int8
 def static_scaled_int8_quant(input: torch.Tensor,
-                             scale: float) -> torch.Tensor:
+                             scale: torch.Tensor) -> torch.Tensor:
     """
     Quantize the input tensor to int8 and return the quantized tensor.
@@ -97,7 +97,7 @@ def apply_weights(self, layer: torch.nn.Module, x: torch.Tensor):
         act_scale = layer.input_scale
 
         # Input quantize
-        x_q = custom_ops.static_scaled_int8_quant(x, act_scale[0].item())
+        x_q = custom_ops.static_scaled_int8_quant(x, act_scale)
 
         return custom_ops.cutlass_scaled_mm_dq(x_q, weight.t(), act_scale,
                                                weight_scale, x.dtype)
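
A hedged usage sketch of the updated Python wrapper: the scale now travels as a one-element float32 CUDA tensor, so call sites like apply_weights above no longer need act_scale[0].item(), which copies the value back to the host and synchronizes the device just to rebuild a Python float. The import alias, shapes, and scale value below are illustrative.

import torch
from vllm import _custom_ops as custom_ops  # import alias assumed, matching the call site above

x = torch.randn(16, 4096, dtype=torch.float16, device="cuda")

# The static activation scale stays on the device as a one-element tensor.
act_scale = torch.tensor([0.05], dtype=torch.float32, device="cuda")  # illustrative value

# New calling convention: pass the tensor itself, not a Python float.
x_q = custom_ops.static_scaled_int8_quant(x, act_scale)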
