export.py: 83 changes (78 additions & 5 deletions)
@@ -185,6 +185,74 @@ def export_onnx(model, im, file, opset, dynamic, simplify, prefix=colorstr('ONNX
return f, model_onnx


@try_export
def export_onnx_with_nms(model, im, file, opset, nms_cfg, dynamic, simplify, prefix=colorstr('ONNX:')):
    # YOLOv5 ONNX export with embedded NMS
check_requirements('onnx>=1.12.0')
import onnx

LOGGER.info(f'\n{prefix} starting export with onnx {onnx.__version__}...')
f = file.with_suffix('.onnx')

from models.common import End2End
model = End2End(model, *nms_cfg, device=im.device)
b, topk, backend = 'batch', nms_cfg[0], nms_cfg[-1]
output_names = ['num_dets', 'boxes', 'scores', 'labels']
output_shapes = {n: {0: b} for n in output_names}
if dynamic == 'batch':
dynamic_cfg = {'images': {0: b}, **output_shapes}
elif dynamic == 'all':
dynamic_cfg = {'images': {0: b, 2: 'height', 3: 'width'}, **output_shapes}
else:
dynamic_cfg, b = {}, im.shape[0]

torch.onnx.export(
model.cpu() if dynamic else model, # --dynamic only compatible with cpu
im.cpu() if dynamic else im,
f,
verbose=False,
opset_version=opset,
do_constant_folding=True, # WARNING: DNN inference with torch>=1.12 may require do_constant_folding=False
input_names=['images'],
output_names=output_names,
dynamic_axes=dynamic_cfg)

# Checks
model_onnx = onnx.load(f) # load onnx model
onnx.checker.check_model(model_onnx) # check onnx model

# Metadata
d = {'stride': int(max(model.stride)), 'names': model.names}
for k, v in d.items():
meta = model_onnx.metadata_props.add()
meta.key, meta.value = k, str(v)

    # Fix output shape info for ONNX models consumed by TensorRT
if backend == 'trt':
shapes = [b, 1, b, topk, 4, b, topk, b, topk]
else:
shapes = [b, 1, b, 'topk', 4, b, 'topk', b, 'topk']
for i in model_onnx.graph.output:
for j in i.type.tensor_type.shape.dim:
j.dim_param = str(shapes.pop(0))
onnx.save(model_onnx, f)

# Simplify
if simplify:
try:
cuda = torch.cuda.is_available()
check_requirements(('onnxruntime-gpu' if cuda else 'onnxruntime', 'onnx-simplifier>=0.4.1'))
import onnxsim

LOGGER.info(f'{prefix} simplifying with onnx-simplifier {onnxsim.__version__}...')
model_onnx, check = onnxsim.simplify(model_onnx)
assert check, 'assert check failed'
onnx.save(model_onnx, f)
except Exception as e:
LOGGER.info(f'{prefix} simplifier failure: {e}')
return f, model_onnx


@try_export
def export_openvino(file, metadata, half, prefix=colorstr('OpenVINO:')):
# YOLOv5 OpenVINO export
@@ -505,7 +573,7 @@ def run(
opset=12, # ONNX: opset version
verbose=False, # TensorRT: verbose log
workspace=4, # TensorRT: workspace size (GB)
-        nms=False,  # TF: add NMS to model
+        nms=False,  # ONNX/TF/TensorRT: add NMS to model; pass a backend (ort, trt, ovo) for ONNX/OpenVINO export
agnostic_nms=False, # TF: add agnostic NMS to model
topk_per_class=100, # TF.js NMS: topk per class to keep
topk_all=100, # TF.js NMS: topk for all classes to keep
@@ -560,9 +628,9 @@ def run(
f[0], _ = export_torchscript(model, im, file, optimize)
if engine: # TensorRT required before ONNX
f[1], _ = export_engine(model, im, file, half, dynamic, simplify, workspace, verbose)
-    if onnx or xml:  # OpenVINO requires ONNX
+    if not nms and (onnx or xml):  # OpenVINO requires ONNX
        f[2], _ = export_onnx(model, im, file, opset, dynamic, simplify)
-    if xml:  # OpenVINO
+    if not nms and xml:  # OpenVINO
f[3], _ = export_openvino(file, metadata, half)
if coreml: # CoreML
f[4], _ = export_coreml(model, im, file, int8, half)
@@ -592,6 +660,11 @@ def run(
if paddle: # PaddlePaddle
f[10], _ = export_paddle(model, im, file, metadata)

if nms and (onnx or xml):
nms_cfg = [topk_all, iou_thres, conf_thres, nms]
f.append(export_onnx_with_nms(model, im, file, opset, nms_cfg, dynamic, simplify)[0])
if xml:
f.append(export_openvino(file.with_suffix('.pt'), metadata, half)[0])
# Finish
f = [str(x) for x in f if x] # filter out '' and None
if any(f):
@@ -622,12 +695,12 @@ def parse_opt(known=False):
parser.add_argument('--keras', action='store_true', help='TF: use Keras')
parser.add_argument('--optimize', action='store_true', help='TorchScript: optimize for mobile')
parser.add_argument('--int8', action='store_true', help='CoreML/TF INT8 quantization')
-    parser.add_argument('--dynamic', action='store_true', help='ONNX/TF/TensorRT: dynamic axes')
+    parser.add_argument('--dynamic', nargs='?', const='all', default=False, help="ONNX/TF/TensorRT: dynamic axes ('batch', 'all'; bare flag means 'all')")
parser.add_argument('--simplify', action='store_true', help='ONNX: simplify model')
parser.add_argument('--opset', type=int, default=17, help='ONNX: opset version')
parser.add_argument('--verbose', action='store_true', help='TensorRT: verbose log')
parser.add_argument('--workspace', type=int, default=4, help='TensorRT: workspace size (GB)')
-    parser.add_argument('--nms', action='store_true', help='TF: add NMS to model')
+    parser.add_argument('--nms', nargs='?', const=True, default=False, help='ONNX/TF/TensorRT: add NMS to model; pass a backend (ort, trt, ovo) for ONNX/OpenVINO export')
parser.add_argument('--agnostic-nms', action='store_true', help='TF: add agnostic NMS to model')
parser.add_argument('--topk-per-class', type=int, default=100, help='TF.js NMS: topk per class to keep')
parser.add_argument('--topk-all', type=int, default=100, help='TF.js NMS: topk for all classes to keep')
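For context, here is a minimal usage sketch of the new end-to-end export path (not part of the diff; the weights path and thresholds are placeholders, and the usual YOLOv5 run() keyword arguments are assumed). Passing --nms with a backend string makes run() skip the plain ONNX/OpenVINO exports and call export_onnx_with_nms(), which wraps the model in End2End before tracing:

# Hypothetical example, assuming the run() keyword arguments shown in this diff
from export import run

run(
    weights='yolov5s.pt',  # placeholder checkpoint
    include=('onnx',),     # request ONNX output; add 'openvino' for the OpenVINO IR as well
    nms='trt',             # backend consumed by End2End: 'ort', 'trt' or 'ovo'
    dynamic='batch',       # new string-valued --dynamic: 'batch' or 'all'
    topk_all=100,          # maximum detections kept per image
    iou_thres=0.45,
    conf_thres=0.25,
    simplify=True)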
models/common.py: 179 changes (179 additions & 0 deletions)
@@ -8,6 +8,7 @@
import json
import math
import platform
import random
import warnings
import zipfile
from collections import OrderedDict, namedtuple
@@ -865,3 +866,181 @@ def forward(self, x):
if isinstance(x, list):
x = torch.cat(x, 1)
return self.linear(self.drop(self.pool(self.conv(x)).flatten(1)))


class ORT_NMS(torch.autograd.Function):
    # ONNX Runtime NMS op: forward() only returns random placeholder indices for tracing;
    # symbolic() maps the call to the standard ONNX NonMaxSuppression operator

@staticmethod
def forward(ctx,
boxes,
scores,
max_output_boxes_per_class=torch.tensor([100]),
iou_threshold=torch.tensor([0.45]),
score_threshold=torch.tensor([0.25])):
device = boxes.device
batch = scores.shape[0]
num_det = random.randint(0, 100)
batches = torch.randint(0, batch, (num_det,)).sort()[0].to(device)
idxs = torch.arange(100, 100 + num_det).to(device)
zeros = torch.zeros((num_det,), dtype=torch.int64).to(device)
selected_indices = torch.cat([batches[None], zeros[None], idxs[None]], 0).T.contiguous()
selected_indices = selected_indices.to(torch.int64)
return selected_indices

@staticmethod
def symbolic(g, boxes, scores, max_output_boxes_per_class, iou_threshold, score_threshold):
return g.op("NonMaxSuppression", boxes, scores, max_output_boxes_per_class, iou_threshold, score_threshold)


class TRT_NMS(torch.autograd.Function):
    # TensorRT NMS op: forward() only returns random placeholder outputs for tracing;
    # symbolic() maps the call to the TensorRT EfficientNMS_TRT plugin

@staticmethod
def forward(
ctx,
boxes,
scores,
background_class=-1,
box_coding=1,
iou_threshold=0.45,
max_output_boxes=100,
plugin_version="1",
score_activation=0,
score_threshold=0.25,
):
batch_size, num_boxes, num_classes = scores.shape
num_det = torch.randint(0, max_output_boxes, (batch_size, 1), dtype=torch.int32)
det_boxes = torch.randn(batch_size, max_output_boxes, 4)
det_scores = torch.randn(batch_size, max_output_boxes)
det_classes = torch.randint(0, num_classes, (batch_size, max_output_boxes), dtype=torch.int32)

return num_det, det_boxes, det_scores, det_classes

@staticmethod
def symbolic(g,
boxes,
scores,
background_class=-1,
box_coding=1,
iou_threshold=0.45,
max_output_boxes=100,
plugin_version="1",
score_activation=0,
score_threshold=0.25):
out = g.op("TRT::EfficientNMS_TRT",
boxes,
scores,
background_class_i=background_class,
box_coding_i=box_coding,
iou_threshold_f=iou_threshold,
max_output_boxes_i=max_output_boxes,
plugin_version_s=plugin_version,
score_activation_i=score_activation,
score_threshold_f=score_threshold,
outputs=4)
nums, boxes, scores, classes = out
return nums, boxes, scores, classes


class ONNX_ORT(nn.Module):
    # ONNX Runtime-compatible NMS post-processing module appended to the model by End2End

def __init__(self, max_obj=100, iou_thres=0.45, score_thres=0.25, device=None):
super().__init__()
self.device = device if device else torch.device("cpu")
self.max_obj = torch.tensor([max_obj]).to(device)
self.iou_threshold = torch.tensor([iou_thres]).to(device)
self.score_threshold = torch.tensor([score_thres]).to(device)
self.max_wh = 7680
self.convert_matrix = torch.tensor([[1, 0, 1, 0], [0, 1, 0, 1], [-0.5, 0, 0.5, 0], [0, -0.5, 0, 0.5]],
dtype=torch.float32,
device=self.device)

def forward(self, x):
batch, anchors, _ = x.shape
boxes = x[:, :, :4]
conf = x[:, :, 4:5]
scores = x[:, :, 5:]
scores *= conf

nms_box = boxes @ self.convert_matrix
nms_score = scores.transpose(1, 2).contiguous()

selected_indices = ORT_NMS.apply(nms_box, nms_score, self.max_obj, self.iou_threshold, self.score_threshold)
batch_inds, cls_inds, box_inds = selected_indices.unbind(1)
selected_score = nms_score[batch_inds, cls_inds, box_inds].unsqueeze(1)
selected_box = nms_box[batch_inds, box_inds, ...]

dets = torch.cat([selected_box, selected_score], dim=1)

batched_dets = dets.unsqueeze(0).repeat(batch, 1, 1)
batch_template = torch.arange(0, batch, dtype=batch_inds.dtype, device=batch_inds.device)
batched_dets = batched_dets.where((batch_inds == batch_template.unsqueeze(1)).unsqueeze(-1),
batched_dets.new_zeros(1))

batched_labels = cls_inds.unsqueeze(0).repeat(batch, 1)
batched_labels = batched_labels.where((batch_inds == batch_template.unsqueeze(1)),
batched_labels.new_ones(1) * -1)

N = batched_dets.shape[0]

batched_dets = torch.cat((batched_dets, batched_dets.new_zeros((N, 1, 5))), 1)
batched_labels = torch.cat((batched_labels, -batched_labels.new_ones((N, 1))), 1)

_, topk_inds = batched_dets[:, :, -1].sort(dim=1, descending=True)

topk_batch_inds = torch.arange(batch, dtype=topk_inds.dtype, device=topk_inds.device).view(-1, 1)
batched_dets = batched_dets[topk_batch_inds, topk_inds, ...]
labels = batched_labels[topk_batch_inds, topk_inds, ...]
boxes, scores = batched_dets.split((4, 1), -1)
scores = scores.squeeze(-1)
num_dets = (labels > -1).sum(1, keepdim=True)
return num_dets, boxes, scores, labels


class ONNX_TRT(nn.Module):
    # TensorRT-compatible NMS post-processing module (EfficientNMS_TRT) appended to the model by End2End

def __init__(self, max_obj=100, iou_thres=0.45, score_thres=0.25, device=None):
super().__init__()
self.device = device if device else torch.device('cpu')
        self.background_class = -1
        self.box_coding = 1
self.iou_threshold = iou_thres
self.max_obj = max_obj
self.plugin_version = '1'
self.score_activation = 0
self.score_threshold = score_thres

def forward(self, x):
boxes = x[:, :, :4]
conf = x[:, :, 4:5]
scores = x[:, :, 5:]
scores *= conf
num_dets, boxes, scores, labels = TRT_NMS.apply(boxes, scores, self.background_class, self.box_coding,
self.iou_threshold, self.max_obj, self.plugin_version,
self.score_activation, self.score_threshold)
return num_dets, boxes, scores, labels


class End2End(nn.Module):
    # YOLOv5 model wrapped with an NMS post-processing module for end-to-end export

def __init__(self, model, max_obj=100, iou_thres=0.45, score_thres=0.25, backend='ort', device=None):
super().__init__()
device = device if device else torch.device('cpu')
self.model = model.to(device)

if backend == 'trt':
self.patch_model = ONNX_TRT
elif backend == 'ort':
self.patch_model = ONNX_ORT
elif backend == 'ovo':
self.patch_model = ONNX_ORT
else:
raise NotImplementedError
self.end2end = self.patch_model(max_obj, iou_thres, score_thres, device)
self.end2end.eval()
self.stride = self.model.stride
self.names = self.model.names

def forward(self, x):
x = self.model(x)[0]
x = self.end2end(x)
return x
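For reference, a minimal sketch of what export_onnx_with_nms() does with this wrapper (assumptions: `model` is a loaded, eval-mode YOLOv5 detection model whose forward pass returns predictions of shape (batch, anchors, 5 + num_classes); the output file name is a placeholder):

import torch
from models.common import End2End

im = torch.zeros(1, 3, 640, 640)  # dummy input for tracing
e2e = End2End(model, max_obj=100, iou_thres=0.45, score_thres=0.25, backend='ort', device=im.device).eval()

# In eager mode ORT_NMS/TRT_NMS only return random placeholders; the real NMS node is
# emitted by their symbolic() methods during ONNX export
torch.onnx.export(e2e, im, 'yolov5s-nms.onnx', opset_version=12,
                  input_names=['images'],
                  output_names=['num_dets', 'boxes', 'scores', 'labels'])
# Exported graph outputs: num_dets (N, 1), boxes (N, K, 4), scores (N, K), labels (N, K),
# where K is the dynamic detection dimension labelled 'topk' by export_onnx_with_nms()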