Skip to content

Commit 303b9ac

Browse files
committed
This is a combination of 5 commits.
New PR for "ultralytics#7736"; Remove unused code; Format ONNX Runtime and TensorRT ONNX outputs; fix; unified outputs
1 parent 65071da commit 303b9ac

File tree

2 files changed

+257
-5
lines changed

2 files changed

+257
-5
lines changed

export.py

Lines changed: 78 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -185,6 +185,74 @@ def export_onnx(model, im, file, opset, dynamic, simplify, prefix=colorstr('ONNX
185185
return f, model_onnx
186186

187187

188+
@try_export
def export_onnx_with_nms(model, im, file, opset, nms_cfg, dynamic, simplify, prefix=colorstr('ONNX:')):
    """
    Export a YOLOv5 model to ONNX with an NMS stage appended to the graph.

    The model is wrapped in ``models.common.End2End`` so that the exported graph exposes four outputs named
    ``num_dets``, ``boxes``, ``scores`` and ``labels``.

    Args:
        model: detection model to wrap and export.
        im: example input tensor used for the export; its device is forwarded to ``End2End``.
        file: path object; the result is written to ``file.with_suffix('.onnx')``.
        opset: ONNX opset version passed to ``torch.onnx.export``.
        nms_cfg: sequence unpacked into ``End2End`` after the model. Its first item is used as the top-k size
            and its last item as the backend name when annotating output shapes.
        dynamic: ``'batch'`` marks only the batch axis as dynamic; ``'all'`` (or boolean ``True``) marks batch,
            height and width as dynamic; any other value exports without dynamic axes. Any truthy value
            moves model and input to CPU for the export.
        simplify: when truthy, run onnx-simplifier on the exported model (failures are logged, not raised).
        prefix: prefix for log messages.

    Returns:
        Tuple ``(f, model_onnx)``: the output path and the loaded (optionally simplified) ONNX model.
    """
    check_requirements('onnx>=1.12.0')
    import onnx

    LOGGER.info(f'\n{prefix} starting export with onnx {onnx.__version__}...')
    f = file.with_suffix('.onnx')

    from models.common import End2End
    model = End2End(model, *nms_cfg, device=im.device)
    b, topk, backend = 'batch', nms_cfg[0], nms_cfg[-1]
    output_names = ['num_dets', 'boxes', 'scores', 'labels']
    output_shapes = {n: {0: b} for n in output_names}

    if dynamic is True:  # a plain boolean flag carries no axis selection; treat it as 'all'
        dynamic = 'all'
    if dynamic == 'batch':
        dynamic_cfg = {'images': {0: b}, **output_shapes}
    elif dynamic == 'all':
        dynamic_cfg = {'images': {0: b, 2: 'height', 3: 'width'}, **output_shapes}
    else:
        dynamic_cfg, b = {}, im.shape[0]  # static export: batch size is the concrete input batch

    torch.onnx.export(
        model.cpu() if dynamic else model,  # --dynamic only compatible with cpu
        im.cpu() if dynamic else im,
        f,
        verbose=False,
        opset_version=opset,
        do_constant_folding=True,  # WARNING: DNN inference with torch>=1.12 may require do_constant_folding=False
        input_names=['images'],
        output_names=output_names,
        dynamic_axes=dynamic_cfg)

    # Checks
    model_onnx = onnx.load(f)  # load onnx model
    onnx.checker.check_model(model_onnx)  # check onnx model

    # Metadata
    d = {'stride': int(max(model.stride)), 'names': model.names}
    for k, v in d.items():
        meta = model_onnx.metadata_props.add()
        meta.key, meta.value = k, str(v)

    # Fix output shape info (used by TensorRT): the TensorRT backend gets the numeric top-k size,
    # the other backends get a symbolic 'topk' dimension.
    k_dim = topk if backend == 'trt' else 'topk'
    # Flattened dims of num_dets (b, 1), boxes (b, k, 4), scores (b, k), labels (b, k).
    shapes = [b, 1, b, k_dim, 4, b, k_dim, b, k_dim]
    for output in model_onnx.graph.output:
        for dim in output.type.tensor_type.shape.dim:
            size = shapes.pop(0)
            if isinstance(size, int):
                dim.dim_value = size  # static size: store it as a value, not as a symbol named "1"/"4"
            else:
                dim.dim_param = str(size)
    onnx.save(model_onnx, f)

    # Simplify
    if simplify:
        try:
            cuda = torch.cuda.is_available()
            check_requirements(('onnxruntime-gpu' if cuda else 'onnxruntime', 'onnx-simplifier>=0.4.1'))
            import onnxsim

            LOGGER.info(f'{prefix} simplifying with onnx-simplifier {onnxsim.__version__}...')
            model_onnx, check = onnxsim.simplify(model_onnx)
            assert check, 'assert check failed'
            onnx.save(model_onnx, f)
        except Exception as e:
            LOGGER.info(f'{prefix} simplifier failure: {e}')
    return f, model_onnx
254+
255+
188256
@try_export
189257
def export_openvino(file, metadata, half, prefix=colorstr('OpenVINO:')):
190258
# YOLOv5 OpenVINO export
@@ -505,7 +573,7 @@ def run(
505573
opset=12, # ONNX: opset version
506574
verbose=False, # TensorRT: verbose log
507575
workspace=4, # TensorRT: workspace size (GB)
508-
nms=False, # TF: add NMS to model
576+
nms=False, # ONNX/TF/TensorRT: NMS config for model
509577
agnostic_nms=False, # TF: add agnostic NMS to model
510578
topk_per_class=100, # TF.js NMS: topk per class to keep
511579
topk_all=100, # TF.js NMS: topk for all classes to keep
@@ -560,9 +628,9 @@ def run(
560628
f[0], _ = export_torchscript(model, im, file, optimize)
561629
if engine: # TensorRT required before ONNX
562630
f[1], _ = export_engine(model, im, file, half, dynamic, simplify, workspace, verbose)
563-
if onnx or xml: # OpenVINO requires ONNX
631+
if not nms and onnx or xml: # OpenVINO requires ONNX
564632
f[2], _ = export_onnx(model, im, file, opset, dynamic, simplify)
565-
if xml: # OpenVINO
633+
if not nms and xml: # OpenVINO
566634
f[3], _ = export_openvino(file, metadata, half)
567635
if coreml: # CoreML
568636
f[4], _ = export_coreml(model, im, file, int8, half)
@@ -592,6 +660,11 @@ def run(
592660
if paddle: # PaddlePaddle
593661
f[10], _ = export_paddle(model, im, file, metadata)
594662

663+
if nms and (onnx or xml):
664+
nms_cfg = [topk_all, iou_thres, conf_thres, nms]
665+
f.append(export_onnx_with_nms(model, im, file, opset, nms_cfg, dynamic, simplify)[0])
666+
if xml:
667+
f.append(export_openvino(file.with_suffix('.pt'), metadata, half)[0])
595668
# Finish
596669
f = [str(x) for x in f if x] # filter out '' and None
597670
if any(f):
@@ -622,12 +695,12 @@ def parse_opt():
622695
parser.add_argument('--keras', action='store_true', help='TF: use Keras')
623696
parser.add_argument('--optimize', action='store_true', help='TorchScript: optimize for mobile')
624697
parser.add_argument('--int8', action='store_true', help='CoreML/TF INT8 quantization')
625-
parser.add_argument('--dynamic', action='store_true', help='ONNX/TF/TensorRT: dynamic axes')
698+
parser.add_argument('--dynamic', nargs='?', const='all', default=False, help='ONNX/TF/TensorRT: dynamic axes')
626699
parser.add_argument('--simplify', action='store_true', help='ONNX: simplify model')
627700
parser.add_argument('--opset', type=int, default=17, help='ONNX: opset version')
628701
parser.add_argument('--verbose', action='store_true', help='TensorRT: verbose log')
629702
parser.add_argument('--workspace', type=int, default=4, help='TensorRT: workspace size (GB)')
630-
parser.add_argument('--nms', action='store_true', help='TF: add NMS to model')
703+
parser.add_argument('--nms', nargs='?', const=True, default=False, help='ONNX/TF/TensorRT: NMS config for model')
631704
parser.add_argument('--agnostic-nms', action='store_true', help='TF: add agnostic NMS to model')
632705
parser.add_argument('--topk-per-class', type=int, default=100, help='TF.js NMS: topk per class to keep')
633706
parser.add_argument('--topk-all', type=int, default=100, help='TF.js NMS: topk for all classes to keep')

models/common.py

Lines changed: 179 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
import json
99
import math
1010
import platform
11+
import random
1112
import warnings
1213
import zipfile
1314
from collections import OrderedDict, namedtuple
@@ -858,3 +859,181 @@ def forward(self, x):
858859
if isinstance(x, list):
859860
x = torch.cat(x, 1)
860861
return self.linear(self.drop(self.pool(self.conv(x)).flatten(1)))
862+
863+
864+
class ORT_NMS(torch.autograd.Function):
    """
    Stand-in op that is exported as an ONNX ``NonMaxSuppression`` node.

    ``forward`` performs no suppression: it fabricates a random index tensor of the right layout so the
    surrounding module can be run in Python, while ``symbolic`` supplies the node written to the ONNX graph.
    """

    @staticmethod
    def forward(ctx,
                boxes,
                scores,
                max_output_boxes_per_class=torch.tensor([100]),
                iou_threshold=torch.tensor([0.45]),
                score_threshold=torch.tensor([0.25])):
        """
        Return a random ``(num_det, 3)`` int64 tensor of ``[batch_index, class_index, box_index]`` rows.

        ``num_det`` is drawn from 0..100, batch indices are random and sorted ascending, the class index is
        always 0 and box indices run upward from 100. The three threshold arguments are not used here.
        """
        target = boxes.device
        n_images = scores.shape[0]
        count = random.randint(0, 100)
        # Drawn on the default device first, then moved, to keep the same RNG stream as before.
        image_col = torch.randint(0, n_images, (count,)).sort()[0].to(target)
        box_col = torch.arange(100, 100 + count).to(target)
        class_col = torch.zeros((count,), dtype=torch.int64).to(target)
        picked = torch.stack((image_col, class_col, box_col), dim=1)
        return picked.to(torch.int64)

    @staticmethod
    def symbolic(g, boxes, scores, max_output_boxes_per_class, iou_threshold, score_threshold):
        """Emit the ONNX ``NonMaxSuppression`` node with all five inputs."""
        return g.op("NonMaxSuppression", boxes, scores, max_output_boxes_per_class, iou_threshold, score_threshold)
886+
887+
888+
class TRT_NMS(torch.autograd.Function):
    """
    Stand-in op that is exported as a ``TRT::EfficientNMS_TRT`` node.

    ``forward`` performs no suppression: it returns random tensors with the four-output layout so the
    surrounding module can be run in Python, while ``symbolic`` supplies the node written to the ONNX graph.
    """

    @staticmethod
    def forward(
        ctx,
        boxes,
        scores,
        background_class=-1,
        box_coding=1,
        iou_threshold=0.45,
        max_output_boxes=100,
        plugin_version="1",
        score_activation=0,
        score_threshold=0.25,
    ):
        """
        Return random placeholder outputs ``(num_det, det_boxes, det_scores, det_classes)``.

        Shapes are ``(batch, 1)``, ``(batch, max_output_boxes, 4)``, ``(batch, max_output_boxes)`` and
        ``(batch, max_output_boxes)``, where ``batch`` and the class count come from ``scores.shape``.
        ``num_det`` and ``det_classes`` are int32. Only ``scores.shape``, ``boxes.device`` and
        ``max_output_boxes`` influence the result.
        """
        batch_size, num_boxes, num_classes = scores.shape
        device = boxes.device  # keep placeholder outputs on the input's device, as ORT_NMS does
        num_det = torch.randint(0, max_output_boxes, (batch_size, 1), dtype=torch.int32, device=device)
        det_boxes = torch.randn(batch_size, max_output_boxes, 4, device=device)
        det_scores = torch.randn(batch_size, max_output_boxes, device=device)
        det_classes = torch.randint(0,
                                    num_classes, (batch_size, max_output_boxes),
                                    dtype=torch.int32,
                                    device=device)

        return num_det, det_boxes, det_scores, det_classes

    @staticmethod
    def symbolic(g,
                 boxes,
                 scores,
                 background_class=-1,
                 box_coding=1,
                 iou_threshold=0.45,
                 max_output_boxes=100,
                 plugin_version="1",
                 score_activation=0,
                 score_threshold=0.25):
        """Emit a ``TRT::EfficientNMS_TRT`` node with four outputs; the settings become node attributes."""
        nums, det_boxes, det_scores, det_classes = g.op("TRT::EfficientNMS_TRT",
                                                        boxes,
                                                        scores,
                                                        background_class_i=background_class,
                                                        box_coding_i=box_coding,
                                                        iou_threshold_f=iou_threshold,
                                                        max_output_boxes_i=max_output_boxes,
                                                        plugin_version_s=plugin_version,
                                                        score_activation_i=score_activation,
                                                        score_threshold_f=score_threshold,
                                                        outputs=4)
        return nums, det_boxes, det_scores, det_classes
935+
936+
937+
class ONNX_ORT(nn.Module):
    """
    Post-processing head that runs ``ORT_NMS`` on raw predictions and packs the result per image.

    ``forward`` returns ``(num_dets, boxes, scores, labels)``.
    """

    def __init__(self, max_obj=100, iou_thres=0.45, score_thres=0.25, device=None):
        """
        Store the NMS settings as one-element tensors on the target device.

        Args:
            max_obj: value passed to ``ORT_NMS`` as ``max_output_boxes_per_class``.
            iou_thres: IoU threshold passed to ``ORT_NMS``.
            score_thres: score threshold passed to ``ORT_NMS``.
            device: device for the stored tensors; CPU when falsy.
        """
        super().__init__()
        self.device = device if device else torch.device("cpu")
        # Use the resolved device for every tensor so they all live where convert_matrix lives.
        self.max_obj = torch.tensor([max_obj]).to(self.device)
        self.iou_threshold = torch.tensor([iou_thres]).to(self.device)
        self.score_threshold = torch.tensor([score_thres]).to(self.device)
        self.max_wh = 7680
        # Right-multiplying [a, b, c, d] yields [a - c/2, b - d/2, a + c/2, b + d/2].
        self.convert_matrix = torch.tensor([[1, 0, 1, 0], [0, 1, 0, 1], [-0.5, 0, 0.5, 0], [0, -0.5, 0, 0.5]],
                                           dtype=torch.float32,
                                           device=self.device)

    def forward(self, x):
        """
        Apply NMS to ``x`` and return per-image detections sorted by descending score.

        Args:
            x: rank-3 tensor; the last axis is split into 4 box values, 1 confidence value and the remaining
                per-class scores (presumably centre-x, centre-y, width, height for the box — verify against
                the model head).

        Returns:
            Tuple ``(num_dets, boxes, scores, labels)``. ``num_dets`` counts, per image, the rows whose score
            is greater than zero; rows that do not belong to an image carry zero boxes/scores and label -1.
        """
        batch, _, _ = x.shape
        boxes = x[:, :, :4]
        conf = x[:, :, 4:5]
        # Out-of-place multiply: the slices are views, so ``*=`` would overwrite the caller's tensor.
        scores = x[:, :, 5:] * conf

        nms_box = boxes @ self.convert_matrix
        nms_score = scores.transpose(1, 2).contiguous()

        selected_indices = ORT_NMS.apply(nms_box, nms_score, self.max_obj, self.iou_threshold, self.score_threshold)
        batch_inds, cls_inds, box_inds = selected_indices.unbind(1)
        selected_score = nms_score[batch_inds, cls_inds, box_inds].unsqueeze(1)
        selected_box = nms_box[batch_inds, box_inds, ...]

        dets = torch.cat([selected_box, selected_score], dim=1)

        # Copy the flat detection list once per image, then blank out rows owned by other images.
        batched_dets = dets.unsqueeze(0).repeat(batch, 1, 1)
        batch_template = torch.arange(0, batch, dtype=batch_inds.dtype, device=batch_inds.device)
        batched_dets = batched_dets.where((batch_inds == batch_template.unsqueeze(1)).unsqueeze(-1),
                                          batched_dets.new_zeros(1))

        batched_labels = cls_inds.unsqueeze(0).repeat(batch, 1)
        batched_labels = batched_labels.where((batch_inds == batch_template.unsqueeze(1)),
                                              batched_labels.new_ones(1) * -1)

        N = batched_dets.shape[0]

        # Append one all-zero detection (label -1) per image — presumably so the sort below always has a row.
        batched_dets = torch.cat((batched_dets, batched_dets.new_zeros((N, 1, 5))), 1)
        batched_labels = torch.cat((batched_labels, -batched_labels.new_ones((N, 1))), 1)

        # Order each image's rows by score (last column), highest first.
        _, topk_inds = batched_dets[:, :, -1].sort(dim=1, descending=True)

        topk_batch_inds = torch.arange(batch, dtype=topk_inds.dtype, device=topk_inds.device).view(-1, 1)
        batched_dets = batched_dets[topk_batch_inds, topk_inds, ...]
        labels = batched_labels[topk_batch_inds, topk_inds, ...]
        boxes, scores = batched_dets.split((4, 1), -1)
        scores = scores.squeeze(-1)
        num_dets = (scores > 0).sum(1, keepdim=True)
        return num_dets, boxes, scores, labels
990+
991+
992+
class ONNX_TRT(nn.Module):
    """
    Post-processing head that feeds raw predictions to ``TRT_NMS``.

    ``forward`` returns ``(num_dets, boxes, scores, labels)`` exactly as produced by ``TRT_NMS``.
    """

    def __init__(self, max_obj=100, iou_thres=0.45, score_thres=0.25, device=None):
        """
        Store the settings forwarded to ``TRT_NMS``.

        Args:
            max_obj: forwarded as ``max_output_boxes``.
            iou_thres: forwarded as ``iou_threshold``.
            score_thres: forwarded as ``score_threshold``.
            device: stored on ``self.device``; CPU when falsy.
        """
        super().__init__()
        self.device = device if device else torch.device('cpu')
        # Plain ints: a trailing comma here would turn these into one-element tuples.
        self.background_class = -1
        self.box_coding = 1
        self.iou_threshold = iou_thres
        self.max_obj = max_obj
        self.plugin_version = '1'
        self.score_activation = 0
        self.score_threshold = score_thres

    def forward(self, x):
        """
        Split ``x`` into boxes and confidence-weighted class scores and run ``TRT_NMS`` on them.

        Args:
            x: rank-3 tensor whose last axis is split into 4 box values, 1 confidence value and the
                remaining per-class scores.

        Returns:
            Tuple ``(num_dets, boxes, scores, labels)`` from ``TRT_NMS``.
        """
        boxes = x[:, :, :4]
        conf = x[:, :, 4:5]
        # Out-of-place multiply: the slices are views, so ``*=`` would overwrite the caller's tensor.
        scores = x[:, :, 5:] * conf
        num_dets, boxes, scores, labels = TRT_NMS.apply(boxes, scores, self.background_class, self.box_coding,
                                                        self.iou_threshold, self.max_obj, self.plugin_version,
                                                        self.score_activation, self.score_threshold)
        return num_dets, boxes, scores, labels
1014+
1015+
1016+
class End2End(nn.Module):
    """
    Wrap a detection model together with an NMS head so both run in one forward pass.

    The NMS head is ``ONNX_TRT`` for backend ``'trt'`` and ``ONNX_ORT`` for backends ``'ort'`` and ``'ovo'``.
    """

    def __init__(self, model, max_obj=100, iou_thres=0.45, score_thres=0.25, backend='ort', device=None):
        """
        Build the wrapper.

        Args:
            model: detection model; moved to ``device`` and expected to expose ``stride`` and ``names``.
            max_obj: forwarded to the NMS head.
            iou_thres: forwarded to the NMS head.
            score_thres: forwarded to the NMS head.
            backend: one of ``'trt'``, ``'ort'`` or ``'ovo'``.
            device: target device; CPU when falsy.

        Raises:
            NotImplementedError: if ``backend`` is not one of the supported names.
        """
        super().__init__()
        device = device if device else torch.device('cpu')
        self.model = model.to(device)

        if backend == 'trt':
            self.patch_model = ONNX_TRT
        elif backend in ('ort', 'ovo'):
            self.patch_model = ONNX_ORT
        else:
            # Name the offending value: a bare exception gives no hint about what was passed in.
            raise NotImplementedError(f"unsupported NMS backend {backend!r}; expected 'trt', 'ort' or 'ovo'")
        self.end2end = self.patch_model(max_obj, iou_thres, score_thres, device)
        self.end2end.eval()
        # Mirror the wrapped model's attributes on the wrapper.
        self.stride = self.model.stride
        self.names = self.model.names

    def forward(self, x):
        """Run the model, take the first element of its output and pass it through the NMS head."""
        x = self.model(x)[0]
        x = self.end2end(x)
        return x

0 commit comments

Comments
 (0)