[Kernel] Initial Activation Quantization Support #4525

Merged
49 commits merged on May 23, 2024

Changes from all commits

Commits (49)
4d27a2c  Initial `CompressedTensors` config + Activation Quantization support … (dsikka, Apr 30, 2024)
92b3703  add get_quant method to compressed tensors config (dsikka, Apr 30, 2024)
2a3eb83  small rebase fixed (dsikka, Apr 30, 2024)
3dd1fe8  format (dsikka, Apr 30, 2024)
f2f8c52  fix mypy complaints (Apr 30, 2024)
c9308eb  Merge branch 'main' into ds-quant (dsikka, Apr 30, 2024)
d9d49b5  format fixes (dsikka, Apr 30, 2024)
b111ee6  Merge branch 'main' into ds-quant (dsikka, May 1, 2024)
c31a7af  format fix post rebase (dsikka, May 1, 2024)
ca01b39  lazy import CompressedTensorsW8A8StaticTensor (#220) (varun-sundar-rabindranath, May 1, 2024)
f0197d4  lazy cutlass_gemm_dq import (#221) (varun-sundar-rabindranath, May 1, 2024)
4624b46  fix asm (May 1, 2024)
75757d5  update shape change (dsikka, May 2, 2024)
e1df0eb  add todo (dsikka, May 2, 2024)
bc0991c  Rename quant_per_tensor -> static_scaled_int8_quant (May 2, 2024)
74ad650  Remove cruft (May 2, 2024)
43c43f3  Merge branch 'main' into ds-quant (dsikka, May 14, 2024)
cf5600f  fixes : typo (May 14, 2024)
169ce7f  py-cutlass temporary hack for num_prompts==1 (May 15, 2024)
03b53e7  yapf (May 15, 2024)
f9df31b  add test_int8_quant (May 16, 2024)
ba4b6b3  call cpp cutlass (May 17, 2024)
3c223c6  Merge branch 'main' into ds-quant (dsikka, May 17, 2024)
b27f31a  remove cutlass py interface (May 17, 2024)
b589cdd  format.sh (May 17, 2024)
98159cf  remove fake-quant (May 17, 2024)
8dbeb31  add compressed tensors test (dsikka, May 17, 2024)
5eeb40a  remove torch.int8 (dsikka, May 17, 2024)
c55e023  format (dsikka, May 17, 2024)
f5cbbd3  fix config parsing to match new model (dsikka, May 20, 2024)
a685957  revert parsing to use default pathway (dsikka, May 20, 2024)
4dfb37f  PR comments (dsikka, May 21, 2024)
de81f9e  Fix scales/zero-points device allocation (May 21, 2024)
15f1863  ruff (May 21, 2024)
bd53847  add better comments (May 21, 2024)
b2926f3  add comment (dsikka, May 22, 2024)
1274386  Merge branch 'main' into ds-quant (dsikka, May 22, 2024)
18640c8  clang format (dsikka, May 22, 2024)
5c5dc84  clang format again (dsikka, May 22, 2024)
a44b4a0  address PR comments (May 22, 2024)
6f0e6e1  clang-format (May 22, 2024)
0090454  remove layer name (dsikka, May 23, 2024)
4b10fd7  remove unused import (dsikka, May 23, 2024)
68a59c7  remove parent name (dsikka, May 23, 2024)
b0afe67  Fix rounding (May 22, 2024)
4f4951e  comment (May 23, 2024)
869de3f  cruft (May 23, 2024)
e68e391  yapf (May 23, 2024)
d77cf50  remove unquantized check (dsikka, May 23, 2024)
1 change: 1 addition & 0 deletions CMakeLists.txt
@@ -167,6 +167,7 @@ set(VLLM_EXT_SRC
  "csrc/layernorm_kernels.cu"
"csrc/quantization/squeezellm/quant_cuda_kernel.cu"
"csrc/quantization/gptq/q_gemm.cu"
"csrc/quantization/compressed_tensors/int8_quant_kernels.cu"
Review comment (Collaborator):

The name compressed_tensors is a bit unclear to me. I suppose this is the official method name from SparseML? If so, can we just use sparseml here?

Reply (Collaborator):

compressed-tensors is the name of the package responsible for saving quantized and sparse models.

So the flow is:

  • use SparseML to apply quantization / sparsity
  • save model to safetensors with a compressed-tensors config
  • load + run in vllm

"csrc/quantization/fp8/common.cu"
"csrc/cuda_utils_kernels.cu"
"csrc/moe_align_block_size_kernels.cu"
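The review thread above describes the intended flow: SparseML applies the quantization/sparsity, the checkpoint is saved to safetensors with a compressed-tensors config, and vLLM loads and runs it. A minimal load-and-run sketch of that last step is shown below; the checkpoint path and the quantization="sparseml" argument are taken from the test added later in this PR, not from a finalized public API.

# Sketch only: load a SparseML-quantized (compressed-tensors) checkpoint in vLLM.
# Model path and quantization="sparseml" mirror tests/quantization/test_compressed_tensors.py.
from vllm import LLM, SamplingParams

llm = LLM(model="nm-testing/tinyllama-one-shot-static-quant-test-compressed",
          quantization="sparseml",
          enforce_eager=True)
outputs = llm.generate(["Hello, my name is"], SamplingParams(max_tokens=32))
print(outputs[0].outputs[0].text)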
3 changes: 3 additions & 0 deletions csrc/ops.h
@@ -93,6 +93,9 @@ int cutlass_scaled_mm_dq(torch::Tensor& out, torch::Tensor const& a,

#endif

void static_scaled_int8_quant(torch::Tensor& out, torch::Tensor& input,
                              float scale);

void squeezellm_gemm(torch::Tensor vec, torch::Tensor mat, torch::Tensor mul,
                     torch::Tensor lookup_table);

3 changes: 3 additions & 0 deletions csrc/pybind.cpp
@@ -67,6 +67,9 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
          "Aligning the number of tokens to be processed by each expert such "
          "that it is divisible by the block size.");

  ops.def("static_scaled_int8_quant", &static_scaled_int8_quant,
          "Compute int8 quantized tensor for given scaling factor");

  // Cache ops
  pybind11::module cache_ops = m.def_submodule("cache_ops", "vLLM cache ops");
  cache_ops.def("swap_blocks", &swap_blocks,
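The binding registered above takes the output tensor as its first argument (out-parameter style), which is also how the kernel test added in this PR calls it. A short sketch, assuming a CUDA device:

# Calling the raw binding (out-parameter style), as tests/kernels/test_int8_quant.py does.
import torch
from vllm._C import ops

x = torch.randn(16, 1024, dtype=torch.float16, device="cuda")
out = torch.empty_like(x, dtype=torch.int8)
ops.static_scaled_int8_quant(out, x, 0.5)  # writes int8 results into `out`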
59 changes: 59 additions & 0 deletions csrc/quantization/compressed_tensors/int8_quant_kernels.cu
@@ -0,0 +1,59 @@
#include <ATen/cuda/CUDAContext.h>
#include <torch/extension.h>
#include <cmath>

#include "../../dispatch_utils.h"

static inline __device__ int8_t float_to_int8_rn(float x) {
#ifdef USE_ROCM
  static const float i8_min =
      static_cast<float>(std::numeric_limits<int8_t>::min());
  static const float i8_max =
      static_cast<float>(std::numeric_limits<int8_t>::max());
  // round
  float dst = std::nearbyint(x);
  // saturate
  dst = std::clamp(dst, i8_min, i8_max);
  return static_cast<int8_t>(dst);
#else
  // CUDA path
  uint32_t dst;
  asm volatile("cvt.rni.sat.s8.f32 %0, %1;" : "=r"(dst) : "f"(x));
  return reinterpret_cast<const int8_t&>(dst);
#endif
}

namespace vllm {

template <typename scalar_t, typename scale_type>
__global__ void static_scaled_int8_quant_kernel(
    const scalar_t* __restrict__ input, int8_t* __restrict__ out,
    scale_type scale, const int hidden_size) {
  const int tid = threadIdx.x;
  const int token_idx = blockIdx.x;

  // One thread block per token (grid = num_tokens); threads stride over
  // the hidden dimension in steps of blockDim.x.
  for (int i = tid; i < hidden_size; i += blockDim.x) {
    out[token_idx * hidden_size + i] =
        float_to_int8_rn(((float)input[token_idx * hidden_size + i]) / scale);
  }
}
}  // namespace vllm

void static_scaled_int8_quant(torch::Tensor& out,    // [..., hidden_size]
                              torch::Tensor& input,  // [..., hidden_size]
                              float scale) {
  TORCH_CHECK(input.is_contiguous());
  TORCH_CHECK(out.is_contiguous());
  int hidden_size = input.size(-1);
  int num_tokens = input.numel() / hidden_size;
  dim3 grid(num_tokens);
  dim3 block(std::min(hidden_size, 1024));
  const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
  VLLM_DISPATCH_FLOATING_TYPES(
      input.scalar_type(), "static_scaled_int8_quant_kernel", [&] {
        vllm::static_scaled_int8_quant_kernel<scalar_t, float>
            <<<grid, block, 0, stream>>>(input.data_ptr<scalar_t>(),
                                         out.data_ptr<int8_t>(), scale,
                                         hidden_size);
      });
}
31 changes: 31 additions & 0 deletions tests/kernels/test_int8_quant.py
@@ -0,0 +1,31 @@
import pytest
import torch

from vllm._C import ops

DTYPES = [torch.half, torch.bfloat16, torch.float]
HIDDEN_SIZES = [16, 67, 768, 2048, 5120, 8192]  # Arbitrary values for testing
NUM_TOKENS = [1, 7, 83, 4096]  # Arbitrary values for testing
SEEDS = [0]
SCALE = [0.1, 0.5, 0.8, 1.2, 2.1]


@pytest.mark.parametrize("num_tokens", NUM_TOKENS)
@pytest.mark.parametrize("hidden_size", HIDDEN_SIZES)
@pytest.mark.parametrize("dtype", DTYPES)
@pytest.mark.parametrize("seed", SEEDS)
@pytest.mark.parametrize("scale", SCALE)
@torch.inference_mode()
def test_quant(num_tokens: int, hidden_size: int, dtype: torch.dtype,
               seed: int, scale: float) -> None:
    torch.random.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    x = torch.rand(num_tokens, hidden_size, dtype=dtype, device="cuda") * 1000

    out1 = (x / scale).round().clamp(
        torch.iinfo(torch.int8).min,
        torch.iinfo(torch.int8).max).to(torch.int8)
    out2 = torch.empty_like(x, dtype=torch.int8)
    ops.static_scaled_int8_quant(out2, x, scale)
    assert torch.allclose(out1, out2,
                          atol=1)  # big atol to account for rounding errors
36 changes: 36 additions & 0 deletions tests/quantization/test_compressed_tensors.py
@@ -0,0 +1,36 @@
"""Test model set-up and weight loading for sparseml-quantized models.
Run `pytest tests/quantization/test_compressed_tensors.py`.
"""

import torch

from vllm.model_executor.layers.quantization.compressed_tensors.compressed_tensors import ( # noqa: E501
CompressedTensorsLinearMethod, CompressedTensorsW8A8StaticTensor)


def test_compressed_tensors_w8a8_static_setup(vllm_runner):
model_path = "nm-testing/tinyllama-one-shot-static-quant-test-compressed"
llm = vllm_runner(model_path, quantization="sparseml", enforce_eager=True)
model = llm.model.llm_engine.model_executor.driver_worker.model_runner.model
layer = model.model.layers[0]

qkv_proj = layer.self_attn.qkv_proj
o_proj = layer.self_attn.o_proj
gate_up_proj = layer.mlp.gate_up_proj
down_proj = layer.mlp.down_proj

assert isinstance(qkv_proj.quant_method, CompressedTensorsLinearMethod)
assert isinstance(o_proj.quant_method, CompressedTensorsLinearMethod)
assert isinstance(gate_up_proj.quant_method, CompressedTensorsLinearMethod)
assert isinstance(down_proj.quant_method, CompressedTensorsLinearMethod)

assert isinstance(qkv_proj.scheme, CompressedTensorsW8A8StaticTensor)

assert qkv_proj.weight.dtype is torch.int8
assert o_proj.weight.dtype is torch.int8
assert gate_up_proj.weight.dtype is torch.int8

assert qkv_proj.weight_scale.shard_splitter is not None
assert qkv_proj.weight_scale.logical_widths is not None
assert qkv_proj.input_scale.dtype is torch.float32
18 changes: 18 additions & 0 deletions vllm/_custom_ops.py
@@ -251,6 +251,24 @@ def scaled_fp8_quant(
    return output, scale


# int8
def static_scaled_int8_quant(input: torch.Tensor,
                             scale: float) -> torch.Tensor:
    """
    Quantize the input tensor to int8 and return the quantized tensor.

    Args:
        input: The input tensor to be quantized to int8.
        scale: Scaling factor for the int8 quantization.

    Returns:
        torch.Tensor: Output tensor in int8.
    """
    q = torch.empty_like(input, dtype=torch.int8)
    vllm_ops.static_scaled_int8_quant(q, input, scale)
    return q


# moe
def moe_align_block_size(topk_ids: torch.Tensor, num_experts: int,
                         block_size: int, sorted_token_ids: torch.Tensor,
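A small usage sketch of the new wrapper (assuming a CUDA device and a vLLM build that contains this PR): unlike the raw binding shown earlier, static_scaled_int8_quant allocates the int8 output itself and forwards to the C++ op.

# Usage sketch (assumption: CUDA device and a build that includes this PR).
import torch
from vllm import _custom_ops as custom_ops

x = torch.randn(4, 4096, dtype=torch.float16, device="cuda")
scale = x.abs().max().item() / 127.0  # simple per-tensor static scale for illustration
q = custom_ops.static_scaled_int8_quant(x, scale)
print(q.dtype, q.shape)  # torch.int8, same shape as x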