Commit bc825b9

jeejeelee authored and minpeter committed
[Quantization] Improve AWQ logic (vllm-project#19431)
Signed-off-by: Jee Jee Li <pandaleefree@gmail.com>
Signed-off-by: minpeter <kali2005611@gmail.com>
1 parent 849eb1f commit bc825b9

File tree

  • vllm/model_executor/layers/quantization/awq.py

1 file changed: +38 −4 lines changed

vllm/model_executor/layers/quantization/awq.py

Lines changed: 38 additions & 4 deletions
```diff
@@ -1,19 +1,23 @@
 # SPDX-License-Identifier: Apache-2.0
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
 
-from typing import Any, Optional
+from typing import Any, Optional, Union
 
 import torch
 
 from vllm import _custom_ops as ops
+from vllm.logger import init_logger
+from vllm.model_executor.layers.fused_moe.layer import FusedMoE
 from vllm.model_executor.layers.linear import (LinearBase, LinearMethodBase,
                                                UnquantizedLinearMethod)
 from vllm.model_executor.layers.quantization import QuantizationMethods
 from vllm.model_executor.layers.quantization.base_config import (
-    QuantizationConfig)
+    QuantizationConfig, QuantizeMethodBase)
 from vllm.model_executor.parameter import (GroupQuantScaleParameter,
                                            PackedvLLMParameter)
 
+logger = init_logger(__name__)
+
 
 class AWQConfig(QuantizationConfig):
     """Config class for AWQ.
@@ -74,12 +78,42 @@ def from_config(cls, config: dict[str, Any]) -> "AWQConfig":
             config, ["modules_to_not_convert"], None)
         return cls(weight_bits, group_size, zero_point, modules_to_not_convert)
 
-    def get_quant_method(self, layer: torch.nn.Module,
-                         prefix: str) -> Optional["LinearMethodBase"]:
+    def get_quant_method(
+        self, layer: torch.nn.Module, prefix: str
+    ) -> Optional[Union["LinearMethodBase", "QuantizeMethodBase"]]:
         if isinstance(layer, LinearBase):
             if is_layer_skipped_awq(prefix, self.modules_to_not_convert):
                 return UnquantizedLinearMethod()
             return AWQLinearMethod(self)
+        elif isinstance(layer, FusedMoE):
+            # Lazy import to avoid circular import.
+            from .awq_marlin import AWQMarlinConfig, AWQMoEMethod
+            from .moe_wna16 import MoeWNA16Config
+            from .utils.marlin_utils import check_moe_marlin_supports_layer
+            if not check_moe_marlin_supports_layer(layer, self.group_size):
+                logger.warning_once(
+                    f"Layer '{prefix}' is not supported by AWQMoeMarlin. "
+                    "Falling back to Moe WNA16 kernels.")
+                config = {
+                    "quant_method": "awq",
+                    "bits": self.weight_bits,
+                    "group_size": self.group_size,
+                    "zero_point": self.zero_point,
+                    "lm_head": False,
+                }
+                return MoeWNA16Config.from_config(config).get_quant_method(
+                    layer, prefix)
+            marlin_compatible_config_dict = {
+                "quant_method": "awq",
+                "bits": self.weight_bits,
+                "group_size": self.group_size,
+                "zero_point": self.zero_point,
+                "lm_head": False,
+                "modules_to_not_convert": self.modules_to_not_convert,
+            }
+            awq_marlin_config = AWQMarlinConfig.from_config(
+                marlin_compatible_config_dict)
+            return AWQMoEMethod(awq_marlin_config)
         return None
 
 
```
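For readers skimming the diff: the new `FusedMoE` branch boils down to a kernel-selection rule. The sketch below restates that rule as a standalone function so it can be run without vLLM installed; `marlin_supported` is a hypothetical stand-in for `check_moe_marlin_supports_layer(layer, group_size)`, which in the real code decides whether the Marlin MoE kernel can handle the layer. This is a minimal illustration of the logic, not the vLLM implementation itself.

```python
# Standalone sketch of the kernel-selection rule added in this commit.
# It mirrors the FusedMoE branch of AWQConfig.get_quant_method;
# `marlin_supported` is a hypothetical stand-in for the real
# check_moe_marlin_supports_layer(layer, group_size) predicate.
from typing import Any, Optional


def select_moe_backend(
    weight_bits: int,
    group_size: int,
    zero_point: bool,
    modules_to_not_convert: Optional[list[str]],
    marlin_supported: bool,
) -> tuple[str, dict[str, Any]]:
    """Return (backend_name, config_dict) the way the new branch does."""
    config: dict[str, Any] = {
        "quant_method": "awq",
        "bits": weight_bits,
        "group_size": group_size,
        "zero_point": zero_point,
        "lm_head": False,
    }
    if not marlin_supported:
        # Fallback: MoeWNA16Config.from_config(config).get_quant_method(...)
        return "moe_wna16", config
    # Preferred path: AWQMarlinConfig.from_config(...) -> AWQMoEMethod
    config["modules_to_not_convert"] = modules_to_not_convert
    return "awq_marlin", config


if __name__ == "__main__":
    backend, _ = select_moe_backend(4, 128, True, None, marlin_supported=False)
    print(backend)  # -> moe_wna16
```

The notable behavioral change is the graceful degradation: rather than failing on MoE layers whose shapes the Marlin kernel cannot handle, `AWQConfig` now logs a one-time warning and routes those layers to the MoE WNA16 kernels, while linear layers keep the existing `AWQLinearMethod` / `UnquantizedLinearMethod` paths unchanged.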