Commit a4675c7

Add WOQ tuning (#1782)
Signed-off-by: yiliu30 <yi4.liu@intel.com>
1 parent ec49a29 commit a4675c7

4 files changed: +52 −2 lines changed

neural_compressor/torch/quantization/__init__.py
Lines changed: 1 addition & 0 deletions

@@ -34,6 +34,7 @@
     FP8Config,
     get_default_fp8_config,
     get_default_fp8_config_set,
+    get_woq_tuning_config,
 )

 from neural_compressor.torch.quantization.autotune import (
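
With this re-export, get_woq_tuning_config becomes part of the package's public API. A quick smoke check, assuming a build that includes this commit (the expected output is inferred from the config.py change further down):

    from neural_compressor.torch.quantization import get_woq_tuning_config

    # Expect five configs: RTN, three GPTQ variants, and AWQ.
    print([type(cfg).__name__ for cfg in get_woq_tuning_config()])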

neural_compressor/torch/quantization/algorithm_entry.py
Lines changed: 2 additions & 0 deletions

@@ -308,6 +308,7 @@ def awq_quantize_entry(
             use_full_range = op_config.use_full_range
 
     run_fn = kwargs.get("run_fn", None)
+    run_args = kwargs.get("run_args", None)
     example_inputs = kwargs.get("example_inputs", None)
     assert example_inputs is not None, "Please provide example_inputs for AWQ quantization."
 

@@ -318,6 +319,7 @@ def awq_quantize_entry(
         bits=-1,  # no quantize for op not in weight_config
         example_inputs=example_inputs,  # must be required
         run_fn=run_fn,
+        run_args=run_args,
         use_auto_scale=use_auto_scale,
         use_mse_search=use_mse_search,
         folding=folding,
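
Threading run_args through awq_quantize_entry lets callers hand extra calibration arguments to run_fn. Judging by the test added below, the quantizer presumably invokes run_fn(model, *run_args); here is a minimal standalone sketch of that convention, with toy stand-ins for the model and data:

    import torch

    def calib_fn(model, dataloader, use_max_length):
        # Forward passes over calibration data; algorithms like AWQ hook such
        # a function to collect activation statistics.
        for batch in dataloader:
            model(batch)

    model = torch.nn.Linear(8, 8)
    run_fn = calib_fn
    run_args = ([torch.randn(2, 8) for _ in range(3)], True)  # mirrors the test's (dataloader, True)
    run_fn(model, *run_args)  # the assumed calling convention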

neural_compressor/torch/quantization/config.py
Lines changed: 25 additions & 2 deletions

@@ -17,7 +17,9 @@
 # pylint:disable=import-error
 
 from collections import OrderedDict
-from typing import Any, Callable, Dict, List, NamedTuple, Optional, Tuple, Union
+from typing import Any, Callable, Dict, List, NamedTuple, Optional
+from typing import OrderedDict as OrderedDictType
+from typing import Tuple, Union
 
 import torch
 

@@ -57,6 +59,7 @@
     "get_default_gptq_config",
     "HQQConfig",
     "get_default_hqq_config",
+    "get_woq_tuning_config",
 ]
 
 

@@ -839,7 +842,7 @@ def get_model_info(model: torch.nn.Module, example_inputs=None) -> List[Tuple[st
 
     def to_config_mapping(
         self, config_list: List[BaseConfig] = None, model_info: List[Tuple[str, str]] = None
-    ) -> OrderedDict[Union[str, str], OrderedDict[str, BaseConfig]]:
+    ) -> OrderedDictType[Union[str, str], OrderedDictType[str, BaseConfig]]:
         if is_ipex_imported():
             return super().to_config_mapping(config_list, model_info)
         config_mapping = OrderedDict({self.name: self})

@@ -1140,3 +1143,23 @@ def get_default_fp8_config_set() -> FP8Config:
 def get_all_registered_configs() -> Dict[str, BaseConfig]:
     registered_configs = config_registry.get_all_configs()
     return registered_configs.get(FRAMEWORK_NAME, {})
+
+
+# =============================================================================
+# Tuning Config
+# =============================================================================
+
+
+######################## WOQ Tuning Config ###############################
+def get_woq_tuning_config() -> list:
+    """Generate the config set for WOQ tuning.
+
+    Returns:
+        the list of WOQ quant configs.
+    """
+    RTN_G32ASYM = RTNConfig(use_sym=False, group_size=32)
+    GPTQ_G32ASYM = GPTQConfig(use_sym=False, group_size=32)
+    GPTQ_G32ASYM_DISABLE_LAST_LINEAR = GPTQConfig(use_sym=False).set_local("*.lm_head", GPTQConfig(dtype="fp32"))
+    GPTQ_G128ASYM = GPTQConfig(group_size=128, use_sym=False)
+    AWQ_G32ASYM = AWQConfig(use_sym=False, group_size=32)
+    return [RTN_G32ASYM, GPTQ_G32ASYM, GPTQ_G32ASYM_DISABLE_LAST_LINEAR, GPTQ_G128ASYM, AWQ_G32ASYM]
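
The list returned by get_woq_tuning_config is presumably consumed in order, so RTN (the cheapest scheme) is tried first, then the GPTQ variants, then AWQ. A condensed usage sketch mirroring the new test below; model, evaluate_accuracy, calibrate, dataloader, and example_inputs are placeholders:

    from neural_compressor.torch.quantization import TuningConfig, autotune, get_woq_tuning_config

    tune_config = TuningConfig(config_set=get_woq_tuning_config())

    def eval_fn(model) -> float:
        return evaluate_accuracy(model)  # placeholder metric; higher is better

    # run_fn/run_args supply calibration for the GPTQ/AWQ candidates.
    best_model = autotune(
        model=model,
        tune_config=tune_config,
        eval_fn=eval_fn,
        run_fn=calibrate,
        run_args=(dataloader,),  # must be a tuple
        example_inputs=example_inputs,
    )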

test/3x/torch/test_autotune.py
Lines changed: 24 additions & 0 deletions

@@ -308,6 +308,30 @@ def eval_acc_fn(model) -> float:
         best_model = autotune(model=build_simple_torch_model(), tune_config=custom_tune_config, eval_fn=eval_acc_fn)
         self.assertIsNone(best_model)
 
+    def test_woq_tuning(self):
+        from neural_compressor.torch.quantization import autotune, get_woq_tuning_config
+
+        baseline = [1]
+        acc_res_lst = baseline + [0.9, 0.95, 0.95, 0.99, 1.1]
+
+        def eval_acc_fn(model):
+            res = acc_res_lst.pop(0)
+            return res
+
+        custom_tune_config = TuningConfig(config_set=get_woq_tuning_config(), tolerable_loss=-1)
+        example_inputs = torch.ones([1, 32], dtype=torch.long)
+        model = get_gpt_j()
+        dataloader = GPTQLLMDataLoader()
+        best_model = autotune(
+            model=model,
+            tune_config=custom_tune_config,
+            eval_fn=eval_acc_fn,
+            run_fn=run_fn_for_gptq,
+            run_args=(dataloader, True),  # run_args should be a tuple
+            example_inputs=example_inputs,
+        )
+        self.assertIsNone(best_model)
+
 
 if __name__ == "__main__":
     unittest.main()
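
A note on the test's mechanics: eval_acc_fn replays scripted scores, so the first call (the baseline measurement) returns 1 and each tuning trial pops the next value; tolerable_loss=-1 appears to set an accuracy target no trial can meet, which would force all five candidates in the config set to be evaluated. The stub pattern in isolation:

    acc_res_lst = [1, 0.9, 0.95, 0.95, 0.99, 1.1]  # baseline, then one score per candidate

    def eval_acc_fn(model):
        return acc_res_lst.pop(0)  # each evaluation consumes the next scripted score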
