
Commit 15ebfd0

Add unit test case for custom quantization config

Signed-off-by: ice-tong <xych6@outlook.com>

1 parent 49f7ab1 commit 15ebfd0

File tree

tests/quantization/test_register_quantization_config.py

1 file changed: +117 -0
tests/quantization/test_register_quantization_config.py: 117 additions & 0 deletions

@@ -0,0 +1,117 @@
"""Tests registering a custom quantization config.

See https://github.com/vllm-project/vllm/issues/11926 for more details.

Run `pytest tests/quantization/test_register_quantization_config.py`.
"""
from typing import Any, Dict, List, Optional

import pytest
import torch
import torch.nn.functional as F

from vllm.model_executor.layers.linear import LinearBase  # noqa: E501
from vllm.model_executor.layers.linear import UnquantizedLinearMethod
from vllm.model_executor.layers.quantization import (
    get_quantization_config, register_quantization_config)
from vllm.model_executor.layers.quantization.base_config import (  # noqa: E501
    QuantizationConfig)


class FakeQuantLinearMethod(UnquantizedLinearMethod):
    """Fake quantization linear method for per-token dynamic quantization."""

    def __init__(self, num_bits: int = 8) -> None:
        """Initialize the quantization method."""
        super().__init__()
        self.num_bits = num_bits

    def apply(self,
              layer: "torch.nn.Module",
              x: "torch.Tensor",
              bias: Optional["torch.Tensor"] = None) -> "torch.Tensor":
        """Perform fake quantization before the linear layer."""

        # Calculate the scales dynamically.
        max_val = torch.amax(x, dim=(0, -1), keepdims=True)
        min_val = torch.amin(x, dim=(0, -1), keepdims=True)
        scales = (max_val - min_val) / (2**self.num_bits - 1)

        # Fake quantize the input: round to the quantization grid, clamp to
        # the signed integer range, then dequantize back to floating point.
        quant_x = torch.clamp(torch.round(x / scales), -2**(self.num_bits - 1),
                              2**(self.num_bits - 1) - 1)
        dequant_x = quant_x * scales

        return F.linear(dequant_x, layer.weight, bias)


@register_quantization_config("custom_quant")
class CustomQuantConfig(QuantizationConfig):
    """Custom quantization config for per-token dynamic fake quantization."""

    def __init__(self, num_bits: int = 8) -> None:
        """Initialize the quantization config."""
        self.num_bits = num_bits

    def get_name(self) -> str:
        """Name of the quantization method."""
        return "custom_quant"

    def get_supported_act_dtypes(self) -> List["torch.dtype"]:
        """List of supported activation dtypes."""
        return [torch.float16, torch.bfloat16]

    @classmethod
    def get_min_capability(cls) -> int:
        """Minimum GPU capability to support the quantization method."""
        return -1

    @staticmethod
    def get_config_filenames() -> List[str]:
        """List of filenames to search for in the model directory."""
        return []

    @classmethod
    def from_config(cls, config: Dict[str, Any]) -> "CustomQuantConfig":
        """Create a config class from the model's quantization config."""
        return CustomQuantConfig(num_bits=config.get("num_bits", 8))

    def get_quant_method(self, layer: "torch.nn.Module",
                         prefix: str) -> Optional["FakeQuantLinearMethod"]:
        """Get the quantization method to use for the quantized layer."""
        if isinstance(layer, LinearBase):
            return FakeQuantLinearMethod(num_bits=self.num_bits)
        return None


def test_register_quantization_config():
    """Test registering a custom quantization config."""

    # The quantization method `custom_quant` should be registered.
    assert get_quantization_config("custom_quant") == CustomQuantConfig

    # The quantization method `custom_quant` already exists, so registering
    # it again should raise an error.
    with pytest.raises(ValueError):
        register_quantization_config("custom_quant")(CustomQuantConfig)


@pytest.mark.parametrize(argnames="model",
                         argvalues=[
                             "meta-llama/Meta-Llama-3-8B-Instruct",
                         ])
def test_custom_quant(vllm_runner, model):
    """Test inference with the custom quantization method."""
    with vllm_runner(model_name=model,
                     quantization="custom_quant",
                     enforce_eager=True) as llm:

        model = llm.model.llm_engine.model_executor.driver_worker.model_runner.model  # noqa: E501
        layer = model.model.layers[0]
        qkv_proj = layer.self_attn.qkv_proj

        # Check that the quantization method is FakeQuantLinearMethod.
        assert isinstance(qkv_proj.quant_method, FakeQuantLinearMethod)

        output = llm.generate_greedy("Hello my name is", max_tokens=20)
        assert output
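
For intuition, the arithmetic in `FakeQuantLinearMethod.apply` can be exercised on its own. The sketch below is illustrative only, not part of the commit; the tensor shape (batch, seq_len, hidden) is an assumption. It repeats the same scale/round/clamp/dequantize steps and shows that the round-trip error stays around one quantization step:

# Standalone sketch of the fake-quantization round trip (illustrative,
# not part of the commit). Assumed activation shape: (batch, seq_len, hidden).
import torch

num_bits = 8
x = torch.randn(2, 5, 16)

# Reducing over dims (0, -1) leaves one scale per token position, shared
# across the batch -- the same reduction apply() performs.
max_val = torch.amax(x, dim=(0, -1), keepdim=True)
min_val = torch.amin(x, dim=(0, -1), keepdim=True)
scales = (max_val - min_val) / (2**num_bits - 1)

quant_x = torch.clamp(torch.round(x / scales),
                      -2**(num_bits - 1), 2**(num_bits - 1) - 1)
dequant_x = quant_x * scales

# For roughly symmetric data the reconstruction error is on the order of one
# quantization step; the signed clamp range can add more error on skewed data.
print("max |x - dequant_x|:", (x - dequant_x).abs().max().item())
print("max quantization step:", scales.max().item())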

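Once registered, the config behaves like any built-in quantization method, as `test_custom_quant` verifies through the `vllm_runner` fixture. A minimal usage sketch through vLLM's public API, assuming `CustomQuantConfig` from the test file has been imported first (which triggers the `@register_quantization_config("custom_quant")` decorator) and a GPU is available; the model name is copied from the test's parametrization:

# Minimal usage sketch (assumes CustomQuantConfig above was imported in this
# process, so "custom_quant" is already registered).
from vllm import LLM, SamplingParams

llm = LLM(model="meta-llama/Meta-Llama-3-8B-Instruct",
          quantization="custom_quant",
          enforce_eager=True)
outputs = llm.generate(["Hello my name is"],
                       SamplingParams(temperature=0.0, max_tokens=20))
print(outputs[0].outputs[0].text)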