Skip to content

Commit f332512

Browse files
author
Matthew Brookhart
committed
ONNX NMS working on GPU; had to remove threading from some kernels.
Follow-ups: fix lint; fix lambda lift tests; fix unit tests; respond to review comments; fix lint.
1 parent 4943e00 commit f332512

File tree

11 files changed

+304
-273
lines changed

11 files changed

+304
-273
lines changed

python/tvm/relay/backend/_backend.py

Lines changed: 0 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -88,13 +88,6 @@ def _tensor_value_repr(tvalue):
8888
return str(tvalue.data.asnumpy())
8989

9090

91-
@tvm._ffi.register_func("relay._ndarray_repr")
92-
def _tensor_constant_repr(tvalue):
93-
tmp = tvalue.asnumpy()
94-
return "NDArray of shape " + str(tmp.shape) + " and dtype " + str(tmp.dtype) +"\n\t" + str(tmp)
95-
96-
97-
9891
@tvm._ffi.register_func("relay._constant_repr")
9992
def _tensor_constant_repr(tvalue):
10093
dtype = tvm.runtime.DataType(tvalue.data.dtype)

python/tvm/relay/frontend/onnx.py

Lines changed: 44 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -2241,6 +2241,17 @@ class NonMaxSuppression(OnnxOpConverter):
22412241

22422242
@classmethod
22432243
def _impl_v10(cls, inputs, attr, params):
2244+
"""
2245+
High level note: ONNX implements what TF calls combined_non_max_suppression
2246+
It passes in scores for each box for every class in the output and expects boxes to be
2247+
analyzed for each class independently
2248+
2249+
It also asks for the data to be returned in a particular format.
2250+
2251+
To support these, we implement a series of loops:
2252+
The first loop splits over class number, performs NMS, and collects the outputs.
2253+
The second (nested) loop takes the outputs and transforms them into the format ONNX wants
2254+
"""
22442255
# Get parameter values
22452256
boxes = inputs[0]
22462257
scores = inputs[1]
@@ -2270,17 +2281,17 @@ def conditionally_squeeze_scalar(x):
22702281
max_output_boxes_per_class = conditionally_squeeze_scalar(max_output_boxes_per_class)
22712282
iou_threshold = conditionally_squeeze_scalar(iou_threshold)
22722283
score_threshold = conditionally_squeeze_scalar(score_threshold)
2284+
2285+
## prepare utility constants
22732286
zero = _op.const(np.array([0]), dtype="int64")
22742287
one = _op.const(np.array([1]), dtype="int64")
2288+
two = _op.const(np.array([2]), dtype="int64")
22752289
three = _op.const(np.array([3]), dtype="int64")
2276-
two_ones = _op.const(np.array([1, 1]), dtype="int64")
22772290
three_ones = _op.const(np.array([1, 1, 1]), dtype="int64")
22782291
four_ones = _op.const(np.array([1, 1, 1, 1]), dtype="int64")
22792292

2280-
def pad_last_dim(x):
2281-
return _op.expand_dims(x, -1, 1)
2282-
2283-
# First Loop Vars
2293+
## First loop: split by class and perform NMS
2294+
# Create Loop Vars
22842295
i = _expr.var("i", shape=(1,), dtype="int64")
22852296
scores_var = _expr.var("scores_var", shape=(_ty.Any(), _ty.Any(), _ty.Any()), dtype=dtype)
22862297
boxes_var = _expr.var("boxes_var", shape=(_ty.Any(), _ty.Any(), 4), dtype=dtype)
@@ -2292,7 +2303,7 @@ def pad_last_dim(x):
22922303
B = _expr.var("B", shape=(1,), dtype="int64")
22932304
C = _expr.var("C", shape=(1,), dtype="int64")
22942305
S = _expr.var("S", shape=(1,), dtype="int64")
2295-
# Outputs of first loop should be padded nms values shape (B, C, 3)
2306+
# Outputs of first loop should be padded nms values shape (B, C, S, 3)
22962307
onnx_out = _expr.var("onnx_out", shape=(_ty.Any(), _ty.Any(), _ty.Any(), 3), dtype="int64")
22972308
# and sizes of valid outputs, shape (B, C, 1)
22982309
nms_size_out = _expr.var("nms_size_out", shape=(_ty.Any(), _ty.Any(), 1), dtype="int64")
@@ -2310,6 +2321,7 @@ def _first_cond(
23102321
onnx_out,
23112322
nms_size_out,
23122323
):
2324+
# Loop over classes, end when i == C
23132325
return _op.min(_op.less(i, C))
23142326

