-
Notifications
You must be signed in to change notification settings - Fork 383
/
onnx_simplifier.py
536 lines (487 loc) · 21.9 KB
/
onnx_simplifier.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
import argparse
import copy
import os
import sys
import re
import tempfile
from typing import List, Dict, Union, Optional, Tuple, Sequence
from rich.text import Text
from rich import print
import numpy as np
import onnx # type: ignore
import onnx.checker # type: ignore
import onnx.helper # type: ignore
import onnx.shape_inference # type: ignore
import onnx.numpy_helper # type: ignore
try:
import onnxruntime as rt # type: ignore
except ImportError:
command = [sys.executable, '-m', 'pip', 'install', 'onnxruntime']
print(Text(f"Installing onnxruntime by `{' '.join(command)}`, please wait for a moment..", style="bold magenta"))
import subprocess
subprocess.check_call(command)
import onnxruntime as rt
import onnxsim.onnxsim_cpp2py_export as C
from . import model_info
from . import model_checking
from . import version
TensorShape = List[int]
TensorShapes = Dict[str, TensorShape]
TensorShapesWithOptionalKey = Dict[Optional[str], TensorShape]
def get_output_names(model: onnx.ModelProto) -> List[str]:
output_names = [opt.name for opt in model.graph.output]
return output_names
def remove_unused_output(
model: onnx.ModelProto, unused_output: Sequence[str]
) -> onnx.ModelProto:
unused_output_names = unused_output
output_names = get_output_names(model)
for unused_output_name in unused_output_names:
if unused_output_name not in output_names:
raise RuntimeError(
f'The model doesn\'t have output named "{unused_output_name}"'
)
for graph_output in copy.deepcopy(model.graph.output):
if graph_output.name in unused_output_names:
model.graph.output.remove(graph_output)
return model
def remove_initializer_from_input(model: onnx.ModelProto) -> onnx.ModelProto:
initializer_names = [x.name for x in model.graph.initializer]
for graph_input in copy.deepcopy(model.graph.input):
if graph_input.name in initializer_names:
model.graph.input.remove(graph_input)
return model
def check_and_update_input_shapes(model: onnx.ModelProto, input_shapes: Optional[TensorShapesWithOptionalKey]) -> Optional[TensorShapes]:
if input_shapes is None:
return None
def get_inputs(model: onnx.ModelProto) -> List[onnx.ValueInfoProto]:
initializer_names = [x.name for x in model.graph.initializer]
return [ipt for ipt in model.graph.input if ipt.name not in initializer_names]
def get_input_names(model: onnx.ModelProto) -> List[str]:
input_names = [ipt.name for ipt in get_inputs(model)]
return input_names
input_names = get_input_names(model)
if None in input_shapes:
if len(input_names) == 1:
input_shapes[input_names[0]] = input_shapes[None]
del input_shapes[None]
else:
raise RuntimeError(
'The model has more than 1 inputs, please use the format "input_name:dim0,dim1,...,dimN" in --input-shape')
for x in input_shapes:
if x not in input_names:
raise RuntimeError(
'The model doesn\'t have input named "{}"'.format(x))
return input_shapes # type: ignore
# A very very large threshold
DEFAULT_TENSOR_SIZE_THRESHOLDHOLD = '1.5GB'
def simplify(
model: Union[str, onnx.ModelProto],
check_n: int = 0,
perform_optimization: bool = True,
skip_fuse_bn: bool = False,
overwrite_input_shapes=None,
test_input_shapes=None,
skipped_optimizers: Optional[List[str]] = None,
skip_constant_folding=False,
skip_shape_inference=False,
input_data=None,
dynamic_input_shape: bool = False,
custom_lib: Optional[str] = None,
include_subgraph: bool = False,
unused_output: Optional[Sequence[str]] = None,
tensor_size_threshold: str = DEFAULT_TENSOR_SIZE_THRESHOLDHOLD,
mutable_initializer: bool = False,
*,
input_shapes=None,
) -> Tuple[onnx.ModelProto, bool]:
"""
:param model: onnx ModelProto object or file path
:param check_n: The simplified model will be checked for `check_n` times by random inputs
:param perform_optimization: Whether to run onnx optimizer on the model
:param skip_fuse_bn: Skip fuse_bn_into_conv onnx optimizer
:param overwrite_input_shapes: If the model has dynamic input shape, user must pass a fixed input shape
for generating random inputs and checking equality.
:param test_input_shapes: If the model has dynamic input shape, user must pass a fixed input shape
for generating random inputs and checking equality.
:param skipped_optimizers: Skip some specific onnx optimizers
:param skip_constant_folding: Skip constant folding
:param skip_shape_inference: Skip shape inference (sometimes shape inference will crash)
:param input_data: Feed custom input data for checking if needed
:param dynamic_input_shape: Deprecated. Not needed anymore.
:param custom_lib: onnxruntime custom ops's shared library
:param include_subgraph: Simplify subgraph (e.g. true graph and false graph of "If" operator) instead of only the main graph
:param unused_output: name of unused outputs that will be eliminated from the model
:param input_shapes: Deprecated. Please use `overwrite_input_shapes` and/or `test_input_shapes` instead.
:return: A tuple (simplified model, success(True) or failed(False))
"""
if dynamic_input_shape:
print(
Text(
"WARNING: The argument `dynamic_input_shape=True` is not needed any more, onnxsim can now support dynamic input shapes natively, please refer to the latest documentation. An error will be raised in the future.",
style="bold red",
)
)
if input_shapes is not None:
print(
Text(
"WARNING: The argument `input_shapes` is deprecated. Please use `overwrite_input_shapes` and/or `test_input_shapes` instead. An error will be raised in the future.",
style="bold red",
)
)
overwrite_input_shapes = input_shapes
test_input_shapes = input_shapes
if not perform_optimization:
# None means skip all optimizers
skipped_optimizers = None
elif skipped_optimizers is None:
skipped_optimizers = []
if skip_fuse_bn and skipped_optimizers is not None:
skipped_optimizers.append("fuse_bn_into_conv")
if isinstance(model, str):
model = onnx.load(model)
if overwrite_input_shapes is None:
overwrite_input_shapes = {}
overwrite_input_shapes = check_and_update_input_shapes(
model, overwrite_input_shapes)
test_input_shapes = check_and_update_input_shapes(
model, test_input_shapes)
for name, input_shape in overwrite_input_shapes.items():
for ipt in model.graph.input:
if ipt.name == name:
for i, dim in enumerate(ipt.type.tensor_type.shape.dim):
dim.dim_value = input_shape[i]
if unused_output is not None:
model = remove_unused_output(model, unused_output)
if not mutable_initializer and model.ir_version >= 4:
model = remove_initializer_from_input(model)
# https://stackoverflow.com/a/60708339
def parse_size(size: str) -> int:
units = {"B": 1, "KB": 2**10, "MB": 2**20, "GB": 2**30, "TB": 2**40}
size = size.upper()
if not re.match(r' ', size):
size = re.sub(r'([KMGT]?B)', r' \1', size)
number, unit = [string.strip() for string in size.split()]
return int(float(number)*units[unit])
tensor_size_threshold = parse_size(tensor_size_threshold)
if tensor_size_threshold > 2**31 - 9999:
raise ValueError("tensor_size_threshold should be less than 2GB")
try:
model_bytes = model.SerializeToString()
model_opt_bytes = C.simplify(
model_bytes,
skipped_optimizers,
not skip_constant_folding,
not skip_shape_inference,
tensor_size_threshold,
)
if len(model_opt_bytes) == 0:
raise ValueError("Simplified model larger than 2GB")
model_opt = onnx.load_from_string(model_opt_bytes)
check_ok = model_checking.compare(
model_opt, model, check_n, test_input_shapes, input_data, custom_lib
)
except (ValueError, onnx.onnx_cpp2py_export.checker.ValidationError):
print("[bold magenta]Simplified model larger than 2GB. Trying to save as external data...[/bold magenta]")
# large models try to convert through a temporary file
with tempfile.TemporaryDirectory() as tmpdirname:
onnx.save(
copy.deepcopy(model),
os.path.join(tmpdirname, 'model.onnx'),
save_as_external_data=True,
)
check_ok = C.simplify_path(
os.path.join(tmpdirname, 'model.onnx'),
os.path.join(tmpdirname, 'opt.onnx'),
skipped_optimizers,
not skip_constant_folding,
not skip_shape_inference,
tensor_size_threshold,
)
check_ok = model_checking.compare(
os.path.join(tmpdirname, 'opt.onnx'),
os.path.join(tmpdirname, 'model.onnx'),
check_n, test_input_shapes, input_data, custom_lib
)
model_opt = onnx.load(os.path.join(tmpdirname, 'opt.onnx'))
return model_opt, check_ok
class PyModelExecutor(C.ModelExecutor):
def Run(self, model_str: str, inputs_str: List[str]):
model = onnx.ModelProto()
model.ParseFromString(model_str)
def deserialize_tp(tp_str):
tp = onnx.TensorProto()
tp.ParseFromString(tp_str)
return tp
input_tps = map(deserialize_tp, inputs_str)
input_arrs = map(onnx.numpy_helper.to_array, input_tps)
input_names = [x.name for x in model.graph.input]
inputs = dict(zip(input_names, input_arrs))
sess_options = rt.SessionOptions()
sess_options.graph_optimization_level = rt.GraphOptimizationLevel(0)
sess_options.log_severity_level = 3
sess = rt.InferenceSession(
model.SerializeToString(),
sess_options=sess_options,
providers=["CPUExecutionProvider"],
)
output_names = [x.name for x in sess.get_outputs()]
run_options = rt.RunOptions()
run_options.log_severity_level = 3
output_arrs = sess.run(output_names, inputs, run_options=run_options)
return [
onnx.numpy_helper.from_array(x).SerializeToString() for x in output_arrs
]
def main():
parser = argparse.ArgumentParser()
parser.add_argument("input_model", help="Input ONNX model")
parser.add_argument("output_model", help="Output ONNX model")
parser.add_argument(
"check_n",
help="Check whether the output is correct with n random inputs",
nargs="?",
type=int,
default=0,
)
parser.add_argument(
"--enable-fuse-bn",
help="This option is deprecated. Fusing bn into conv is enabled by default.",
action="store_true",
)
parser.add_argument(
"--skip-fuse-bn", help="Skip fusing batchnorm into conv.", action="store_true"
)
parser.add_argument(
"--skip-optimization",
help="Skip all ONNX optimizers or some of them. To skip all optimizers, use `onnxsim a.onnx b.onnx --skip-optimization`. To skip some of optimizers, use something like `onnxsim a.onnx b.onnx --skip-optimization fuse_bn_into_conv fuse_pad_into_pool`.",
type=str,
nargs="*",
)
parser.add_argument("--skip-constant-folding", help="Skip constant folding", action="store_true")
parser.add_argument(
"--input-shape",
help="This argument has been renamed to --overwrite-input-shape, please refer to it",
type=str,
nargs="+",
)
parser.add_argument(
"--overwrite-input-shape",
help='Overwrite the input shape. The format is "input_name:dim0,dim1,...,dimN" or simply "dim0,dim1,...,dimN" when there is only one input, for example, "data:1,3,224,224" or "1,3,224,224". Note: you might want to use some visualization tools like netron to make sure what the input name and dimension ordering (NCHW or NHWC) is.',
type=str,
nargs="+",
)
parser.add_argument(
"--test-input-shape",
help='The input shape to generated random inputs for test, useful when the input shape is dynamic. The format is "input_name:dim0,dim1,...,dimN" or simply "dim0,dim1,...,dimN" when there is only one input, for example, "data:1,3,224,224" or "1,3,224,224". Note: you might want to use some visualization tools like netron to make sure what the input name and dimension ordering (NCHW or NHWC) is.',
type=str,
nargs="+",
)
parser.add_argument(
"--skip-optimizer",
help="Deprecated. Refer to --skip-optimization",
type=str,
nargs="+",
)
parser.add_argument(
"--skip-shape-inference", help="Skip shape inference", action="store_true"
)
parser.add_argument(
"--enable-onnxruntime-optimization",
help="Enable ONNX Runtime's ORT_ENABLE_BASIC level optimization.",
action="store_true",
)
parser.add_argument(
"--dynamic-input-shape",
help="Deprecated. Not needed any more.",
action="store_true",
)
parser.add_argument(
"--input-data-path",
help='input data, The value should be "input_name1:xxx1.bin" "input_name2:xxx2.bin ...", input data should be a binary data file.',
type=str,
nargs="+",
)
parser.add_argument(
"--custom-lib", help="Deprecated. Not needed any more.", type=str
)
parser.add_argument(
"--include-subgraph",
help='Experimental feature. Simplify subgraph (e.g. true graph and false graph of "If" operator) instead of only the main graph',
action="store_true",
)
parser.add_argument(
"--unused-output",
help="Name of unused outputs that will be eliminated from the model",
type=str,
nargs="+",
)
parser.add_argument(
"--no-large-tensor",
help="Some ops like Tile and ConstantOfShape can produce large tensor and make the model size much larger. Specifying this flag to skip folding these ops, with loss of some optimization chances. It can be followed with a threshold, for example, --no-large-tensor 1M or --no-large-tensor 100KB. A simple '--no-large-tensor' means '--no-large-tensor 1KB'.",
type=str,
const='1KB',
default=DEFAULT_TENSOR_SIZE_THRESHOLDHOLD,
nargs="?",
dest="tensor_size_threshold",
)
parser.add_argument(
"--mutable-initializer",
help="By ONNX specification, initializers can also serve as inputs. This allows users to overwrite their values during runtime, but some useful optimizations like fuse-conv-and-bn will not be applicable anymore. In almost all cases, having an initializer that is also an input is unintended (usually caused by a out-dated PyTorch). So onnxsim treats all initializers immutable to enabling all optimizations. If it is not wanted, you can specify '--mutable-initializer' to disable this behavior.",
action="store_true",
)
parser.add_argument(
"--save-as-external-data",
help="Save parameters as external data. This will make the .onnx file much smaller, but the .onnx file will depend on the external data file (.data).",
action="store_true",
)
parser.add_argument('-v', '--version', action='version', version='onnxsim ' + version.version)
args = parser.parse_args()
if args.enable_fuse_bn:
print(
Text(
'WARNING: "--enable-fuse-bn" is not needed any more, because fuse bn is enabled by default. "--enable-fuse-bn" flag is ignored now and will raise an error in the future.',
style="bold red",
)
)
if args.dynamic_input_shape:
print(
Text(
'WARNING: "--dynamic-input-shape" is not needed any more, onnxsim v0.4 now handles dynamic input shapes automatically. "--dynamic-input-shape" flag is ignored now and will raise an error in the future.',
style="bold red",
)
)
assert not (args.input_shape is not None and args.overwrite_input_shape is not None)
if args.input_shape:
print(
Text(
'WARNING: "--input-shape" is renamed to "--overwrite-input-shape". Please use it instead.',
style="bold red",
)
)
args.overwrite_input_shape = args.input_shape
if args.include_subgraph:
print(
Text(
"WARNING: subgraph optimization is not supported in v0.4 for now.",
style="bold red",
)
)
assert not (args.skip_optimizer is not None and args.skip_optimization is not None)
if args.skip_optimizer:
print(
Text(
'WARNING: "--skip-optimizer" is renamed to "--skip-optimization". Please use it instead.',
style="bold red",
)
)
args.skip_optimization = args.skip_optimizer
if args.skip_optimization is None:
# user doesn't specify --skip-optimization
args.skip_optimization = []
elif len(args.skip_optimization) == 0:
# user specify --skip-optimization without any certain optimizer name
# set it to None means skip all optimizations
args.skip_optimization = None
if args.skip_fuse_bn and args.skip_optimization is not None:
args.skip_optimization.append("fuse_bn_into_conv")
perform_optimization = False if args.skip_optimization is None else True
def parse_shapes(shapes_arg):
shapes = {}
if shapes_arg is not None:
for x in shapes_arg:
if ':' not in x:
shapes[None] = list(map(int, x.split(',')))
else:
pieces = x.split(':')
# for the input name like input:0
name, shape = ':'.join(
pieces[:-1]), list(map(int, pieces[-1].split(',')))
shapes.update({name: shape})
return shapes
test_input_shapes = parse_shapes(args.test_input_shape)
overwrite_input_shapes = parse_shapes(args.overwrite_input_shape)
if args.enable_onnxruntime_optimization:
tmp_file = tempfile.NamedTemporaryFile()
sess_options = rt.SessionOptions()
# Set graph optimization level
sess_options.graph_optimization_level = rt.GraphOptimizationLevel.ORT_ENABLE_BASIC
# To enable model serialization after graph optimization
sess_options.optimized_model_filepath = tmp_file.name
_ = rt.InferenceSession(args.input_model, sess_options, providers=["CPUExecutionProvider"])
model = onnx.load(tmp_file.name)
else:
model = onnx.load(args.input_model)
if args.tensor_size_threshold == DEFAULT_TENSOR_SIZE_THRESHOLDHOLD:
for node in model.graph.node:
if node.op_type in ["Tile", "ConstantOfShape"]:
print(
Text(
'Your model contains "Tile" ops or/and "ConstantOfShape" ops. Folding these ops can make the simplified model much larger. If it is not expected, please specify "--no-large-tensor" (which will lose some optimization chances)',
style="bold magenta",
)
)
break
if not args.mutable_initializer:
initializer_names = set([x.name for x in model.graph.initializer])
input_names = set([x.name for x in model.graph.input])
if len(initializer_names.intersection(input_names)) > 0:
print(
Text(
'Your model contains initializers that are also inputs. This is usually caused by an out-dated PyTorch. onnxsim treats all initializers immutable to enabling all optimizations. If it is not wanted, please specify "--mutable-initializer" to disable this behavior.',
style="bold magenta",
)
)
input_tensors = None
if args.input_data_path is not None:
input_tensors = {}
for x in args.input_data_path:
pieces = x.split(':')
name, data = ':'.join(pieces[:-1]), pieces[-1]
input_tensors.update({name: np.load(data)})
print("Simplifying...")
model_opt, check_ok = simplify(
model,
args.check_n,
perform_optimization,
False,
overwrite_input_shapes,
test_input_shapes,
args.skip_optimization,
args.skip_constant_folding,
args.skip_shape_inference,
input_tensors,
False,
args.custom_lib,
args.include_subgraph,
args.unused_output,
args.tensor_size_threshold,
args.mutable_initializer,
)
try:
if not args.save_as_external_data:
onnx.save(model_opt, args.output_model)
else:
raise ValueError("save_as_external_data")
except ValueError:
# large models (>2GB) which onnx.save doesn't support,
# or explicitly specified --save-as-external-data
external_data_path = os.path.basename(args.output_model) + '.data'
if os.path.exists(external_data_path):
os.remove(external_data_path)
onnx.save(
copy.deepcopy(model_opt),
args.output_model,
save_as_external_data=True,
all_tensors_to_one_file=True,
location=external_data_path,
)
if check_ok:
print("Finish! Here is the difference:")
model_info.print_simplifying_info(model, model_opt)
else:
print(
'Check failed. Please be careful to use the simplified model, or try specifying "--skip-fuse-bn" or "--skip-optimization" (run "onnxsim -h" for details).'
)
print("Here is the difference after simplification:")
model_info.print_simplifying_info(model, model_opt)
sys.exit(1)