@@ -1,19 +1,23 @@
 # SPDX-License-Identifier: Apache-2.0
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
 
-from typing import Any, Optional
+from typing import Any, Optional, Union
 
 import torch
 
 from vllm import _custom_ops as ops
+from vllm.logger import init_logger
+from vllm.model_executor.layers.fused_moe.layer import FusedMoE
 from vllm.model_executor.layers.linear import (LinearBase, LinearMethodBase,
                                                UnquantizedLinearMethod)
 from vllm.model_executor.layers.quantization import QuantizationMethods
 from vllm.model_executor.layers.quantization.base_config import (
-    QuantizationConfig)
+    QuantizationConfig, QuantizeMethodBase)
 from vllm.model_executor.parameter import (GroupQuantScaleParameter,
                                            PackedvLLMParameter)
 
+logger = init_logger(__name__)
+
 
 class AWQConfig(QuantizationConfig):
     """Config class for AWQ.
@@ -74,12 +78,42 @@ def from_config(cls, config: dict[str, Any]) -> "AWQConfig":
             config, ["modules_to_not_convert"], None)
         return cls(weight_bits, group_size, zero_point, modules_to_not_convert)
 
-    def get_quant_method(self, layer: torch.nn.Module,
-                         prefix: str) -> Optional["LinearMethodBase"]:
+    def get_quant_method(
+        self, layer: torch.nn.Module, prefix: str
+    ) -> Optional[Union["LinearMethodBase", "QuantizeMethodBase"]]:
         if isinstance(layer, LinearBase):
             if is_layer_skipped_awq(prefix, self.modules_to_not_convert):
                 return UnquantizedLinearMethod()
             return AWQLinearMethod(self)
+        elif isinstance(layer, FusedMoE):
+            # Lazy import to avoid circular import.
+            from .awq_marlin import AWQMarlinConfig, AWQMoEMethod
+            from .moe_wna16 import MoeWNA16Config
+            from .utils.marlin_utils import check_moe_marlin_supports_layer
+            if not check_moe_marlin_supports_layer(layer, self.group_size):
+                logger.warning_once(
+                    f"Layer '{prefix}' is not supported by AWQMoeMarlin. "
+                    "Falling back to Moe WNA16 kernels.")
+                config = {
+                    "quant_method": "awq",
+                    "bits": self.weight_bits,
+                    "group_size": self.group_size,
+                    "zero_point": self.zero_point,
+                    "lm_head": False,
+                }
+                return MoeWNA16Config.from_config(config).get_quant_method(
+                    layer, prefix)
+            marlin_compatible_config_dict = {
+                "quant_method": "awq",
+                "bits": self.weight_bits,
+                "group_size": self.group_size,
+                "zero_point": self.zero_point,
+                "lm_head": False,
+                "modules_to_not_convert": self.modules_to_not_convert,
+            }
+            awq_marlin_config = AWQMarlinConfig.from_config(
+                marlin_compatible_config_dict)
+            return AWQMoEMethod(awq_marlin_config)
         return None
 
 