@@ -91,7 +91,7 @@ def export_formats():
91
91
['TensorFlow Lite' , 'tflite' , '.tflite' , True , False ],
92
92
['TensorFlow Edge TPU' , 'edgetpu' , '_edgetpu.tflite' , False , False ],
93
93
['TensorFlow.js' , 'tfjs' , '_web_model' , False , False ],
94
- ['PaddlePaddle' , 'paddle' , '_paddle_model' , True , True ],]
94
+ ['PaddlePaddle' , 'paddle' , '_paddle_model' , True , True ], ]
95
95
return pd .DataFrame (x , columns = ['Format' , 'Argument' , 'Suffix' , 'CPU' , 'GPU' ])
96
96
97
97
@@ -185,6 +185,66 @@ def export_onnx(model, im, file, opset, dynamic, simplify, prefix=colorstr('ONNX
185
185
return f , model_onnx
186
186
187
187
188
@try_export
def export_onnx_for_backend(model, im, file, opset, nms_cfg, dynamic, simplify, prefix=colorstr('ONNX:')):
    """YOLOv5 ONNX export with an End2End NMS head fused into the graph.

    Args:
        model: the PyTorch model to export (wrapped in End2End below).
        im: example input tensor used for tracing.
        file: output path; suffix is replaced with '.onnx'.
        opset: ONNX opset version.
        nms_cfg: sequence (topk_all, iou_thres, conf_thres, backend) where
            backend is 'ort' (ONNX Runtime) or 'trt' (TensorRT).
        dynamic: export with a dynamic batch dimension (forces CPU tracing).
        simplify: run onnx-simplifier on the exported graph (best-effort).

    Returns:
        (path, onnx.ModelProto) of the exported model.

    Raises:
        ValueError: if the backend in nms_cfg is not 'ort' or 'trt'.
    """
    check_requirements(('onnx',))
    import onnx

    LOGGER.info(f'\n{prefix} starting export with onnx {onnx.__version__}...')
    f = file.with_suffix('.onnx')

    # Capture metadata from the raw model BEFORE it is rebound to the End2End
    # wrapper: the wrapper is not guaranteed to forward .stride/.names.
    stride, names = int(max(model.stride)), model.names

    from models.common import End2End
    model = End2End(model, *nms_cfg, device=im.device)

    backend = nms_cfg[-1]
    if backend == 'ort':  # ONNX Runtime: single fused detection output
        output_names = ['outputs']
    elif backend == 'trt':  # TensorRT: EfficientNMS-style outputs
        output_names = ['num_dets', 'det_boxes', 'det_scores', 'det_classes']
    else:
        # Previously fell through with output_names unbound -> NameError.
        raise ValueError(f"unsupported backend '{backend}', expected 'ort' or 'trt'")

    # Batch dimension is dynamic for both backends (branches were identical).
    dynamic_cfg = {n: {0: 'batch'} for n in output_names} if dynamic else None

    torch.onnx.export(
        model.cpu() if dynamic else model,  # --dynamic only compatible with cpu
        im.cpu() if dynamic else im,
        f,
        verbose=False,
        opset_version=opset,
        training=torch.onnx.TrainingMode.EVAL,
        do_constant_folding=True,
        input_names=['images'],
        output_names=output_names,
        dynamic_axes=dynamic_cfg)

    # Checks
    model_onnx = onnx.load(f)  # load onnx model
    onnx.checker.check_model(model_onnx)  # check onnx model

    # Metadata
    d = {'stride': stride, 'names': names}
    for k, v in d.items():
        meta = model_onnx.metadata_props.add()
        meta.key, meta.value = k, str(v)
    onnx.save(model_onnx, f)

    # Simplify (best-effort: failures are logged, not raised)
    if simplify:
        try:
            cuda = torch.cuda.is_available()
            check_requirements(('onnxruntime-gpu' if cuda else 'onnxruntime', 'onnx-simplifier>=0.4.1'))
            import onnxsim

            LOGGER.info(f'{prefix} simplifying with onnx-simplifier {onnxsim.__version__}...')
            model_onnx, check = onnxsim.simplify(model_onnx)
            assert check, 'assert check failed'
            onnx.save(model_onnx, f)
        except Exception as e:
            LOGGER.info(f'{prefix} simplifier failure: {e}')
    return f, model_onnx
247
+
188
248
@try_export
189
249
def export_openvino (file , metadata , half , prefix = colorstr ('OpenVINO:' )):
190
250
# YOLOv5 OpenVINO export
@@ -447,9 +507,9 @@ def export_tfjs(file, prefix=colorstr('TensorFlow.js:')):
447
507
r'"Identity.?.?": {"name": "Identity.?.?"}, '
448
508
r'"Identity.?.?": {"name": "Identity.?.?"}, '
449
509
r'"Identity.?.?": {"name": "Identity.?.?"}}}' , r'{"outputs": {"Identity": {"name": "Identity"}, '
450
- r'"Identity_1": {"name": "Identity_1"}, '
451
- r'"Identity_2": {"name": "Identity_2"}, '
452
- r'"Identity_3": {"name": "Identity_3"}}}' , json )
510
+ r'"Identity_1": {"name": "Identity_1"}, '
511
+ r'"Identity_2": {"name": "Identity_2"}, '
512
+ r'"Identity_3": {"name": "Identity_3"}}}' , json )
453
513
j .write (subst )
454
514
return f , None
455
515
@@ -506,6 +566,7 @@ def run(
506
566
verbose = False , # TensorRT: verbose log
507
567
workspace = 4 , # TensorRT: workspace size (GB)
508
568
nms = False , # TF: add NMS to model
569
+ backend = 'ort' , # Backend for export NMS
509
570
agnostic_nms = False , # TF: add agnostic NMS to model
510
571
topk_per_class = 100 , # TF.js NMS: topk per class to keep
511
572
topk_all = 100 , # TF.js NMS: topk for all classes to keep
@@ -518,6 +579,7 @@ def run(
518
579
flags = [x in include for x in fmts ]
519
580
assert sum (flags ) == len (include ), f'ERROR: Invalid --include { include } , valid --include arguments are { fmts } '
520
581
jit , onnx , xml , engine , coreml , saved_model , pb , tflite , edgetpu , tfjs , paddle = flags # export booleans
582
+ end2end , onnx = onnx and nms , onnx and not nms
521
583
file = Path (url2file (weights ) if str (weights ).startswith (('http:/' , 'https:/' )) else weights ) # PyTorch weights
522
584
523
585
# Load PyTorch model
@@ -554,7 +616,7 @@ def run(
554
616
LOGGER .info (f"\n { colorstr ('PyTorch:' )} starting from { file } with output shape { shape } ({ file_size (file ):.1f} MB)" )
555
617
556
618
# Exports
557
- f = ['' ] * len (fmts ) # exported filenames
619
+ f = ['' ] * ( len (fmts )) + 1 # exported filenames
558
620
warnings .filterwarnings (action = 'ignore' , category = torch .jit .TracerWarning ) # suppress TracerWarning
559
621
if jit : # TorchScript
560
622
f [0 ], _ = export_torchscript (model , im , file , optimize )
@@ -592,6 +654,9 @@ def run(
592
654
if paddle : # PaddlePaddle
593
655
f [10 ], _ = export_paddle (model , im , file , metadata )
594
656
657
+ if end2end :
658
+ nms_cfg = [topk_all , iou_thres , conf_thres , backend ]
659
+ f [11 ], _ = export_onnx_for_backend (model , im , file , opset , nms_cfg , dynamic , simplify )
595
660
# Finish
596
661
f = [str (x ) for x in f if x ] # filter out '' and None
597
662
if any (f ):
@@ -628,6 +693,7 @@ def parse_opt():
628
693
parser .add_argument ('--verbose' , action = 'store_true' , help = 'TensorRT: verbose log' )
629
694
parser .add_argument ('--workspace' , type = int , default = 4 , help = 'TensorRT: workspace size (GB)' )
630
695
parser .add_argument ('--nms' , action = 'store_true' , help = 'TF: add NMS to model' )
696
+ parser .add_argument ('--backend' , type = str , default = 'ort' , help = 'Backend for export NMS' )
631
697
parser .add_argument ('--agnostic-nms' , action = 'store_true' , help = 'TF: add agnostic NMS to model' )
632
698
parser .add_argument ('--topk-per-class' , type = int , default = 100 , help = 'TF.js NMS: topk per class to keep' )
633
699
parser .add_argument ('--topk-all' , type = int , default = 100 , help = 'TF.js NMS: topk for all classes to keep' )
0 commit comments