add docstring for torch.quantization and torch.utils #1928

Merged
merged 1 commit on Jul 17, 2024
148 changes: 143 additions & 5 deletions neural_compressor/torch/quantization/algorithm_entry.py
@@ -66,7 +66,16 @@ def rtn_entry(
*args,
**kwargs,
) -> torch.nn.Module:
"""The main entry to apply rtn quantization."""
"""The main entry to apply rtn quantization.

Args:
model (torch.nn.Module): raw fp32 model or prepared model.
configs_mapping (Dict[Tuple[str, callable], RTNConfig]): per-op configuration.
mode (Mode, optional): select from [PREPARE, CONVERT and QUANTIZE]. Defaults to Mode.QUANTIZE.

Returns:
torch.nn.Module: prepared model or quantized model.
"""
from neural_compressor.torch.algorithms.weight_only.rtn import RTNQuantizer
from neural_compressor.torch.algorithms.weight_only.save_load import save
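
Usage sketch for an entry like rtn_entry, based only on the signature and docstring above: the import paths for RTNConfig and the Mode enum, and the (op_name, op_type) key format, are assumptions rather than details confirmed by this diff.

import torch

from neural_compressor.torch.quantization import RTNConfig  # assumed export location
from neural_compressor.torch.quantization.algorithm_entry import rtn_entry
from neural_compressor.torch.utils import Mode  # assumed location of the Mode enum

model = torch.nn.Sequential(torch.nn.Linear(64, 64))
# configs_mapping maps (op_name, op_type) pairs to a per-op RTNConfig.
configs_mapping = {("0", torch.nn.Linear): RTNConfig()}
q_model = rtn_entry(model, configs_mapping, mode=Mode.QUANTIZE)  # one-shot quantize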

@@ -110,6 +119,16 @@ def gptq_entry(
*args,
**kwargs,
) -> torch.nn.Module:
"""The main entry to apply gptq quantization.

Args:
model (torch.nn.Module): raw fp32 model or prepared model.
configs_mapping (Dict[Tuple[str, callable], GPTQConfig]): per-op configuration.
mode (Mode, optional): select from [PREPARE, CONVERT and QUANTIZE]. Defaults to Mode.QUANTIZE.

Returns:
torch.nn.Module: prepared model or quantized model.
"""
logger.info("Quantize model with the GPTQ algorithm.")
from neural_compressor.torch.algorithms.weight_only.gptq import GPTQuantizer
from neural_compressor.torch.algorithms.weight_only.save_load import save
@@ -164,6 +183,16 @@ def static_quant_entry(
*args,
**kwargs,
) -> torch.nn.Module:
"""The main entry to apply static quantization, includes pt2e quantization and ipex quantization.

Args:
model (torch.nn.Module): raw fp32 model or prepared model.
configs_mapping (Dict[Tuple[str, callable], StaticQuantConfig]): per-op configuration.
mode (Mode, optional): select from [PREPARE, CONVERT and QUANTIZE]. Defaults to Mode.QUANTIZE.

Returns:
torch.nn.Module: prepared model or quantized model.
"""
if not is_ipex_imported():
return pt2e_static_quant_entry(model, configs_mapping, mode, *args, **kwargs)
logger.info("Quantize model with the static quant algorithm.")
@@ -207,7 +236,23 @@ def static_quant_entry(
###################### PT2E Dynamic Quant Algo Entry ##################################
@register_algo(name=PT2E_DYNAMIC_QUANT)
@torch.no_grad()
def pt2e_dynamic_quant_entry(model: torch.nn.Module, configs_mapping, mode: Mode, *args, **kwargs) -> torch.nn.Module:
def pt2e_dynamic_quant_entry(
model: torch.nn.Module,
configs_mapping,
mode: Mode,
*args,
**kwargs,
) -> torch.nn.Module:
"""The main entry to apply pt2e dynamic quantization.

Args:
model (torch.nn.Module): raw fp32 model or prepared model.
configs_mapping: per-op configuration.
mode (Mode, optional): select from [PREPARE, CONVERT and QUANTIZE]. Defaults to Mode.QUANTIZE.

Returns:
torch.nn.Module: prepared model or quantized model.
"""
logger.info("Quantize model with the PT2E static quant algorithm.")
from neural_compressor.torch.algorithms.pt2e_quant.core import W8A8PT2EQuantizer
from neural_compressor.torch.algorithms.pt2e_quant.save_load import save
@@ -230,7 +275,23 @@ def pt2e_dynamic_quant_entry(model: torch.nn.Module, configs_mapping, mode: Mode
###################### PT2E Static Quant Algo Entry ##################################
@register_algo(name=PT2E_STATIC_QUANT)
@torch.no_grad()
def pt2e_static_quant_entry(model: torch.nn.Module, configs_mapping, mode: Mode, *args, **kwargs) -> torch.nn.Module:
def pt2e_static_quant_entry(
model: torch.nn.Module,
configs_mapping,
mode: Mode,
*args,
**kwargs,
) -> torch.nn.Module:
"""The main entry to apply pt2e static quantization.

Args:
model (torch.nn.Module): raw fp32 model or prepared model.
configs_mapping: per-op configuration.
mode (Mode, optional): select from [PREPARE, CONVERT and QUANTIZE]. Defaults to Mode.QUANTIZE.

Returns:
torch.nn.Module: prepared model or quantized model.
"""
logger.info("Quantize model with the PT2E static quant algorithm.")
from neural_compressor.torch.algorithms.pt2e_quant.core import W8A8PT2EQuantizer
from neural_compressor.torch.algorithms.pt2e_quant.save_load import save
@@ -259,6 +320,16 @@ def smooth_quant_entry(
*args,
**kwargs,
) -> torch.nn.Module:
"""The main entry to apply smooth quantization.

Args:
model (torch.nn.Module): raw fp32 model or prepared model.
configs_mapping (Dict[Tuple[str, callable], SmoothQuantConfig]): per-op configuration.
mode (Mode, optional): select from [PREPARE, CONVERT and QUANTIZE]. Defaults to Mode.QUANTIZE.

Returns:
torch.nn.Module: prepared model or quantized model.
"""
logger.info("Quantize model with the smooth quant algorithm.")
from neural_compressor.torch.algorithms.smooth_quant import SmoothQuantQuantizer, TorchSmoothQuant

@@ -318,6 +389,16 @@ def awq_quantize_entry(
*args,
**kwargs,
) -> torch.nn.Module:
"""The main entry to apply AWQ quantization.

Args:
model (torch.nn.Module): raw fp32 model or prepared model.
configs_mapping (Dict[Tuple[str, callable], AWQConfig]): per-op configuration.
mode (Mode, optional): select from [PREPARE, CONVERT and QUANTIZE]. Defaults to Mode.QUANTIZE.

Returns:
torch.nn.Module: prepared model or quantized model.
"""
logger.info("Quantize model with the AWQ algorithm.")
from neural_compressor.torch.algorithms.weight_only.awq import AWQQuantizer
from neural_compressor.torch.algorithms.weight_only.save_load import save
@@ -383,8 +464,22 @@ def awq_quantize_entry(
###################### TEQ Algo Entry ##################################
@register_algo(name=TEQ)
def teq_quantize_entry(
model: torch.nn.Module, configs_mapping: Dict[Tuple[str, callable], TEQConfig], mode: Mode, *args, **kwargs
model: torch.nn.Module,
configs_mapping: Dict[Tuple[str, callable], TEQConfig],
mode: Mode,
*args,
**kwargs,
) -> torch.nn.Module:
"""The main entry to apply TEQ quantization.

Args:
model (torch.nn.Module): raw fp32 model or prepared model.
configs_mapping (Dict[Tuple[str, callable], TEQConfig]): per-op configuration.
mode (Mode, optional): select from [PREPARE, CONVERT and QUANTIZE]. Defaults to Mode.QUANTIZE.

Returns:
torch.nn.Module: prepared model or quantized model.
"""
from neural_compressor.torch.algorithms.weight_only.save_load import save
from neural_compressor.torch.algorithms.weight_only.teq import TEQuantizer

@@ -445,6 +540,16 @@ def autoround_quantize_entry(
*args,
**kwargs,
) -> torch.nn.Module:
"""The main entry to apply AutoRound quantization.

Args:
model (torch.nn.Module): raw fp32 model or prepared model.
configs_mapping (Dict[Tuple[str, callable], AutoRoundConfig]): per-op configuration.
mode (Mode, optional): select from [PREPARE, CONVERT and QUANTIZE]. Defaults to Mode.QUANTIZE.

Returns:
torch.nn.Module: prepared model or quantized model.
"""
from neural_compressor.torch.algorithms.weight_only.autoround import AutoRoundQuantizer
from neural_compressor.torch.algorithms.weight_only.save_load import save

@@ -522,6 +627,16 @@ def hqq_entry(
*args,
**kwargs,
) -> torch.nn.Module:
"""The main entry to apply AutoRound quantization.

Args:
model (torch.nn.Module): raw fp32 model or prepared model.
configs_mapping (Dict[Tuple[str, callable], AutoRoundConfig]): per-op configuration.
mode (Mode, optional): select from [PREPARE, CONVERT and QUANTIZE]. Defaults to Mode.QUANTIZE.

Returns:
torch.nn.Module: prepared model or quantized model.
"""
from neural_compressor.torch.algorithms.weight_only.hqq import HQQuantizer
from neural_compressor.torch.algorithms.weight_only.save_load import save

@@ -564,6 +679,16 @@ def mx_quant_entry(
*args,
**kwargs,
) -> torch.nn.Module:
"""The main entry to apply AutoRound quantization.

Args:
model (torch.nn.Module): raw fp32 model or prepared model.
configs_mapping (Dict[Tuple[str, callable], AutoRoundConfig]): per-op configuration.
mode (Mode, optional): select from [PREPARE, CONVERT and QUANTIZE]. Defaults to Mode.QUANTIZE.

Returns:
torch.nn.Module: prepared model or quantized model.
"""
logger.info("Quantize model with the mx quant algorithm.")
from neural_compressor.torch.algorithms.mx_quant.mx import MXQuantizer

@@ -578,8 +703,21 @@ def mx_quant_entry(
###################### Mixed Precision Algo Entry ##################################
@register_algo(MIX_PRECISION)
def mix_precision_entry(
model: torch.nn.Module, configs_mapping: Dict[Tuple[str], MixPrecisionConfig], *args, **kwargs
model: torch.nn.Module,
configs_mapping: Dict[Tuple[str], MixPrecisionConfig],
*args,
**kwargs,
) -> torch.nn.Module:
"""The main entry to apply Mixed Precision.

Args:
model (torch.nn.Module): raw fp32 model or prepared model.
configs_mapping (Dict[Tuple[str, callable], MixPrecisionConfig]): per-op configuration.
mode (Mode, optional): select from [PREPARE, CONVERT and QUANTIZE]. Defaults to Mode.QUANTIZE.

Returns:
torch.nn.Module: prepared model or quantized model.
"""
# only support fp16 and bf16 now, more types might be added later
from neural_compressor.torch.algorithms.mix_precision import HalfPrecisionConverter

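Usage sketch for mix_precision_entry, which takes no Mode argument: MixPrecisionConfig's export location, its dtype parameter, and the mapping-key format are assumptions, not details shown in this diff.

import torch

from neural_compressor.torch.quantization import MixPrecisionConfig  # assumed export location
from neural_compressor.torch.quantization.algorithm_entry import mix_precision_entry

model = torch.nn.Sequential(torch.nn.Linear(16, 16))
# Keys are assumed to follow the same (op_name, op_type) convention as the other entries.
configs_mapping = {("0", torch.nn.Linear): MixPrecisionConfig(dtype="bf16")}
bf16_model = mix_precision_entry(model, configs_mapping)
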
10 changes: 10 additions & 0 deletions neural_compressor/torch/quantization/autotune.py
@@ -32,13 +32,23 @@


def get_rtn_double_quant_config_set() -> List[RTNConfig]:
"""Generate RTN double quant config set.

Returns:
List[RTNConfig]: a set of quant config
"""
rtn_double_quant_config_set = []
for double_quant_type, double_quant_config in constants.DOUBLE_QUANT_CONFIGS.items():
rtn_double_quant_config_set.append(RTNConfig.from_dict(double_quant_config))
return rtn_double_quant_config_set


def get_all_config_set() -> Union[BaseConfig, List[BaseConfig]]:
"""Generate all quant config set.

Returns:
Union[BaseConfig, List[BaseConfig]]: a set of quant config
"""
return get_all_config_set_from_config_registry(fwk_name=FRAMEWORK_NAME)


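Usage sketch for the two helpers above; it relies only on the signatures shown in this hunk and the autotune module path from the file header.

from neural_compressor.torch.quantization.autotune import (
    get_all_config_set,
    get_rtn_double_quant_config_set,
)

# One RTNConfig is built per predefined entry in constants.DOUBLE_QUANT_CONFIGS.
for cfg in get_rtn_double_quant_config_set():
    print(type(cfg).__name__)

# All config sets registered for the torch framework.
print(get_all_config_set())
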
4 changes: 4 additions & 0 deletions neural_compressor/torch/quantization/config.py
@@ -67,12 +67,16 @@


class OperatorConfig(NamedTuple):
"""OperatorConfig."""

config: BaseConfig
operators: List[Union[str, Callable]]
valid_func_list: List[Callable] = []


class TorchBaseConfig(BaseConfig):
"""Base config class for torch backend."""

# re-write func _get_op_name_op_type_config to fallback op_type with string
# because there are some special op_types for IPEX backend: `Linear&Relu`, `Linear&add`, ...
def _get_op_name_op_type_config(self):
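Construction sketch for the OperatorConfig NamedTuple documented above; RTNConfig is used only as an example of a BaseConfig subclass, and its import path is an assumption.

import torch

from neural_compressor.torch.quantization import RTNConfig  # assumed export location
from neural_compressor.torch.quantization.config import OperatorConfig

op_cfg = OperatorConfig(
    config=RTNConfig(),  # any BaseConfig subclass
    operators=[torch.nn.Linear, "Conv2d"],  # operators given by class or by name
)  # valid_func_list falls back to its [] default
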
13 changes: 13 additions & 0 deletions neural_compressor/torch/utils/environ.py
@@ -22,13 +22,15 @@

################ Check imported sys.module first to decide behavior #################
def is_ipex_imported() -> bool:
"""Check whether intel_extension_for_pytorch is imported."""
for name, _ in sys.modules.items():
if name == "intel_extension_for_pytorch":
return True
return False


def is_transformers_imported() -> bool:
"""Check whether transformers is imported."""
for name, _ in sys.modules.items():
if name == "transformers":
return True
@@ -37,6 +39,11 @@ def is_transformers_imported():

################ Check available sys.module to decide behavior #################
def is_package_available(package_name):
"""Check if the package exists in the environment without importing.

Args:
package_name (str): package name
"""
from importlib.util import find_spec

package_spec = find_spec(package_name)
@@ -52,6 +59,7 @@ def is_package_available(package_name):


def is_hpex_available():
"""Returns whether hpex is available."""
return _hpex_available


@@ -63,10 +71,12 @@ def is_hpex_available():


def is_ipex_available():
"""Return whether ipex is available."""
return _ipex_available


def get_ipex_version():
"""Return ipex version if ipex exists."""
if is_ipex_available():
try:
import intel_extension_for_pytorch as ipex
@@ -84,6 +94,7 @@ def get_ipex_version():


def get_torch_version():
"""Return torch version if ipex exists."""
try:
torch_version = torch.__version__.split("+")[0]
except ValueError as e: # pragma: no cover
@@ -96,6 +107,7 @@ def get_torch_version():


def get_accelerator(device_name="auto"):
"""Return the recommended accelerator based on device priority."""
global accelerator # update the global accelerator when calling this func
from neural_compressor.torch.utils.auto_accelerator import auto_detect_accelerator

@@ -109,6 +121,7 @@ def get_accelerator(device_name="auto"):

# for habana ease-of-use
def device_synchronize(raw_func):
"""Function decorator that calls accelerated.synchronize before and after a function call."""
from functools import wraps

@wraps(raw_func)
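Usage sketch exercising the helpers documented in environ.py; it relies only on the function names shown in this diff and the module path from the file header.

from neural_compressor.torch.utils.environ import (
    get_accelerator,
    get_torch_version,
    is_ipex_available,
    is_ipex_imported,
    is_package_available,
)

print(is_package_available("torch"))  # True, checked without importing the package
print(is_ipex_available(), is_ipex_imported())
print(get_torch_version())  # torch version with any "+cpu"/"+cu121" suffix stripped
print(get_accelerator())  # highest-priority accelerator detected for "auto"
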
6 changes: 6 additions & 0 deletions neural_compressor/torch/utils/utility.py
@@ -115,6 +115,7 @@ def set_module(model, op_name, new_module):


def get_model_info(model: torch.nn.Module, white_module_list: List[Callable]) -> List[Tuple[str, str]]:
"""Get model info according to white_module_list."""
module_dict = dict(model.named_modules())
filter_result = []
filter_result_set = set()
@@ -129,6 +130,11 @@ def get_model_info(model: torch.nn.Module, white_module_list: List[Callable]) ->


def get_double_quant_config_dict(double_quant_type="BNB_NF4"):
"""Query config dict of double_quant according to double_quant_type.

Args:
double_quant_type (str, optional): double_quant type. Defaults to "BNB_NF4".
"""
from neural_compressor.torch.utils.constants import DOUBLE_QUANT_CONFIGS

assert double_quant_type in DOUBLE_QUANT_CONFIGS, "Supported double quant configs: {}".format(
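Usage sketch for the two utility helpers above; the example output is illustrative, and only the signatures and the "BNB_NF4" default shown in this diff are relied on.

import torch

from neural_compressor.torch.utils.utility import get_double_quant_config_dict, get_model_info

model = torch.nn.Sequential(torch.nn.Linear(8, 8), torch.nn.ReLU())
# Only modules whose type appears in white_module_list are reported.
print(get_model_info(model, white_module_list=[torch.nn.Linear]))  # e.g. [("0", "Linear")]

print(get_double_quant_config_dict("BNB_NF4"))  # predefined double-quant settings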