|
17 | 17 | # pylint:disable=import-error |
18 | 18 |
|
19 | 19 | from collections import OrderedDict |
20 | | -from typing import Any, Callable, Dict, List, NamedTuple, Optional, Tuple, Union |
| 20 | +from typing import Any, Callable, Dict, List, NamedTuple, Optional |
| 21 | +from typing import OrderedDict as OrderedDictType |
| 22 | +from typing import Tuple, Union |
21 | 23 |
|
22 | 24 | import torch |
23 | 25 |
|
|
57 | 59 | "get_default_gptq_config", |
58 | 60 | "HQQConfig", |
59 | 61 | "get_default_hqq_config", |
| 62 | + "get_woq_tuning_config", |
60 | 63 | ] |
61 | 64 |
|
62 | 65 |
|
@@ -839,7 +842,7 @@ def get_model_info(model: torch.nn.Module, example_inputs=None) -> List[Tuple[st |
839 | 842 |
|
840 | 843 | def to_config_mapping( |
841 | 844 | self, config_list: List[BaseConfig] = None, model_info: List[Tuple[str, str]] = None |
842 | | - ) -> OrderedDict[Union[str, str], OrderedDict[str, BaseConfig]]: |
| 845 | + ) -> OrderedDictType[Union[str, str], OrderedDictType[str, BaseConfig]]: |
843 | 846 | if is_ipex_imported(): |
844 | 847 | return super().to_config_mapping(config_list, model_info) |
845 | 848 | config_mapping = OrderedDict({self.name: self}) |
@@ -1140,3 +1143,23 @@ def get_default_fp8_config_set() -> FP8Config: |
1140 | 1143 | def get_all_registered_configs() -> Dict[str, BaseConfig]: |
1141 | 1144 | registered_configs = config_registry.get_all_configs() |
1142 | 1145 | return registered_configs.get(FRAMEWORK_NAME, {}) |
| 1146 | + |
| 1147 | + |
| 1148 | +# ============================================================================= |
| 1149 | +# Tuning Config |
| 1150 | +# ============================================================================= |
| 1151 | + |
| 1152 | + |
| 1153 | +######################## WOQ Tuning Config ############################### |
| 1154 | +def get_woq_tuning_config() -> list: |
| 1155 | + """Generate the config set for WOQ tuning. |
| 1156 | +
|
| 1157 | + Returns: |
| 1158 | + the list of WOQ quant config. |
| 1159 | + """ |
| 1160 | + RTN_G32ASYM = RTNConfig(use_sym=False, group_size=32) |
| 1161 | + GPTQ_G32ASYM = GPTQConfig(use_sym=False, group_size=32) |
| 1162 | + GPTQ_G32ASYM_DISABLE_LAST_LINEAR = GPTQConfig(use_sym=False).set_local("*.lm_head", GPTQConfig(dtype="fp32")) |
| 1163 | + GPTQ_G128ASYM = GPTQConfig(group_size=128, use_sym=False) |
| 1164 | + AWQ_G32ASYM = AWQConfig(use_sym=False, group_size=32) |
| 1165 | + return [RTN_G32ASYM, GPTQ_G32ASYM, GPTQ_G32ASYM_DISABLE_LAST_LINEAR, GPTQ_G128ASYM, AWQ_G32ASYM] |
0 commit comments