Commit 54134bf
Migrate onnxrt GPTQ and AWQ WOQ to 3.x API (#1570)
Signed-off-by: yuwenzho <yuwen.zhou@intel.com>
1 parent efea089 commit 54134bf

File tree

20 files changed: +2233 -347 lines changed

neural_compressor/onnxrt/__init__.py

Lines changed: 14 additions & 2 deletions

@@ -17,9 +17,15 @@
 from neural_compressor.onnxrt.utils.utility import register_algo
 from neural_compressor.onnxrt.quantization import (
     rtn_quantize_entry,
-    smooth_quant_entry,
     RTNConfig,
     get_default_rtn_config,
+    gptq_quantize_entry,
+    GPTQConfig,
+    get_default_gptq_config,
+    awq_quantize_entry,
+    AWQConfig,
+    get_default_awq_config,
+    smooth_quant_entry,
     SmoohQuantConfig,
     get_default_sq_config,
     CalibrationDataReader,
@@ -28,9 +34,15 @@
 __all__ = [
     "register_algo",
     "rtn_quantize_entry",
-    "smooth_quant_entry",
     "RTNConfig",
     "get_default_rtn_config",
+    "gptq_quantize_entry",
+    "GPTQConfig",
+    "get_default_gptq_config",
+    "awq_quantize_entry",
+    "AWQConfig",
+    "get_default_awq_config",
+    "smooth_quant_entry",
     "SmoohQuantConfig",
     "get_default_sq_config",
     "CalibrationDataReader",

neural_compressor/onnxrt/algorithms/__init__.py

Lines changed: 3 additions & 4 deletions

@@ -15,8 +15,7 @@

 from neural_compressor.onnxrt.algorithms.smoother import Smoother
 from neural_compressor.onnxrt.algorithms.weight_only.rtn import apply_rtn_on_model
+from neural_compressor.onnxrt.algorithms.weight_only.gptq import apply_gptq_on_model
+from neural_compressor.onnxrt.algorithms.weight_only.awq import apply_awq_on_model

-__all__ = [
-    "Smoother",
-    "apply_rtn_on_model",
-]
+__all__ = ["Smoother", "apply_rtn_on_model", "apply_gptq_on_model", "apply_awq_on_model"]

neural_compressor/onnxrt/algorithms/smoother/calibrator.py

Lines changed: 23 additions & 19 deletions

@@ -17,41 +17,43 @@
 import tempfile
 from importlib.util import find_spec
 from pathlib import Path
+from typing import List

 import numpy as np
 import onnx
 import onnx.numpy_helper as numpy_helper
 import onnxruntime

 from neural_compressor.common import Logger
+from neural_compressor.onnxrt.quantization.calibrate import CalibrationDataReader
+from neural_compressor.onnxrt.utils.onnx_model import ONNXModel

 logger = Logger().get_logger()

+__all__ = ["Calibrator"]
+

 class Calibrator:
     """Dump information for smooth quant."""

     def __init__(
         self,
-        model,
-        dataloader,
-        dump_op_types,
-        iterations=[],
-        providers=["CPUExecutionProvider"],
+        model: ONNXModel,
+        dataloader: CalibrationDataReader,
+        iterations: List[int] = [],
+        providers: List[str] = ["CPUExecutionProvider"],
         **kwargs,
     ):
-        """Initialization.
+        """Initialize a Calibrator to dump information.

         Args:
-            model (object): ONNXModel object
-            dataloader (object): user implemented object to read in and preprocess calibration dataset
-            dump_op_types (list): operator types to be calibrated and quantized
-            iterations (list, optional): tensor of which iteration will be collected. Defaults to [].
-            providers (list, optional): execution provider for onnxruntime. Defaults to ['CPUExecutionProvider'].
+            model (ONNXModel): ONNXModel object.
+            dataloader (CalibrationDataReader): user-implemented object that reads in and preprocesses the calibration dataset.
+            iterations (List[int], optional): the iterations whose tensors will be collected. Defaults to [].
+            providers (List[str], optional): execution providers for onnxruntime. Defaults to ["CPUExecutionProvider"].
         """
         self.model_wrapper = model
         self.dataloader = dataloader
-        self.dump_op_types = dump_op_types
         self.augmented_model = None
         self.iterations = iterations
         self.providers = providers
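Since dataloader is now typed as CalibrationDataReader, a minimal reader sketch may help. This assumes the 3.x interface mirrors onnxruntime's calibration reader (get_next() returns a feed dict, then None when exhausted); the rewind() method is likewise an assumption.

```python
# Minimal CalibrationDataReader sketch; interface details are assumed.
import numpy as np

from neural_compressor.onnxrt import CalibrationDataReader

class RandomDataReader(CalibrationDataReader):
    def __init__(self, input_name: str, num_samples: int = 8):
        self._samples = [
            {input_name: np.random.rand(1, 3, 224, 224).astype(np.float32)}
            for _ in range(num_samples)
        ]
        self._iter = iter(self._samples)

    def get_next(self):
        # next feed dict, or None once the calibration data is exhausted
        return next(self._iter, None)

    def rewind(self):
        # assumed hook for re-running calibration from the first sample
        self._iter = iter(self._samples)
```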
@@ -87,17 +89,18 @@ def _check_is_group_conv(self, node):
                 return True
         return False

-    def _get_input_tensor_of_ops(self, op_types=["MatMul", "Gemm", "Conv", "FusedConv"]):
+    def _get_input_tensor_of_ops(self, op_types: List[str] = ["MatMul", "Gemm", "Conv", "FusedConv"]):
         """Traverse the graph and get all the data tensors flowing into layers of {op_types}.

         Group conv is excluded.
         # TODO: the tensors could be set/filtered in configuration.

         Args:
-            op_types: The op types whose input tensor will be dumped
+            op_types (List[str], optional): The op types whose input tensors will be dumped.
+                Defaults to ["MatMul", "Gemm", "Conv", "FusedConv"].

         Returns:
-            A dict of dumped tensor: node info
+            dict: a mapping from dumped tensor name to node info.
         """
         tensors_to_node = {}
         initializers = {i.name: i for i in self.model_wrapper.initializer()}
@@ -111,7 +114,7 @@ def _get_input_tensor_of_ops(self, op_types=["MatMul", "Gemm", "Conv", "FusedConv"]):
                 tensors_to_node.setdefault(node.input[0], []).append([node.name, node.input, node.output])
         return tensors_to_node

-    def _get_max_per_channel(self, datas: list, percentile):
+    def _get_max_per_channel(self, datas, percentile):
         """Get the max values per input channel.

         Args:
@@ -200,14 +203,15 @@ def _collect_data(ort_inputs):
             idx += 1
         return output_dicts

-    def calib_smooth(self, op_types, percentile=99.999):
+    def calib_smooth(self, op_types, percentile: float = 99.999):
         """Smooth model calibration.

         Mainly get the max info per channel of input tensors.

         Args:
-            percentile:Percentile of calibration to remove outliers
-            op_types: The op types whose input tensor will be dumped
+            op_types (list): The op types whose input tensors will be dumped.
+            percentile (float, optional): Percentile of calibration to remove outliers.
+                Defaults to 99.999.

         Returns:
             max_vals_per_channel: max values per channel of input tensors
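To illustrate the statistic calib_smooth collects, a standalone per-channel percentile max in numpy, assuming activations stacked as (samples, channels). This is a sketch of the idea, not the library's implementation.

```python
# Per-channel "robust max": the percentile clips outliers before the max.
import numpy as np

def max_per_channel(data: np.ndarray, percentile: float = 99.999) -> np.ndarray:
    # data: (samples, channels) -> one value per channel
    return np.percentile(np.abs(data), percentile, axis=0)

acts = np.random.randn(1000, 768).astype(np.float32)
per_channel_max = max_per_channel(acts)  # shape: (768,)
```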

neural_compressor/onnxrt/algorithms/smoother/core.py

Lines changed: 49 additions & 34 deletions

@@ -1,6 +1,3 @@
-#
-# -*- coding: utf-8 -*-
-#
 # Copyright (c) 2023 Intel Corporation
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
@@ -18,6 +15,8 @@

 import copy
 import os
+from pathlib import Path
+from typing import List, Union

 import numpy as np
 import onnx
@@ -26,6 +25,7 @@

 from neural_compressor.common import Logger
 from neural_compressor.onnxrt.algorithms.smoother.calibrator import Calibrator
+from neural_compressor.onnxrt.quantization.calibrate import CalibrationDataReader
 from neural_compressor.onnxrt.utils.onnx_model import ONNXModel
 from neural_compressor.onnxrt.utils.utility import (
     get_qrange_for_qType,
@@ -36,7 +36,9 @@

 logger = Logger().get_logger()

-dtype_map = {
+__all__ = ["Smoother"]
+
+_dtype_map = {
     np.dtype("float32"): 1,
     np.dtype("uint8"): 2,
     np.dtype("int8"): 3,
@@ -47,7 +49,7 @@
 }


-def get_quant_dequant_output(model, input_data, output_data, providers):
+def _get_quant_dequant_output(model, input_data, output_data, providers):
     """Get loss between fp32 output and QDQ output.

     Args:
@@ -58,14 +60,14 @@ def _get_quant_dequant_output(model, input_data, output_data, providers):
     """
     import onnxruntime as ort

-    input_data = quant_dequant_data(input_data, 2, "asym")
+    input_data = _quant_dequant_data(input_data, 2, "asym")
     sess = ort.InferenceSession(model.SerializeToString(), providers=providers)
     preds = sess.run(None, {model.graph.input[0].name: input_data})
     loss = np.sum(np.abs(output_data - preds) ** 2)
     return loss


-def make_sub_graph(node, inits, input_data, output_data, opset, ir_version):
+def _make_sub_graph(node, inits, input_data, output_data, opset, ir_version):
     """Build a model with the specific node.

     Args:
@@ -78,15 +80,15 @@ def _make_sub_graph(node, inits, input_data, output_data, opset, ir_version):
     """
     from onnx import helper

-    input = helper.make_tensor_value_info(node.input[0], dtype_map[input_data.dtype], input_data.shape)
-    output = helper.make_tensor_value_info(node.output[0], dtype_map[output_data.dtype], output_data.shape)
+    input = helper.make_tensor_value_info(node.input[0], _dtype_map[input_data.dtype], input_data.shape)
+    output = helper.make_tensor_value_info(node.output[0], _dtype_map[output_data.dtype], output_data.shape)
     graph = helper.make_graph([node], "sub_graph", [input], [output], inits)
     model = helper.make_model(graph, opset_imports=opset)
     model.ir_version = ir_version
     return model


-def quant_dequant_data(data, qType=3, scheme="sym"):
+def _quant_dequant_data(data, qType=3, scheme="sym"):
     """Quantize and then dequantize data.

     Args:
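For reference, a self-contained symmetric int8 quantize-dequantize roundtrip in the spirit of _quant_dequant_data; the library's exact scale computation, rounding mode, and clipping range may differ from this sketch.

```python
# Symmetric per-tensor int8 quant-dequant (sketch; details are assumptions).
import numpy as np

def quant_dequant_sym_int8(data: np.ndarray) -> np.ndarray:
    amax = float(np.max(np.abs(data)))
    scale = amax / 127.0 if amax > 0 else 1.0  # avoid divide-by-zero
    q = np.clip(np.round(data / scale), -127, 127)  # quantize to int8 range
    return (q * scale).astype(data.dtype)  # dequantize back to float
```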
@@ -113,9 +115,9 @@ class Smoother:

     def __init__(
         self,
-        model,
-        dataloader,
-        providers=["CPUExecutionProvider"],
+        model: Union[onnx.ModelProto, ONNXModel, Path, str],
+        dataloader: CalibrationDataReader,
+        providers: List[str] = ["CPUExecutionProvider"],
     ):
         """Initialize the attributes of class."""
         self.model = model if isinstance(model, ONNXModel) else ONNXModel(model, load_external_data=True)
@@ -138,30 +140,37 @@ def __init__(

     def transform(
         self,
-        alpha=0.5,
-        folding=True,
-        percentile=99.999,
-        op_types=["Gemm", "Conv", "MatMul", "FusedConv"],
-        scales_per_op=True,
-        calib_iter=100,
-        auto_alpha_args={"alpha_min": 0.3, "alpha_max": 0.7, "alpha_step": 0.05, "attn_method": "min"},
+        alpha: Union[float, str] = 0.5,
+        folding: bool = True,
+        percentile: float = 99.999,
+        op_types: List[str] = ["Gemm", "Conv", "MatMul", "FusedConv"],
+        scales_per_op: bool = True,
+        calib_iter: int = 100,
+        auto_alpha_args: dict = {"alpha_min": 0.3, "alpha_max": 0.7, "alpha_step": 0.05, "attn_method": "min"},
         *args,
         **kwargs
     ):
         """The main entry of smooth quant.

         Args:
-            alpha (float or str): alpha value to balance the quantization difficulty of activation and weight.
-            folding (bool): whether fold those foldable Mul which are inserted for smooth quant
-            percentile (float): percentile of calibration to remove outliers
-            op_types (list): the op type to be smooth quantized
-            scales_per_op (bool): True, each op will have an individual scale, mainlyfor accuracy
-                False, ops with the same input will share a scale, mainly for performance
-            calib_iter (int): iteration num for calibration
+            alpha (float or str, optional): alpha value to balance the quantization difficulty of activation and weight.
+                Defaults to 0.5.
+            folding (bool, optional): whether to fold the foldable Mul ops inserted for smooth quant.
+                Defaults to True.
+            percentile (float, optional): percentile of calibration to remove outliers.
+                Defaults to 99.999.
+            op_types (list, optional): the op types to be smooth quantized.
+                Defaults to ["Gemm", "Conv", "MatMul", "FusedConv"].
+            scales_per_op (bool, optional): True, each op will have an individual scale, mainly for accuracy;
+                False, ops with the same input will share a scale, mainly for performance.
+                Defaults to True.
+            calib_iter (int, optional): iteration num for calibration. Defaults to 100.
+            auto_alpha_args (dict, optional): alpha args for auto smooth.
+                Defaults to {"alpha_min": 0.3, "alpha_max": 0.7, "alpha_step": 0.05, "attn_method": "min"}.

         Returns:
-            A FP32 model with the same architecture as the orig model but with different weight which will be
-            benefit to quantization
+            onnx.ModelProto: an FP32 model with the same architecture as the original model
+                but with different weights, which benefits quantization.
         """
         self.scales_per_op = scales_per_op
         self.clean()
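The alpha knob in transform balances how much quantization difficulty is migrated from activations into weights. A sketch of the per-channel scale from the SmoothQuant formulation, s_j = max|X_j|^alpha / max|W_j|^(1 - alpha); the library's actual computation adds more guards than shown here.

```python
# SmoothQuant-style per-channel scales controlled by alpha (sketch only;
# the 1e-5 floor is an assumed numerical guard, not this file's code).
import numpy as np

def smooth_scales(act_max: np.ndarray, weight_max: np.ndarray, alpha: float = 0.5) -> np.ndarray:
    act_max = np.clip(act_max, 1e-5, None)        # guard zero channel maxima
    weight_max = np.clip(weight_max, 1e-5, None)
    return np.power(act_max, alpha) / np.power(weight_max, 1.0 - alpha)
```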
@@ -207,7 +216,6 @@ def _dump_op_info(self, percentile, op_types, iterations):
         calibrator = Calibrator(
             self.model,
             self.dataloader,
-            [],
             iterations=list(range(0, iterations)),
             backend=self.providers,
         )
@@ -388,7 +396,7 @@ def _get_output_loss(self, node_name, scale, calib_iter):
         )
         base_dir = "" if not self.model.is_large_model else os.path.dirname(self.model.model_path)
         weight = onnx.numpy_helper.to_array(self.model.get_initializer(node.input[1]), base_dir)
-        weight_q = quant_dequant_data(weight)
+        weight_q = _quant_dequant_data(weight)

         self.model.set_initializer(node.input[1], weight_q)
         inits = [self.model.get_initializer(i) for i in node.input if self.model.get_initializer(i) is not None]
@@ -404,15 +412,15 @@ def _get_output_loss(self, node_name, scale, calib_iter):

         outputs = session.run(added_tensors, inputs)
         if model is None:
-            model = make_sub_graph(
+            model = _make_sub_graph(
                 node,
                 inits,
                 outputs[0],
                 outputs[1],
                 self.model.model.opset_import,
                 self.model.model.ir_version,
             )
-        loss += get_quant_dequant_output(model, outputs[0] * scale, outputs[1], self.providers)
+        loss += _get_quant_dequant_output(model, outputs[0] * scale, outputs[1], self.providers)

         self.model.remove_tensors_from_outputs([i for i in added_tensors if i not in orig_outputs])
         self.model.set_initializer(node.input[1], weight)
@@ -431,7 +439,14 @@ def _reshape_scale_for_input(self, tensor, key):
         scale = np.reshape(self.tensor_scales_info[key], (1, self.tensor_scales_info[key].shape[0]))
         return scale

-    def _auto_tune_alpha(self, calib_iter, alpha_min=0.3, alpha_max=0.7, alpha_step=0.05, attn_method="min"):
+    def _auto_tune_alpha(
+        self,
+        calib_iter,
+        alpha_min: float = 0.3,
+        alpha_max: float = 0.7,
+        alpha_step: float = 0.05,
+        attn_method: str = "min",
+    ):
         """Perform alpha-tuning to obtain layer-wise optimal alpha values and adjust parameters accordingly.

         Args:
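The keyword-style signature makes the tuning grid explicit. A hedged sketch of the grid search those parameters imply, scoring each candidate alpha with a caller-supplied loss; the real method measures per-layer QDQ output loss and picks alphas accordingly.

```python
# Grid search over [alpha_min, alpha_max] in alpha_step increments;
# loss_fn stands in for the per-layer QDQ loss the real code computes.
import numpy as np

def grid_search_alpha(loss_fn, alpha_min=0.3, alpha_max=0.7, alpha_step=0.05):
    alphas = np.arange(alpha_min, alpha_max + 1e-9, alpha_step)
    return min(alphas, key=loss_fn)  # alpha with the smallest loss

best = grid_search_alpha(lambda a: abs(a - 0.5))  # toy loss -> 0.5
```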