23152327
def _first_body(
@@ -2325,12 +2337,15 @@ def _first_body(
23252337
onnx_out,
23262338
nms_size_out,
23272339
):
2340+
# slice to get current class
23282341
begin = _op.concatenate([zero, i, zero], axis=0)
23292342
end = _op.concatenate([B, i + one, S], axis=0)
23302343
class_scores = _op.strided_slice(scores, begin, end, three_ones)
23312344
class_scores = _op.expand_dims(_op.squeeze(class_scores, [1]), -1, 1)
2345+
# combine scores and boxes
23322346
data = _op.concatenate([class_scores, boxes], axis=-1)
23332347

2348+
# get valid counts
23342349
ct, data, indices = _op.vision.get_valid_counts(
23352350
data, score_threshold=score_threshold, id_index=-1, score_index=0
23362351
)
@@ -2339,6 +2354,7 @@ def _first_body(
23392354
top_k = -1
23402355
# ONNX doesn't have class id for nms input
23412356
score_index = 0
2357+
# perform nms on current class
23422358
nms_ret = _op.vision.non_max_suppression(
23432359
data=data,
23442360
valid_count=ct,
@@ -2353,6 +2369,7 @@ def _first_body(
23532369
return_indices=True,
23542370
invalid_to_bottom=False,
23552371
)
2372+
# partially prepare ONNX output format by labeling batch_num, class_id
23562373
nms_padded_out = _op.expand_dims(nms_ret[0], -1, 1)
23572374
batch_num = _op.expand_dims(_op.arange(_op.squeeze(B, [0]), dtype="int64"), -1, 1)
23582375
batch_num = _op.broadcast_to(batch_num, _op.shape_of(nms_ret[0], dtype="int64"))
@@ -2362,6 +2379,7 @@ def _first_body(
23622379
[batch_num, class_num, _op.cast(nms_padded_out, "int64")], -1
23632380
)
23642381
new_onnx_out = _op.expand_dims(new_onnx_out, 1, 1)
2382+
# store valid nms outputs for this class
23652383
nms_size = _op.cast(nms_ret[1], "int64")
23662384
nms_size = _op.expand_dims(nms_size, 1, 1)
23672385
return [
@@ -2378,6 +2396,7 @@ def _first_body(
23782396
_op.concatenate([nms_size_out, nms_size], axis=1),
23792397
]
23802398

2399+
# create the first loop
23812400
first_loop = _loops.while_loop(
23822401
_first_cond,
23832402
[
@@ -2396,6 +2415,8 @@ def _first_body(
23962415
_first_body,
23972416
)
23982417

2418+
## Second loop slices outputs of the first loop for valid boxes and
2419+
## concats in the order ONNX wants
23992420
# Second inner Loop Vars
24002421
i = _expr.var("i", shape=(1,), dtype="int64")
24012422
j = _expr.var("j", shape=(1,), dtype="int64")
@@ -2408,14 +2429,17 @@ def _first_body(
24082429
out = _expr.var("out", shape=(_ty.Any(), 3), dtype="int64")
24092430

24102431
def _inner_cond(i, j, C, onnx_out, nms_size, out):
2432+
# inner loop over number of classes
24112433
return _op.min(_op.less(j, C))
24122434

24132435
def _inner_body(i, j, C, onnx_out, nms_size, out):
2414-
start = _op.concatenate([i, j, zero], axis=0)
2415-
end = _op.concatenate([i + one, j + one, one], axis=0)
2436+
# slice to get current batch and class for valid box indicator
2437+
start = _op.concatenate([i, j + one, zero], axis=0)
2438+
end = _op.concatenate([i + one, j + two, one], axis=0)
24162439
num_valid_boxes = _op.reshape(_op.strided_slice(nms_size, start, end, three_ones), [1])
2417-
start = _op.concatenate([i, j, zero, zero], axis=0)
2418-
end = _op.concatenate([i + one, j + one, num_valid_boxes, three], axis=0)
2440+
# slice to get current batch, class, and valid outputs
2441+
start = _op.concatenate([i, j + one, zero, zero], axis=0)
2442+
end = _op.concatenate([i + one, j + two, num_valid_boxes, three], axis=0)
24192443
new_out = _op.squeeze(_op.strided_slice(onnx_out, start, end, four_ones), [0, 1])
24202444
return i, j + one, C, onnx_out, nms_size, _op.concatenate([out, new_out], axis=0)
24212445

@@ -2435,23 +2459,27 @@ def _inner_body(i, j, C, onnx_out, nms_size, out):
24352459
out = _expr.var("out", shape=(_ty.Any(), 3), dtype="int64")
24362460

24372461
def _outer_cond(i, B, C, onnx_out, nms_size_out, out):
2462+
# Outer loop is over batch size
24382463
return _op.min(_op.less(i, B))
24392464

24402465
def _outer_body(i, B, C, onnx_out, nms_size_out, out):
2466+
# Outer loop just calls inner loop
24412467
init_count = _op.const(np.array([0]), dtype="int64")
24422468
inner_loop_vals = inner_loop(i, init_count, C, onnx_out, nms_size_out, out)
24432469
return i + one, B, C, onnx_out, nms_size_out, _expr.TupleGetItem(inner_loop_vals, 5)
24442470

2471+
# Create the second loop
24452472
outer_loop = _loops.while_loop(
24462473
_outer_cond, [i, B, C, onnx_out, nms_size_out, out], _outer_body
24472474
)
24482475

2476+
# Call the first loop, perform NMS
24492477
B, C, S = _op.split(_op.shape_of(scores, dtype="int64"), 3)
24502478
init_count = _op.const(np.array([0]), dtype="int64")
2451-
init_onnx_out = _op.const([], dtype="int64")
2452-
init_onnx_out = _op.broadcast_to(init_onnx_out, _op.concatenate([B, zero, S, three], 0))
2453-
init_nms_size_out = _op.const([], dtype="int64")
2454-
init_nms_size_out = _op.broadcast_to(init_nms_size_out, _op.concatenate([B, zero, one], 0))
2479+
init_onnx_out = _op.const([1], dtype="int64")
2480+
init_onnx_out = _op.broadcast_to(init_onnx_out, _op.concatenate([B, one, S, three], 0))
2481+
init_nms_size_out = _op.const([1], dtype="int64")
2482+
init_nms_size_out = _op.broadcast_to(init_nms_size_out, _op.concatenate([B, one, one], 0))
24552483
loop_vals = first_loop(
24562484
init_count,
24572485
scores,
@@ -2468,9 +2496,11 @@ def _outer_body(i, B, C, onnx_out, nms_size_out, out):
24682496
onnx_output = _expr.TupleGetItem(loop_vals, 9)
24692497
nms_size_output = _expr.TupleGetItem(loop_vals, 10)
24702498

2499+
# Call the second loop, rework outputs into correct form
24712500
init_count = _op.const(np.array([0]).astype("int64"), dtype="int64")
24722501
init_out = _op.const(np.array([]).reshape([0, 3]).astype("int64"), dtype="int64")
24732502
loop_vals = outer_loop(init_count, B, C, onnx_output, nms_size_output, init_out)
2503+
24742504
return _expr.TupleGetItem(loop_vals, 5)
24752505

24762506

python/tvm/relay/op/vision/nms.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -48,7 +48,7 @@ def get_valid_counts(data, score_threshold, id_index=0, score_index=1):
4848
out_indices: relay.Expr
4949
Indices in input data
5050
"""
51-
if isinstance(score_threshold, float):
51+
if not isinstance(score_threshold, expr.Expr):
5252
score_threshold = expr.const(score_threshold, "float32")
5353
return expr.TupleWrapper(
5454
_make.get_valid_counts(data, score_threshold, id_index, score_index), 3
@@ -128,9 +128,9 @@ def non_max_suppression(
128128
If return_indices is True, return relay.Tuple of two 2-D tensors, with
129129
shape [batch_size, num_anchors] and [batch_size, num_valid_anchors] respectively.
130130
"""
131-
if isinstance(max_output_size, int):
131+
if not isinstance(max_output_size, expr.Expr):
132132
max_output_size = expr.const(max_output_size, "int32")
133-
if isinstance(iou_threshold, float):
133+
if not isinstance(iou_threshold, expr.Expr):
134134
iou_threshold = expr.const(iou_threshold, "float32")
135135
out = _make.non_max_suppression(
136136
data,

0 commit comments

Comments
 (0)