Commit cbb2f59

[Kernel] Pass a device pointer into the quantize kernel for the scales (#5159)
1 parent 0ab278c commit cbb2f59
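This commit switches the scale argument of static_scaled_int8_quant from a plain float to a single-element device tensor, so the CUDA kernel reads the scale through a device pointer instead of receiving it by value, and callers no longer have to pull the value to the host with .item(). A minimal usage sketch of the updated Python wrapper, assuming a CUDA build of vLLM at this commit (shapes, dtypes, and the scale value are illustrative):

    import torch
    from vllm import _custom_ops as custom_ops

    x = torch.randn(16, 1024, dtype=torch.float16, device="cuda")
    # The scale is now a one-element float32 tensor that stays on the GPU,
    # not a Python float.
    scale = torch.tensor([0.02], dtype=torch.float32, device="cuda")
    x_q = custom_ops.static_scaled_int8_quant(x, scale)  # returns an int8 tensor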

File tree

5 files changed, +16 -11 lines


csrc/ops.h

Lines changed: 2 additions & 2 deletions
@@ -94,8 +94,8 @@ int cutlass_scaled_mm_dq(torch::Tensor& out, torch::Tensor const& a,
 
 #endif
 
-void static_scaled_int8_quant(torch::Tensor& out, torch::Tensor& input,
-                              float scale);
+void static_scaled_int8_quant(torch::Tensor& out, torch::Tensor const& input,
+                              torch::Tensor const& scale);
 
 void squeezellm_gemm(torch::Tensor vec, torch::Tensor mat, torch::Tensor mul,
                      torch::Tensor lookup_table);

csrc/quantization/compressed_tensors/int8_quant_kernels.cu

Lines changed: 9 additions & 6 deletions
@@ -28,9 +28,10 @@ namespace vllm {
 template <typename scalar_t, typename scale_type>
 __global__ void static_scaled_int8_quant_kernel(
     const scalar_t* __restrict__ input, int8_t* __restrict__ out,
-    scale_type scale, const int hidden_size) {
+    const scale_type* scale_ptr, const int hidden_size) {
   const int tid = threadIdx.x;
   const int token_idx = blockIdx.x;
+  scale_type scale = *scale_ptr;
 
   for (int i = tid; i < hidden_size; i += blockDim.x) {
     out[token_idx * hidden_size + i] =
@@ -39,11 +40,13 @@ __global__ void static_scaled_int8_quant_kernel(
 }
 }  // namespace vllm
 
-void static_scaled_int8_quant(torch::Tensor& out,    // [..., hidden_size]
-                              torch::Tensor& input,  // [..., hidden_size]
-                              float scale) {
+void static_scaled_int8_quant(torch::Tensor& out,          // [..., hidden_size]
+                              torch::Tensor const& input,  // [..., hidden_size]
+                              torch::Tensor const& scale) {
   TORCH_CHECK(input.is_contiguous());
   TORCH_CHECK(out.is_contiguous());
+  TORCH_CHECK(scale.numel() == 1);
+
   int hidden_size = input.size(-1);
   int num_tokens = input.numel() / hidden_size;
   dim3 grid(num_tokens);
@@ -53,7 +56,7 @@ void static_scaled_int8_quant(torch::Tensor& out,  // [..., hidden_size]
       input.scalar_type(), "static_scaled_int8_quant_kernel", [&] {
         vllm::static_scaled_int8_quant_kernel<scalar_t, float>
             <<<grid, block, 0, stream>>>(input.data_ptr<scalar_t>(),
-                                         out.data_ptr<int8_t>(), scale,
-                                         hidden_size);
+                                         out.data_ptr<int8_t>(),
+                                         scale.data_ptr<float>(), hidden_size);
       });
 }
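On the C++ side the wrapper now checks that the scale tensor holds exactly one element and forwards scale.data_ptr<float>() to the kernel, which dereferences it on the device (scale_type scale = *scale_ptr;). Callers therefore supply a one-element float32 tensor that already lives on the GPU. A small sketch of that contract, with illustrative names not taken from the diff:

    import torch

    # Hypothetical per-tensor activation scale kept on the GPU.
    input_scale = torch.tensor([0.02], dtype=torch.float32, device="cuda")

    # Before this commit the scalar had to be read back to the host, e.g.
    #   static_scaled_int8_quant(out, x, input_scale[0].item())
    # Now the tensor itself is passed and the kernel reads it through a pointer:
    #   static_scaled_int8_quant(out, x, input_scale)
    assert input_scale.numel() == 1  # mirrors the new TORCH_CHECK in the wrapper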

tests/kernels/test_int8_quant.py

Lines changed: 3 additions & 1 deletion
@@ -26,6 +26,8 @@ def test_quant(num_tokens: int, hidden_size: int, dtype: torch.dtype,
                        torch.iinfo(torch.int8).min,
                        torch.iinfo(torch.int8).max).to(torch.int8)
     out2 = torch.empty_like(x, dtype=torch.int8)
-    ops.static_scaled_int8_quant(out2, x, scale)
+    scale_argument = torch.tensor([scale], dtype=torch.float32, device="cuda")
+
+    ops.static_scaled_int8_quant(out2, x, scale_argument)
     assert torch.allclose(out1, out2,
                           atol=1)  # big atol to account for rounding errors

vllm/_custom_ops.py

Lines changed: 1 addition & 1 deletion
@@ -265,7 +265,7 @@ def scaled_fp8_quant(
 
 # int8
 def static_scaled_int8_quant(input: torch.Tensor,
-                             scale: float) -> torch.Tensor:
+                             scale: torch.Tensor) -> torch.Tensor:
     """
     Quantize the input tensor to int8 and return the quantized tensor.
 
vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_statictensor.py

Lines changed: 1 addition & 1 deletion
@@ -97,7 +97,7 @@ def apply_weights(self, layer: torch.nn.Module, x: torch.Tensor):
         act_scale = layer.input_scale
 
         # Input quantize
-        x_q = custom_ops.static_scaled_int8_quant(x, act_scale[0].item())
+        x_q = custom_ops.static_scaled_int8_quant(x, act_scale)
 
         return custom_ops.cutlass_scaled_mm_dq(x_q, weight.t(), act_scale,
                                                weight_scale, x.dtype)
