
implement TorchBaseConfig #1911


Merged · 1 commit · Jul 16, 2024
41 changes: 29 additions & 12 deletions neural_compressor/torch/quantization/config.py
@@ -71,9 +71,26 @@ class OperatorConfig(NamedTuple):
    valid_func_list: List[Callable] = []


+class TorchBaseConfig(BaseConfig):
+    # Override _get_op_name_op_type_config so that string op_types fall back correctly,
+    # because the IPEX backend uses special fused op_types: `Linear&Relu`, `Linear&add`, ...
+    def _get_op_name_op_type_config(self):
+        op_type_config_dict = dict()
+        op_name_config_dict = dict()
+        for name, config in self.local_config.items():
+            if self._is_op_type(name):
+                # Convert the Callable op_type to its string name.
+                new_name = self._op_type_to_str(name)
+                op_type_config_dict[new_name] = config
+            else:
+                # A string key may be an op name or a fused op_type string,
+                # so register it in both dicts.
+                op_name_config_dict[name] = config
+                op_type_config_dict[name] = config
+        return op_type_config_dict, op_name_config_dict
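
To make the fallback concrete, here is a minimal, self-contained sketch of the dispatch above. `MiniConfig`, its `_is_op_type` heuristic, and the sample config values are illustrative assumptions, not the real `BaseConfig` internals:

from typing import Callable, Dict, Tuple, Union

import torch


class MiniConfig:
    """Illustrative stand-in for BaseConfig; not the real class."""

    def __init__(self, local_config: Dict[Union[str, Callable], str]):
        self.local_config = local_config

    @staticmethod
    def _is_op_type(name) -> bool:
        # Assumed heuristic: Callables (e.g. nn.Module classes) are op types,
        # plain strings are not.
        return not isinstance(name, str)

    @staticmethod
    def _op_type_to_str(op_type: Callable) -> str:
        return op_type.__name__  # torch.nn.Linear -> "Linear"

    def _get_op_name_op_type_config(self) -> Tuple[dict, dict]:
        op_type_config_dict, op_name_config_dict = {}, {}
        for name, config in self.local_config.items():
            if self._is_op_type(name):
                op_type_config_dict[self._op_type_to_str(name)] = config
            else:
                # Fallback: the string may be an op name or an IPEX fused
                # op_type ("Linear&add"), so it lands in both dicts.
                op_name_config_dict[name] = config
                op_type_config_dict[name] = config
        return op_type_config_dict, op_name_config_dict


cfg = MiniConfig({torch.nn.Linear: "int8", "Linear&add": "fp32"})
op_type_cfg, op_name_cfg = cfg._get_op_name_op_type_config()
print(op_type_cfg)  # {'Linear': 'int8', 'Linear&add': 'fp32'}
print(op_name_cfg)  # {'Linear&add': 'fp32'}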


######################## RTN Config ###############################
@register_config(framework_name=FRAMEWORK_NAME, algo_name=RTN, priority=PRIORITY_RTN)
-class RTNConfig(BaseConfig):
+class RTNConfig(TorchBaseConfig):
    """Config class for round-to-nearest weight-only quantization."""

    name = RTN
@@ -238,7 +255,7 @@ def get_default_double_quant_config(type="BNB_NF4"):

######################## GPTQ Config ###############################
@register_config(framework_name=FRAMEWORK_NAME, algo_name=GPTQ, priority=PRIORITY_GPTQ)
-class GPTQConfig(BaseConfig):
+class GPTQConfig(TorchBaseConfig):
    """Config class for GPTQ.

    GPTQ: Accurate Post-Training Quantization for Generative Pre-trained Transformers.
@@ -390,7 +407,7 @@ def get_default_gptq_config() -> GPTQConfig:

######################## AWQ Config ###############################
@register_config(framework_name=FRAMEWORK_NAME, algo_name=AWQ, priority=PRIORITY_AWQ)
-class AWQConfig(BaseConfig):
+class AWQConfig(TorchBaseConfig):
    """Config class for AWQ.

    AWQ: Activation-aware Weight Quantization for LLM Compression and Acceleration.
@@ -532,7 +549,7 @@ def get_default_awq_config() -> AWQConfig:

######################## TEQ Config ###############################
@register_config(framework_name=FRAMEWORK_NAME, algo_name=TEQ, priority=PRIORITY_TEQ)
-class TEQConfig(BaseConfig):
+class TEQConfig(TorchBaseConfig):
    """Config class for TEQ.

    TEQ: Trainable Equivalent Transformation for Quantization of LLMs.
@@ -670,7 +687,7 @@ def get_default_teq_config() -> TEQConfig:

######################## AUTOROUND Config ###############################
@register_config(framework_name=FRAMEWORK_NAME, algo_name=AUTOROUND, priority=PRIORITY_AUTOROUND)
-class AutoRoundConfig(BaseConfig):
+class AutoRoundConfig(TorchBaseConfig):
    """Config class for AUTOROUND.

    AUTOROUND: Optimize Weight Rounding via Signed Gradient Descent for the Quantization of LLMs.
@@ -815,7 +832,7 @@ def get_default_AutoRound_config() -> AutoRoundConfig:

######################## MX Config ###############################
@register_config(framework_name=FRAMEWORK_NAME, algo_name=MX_QUANT)
-class MXQuantConfig(BaseConfig):
+class MXQuantConfig(TorchBaseConfig):
    """Config class for MX quantization."""

    supported_configs: List[OperatorConfig] = []
@@ -928,7 +945,7 @@ def get_default_mx_config() -> MXQuantConfig:

######################## Dynamic Quant Config ###############################
@register_config(framework_name=FRAMEWORK_NAME, algo_name=PT2E_DYNAMIC_QUANT)
-class DynamicQuantConfig(BaseConfig):
+class DynamicQuantConfig(TorchBaseConfig):
    """Config class for dynamic quantization."""

    name = PT2E_DYNAMIC_QUANT
@@ -1002,7 +1019,7 @@ def get_default_dynamic_config() -> DynamicQuantConfig:

######################## Static Quant Config ###############################
@register_config(framework_name=FRAMEWORK_NAME, algo_name=STATIC_QUANT)
-class StaticQuantConfig(BaseConfig):
+class StaticQuantConfig(TorchBaseConfig):
    """Config class for static quantization."""

    name = STATIC_QUANT
@@ -1091,7 +1108,7 @@ def get_default_static_config() -> StaticQuantConfig:

######################## Smooth Quant Config ###############################
@register_config(framework_name=FRAMEWORK_NAME, algo_name=SMOOTH_QUANT)
-class SmoothQuantConfig(BaseConfig):
+class SmoothQuantConfig(TorchBaseConfig):
    """Config class for smooth quantization."""

    name = SMOOTH_QUANT
@@ -1205,7 +1222,7 @@ def get_default_sq_config() -> SmoothQuantConfig:

######################## HQQ Config ###############################
@register_config(framework_name=FRAMEWORK_NAME, algo_name=HQQ, priority=PRIORITY_HQQ)
-class HQQConfig(BaseConfig):
+class HQQConfig(TorchBaseConfig):
    # Half-Quadratic Quantization (HQQ), more details:
    # Blog: https://mobiusml.github.io/hqq_blog/
    # Code: https://github.com/mobiusml/hqq
@@ -1286,7 +1303,7 @@ def get_default_hqq_config() -> HQQConfig:

######################## FP8 Config ###############################
@register_config(framework_name=FRAMEWORK_NAME, algo_name=FP8_QUANT)
-class FP8Config(BaseConfig):
+class FP8Config(TorchBaseConfig):
    """Config class for FP8 quantization."""

    name = FP8_QUANT
@@ -1381,7 +1398,7 @@ def get_default_fp8_config_set() -> FP8Config:

######################## MixPrecision Config ###############################
@register_config(framework_name=FRAMEWORK_NAME, algo_name=MIX_PRECISION)
-class MixPrecisionConfig(BaseConfig):
+class MixPrecisionConfig(TorchBaseConfig):
    """Config class for mix-precision."""

    name = MIX_PRECISION
2 changes: 1 addition & 1 deletion test/3x/torch/quantization/test_static_quant.py
@@ -76,7 +76,7 @@ def test_static_quant_fallback(self):
        quant_config = get_default_static_config()
        example_inputs = self.input
        # fall back by op_type (a Callable and an IPEX fused op_type string)
-       quant_config.set_local(torch.nn.Linear, StaticQuantConfig(w_dtype="fp32", act_dtype="fp32"))
+       quant_config.set_local([torch.nn.Linear, "Linear&add"], StaticQuantConfig(w_dtype="fp32", act_dtype="fp32"))
        prepared_model = prepare(fp32_model, quant_config=quant_config, example_inputs=example_inputs)
        run_fn(prepared_model)
        q_model = convert(prepared_model)
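
With TorchBaseConfig in place, a single set_local call can fall back both a module class and an IPEX fused op_type string. A hedged end-to-end sketch of that flow follows; the toy model and calibration call are assumptions mirroring the test above, and running the static-quant path in practice requires intel_extension_for_pytorch:

import torch
from neural_compressor.torch.quantization import (
    StaticQuantConfig,
    convert,
    get_default_static_config,
    prepare,
)

model = torch.nn.Sequential(torch.nn.Linear(8, 8), torch.nn.ReLU())
example_inputs = torch.randn(4, 8)

quant_config = get_default_static_config()
# Keep Linear modules and the IPEX fused "Linear&add" pattern in fp32.
quant_config.set_local(
    [torch.nn.Linear, "Linear&add"],
    StaticQuantConfig(w_dtype="fp32", act_dtype="fp32"),
)

prepared = prepare(model, quant_config=quant_config, example_inputs=example_inputs)
prepared(example_inputs)  # calibration pass
q_model = convert(prepared)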