|
8 | 8 | import json
|
9 | 9 | import math
|
10 | 10 | import platform
|
| 11 | +import random |
11 | 12 | import warnings
|
12 | 13 | import zipfile
|
13 | 14 | from collections import OrderedDict, namedtuple
|
@@ -858,3 +859,181 @@ def forward(self, x):
|
858 | 859 | if isinstance(x, list):
|
859 | 860 | x = torch.cat(x, 1)
|
860 | 861 | return self.linear(self.drop(self.pool(self.conv(x)).flatten(1)))
|
| 862 | + |
| 863 | + |
| 864 | +class ORT_NMS(torch.autograd.Function): |
| 865 | + |
| 866 | + @staticmethod |
| 867 | + def forward(ctx, |
| 868 | + boxes, |
| 869 | + scores, |
| 870 | + max_output_boxes_per_class=torch.tensor([100]), |
| 871 | + iou_threshold=torch.tensor([0.45]), |
| 872 | + score_threshold=torch.tensor([0.25])): |
| 873 | + device = boxes.device |
| 874 | + batch = scores.shape[0] |
| 875 | + num_det = random.randint(0, 100) |
| 876 | + batches = torch.randint(0, batch, (num_det,)).sort()[0].to(device) |
| 877 | + idxs = torch.arange(100, 100 + num_det).to(device) |
| 878 | + zeros = torch.zeros((num_det,), dtype=torch.int64).to(device) |
| 879 | + selected_indices = torch.cat([batches[None], zeros[None], idxs[None]], 0).T.contiguous() |
| 880 | + selected_indices = selected_indices.to(torch.int64) |
| 881 | + return selected_indices |
| 882 | + |
| 883 | + @staticmethod |
| 884 | + def symbolic(g, boxes, scores, max_output_boxes_per_class, iou_threshold, score_threshold): |
| 885 | + return g.op("NonMaxSuppression", boxes, scores, max_output_boxes_per_class, iou_threshold, score_threshold) |
| 886 | + |
| 887 | + |
| 888 | +class TRT_NMS(torch.autograd.Function): |
| 889 | + |
| 890 | + @staticmethod |
| 891 | + def forward( |
| 892 | + ctx, |
| 893 | + boxes, |
| 894 | + scores, |
| 895 | + background_class=-1, |
| 896 | + box_coding=1, |
| 897 | + iou_threshold=0.45, |
| 898 | + max_output_boxes=100, |
| 899 | + plugin_version="1", |
| 900 | + score_activation=0, |
| 901 | + score_threshold=0.25, |
| 902 | + ): |
| 903 | + batch_size, num_boxes, num_classes = scores.shape |
| 904 | + num_det = torch.randint(0, max_output_boxes, (batch_size, 1), dtype=torch.int32) |
| 905 | + det_boxes = torch.randn(batch_size, max_output_boxes, 4) |
| 906 | + det_scores = torch.randn(batch_size, max_output_boxes) |
| 907 | + det_classes = torch.randint(0, num_classes, (batch_size, max_output_boxes), dtype=torch.int32) |
| 908 | + |
| 909 | + return num_det, det_boxes, det_scores, det_classes |
| 910 | + |
| 911 | + @staticmethod |
| 912 | + def symbolic(g, |
| 913 | + boxes, |
| 914 | + scores, |
| 915 | + background_class=-1, |
| 916 | + box_coding=1, |
| 917 | + iou_threshold=0.45, |
| 918 | + max_output_boxes=100, |
| 919 | + plugin_version="1", |
| 920 | + score_activation=0, |
| 921 | + score_threshold=0.25): |
| 922 | + out = g.op("TRT::EfficientNMS_TRT", |
| 923 | + boxes, |
| 924 | + scores, |
| 925 | + background_class_i=background_class, |
| 926 | + box_coding_i=box_coding, |
| 927 | + iou_threshold_f=iou_threshold, |
| 928 | + max_output_boxes_i=max_output_boxes, |
| 929 | + plugin_version_s=plugin_version, |
| 930 | + score_activation_i=score_activation, |
| 931 | + score_threshold_f=score_threshold, |
| 932 | + outputs=4) |
| 933 | + nums, boxes, scores, classes = out |
| 934 | + return nums, boxes, scores, classes |
| 935 | + |
| 936 | + |
| 937 | +class ONNX_ORT(nn.Module): |
| 938 | + |
| 939 | + def __init__(self, max_obj=100, iou_thres=0.45, score_thres=0.25, device=None): |
| 940 | + super().__init__() |
| 941 | + self.device = device if device else torch.device("cpu") |
| 942 | + self.max_obj = torch.tensor([max_obj]).to(device) |
| 943 | + self.iou_threshold = torch.tensor([iou_thres]).to(device) |
| 944 | + self.score_threshold = torch.tensor([score_thres]).to(device) |
| 945 | + self.max_wh = 7680 |
| 946 | + self.convert_matrix = torch.tensor([[1, 0, 1, 0], [0, 1, 0, 1], [-0.5, 0, 0.5, 0], [0, -0.5, 0, 0.5]], |
| 947 | + dtype=torch.float32, |
| 948 | + device=self.device) |
| 949 | + |
| 950 | + def forward(self, x): |
| 951 | + batch, anchors, _ = x.shape |
| 952 | + boxes = x[:, :, :4] |
| 953 | + conf = x[:, :, 4:5] |
| 954 | + scores = x[:, :, 5:] |
| 955 | + scores *= conf |
| 956 | + |
| 957 | + nms_box = boxes @ self.convert_matrix |
| 958 | + nms_score = scores.transpose(1, 2).contiguous() |
| 959 | + |
| 960 | + selected_indices = ORT_NMS.apply(nms_box, nms_score, self.max_obj, self.iou_threshold, self.score_threshold) |
| 961 | + batch_inds, cls_inds, box_inds = selected_indices.unbind(1) |
| 962 | + selected_score = nms_score[batch_inds, cls_inds, box_inds].unsqueeze(1) |
| 963 | + selected_box = nms_box[batch_inds, box_inds, ...] |
| 964 | + |
| 965 | + dets = torch.cat([selected_box, selected_score], dim=1) |
| 966 | + |
| 967 | + batched_dets = dets.unsqueeze(0).repeat(batch, 1, 1) |
| 968 | + batch_template = torch.arange(0, batch, dtype=batch_inds.dtype, device=batch_inds.device) |
| 969 | + batched_dets = batched_dets.where((batch_inds == batch_template.unsqueeze(1)).unsqueeze(-1), |
| 970 | + batched_dets.new_zeros(1)) |
| 971 | + |
| 972 | + batched_labels = cls_inds.unsqueeze(0).repeat(batch, 1) |
| 973 | + batched_labels = batched_labels.where((batch_inds == batch_template.unsqueeze(1)), |
| 974 | + batched_labels.new_ones(1) * -1) |
| 975 | + |
| 976 | + N = batched_dets.shape[0] |
| 977 | + |
| 978 | + batched_dets = torch.cat((batched_dets, batched_dets.new_zeros((N, 1, 5))), 1) |
| 979 | + batched_labels = torch.cat((batched_labels, -batched_labels.new_ones((N, 1))), 1) |
| 980 | + |
| 981 | + _, topk_inds = batched_dets[:, :, -1].sort(dim=1, descending=True) |
| 982 | + |
| 983 | + topk_batch_inds = torch.arange(batch, dtype=topk_inds.dtype, device=topk_inds.device).view(-1, 1) |
| 984 | + batched_dets = batched_dets[topk_batch_inds, topk_inds, ...] |
| 985 | + labels = batched_labels[topk_batch_inds, topk_inds, ...] |
| 986 | + boxes, scores = batched_dets.split((4, 1), -1) |
| 987 | + scores = scores.squeeze(-1) |
| 988 | + num_dets = (scores > 0).sum(1, keepdim=True) |
| 989 | + return num_dets, boxes, scores, labels |
| 990 | + |
| 991 | + |
| 992 | +class ONNX_TRT(nn.Module): |
| 993 | + |
| 994 | + def __init__(self, max_obj=100, iou_thres=0.45, score_thres=0.25, device=None): |
| 995 | + super().__init__() |
| 996 | + self.device = device if device else torch.device('cpu') |
| 997 | + self.background_class = -1, |
| 998 | + self.box_coding = 1, |
| 999 | + self.iou_threshold = iou_thres |
| 1000 | + self.max_obj = max_obj |
| 1001 | + self.plugin_version = '1' |
| 1002 | + self.score_activation = 0 |
| 1003 | + self.score_threshold = score_thres |
| 1004 | + |
| 1005 | + def forward(self, x): |
| 1006 | + boxes = x[:, :, :4] |
| 1007 | + conf = x[:, :, 4:5] |
| 1008 | + scores = x[:, :, 5:] |
| 1009 | + scores *= conf |
| 1010 | + num_dets, boxes, scores, labels = TRT_NMS.apply(boxes, scores, self.background_class, self.box_coding, |
| 1011 | + self.iou_threshold, self.max_obj, self.plugin_version, |
| 1012 | + self.score_activation, self.score_threshold) |
| 1013 | + return num_dets, boxes, scores, labels |
| 1014 | + |
| 1015 | + |
| 1016 | +class End2End(nn.Module): |
| 1017 | + |
| 1018 | + def __init__(self, model, max_obj=100, iou_thres=0.45, score_thres=0.25, backend='ort', device=None): |
| 1019 | + super().__init__() |
| 1020 | + device = device if device else torch.device('cpu') |
| 1021 | + self.model = model.to(device) |
| 1022 | + |
| 1023 | + if backend == 'trt': |
| 1024 | + self.patch_model = ONNX_TRT |
| 1025 | + elif backend == 'ort': |
| 1026 | + self.patch_model = ONNX_ORT |
| 1027 | + elif backend == 'ovo': |
| 1028 | + self.patch_model = ONNX_ORT |
| 1029 | + else: |
| 1030 | + raise NotImplementedError |
| 1031 | + self.end2end = self.patch_model(max_obj, iou_thres, score_thres, device) |
| 1032 | + self.end2end.eval() |
| 1033 | + self.stride = self.model.stride |
| 1034 | + self.names = self.model.names |
| 1035 | + |
| 1036 | + def forward(self, x): |
| 1037 | + x = self.model(x)[0] |
| 1038 | + x = self.end2end(x) |
| 1039 | + return x |
0 commit comments