
Commit 75a6117

bigPYJ1151 authored and sumitd2 committed
[Hardware][CPU] Support AWQ for CPU backend (vllm-project#7515)
Signed-off-by: Sumit Dubey <sumit.dubey2@ibm.com>
1 parent c50b1b0 commit 75a6117

File tree

9 files changed, +214 -7 lines changed

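As context for the diffs below, a minimal usage sketch of the path this commit enables (assumptions: a CPU-only vLLM build containing this commit, intel_extension_for_pytorch>=2.4.0 installed, and the AWQ checkpoint name taken from the new test file):

# Minimal sketch, not part of the commit: run an AWQ model on the CPU backend;
# the quantization method is picked up from the checkpoint's quant config.
from vllm import LLM, SamplingParams

llm = LLM(model="casperhansen/llama-3-8b-instruct-awq", dtype="bfloat16")
outputs = llm.generate(["The capital of France is"],
                       SamplingParams(temperature=0.0, max_tokens=32))
print(outputs[0].outputs[0].text)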

.buildkite/run-cpu-test.sh

Lines changed: 8 additions & 2 deletions
@@ -27,13 +27,19 @@ docker exec cpu-test bash -c "
   pytest -v -s tests/models/decoder_only/language \
     --ignore=tests/models/test_fp8.py \
     --ignore=tests/models/decoder_only/language/test_jamba.py \
+    --ignore=tests/models/decoder_only/language/test_granitemoe.py \
     --ignore=tests/models/decoder_only/language/test_danube3_4b.py" # Mamba and Danube3-4B on CPU is not supported
 
 # Run compressed-tensor test
+# docker exec cpu-test bash -c "
+#   pytest -s -v \
+#   tests/quantization/test_compressed_tensors.py::test_compressed_tensors_w8a8_static_setup \
+#   tests/quantization/test_compressed_tensors.py::test_compressed_tensors_w8a8_dynanmic_per_token"
+
+# Run AWQ test
 docker exec cpu-test bash -c "
   pytest -s -v \
-  tests/quantization/test_compressed_tensors.py::test_compressed_tensors_w8a8_static_setup \
-  tests/quantization/test_compressed_tensors.py::test_compressed_tensors_w8a8_dynanmic_per_token"
+  tests/quantization/test_ipex_quant.py"
 
 # online inference
 docker exec cpu-test bash -c "

Dockerfile.cpu

Lines changed: 1 addition & 1 deletion
@@ -22,7 +22,7 @@ ENV LD_PRELOAD="/usr/lib/x86_64-linux-gnu/libtcmalloc_minimal.so.4:/usr/local/li
 
 RUN echo 'ulimit -c 0' >> ~/.bashrc
 
-RUN pip install https://intel-extension-for-pytorch.s3.amazonaws.com/ipex_dev/cpu/intel_extension_for_pytorch-2.4.0%2Bgitfbaa4bc-cp310-cp310-linux_x86_64.whl
+RUN pip install intel_extension_for_pytorch==2.4.0
 
 WORKDIR /workspace

docs/source/quantization/supported_hardware.rst

Lines changed: 2 additions & 2 deletions
@@ -28,7 +28,7 @@ The table below shows the compatibility of various quantization implementations
       - ✅︎
       - ✗
       - ✗
-      -
+      - ✅︎
       - ✗
       - ✗
     * - GPTQ
@@ -61,7 +61,7 @@ The table below shows the compatibility of various quantization implementations
       - ✅︎
       - ✗
       - ✗
-      -
+      - ✅︎
       - ✗
       - ✗
     * - FP8 (W8A8)

tests/quantization/test_ipex_quant.py

Lines changed: 28 additions & 0 deletions
@@ -0,0 +1,28 @@
+"""Test model set-up and inference for quantized HF models supported
+on the CPU backend using IPEX (including AWQ).
+
+Validating the configuration and printing results for manual checking.
+
+Run `pytest tests/quantization/test_ipex_quant.py`.
+"""
+
+import pytest
+
+from vllm.platforms import current_platform
+
+MODELS = [
+    "casperhansen/llama-3-8b-instruct-awq",
+]
+DTYPE = ["bfloat16"]
+
+
+@pytest.mark.skipif(not current_platform.is_cpu(),
+                    reason="only supports the CPU backend.")
+@pytest.mark.parametrize("model", MODELS)
+@pytest.mark.parametrize("dtype", DTYPE)
+def test_ipex_quant(vllm_runner, model, dtype):
+    with vllm_runner(model, dtype=dtype) as llm:
+        output = llm.generate_greedy(["The capital of France is"],
+                                     max_tokens=32)
+        assert output
+        print(output)

vllm/model_executor/layers/linear.py

Lines changed: 1 addition & 1 deletion
@@ -27,7 +27,7 @@
     "AWQLinearMethod", "GPTQMarlinLinearMethod", "Fp8LinearMethod",
     "MarlinLinearMethod", "QQQLinearMethod", "GPTQMarlin24LinearMethod",
     "TPUInt8LinearMethod", "GPTQLinearMethod", "FBGEMMFp8LinearMethod",
-    "ModelOptFp8LinearMethod"
+    "ModelOptFp8LinearMethod", "IPEXAWQLinearMethod"
 ]
 

vllm/model_executor/layers/quantization/__init__.py

Lines changed: 2 additions & 0 deletions
@@ -21,6 +21,7 @@
     GPTQMarlinConfig)
 from vllm.model_executor.layers.quantization.gptq_marlin_24 import (
     GPTQMarlin24Config)
+from vllm.model_executor.layers.quantization.ipex_quant import IPEXConfig
 from vllm.model_executor.layers.quantization.marlin import MarlinConfig
 from vllm.model_executor.layers.quantization.modelopt import ModelOptFp8Config
 from vllm.model_executor.layers.quantization.neuron_quant import (
@@ -49,6 +50,7 @@
     "qqq": QQQConfig,
     "experts_int8": ExpertsInt8Config,
     "neuron_quant": NeuronQuantConfig,
+    "ipex": IPEXConfig,
 }
 
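With the registry entry above, the new backend resolves by name. A minimal sketch, assuming the module's existing get_quantization_config() lookup helper is unchanged:

# Sketch (assumption: get_quantization_config() is the module's existing lookup
# helper): the newly registered "ipex" key should resolve to IPEXConfig.
from vllm.model_executor.layers.quantization import get_quantization_config

ipex_cfg_cls = get_quantization_config("ipex")
assert ipex_cfg_cls.get_name() == "ipex"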

vllm/model_executor/layers/quantization/awq_marlin.py

Lines changed: 4 additions & 0 deletions
@@ -20,6 +20,7 @@
 from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead
 from vllm.model_executor.parameter import (GroupQuantScaleParameter,
                                            PackedvLLMParameter)
+from vllm.platforms import current_platform
 from vllm.scalar_type import scalar_types
 
 logger = init_logger(__name__)
@@ -123,6 +124,9 @@ def is_awq_marlin_compatible(cls, quant_config: Dict[str, Any]):
         group_size = quant_config.get("group_size")
         has_zp = quant_config.get("zero_point")
 
+        if not current_platform.is_cuda():
+            return False
+
         if quant_method != "awq":
             return False
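The is_cuda() guard above keeps AWQ-Marlin off non-CUDA platforms so that, on CPU, an AWQ checkpoint falls through to the IPEX config introduced below. A hypothetical, self-contained illustration of that dispatch (the function name and the fallback branch are illustrative, not the actual vLLM code):

# Illustrative only: how AWQ checkpoints are routed per platform after this commit.
def select_awq_backend(platform: str) -> str:
    if platform == "cuda":
        return "awq_marlin"  # GPU path is unchanged
    if platform == "cpu":
        return "ipex"        # new IPEX-backed AWQ path
    return "awq"             # plain AWQ fallback (illustrative)

assert select_awq_backend("cpu") == "ipex"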

vllm/model_executor/layers/quantization/ipex_quant.py

Lines changed: 166 additions & 0 deletions
@@ -0,0 +1,166 @@
+from typing import Any, Dict, List, Optional
+
+import torch
+
+from vllm.model_executor.layers.linear import LinearBase, LinearMethodBase
+from vllm.model_executor.layers.quantization.awq import AWQLinearMethod
+from vllm.model_executor.layers.quantization.base_config import (
+    QuantizationConfig)
+from vllm.platforms import current_platform
+
+
+class IPEXConfig(QuantizationConfig):
+    """INT8 quantization config class using IPEX for the CPU backend,
+    including AWQ.
+    """
+
+    IPEX_QUANT_METHOD_MAP = {
+        "awq": 1,
+        "gptq": 2,
+    }
+
+    def __init__(
+        self,
+        method: str,
+        weight_bits: int,
+        group_size: int,
+    ) -> None:
+        self.method = method
+        self.weight_bits = weight_bits
+        self.group_size = group_size
+        self.pack_factor = 32 // self.weight_bits
+
+        if self.weight_bits not in [4]:
+            raise ValueError(f"IPEX quantization supports weight bits [4], "
+                             f"but got {self.weight_bits}.")
+
+        if self.method == "awq":
+            self.quant_method = IPEXAWQLinearMethod
+        else:
+            raise ValueError(f"IPEX quantization supports [awq], "
+                             f"but got {self.method}.")
+
+    def __repr__(self) -> str:
+        return (f"IPEXConfig(method={self.method}"
+                f"weight_bits={self.weight_bits}, "
+                f"group_size={self.group_size}")
+
+    def get_ipex_quant_method_id(self) -> int:
+        return IPEXConfig.IPEX_QUANT_METHOD_MAP[self.method]
+
+    @classmethod
+    def get_name(cls) -> str:
+        return "ipex"
+
+    @classmethod
+    def get_supported_act_dtypes(cls) -> List[torch.dtype]:
+        return [torch.bfloat16]
+
+    @classmethod
+    def get_min_capability(cls) -> int:
+        return -1
+
+    @staticmethod
+    def get_config_filenames() -> List[str]:
+        return [
+            "quant_config.json",
+            "quantize_config.json",
+        ]
+
+    @classmethod
+    def from_config(cls, config: Dict[str, Any]) -> "IPEXConfig":
+        method = cls.get_from_keys(config, ["quant_method"]).lower()
+        weight_bits = cls.get_from_keys(config, ["w_bit", "bits"])
+        group_size = cls.get_from_keys(config, ["q_group_size", "group_size"])
+        return cls(method, weight_bits, group_size)
+
+    @classmethod
+    def override_quantization_method(cls, hf_quant_cfg,
+                                     user_quant) -> Optional[str]:
+        if not current_platform.is_cpu():
+            return None
+
+        quant_method = hf_quant_cfg.get("quant_method", "").lower()
+
+        if quant_method in ["awq"]:
+            return cls.get_name()
+
+        return None
+
+    def get_quant_method(self, layer: torch.nn.Module,
+                         prefix: str) -> Optional["LinearMethodBase"]:
+        if isinstance(layer, LinearBase):
+            return self.quant_method(self)
+        return None
+
+    def get_scaled_act_names(self) -> List[str]:
+        if self.method == "awq":
+            return ["gelu", "gelu_fast", "gelu_new", "gelu_pytorch_tanh"]
+        else:
+            return []
+
+
+class IPEXAWQLinearMethod(AWQLinearMethod):
+    """AWQ linear method using IPEX for the CPU backend.
+    """
+
+    def __init__(self, quant_config: IPEXConfig):
+        self.quant_config = quant_config  # type: ignore
+
+    def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
+        super().process_weights_after_loading(layer=layer)
+
+        bias = layer.bias if not layer.skip_bias_add else None
+
+        try:
+            import intel_extension_for_pytorch as ipex
+            if ipex.__version__ < "2.4.0":
+                raise ImportError("intel_extension_for_pytorch version is "
+                                  "wrong. Please install "
+                                  "intel_extension_for_pytorch>=2.4.0.")
+        except ImportError as err:
+            raise ImportError(
+                "Please install "
+                "intel_extension_for_pytorch>=2.4.0 via "
+                "`pip install intel_extension_for_pytorch>=2.4.0`"
+                " to use IPEX-AWQ linear method.") from err
+
+        # Using the compute dtype (lowp_mode) as INT8 to leverage instructions
+        # with better performance.
+        lowp_mode = ipex.quantization.WoqLowpMode.INT8
+        # The weight will be de-packed from INT4 to INT8.
+        weight_dtype = ipex.quantization.WoqWeightDtype.INT4
+        # The float activation will be quantized (dynamic, per-token) to INT8.
+        act_quant_mode = ipex.quantization.WoqActQuantMode.PER_BATCH
+
+        qconfig = ipex.quantization.get_weight_only_quant_qconfig_mapping(
+            weight_dtype=weight_dtype,
+            lowp_mode=lowp_mode,
+            act_quant_mode=act_quant_mode,
+            group_size=self.quant_config.group_size,
+        )
+
+        layer.ipex_output_size = layer.qweight.size(
+            1) * self.quant_config.pack_factor
+        layer.ipex_qlinear = ipex.nn.modules.weight_only_quantization.\
+            WeightOnlyQuantizedLinear.from_weight(
+                layer.qweight,
+                layer.scales,
+                layer.qzeros,
+                layer.qweight.size(0),
+                layer.ipex_output_size,
+                qconfig=qconfig,
+                bias=bias,
+                group_size=self.quant_config.group_size,
+                quant_method=
+                self.quant_config.get_ipex_quant_method_id()  # type: ignore
+            )
+
+    def apply(self,
+              layer: torch.nn.Module,
+              x: torch.Tensor,
+              bias: Optional[torch.Tensor] = None) -> torch.Tensor:
+        reshaped_x = x.reshape(-1, x.shape[-1])
+        out = layer.ipex_qlinear(reshaped_x)
+
+        return out.reshape(x.shape[:-1] + (layer.ipex_output_size, ))
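For reference, a minimal standalone sketch of the IPEX weight-only quantization settings built in process_weights_after_loading above (assumes intel_extension_for_pytorch>=2.4.0 on a CPU host; the group size here is a hypothetical example, the real value comes from the checkpoint's quantization config):

# Sketch mirroring the qconfig construction in IPEXAWQLinearMethod above.
import intel_extension_for_pytorch as ipex

qconfig = ipex.quantization.get_weight_only_quant_qconfig_mapping(
    weight_dtype=ipex.quantization.WoqWeightDtype.INT4,   # AWQ stores 4-bit weights
    lowp_mode=ipex.quantization.WoqLowpMode.INT8,         # INT8 compute path
    act_quant_mode=ipex.quantization.WoqActQuantMode.PER_BATCH,
    group_size=128,  # hypothetical; taken from the model's quant config in practice
)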

vllm/worker/cpu_worker.py

Lines changed: 2 additions & 1 deletion
@@ -215,7 +215,8 @@ def _is_encoder_decoder_model(self):
     def init_device(self) -> None:
         if self.local_omp_cpuid != "all":
             ret = torch.ops._C_utils.init_cpu_threads_env(self.local_omp_cpuid)
-            logger.info(ret)
+            if ret:
+                logger.info(ret)
 
         self.init_distributed_environment()
         # Set random seed.

0 commit comments
