Support calib_func on TF 3x API #1934

Merged: 8 commits, Jul 17, 2024
9 changes: 8 additions & 1 deletion docs/3x/TensorFlow.md
@@ -23,14 +23,15 @@ Intel(R) Neural Compressor provides `quantize_model` and `autotune` as main interfaces

**quantize_model**

- The design philosophy of the `quantize_model` interface is easy-of-use. With minimal parameters requirement, including `model`, `quant_config`, `calib_dataloader` and `calib_iteration`, it offers a straightforward choice of quantizing TF model in one-shot.
+ The design philosophy of the `quantize_model` interface is ease of use. With minimal parameter requirements, including `model`, `quant_config`, `calib_dataloader`, `calib_iteration`, it offers a straightforward way to quantize a TF model in one shot.

```python
def quantize_model(
    model: Union[str, tf.keras.Model, BaseModel],
    quant_config: Union[BaseConfig, list],
    calib_dataloader: Callable = None,
    calib_iteration: int = 100,
+   calib_func: Callable = None,
):
```
`model` should be a string of the model's location, the object of Keras model or INC TF model wrapper class.
@@ -41,6 +42,9 @@ def quantize_model(

`calib_iteration` is used to decide how many iterations the calibration process will be run.

+ `calib_func` is a substitution for `calib_dataloader` when the built-in calibration function of INC does not work for model inference.

Here is a simple example of using `quantize_model` interface with a dummy calibration dataloader and the default `StaticQuantConfig`:
```python
from neural_compressor.tensorflow import StaticQuantConfig, quantize_model
@@ -68,6 +72,7 @@ def autotune(
    eval_args: Optional[Tuple[Any]] = None,
    calib_dataloader: Callable = None,
    calib_iteration: int = 100,
+   calib_func: Callable = None,
) -> Optional[BaseModel]:
```
`model` should be a string of the model's location, the object of Keras model or INC TF model wrapper class.
@@ -82,6 +87,8 @@

`calib_iteration` is used to decide how many iterations the calibration process will be run.

+ `calib_func` is a substitution for `calib_dataloader` when the built-in calibration function of INC does not work for model inference.

Here is a simple example of using `autotune` interface with different quantization rules defined by a list of `StaticQuantConfig`:
```python
from neural_compressor.common.base_tuning import TuningConfig
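To make the new parameter concrete, here is a minimal sketch of a user-defined `calib_func` passed to `quantize_model`. The dummy input shape, batch handling, and saved-model path are illustrative assumptions rather than part of this PR, and the exact model object handed to `calib_func` depends on the backend (it is assumed to be directly callable here):

```python
import numpy as np

from neural_compressor.tensorflow import StaticQuantConfig, quantize_model

# Hypothetical calibration samples; the shape must match the real model input.
calib_data = np.random.random((32, 224, 224, 3)).astype(np.float32)


def calib_func(model):
    """Run enough inference steps for INC to observe activation ranges."""
    for i in range(0, len(calib_data), 8):  # four batches of eight samples
        model(calib_data[i : i + 8])


quant_config = StaticQuantConfig()
q_model = quantize_model("./fp32_saved_model", quant_config, calib_func=calib_func)
```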
12 changes: 8 additions & 4 deletions neural_compressor/tensorflow/algorithms/smoother/core.py
@@ -37,19 +37,23 @@ class SmoothQuant:
    def __init__(
        self,
        config: SmoothQuantConfig,
-       calib_dataloader: Callable,
+       calib_dataloader: Callable = None,
        calib_iteration: int = 1,
+       calib_func: Callable = None,
    ):
"""Convert the model by smooth quant.

Args:
config: the SmoothQuantConfig class used to set this class
calibdataloader: the calibration dataloader
calib_iteration: how many steps of iterations on the dataloader to move forward
config: the SmoothQuantConfig class used to set this class.
calibdataloader: the calibration dataloader.
calib_iteration: how many steps of iterations on the dataloader to move forward.
calib_func: the function used for calibration, should be a substitution for calib_dataloader
when the built-in calibration function of INC does not work for model inference.

Returns:
model: A smoothed Tensorflow model
"""
assert calib_func is None, "calibration function is not supported for smooth quant."
self.config = config
self.calib_dataloader = calib_dataloader
self.calib_iteration = calib_iteration
12 changes: 6 additions & 6 deletions neural_compressor/tensorflow/algorithms/static_quant/keras.py
@@ -314,16 +314,18 @@ def fuse_conv_bn(conv_weight, bn_weight, conv_type="Conv2D", eps=1.0e-5):
        return bn_fused_model

    @dump_elapsed_time("Pass quantize model")
-   def quantize(self, quant_config, model, dataloader, iteration, q_func=None):
+   def quantize(self, quant_config, model, dataloader, iteration, calib_func=None):
        """Execute the quantize process on the specified model.

        Args:
-           tune_cfg(dict): The user defined 'StaticQuantConfig' class.
+           quant_config(dict): The user defined 'StaticQuantConfig' class.
            model (object): The model to do quantization.
            dataloader(object): The calibration dataloader used to load quantization dataset.
            iteration(int): The iteration of calibration.
-           q_func (optional): training function for quantization aware training mode.
+           calib_func (optional): the function used for calibration, should be a substitution for calibration
+               dataloader when the built-in calibration function of INC does not work for model inference.
        """
+       assert calib_func is None, "The calibration function is not supported on Keras backend yet"
        self.query_fw_capability(model)
        converter = KerasConfigConverter(quant_config, iteration)
        tune_cfg = converter.parse_to_tune_cfg()
@@ -367,15 +369,13 @@ def quantize(self, quant_config, model, dataloader, iteration, q_func=None):

        return quantized_model

-   def _calibrate(self, model, dataloader, calib_interation):
+   def _calibrate(self, model, dataloader=None, calib_interation=None):
        """Apply calibration.

        Args:
            model (tf.keras.Model): The model inserted with FakeQuant layers for calibration.
            dataloader(object): The calibration dataloader used to load quantization dataset.
            iteration(int): The iteration of calibration.
-           fq_output_layers (dict): A dict mapping from names of FakeQuant layers to
-               names of their output layers.
        """
        # run eagerly to fetch the numpy min/max
        results = {}
36 changes: 18 additions & 18 deletions neural_compressor/tensorflow/algorithms/static_quant/tensorflow.py
@@ -172,7 +172,7 @@ def quantize(
        model: BaseModel,
        calib_dataloader: Callable = None,
        calib_iteration: int = 100,
-       q_func=None,
+       calib_func: Callable = None,
    ):
        """Execute the quantize process on the specified model.

@@ -181,11 +181,11 @@
            model: the fp32 model to be quantized.
            calib_dataloader: a data loader for calibration.
            calib_iteration: the iteration of calibration.
-           q_func: training function for quantization aware training mode,
-               which not enabled for tensorflow yet.
+           calib_func: the function used for calibration, should be a substitution for calib_dataloader
+               when the built-in calibration function of INC does not work for model inference.

        Returns:
-           tf.compat.v1.GraphDef: the quantized model
+           converted_model: the quantized INC model wrapper.
        """
        assert (
            self.approach != "post_training_dynamic_quant"
@@ -195,7 +195,7 @@
self.approach != "quant_aware_training"
), "Quantize Aware Training is not supported on TensorFlow framework now!"

self.calib_sampling_size = calib_dataloader.batch_size * calib_iteration
self.calib_sampling_size = calib_dataloader.batch_size * calib_iteration if calib_dataloader else 100
tune_cfg = self.parse_quant_config(quant_config, model, calib_iteration)
self._tuning_cfg_to_fw(tune_cfg)
self.bf16_ops.extend(self.smooth_quant_mul_ops)
@@ -228,7 +228,7 @@ def quantize(
            fp32_ops=self.fp32_ops,
            bf16_ops=self.bf16_ops,
            data_loader=calib_dataloader,
-           calib_func=q_func,
+           calib_func=calib_func,
            qdq_enabled=self.qdq_enabled,
            new_api=self.new_api,
            performance_only=self.performance_only,
@@ -251,7 +251,7 @@
            fp32_ops=self.fp32_ops,
            bf16_ops=self.bf16_ops,
            data_loader=calib_dataloader,
-           calib_func=q_func,
+           calib_func=calib_func,
            qdq_enabled=self.qdq_enabled,
            new_api=self.new_api,
            performance_only=self.performance_only,
@@ -275,7 +275,7 @@
            fp32_ops=self.fp32_ops,
            bf16_ops=self.bf16_ops,
            data_loader=calib_dataloader,
-           calib_func=q_func,
+           calib_func=calib_func,
            qdq_enabled=self.qdq_enabled,
            new_api=self.new_api,
            performance_only=self.performance_only,
@@ -750,21 +750,21 @@ def quantize(
        model: BaseModel,
        calib_dataloader: Callable = None,
        calib_iteration: int = 100,
-       q_func=None,
+       calib_func: Callable = None,
    ):
        """Execute the quantize process on the specified model.

        Args:
-           tune_cfg (dict): quantization configuration
-           model (tf.compat.v1.GraphDef): fp32 model
-           data_loader (generator): generator the data and labels
-           q_func (optional): training function for quantization aware training mode,
-               which not enabled for tensorflow yet.
+           quant_config: a quantization configuration.
+           model: the fp32 model to be quantized.
+           calib_dataloader: a data loader for calibration.
+           calib_iteration: the iteration of calibration.
+           calib_func: the function used for calibration, should be a substitution for calib_dataloader
+               when the built-in calibration function of INC does not work for model inference.

        Returns:
-           tf.compat.v1.GraphDef: the quantized model
+           converted_model: the quantized INC model wrapper.
        """
-       assert q_func is None, "quantization aware training mode is not support on tensorflow"
        self.calib_sampling_size = calib_dataloader.batch_size * calib_iteration
        tune_cfg = self.parse_quant_config(quant_config, model, calib_iteration)
        self._tuning_cfg_to_fw(tune_cfg)
@@ -798,7 +798,7 @@ def quantize(
            fp32_ops=self.fp32_ops,
            bf16_ops=self.bf16_ops,
            data_loader=calib_dataloader,
-           calib_func=q_func,
+           calib_func=calib_func,
            itex_mode=self.itex_mode,
            qdq_enabled=self.qdq_enabled,
            new_api=self.new_api,
@@ -846,7 +846,7 @@ def quantize(
            fp32_ops=self.fp32_ops,
            bf16_ops=self.bf16_ops,
            data_loader=calib_dataloader,
-           calib_func=q_func,
+           calib_func=calib_func,
            itex_mode=self.itex_mode,
            qdq_enabled=self.qdq_enabled,
            new_api=self.new_api,
21 changes: 19 additions & 2 deletions neural_compressor/tensorflow/quantization/algorithm_entry.py
@@ -28,6 +28,7 @@ def static_quant_entry(
    quant_config: BaseConfig,
    calib_dataloader: Callable = None,
    calib_iteration: int = 100,
+   calib_func: Callable = None,
):
    """The main entry to apply static quantization.

@@ -36,6 +37,8 @@
        quant_config: a quantization configuration.
        calib_dataloader: a data loader for calibration.
        calib_iteration: the iteration of calibration.
+       calib_func: the function used for calibration, should be a substitution for calib_dataloader
+           when the built-in calibration function of INC does not work for model inference.

    Returns:
        q_model: the quantized model.
@@ -49,7 +52,7 @@
        framework = TensorFlowAdaptor

    quantizer = framework(TFConfig.global_config)
-   q_model = quantizer.quantize(quant_config, model, calib_dataloader, calib_iteration)
+   q_model = quantizer.quantize(quant_config, model, calib_dataloader, calib_iteration, calib_func)
    TFConfig.reset_global_config()

    return q_model
@@ -61,12 +64,26 @@ def smooth_quant_entry(
    smooth_quant_config: SmoothQuantConfig,
    calib_dataloader: Callable = None,
    calib_iteration: int = 100,
+   calib_func: Callable = None,
):
    """The main entry to apply smooth quantization.

+   Args:
+       model: a fp32 model to be quantized.
+       smooth_quant_config: a quantization configuration.
+       calib_dataloader: a data loader for calibration.
+       calib_iteration: the iteration of calibration.
+       calib_func: the function used for calibration, should be a substitution for calib_dataloader
+           when the built-in calibration function of INC does not work for model inference.
+
+   Returns:
+       q_model: the quantized model.
+   """
+   assert not isinstance(model, KerasModel), "INC don't support smooth quantization for Keras models now."
+
    from neural_compressor.tensorflow.algorithms import SmoothQuant

-   converter = SmoothQuant(smooth_quant_config, calib_dataloader, calib_iteration)
+   converter = SmoothQuant(smooth_quant_config, calib_dataloader, calib_iteration, calib_func)
    sq_model = converter(model)

    return sq_model
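Two backends explicitly reject the new parameter for now: `SmoothQuant.__init__` and the Keras adaptor's `quantize` both assert that `calib_func` is `None`. Below is a hedged sketch of the resulting behavior, assuming `SmoothQuantConfig` is importable from `neural_compressor.tensorflow` alongside the other configs and that a default-constructed config is valid:

```python
from neural_compressor.tensorflow import SmoothQuantConfig, quantize_model


def calib_func(model):
    ...  # user inference code would go here


# Routing a SmoothQuantConfig together with a calib_func reaches the
# assertion added in SmoothQuant.__init__ above and fails fast.
try:
    quantize_model("./fp32_saved_model", SmoothQuantConfig(), calib_func=calib_func)
except AssertionError as err:
    print(err)  # calibration function is not supported for smooth quant.
```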
7 changes: 5 additions & 2 deletions neural_compressor/tensorflow/quantization/autotune.py
@@ -44,6 +44,7 @@ def autotune(
    eval_args: Optional[Tuple[Any]] = None,
    calib_dataloader: Callable = None,
    calib_iteration: int = 100,
+   calib_func: Callable = None,
) -> Optional[BaseModel]:
    """The main entry of auto-tune."""
    model = Model(model)
@@ -57,7 +58,7 @@
        tuning_logger.trial_start(trial_index=trial_index)
        tuning_logger.execution_start()
        logger.info(quant_config.to_dict())
-       q_model = quantize_model(model, quant_config, calib_dataloader, calib_iteration)
+       q_model = quantize_model(model, quant_config, calib_dataloader, calib_iteration, calib_func)
        tuning_logger.execution_end()
        tuning_logger.evaluation_start()
        eval_result: float = eval_func_wrapper.evaluate(q_model)
@@ -71,7 +72,9 @@
logger.info("Re-quantizing with best quantization config...")
del q_model
best_quant_config: BaseConfig = best_trial_record.quant_config
best_quant_model = quantize_model(model, best_quant_config, calib_dataloader, calib_iteration)
best_quant_model = quantize_model(
model, best_quant_config, calib_dataloader, calib_iteration, calib_func
)
else:
best_quant_model = q_model
break
14 changes: 11 additions & 3 deletions neural_compressor/tensorflow/quantization/quantize.py
@@ -32,6 +32,7 @@ def quantize_model(
    quant_config: Union[BaseConfig, list],
    calib_dataloader: Callable = None,
    calib_iteration: int = 100,
+   calib_func: Callable = None,
):
    """The main entry to quantize model.

@@ -40,16 +41,20 @@
        quant_config: single or lists of quantization configuration.
        calib_dataloader: a data loader for calibration.
        calib_iteration: the iteration of calibration.
+       calib_func: the function used for calibration, should be a substitution for calib_dataloader
+           when the built-in calibration function of INC does not work for model inference.

    Returns:
        q_model: the quantized model.
    """
    q_model = Model(model)
    if isinstance(quant_config, list):
        for config in quant_config:
-           q_model = quantize_model_with_single_config(q_model, config, calib_dataloader, calib_iteration)
+           q_model = quantize_model_with_single_config(q_model, config, calib_dataloader, calib_iteration, calib_func)
    else:
-       q_model = quantize_model_with_single_config(q_model, quant_config, calib_dataloader, calib_iteration)
+       q_model = quantize_model_with_single_config(
+           q_model, quant_config, calib_dataloader, calib_iteration, calib_func
+       )

    return q_model

@@ -59,6 +64,7 @@ def quantize_model_with_single_config(
    quant_config: BaseConfig,
    calib_dataloader: Callable = None,
    calib_iteration: int = 100,
+   calib_func: Callable = None,
):
    """Quantize model using single config.

@@ -67,6 +73,8 @@
        quant_config: a quantization configuration.
        calib_dataloader: a data loader for calibration.
        calib_iteration: the iteration of calibration.
+       calib_func: the function used for calibration, should be a substitution for calib_dataloader
+           when the built-in calibration function of INC does not work for model inference.

    Returns:
        q_model: the quantized model.
@@ -89,5 +97,5 @@
    for algo_name, algo_func in algos_mapping.items():
        if need_apply(configs_mapping, algo_name):
            logger.info(f"Start to apply {algo_name} on the model.")
-           q_model = algo_func(q_model, configs_mapping, calib_dataloader, calib_iteration)
+           q_model = algo_func(q_model, configs_mapping, calib_dataloader, calib_iteration, calib_func)
    return q_model
@@ -231,6 +231,10 @@ def _inference(self, model):
        Args:
            model(TensorflowBaseModel): input TensorflowBaseModel
        """
+       if self.calib_func:
+           self.calib_func(model)
+           return
+
        if model.model_type == "llm_saved_model":
            self._inference_llm(model)
            return
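A note on the dispatch shown above: when `calib_func` is supplied, it short-circuits `_inference` entirely, so the user function is responsible for running every inference step needed for calibration, and both the `llm_saved_model` path and the built-in `calib_dataloader`/`calib_iteration` loop are bypassed. Combined with the `calib_sampling_size` fallback to 100 when no dataloader is given (see the tensorflow.py hunk above), this is what allows `calib_dataloader` and `calib_iteration` to be omitted when a `calib_func` is provided.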