Skip to content

Commit addc9a1

Browse files
Add quantization on-the-fly feature
1 parent 91e0931 commit addc9a1

File tree

2 files changed

+73
-2
lines changed

2 files changed

+73
-2
lines changed

tf2mplabh3/main.py

Lines changed: 18 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -67,7 +67,16 @@ def main():
6767
'--overwrite', action='store_true',
6868
help="Overwrite existing ONNX or C model files. By default, existing files are not overwritten."
6969
)
70-
70+
parser.add_argument(
71+
'-quant', '--int8_quantize',
72+
default=0,
73+
help="Quantize on the fly from FP32 to INT8"
74+
)
75+
parser.add_argument(
76+
'-onnx_quant', '--onnx_quant_model',
77+
default=os.path.join(PROJECT_ROOT, "examples", "model_int8.onnx"),
78+
help="Path where to store the ONNX Model File"
79+
)
7180
args = parser.parse_args()
7281
global verbosity
7382
verbosity = args.verbosity
@@ -97,6 +106,13 @@ def main():
97106
# Convert TensorFlow model to ONNX
98107
print(color_text("[MAIN] Starting Tensorflow to ONNX Conversion", "green"), flush=True)
99108
tf2onnx_converter(args.model, args.onnx_model, args.tag, args.signature_def, verbosity)
109+
onnx_model_to_convert=args.onnx_model
110+
111+
if bool(args.int8_quantize):
112+
from .onnx_quantization import quantize_and_compare_nodes
113+
print(color_text("[MAIN] Starting ONNX FP32 to ONNX INT8 Quantization", "green"), flush=True)
114+
quantize_and_compare_nodes(args.onnx_model,args.onnx_quant_model,verbosity_level=verbosity)
115+
onnx_model_to_convert=args.onnx_quant_model
100116

101117
verbose("[MAIN] Ensuring the parent directory of the C model file exists")
102118
parent_dir = os.path.dirname(args.c_model_file)
@@ -111,7 +127,7 @@ def main():
111127
t.start()
112128
with open(args.c_model_file, "w") as c_file:
113129
process = subprocess.Popen(
114-
[args.onnx2c, args.onnx_model],
130+
[args.onnx2c, onnx_model_to_convert],
115131
stdout=c_file,
116132
stderr=subprocess.PIPE,
117133
text=True

tf2mplabh3/onnx_quantization.py

Lines changed: 55 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,55 @@
1+
from onnxruntime.quantization import quantize_dynamic, QuantType
2+
import onnx
3+
from .utils import color_text
4+
5+
# Module-level verbosity flag; updated by quantize_and_compare_nodes().
verbosity = 0


def verbose(msg):
    """Print *msg* in yellow, but only when verbose logging (level 1) is on."""
    if verbosity == 1:
        print(color_text(msg, "yellow"), flush=True)
11+
12+
def quantize_and_compare_nodes(
    model_input_path,
    model_output_path,
    nodes_to_exclude=None,
    weight_type=QuantType.QInt8,
    verbosity_level=0
):
    """
    Quantize an ONNX model and print the new nodes introduced by quantization.

    Runs onnxruntime dynamic (weight-only) quantization, then diffs the
    (name, op_type) node sets of the original and quantized graphs so the
    quantization-specific operators that were inserted can be logged.

    Args:
        model_input_path (str): Path to the original (FP32) ONNX model.
        model_output_path (str): Path to save the quantized ONNX model.
        nodes_to_exclude (list, optional): Node names to exclude from quantization.
        weight_type (QuantType, optional): Weight quantization type (default: QuantType.QInt8).
        verbosity_level (int, optional): 0 for almost no logs, 1 for full logs.
    """
    global verbosity
    verbosity = verbosity_level

    def get_node_names_and_types(model_path):
        # Load the graph and collect (name, op_type) pairs for set comparison.
        model = onnx.load(model_path)
        return [(node.name, node.op_type) for node in model.graph.node]

    # Quantize the model (dynamic quantization: no calibration data needed).
    quantize_dynamic(
        model_input=model_input_path,
        model_output=model_output_path,
        weight_type=weight_type,
        nodes_to_exclude=nodes_to_exclude or []
    )

    # Diff the node sets: anything present only in the quantized graph was
    # introduced by the quantizer (e.g. DynamicQuantizeLinear, MatMulInteger).
    original_set = set(get_node_names_and_types(model_input_path))
    quantized_set = set(get_node_names_and_types(model_output_path))
    new_nodes = quantized_set - original_set

    # Fixed: original message had a stray "(" before the log tag.
    verbose("[ONNX_QUANTIZATION] New Nodes introduced by the quantization")
    for name, op_type in new_nodes:
        verbose(f"[ONNX_QUANTIZATION] Name: {name}, OpType: {op_type}")

0 commit comments

Comments
 (0)