
Commit 054466b

Matthew Brookhart and jroesch authored
[ONNX] NMS in ONNX (#6839)
* NMS partially working on CPU, fails on GPU
* support dynamic iou_threshold
* WIP NMS with while loops
* working nms with dynamic shapes
* add a test with dynamic score_threshold and pass it
* Fix type checking in lambda lift
* ONNX NMS working on GPU, had to remove threading from some kernels
  fix lint
  fix lambda lift tests
  fix unit tests
  respond to review comments
  fix lint
* better parallelize get_valid_counts
* improve nms parallelization
* respond to cuda/thrust enablement issue

Co-authored-by: Jared Roesch <roeschinc@gmail.com>
1 parent 6be4d0a commit 054466b

File tree

14 files changed: +827 additions, −210 deletions


include/tvm/relay/attrs/vision.h

Lines changed: 4 additions & 8 deletions
@@ -73,14 +73,12 @@ struct MultiBoxTransformLocAttrs : public tvm::AttrsNode<MultiBoxTransformLocAttrs> {
 
 /*! \brief Attributes used in get_valid_counts operator */
 struct GetValidCountsAttrs : public tvm::AttrsNode<GetValidCountsAttrs> {
-  double score_threshold;
+  Optional<FloatImm> score_threshold;
   int id_index;
   int score_index;
 
   TVM_DECLARE_ATTRS(GetValidCountsAttrs, "relay.attrs.GetValidCountsAttrs") {
-    TVM_ATTR_FIELD(score_threshold)
-        .set_default(0.0)
-        .describe("Lower limit of score for valid bounding boxes.");
+    TVM_ATTR_FIELD(score_threshold).describe("Lower limit of score for valid bounding boxes.");
     TVM_ATTR_FIELD(id_index).set_default(0).describe("Axis index of id.");
     TVM_ATTR_FIELD(score_index).set_default(1).describe("Index of the scores/confidence of boxes.");
   }
@@ -89,7 +87,7 @@ struct GetValidCountsAttrs : public tvm::AttrsNode<GetValidCountsAttrs> {
 /*! \brief Attributes used in non_maximum_suppression operator */
 struct NonMaximumSuppressionAttrs : public tvm::AttrsNode<NonMaximumSuppressionAttrs> {
   Optional<Integer> max_output_size;
-  double iou_threshold;
+  Optional<FloatImm> iou_threshold;
   bool force_suppress;
   int top_k;
   int coord_start;
@@ -100,9 +98,7 @@ struct NonMaximumSuppressionAttrs : public tvm::AttrsNode<NonMaximumSuppressionAttrs> {
 
   TVM_DECLARE_ATTRS(NonMaximumSuppressionAttrs, "relay.attrs.NonMaximumSuppressionAttrs") {
     TVM_ATTR_FIELD(max_output_size).describe("Max number of output valid boxes for each instance.");
-    TVM_ATTR_FIELD(iou_threshold)
-        .set_default(0.5)
-        .describe("Non-maximum suppression iou threshold.");
+    TVM_ATTR_FIELD(iou_threshold).describe("Non-maximum suppression iou threshold.");
     TVM_ATTR_FIELD(force_suppress)
         .set_default(false)
         .describe("Suppress all detections regardless of class_id.");

python/tvm/relay/frontend/onnx.py

Lines changed: 269 additions & 0 deletions
@@ -2303,6 +2303,274 @@ def _impl_v1(cls, inputs, attr, params):
         return _expr.If(cond, then_expr, else_expr)
 
 
+class NonMaxSuppression(OnnxOpConverter):
+    """Operator converter for NonMaxSuppression."""
+
+    @classmethod
+    def _impl_v10(cls, inputs, attr, params):
+        """
+        High level note: ONNX implements what TF calls combined_non_max_suppression.
+        It passes in scores for each box for every class and expects the boxes to be
+        analyzed for each class independently.
+
+        It also asks for the data to be returned in a particular format.
+
+        To support these, we implement a series of loops:
+        The first loop splits over class number, performs NMS, and collects the outputs.
+        The second (nested) loop takes the outputs and transforms them into the format ONNX wants.
+        """
+        # Get parameter values
+        boxes = inputs[0]
+        scores = inputs[1]
+        max_output_boxes_per_class = inputs[2]
+        iou_threshold = inputs[3]
+        score_threshold = inputs[4]
+
+        dtype = infer_type(boxes).checked_type.dtype
+
+        if "center_point_box" in attr:
+            assert (
+                attr["center_point_box"] == 0
+            ), "Only support center_point_box = 0 in onnx importer right now"
+
+        if iou_threshold is None:
+            iou_threshold = _expr.const(0.0, dtype="float32")
+        if score_threshold is None:
+            score_threshold = _expr.const(0.0, dtype="float32")
+
+        def conditionally_squeeze_scalar(x):
+            rank = len(infer_shape(x))
+            assert rank <= 1, "nms thresholds must be scalars"
+            if rank == 1:
+                return _op.squeeze(x, [0])
+            return x
+
+        max_output_boxes_per_class = conditionally_squeeze_scalar(max_output_boxes_per_class)
+        iou_threshold = conditionally_squeeze_scalar(iou_threshold)
+        score_threshold = conditionally_squeeze_scalar(score_threshold)
+
+        ## prepare utility constants
+        zero = _op.const(np.array([0]), dtype="int64")
+        one = _op.const(np.array([1]), dtype="int64")
+        two = _op.const(np.array([2]), dtype="int64")
+        three = _op.const(np.array([3]), dtype="int64")
+        three_ones = _op.const(np.array([1, 1, 1]), dtype="int64")
+        four_ones = _op.const(np.array([1, 1, 1, 1]), dtype="int64")
+
+        ## First loop: split by class and perform NMS
+        # Create Loop Vars
+        i = _expr.var("i", shape=(1,), dtype="int64")
+        scores_var = _expr.var("scores_var", shape=(_ty.Any(), _ty.Any(), _ty.Any()), dtype=dtype)
+        boxes_var = _expr.var("boxes_var", shape=(_ty.Any(), _ty.Any(), 4), dtype=dtype)
+        max_output_boxes_per_class_var = _expr.var(
+            "max_output_boxes_per_class_var", shape=(), dtype="int64"
+        )
+        iou_threshold_var = _expr.var("iou_threshold_var", shape=(), dtype="float32")
+        score_threshold_var = _expr.var("score_threshold_var", shape=(), dtype="float32")
+        B = _expr.var("B", shape=(1,), dtype="int64")
+        C = _expr.var("C", shape=(1,), dtype="int64")
+        S = _expr.var("S", shape=(1,), dtype="int64")
+        # Outputs of first loop should be padded nms values shape (B, C, S, 3)
+        onnx_out = _expr.var("onnx_out", shape=(_ty.Any(), _ty.Any(), _ty.Any(), 3), dtype="int64")
+        # and sizes of valid outputs, shape (B, C, 1)
+        nms_size_out = _expr.var("nms_size_out", shape=(_ty.Any(), _ty.Any(), 1), dtype="int64")
+
+        def _first_cond(
+            i,
+            scores,
+            boxes,
+            B,
+            C,
+            S,
+            max_output_boxes_per_class,
+            iou_threshold,
+            score_threshold,
+            onnx_out,
+            nms_size_out,
+        ):
+            # Loop over classes, end when i == C
+            return _op.min(_op.less(i, C))
+
+        def _first_body(
+            i,
+            scores,
+            boxes,
+            B,
+            C,
+            S,
+            max_output_boxes_per_class,
+            iou_threshold,
+            score_threshold,
+            onnx_out,
+            nms_size_out,
+        ):
+            # slice to get current class
+            begin = _op.concatenate([zero, i, zero], axis=0)
+            end = _op.concatenate([B, i + one, S], axis=0)
+            class_scores = _op.strided_slice(scores, begin, end, three_ones)
+            class_scores = _op.expand_dims(_op.squeeze(class_scores, [1]), -1, 1)
+            # combine scores and boxes
+            data = _op.concatenate([class_scores, boxes], axis=-1)
+
+            # get valid counts
+            ct, data, indices = _op.vision.get_valid_counts(
+                data, score_threshold=score_threshold, id_index=-1, score_index=0
+            )
+            # the reason for using get_valid_counts here is inference performance
+            # ONNX NMS doesn't have parameter top_k
+            top_k = -1
+            # ONNX doesn't have class id for nms input
+            score_index = 0
+            # perform nms on current class
+            nms_ret = _op.vision.non_max_suppression(
+                data=data,
+                valid_count=ct,
+                indices=indices,
+                max_output_size=max_output_boxes_per_class,
+                iou_threshold=iou_threshold,
+                force_suppress=True,
+                top_k=top_k,
+                coord_start=1,
+                score_index=score_index,
+                id_index=-1,
+                return_indices=True,
+                invalid_to_bottom=False,
+            )
+            # partially prepare ONNX output format by labeling batch_num, class_id
+            nms_padded_out = _op.expand_dims(nms_ret[0], -1, 1)
+            batch_num = _op.expand_dims(_op.arange(_op.squeeze(B, [0]), dtype="int64"), -1, 1)
+            batch_num = _op.broadcast_to(batch_num, _op.shape_of(nms_ret[0], dtype="int64"))
+            batch_num = _op.expand_dims(batch_num, -1, 1)
+            class_num = _op.broadcast_to(i, _op.shape_of(nms_padded_out, dtype="int64"))
+            new_onnx_out = _op.concatenate(
+                [batch_num, class_num, _op.cast(nms_padded_out, "int64")], -1
+            )
+            new_onnx_out = _op.expand_dims(new_onnx_out, 1, 1)
+            # store valid nms outputs for this class
+            nms_size = _op.cast(nms_ret[1], "int64")
+            nms_size = _op.expand_dims(nms_size, 1, 1)
+            return [
+                i + one,
+                scores,
+                boxes,
+                B,
+                C,
+                S,
+                max_output_boxes_per_class,
+                iou_threshold,
+                score_threshold,
+                _op.concatenate([onnx_out, new_onnx_out], axis=1),
+                _op.concatenate([nms_size_out, nms_size], axis=1),
+            ]
+
+        # create the first loop
+        first_loop = _loops.while_loop(
+            _first_cond,
+            [
+                i,
+                scores_var,
+                boxes_var,
+                B,
+                C,
+                S,
+                max_output_boxes_per_class_var,
+                iou_threshold_var,
+                score_threshold_var,
+                onnx_out,
+                nms_size_out,
+            ],
+            _first_body,
+        )
+
+        ## Second loop slices outputs of the first loop for valid boxes and
+        ## concatenates them in the order ONNX wants
+        # Second inner Loop Vars
+        i = _expr.var("i", shape=(1,), dtype="int64")
+        j = _expr.var("j", shape=(1,), dtype="int64")
+        B = _expr.var("B", shape=(1,), dtype="int64")
+        C = _expr.var("C", shape=(1,), dtype="int64")
+        # Outputs of first loop should be padded nms values shape (B, C, S, 3)
+        onnx_out = _expr.var("onnx_out", shape=(_ty.Any(), _ty.Any(), _ty.Any(), 3), dtype="int64")
+        # and sizes of valid outputs, shape (B, C, 1)
+        nms_size_out = _expr.var("nms_size_out", shape=(_ty.Any(), _ty.Any(), 1), dtype="int64")
+        out = _expr.var("out", shape=(_ty.Any(), 3), dtype="int64")
+
+        def _inner_cond(i, j, C, onnx_out, nms_size, out):
+            # inner loop over number of classes
+            return _op.min(_op.less(j, C))
+
+        def _inner_body(i, j, C, onnx_out, nms_size, out):
+            # slice to get current batch and class for valid box indicator
+            start = _op.concatenate([i, j + one, zero], axis=0)
+            end = _op.concatenate([i + one, j + two, one], axis=0)
+            num_valid_boxes = _op.reshape(_op.strided_slice(nms_size, start, end, three_ones), [1])
+            # slice to get current batch, class, and valid outputs
+            start = _op.concatenate([i, j + one, zero, zero], axis=0)
+            end = _op.concatenate([i + one, j + two, num_valid_boxes, three], axis=0)
+            new_out = _op.squeeze(_op.strided_slice(onnx_out, start, end, four_ones), [0, 1])
+            return i, j + one, C, onnx_out, nms_size, _op.concatenate([out, new_out], axis=0)
+
+        inner_loop = _loops.while_loop(
+            _inner_cond, [i, j, C, onnx_out, nms_size_out, out], _inner_body
+        )
+
+        # Second Outer Loop Vars
+        i = _expr.var("i", shape=(1,), dtype="int64")
+        j = _expr.var("j", shape=(1,), dtype="int64")
+        B = _expr.var("B", shape=(1,), dtype="int64")
+        C = _expr.var("C", shape=(1,), dtype="int64")
+        # Outputs of first loop should be padded nms values shape (B, C, S, 3)
+        onnx_out = _expr.var("onnx_out", shape=(_ty.Any(), _ty.Any(), _ty.Any(), 3), dtype="int64")
+        # and sizes of valid outputs, shape (B, C, 1)
+        nms_size_out = _expr.var("nms_size_out", shape=(_ty.Any(), _ty.Any(), 1), dtype="int64")
+        out = _expr.var("out", shape=(_ty.Any(), 3), dtype="int64")
+
+        def _outer_cond(i, B, C, onnx_out, nms_size_out, out):
+            # Outer loop is over batch size
+            return _op.min(_op.less(i, B))
+
+        def _outer_body(i, B, C, onnx_out, nms_size_out, out):
+            # Outer loop just calls inner loop
+            init_count = _op.const(np.array([0]), dtype="int64")
+            inner_loop_vals = inner_loop(i, init_count, C, onnx_out, nms_size_out, out)
+            return i + one, B, C, onnx_out, nms_size_out, _expr.TupleGetItem(inner_loop_vals, 5)
+
+        # Create the second loop
+        outer_loop = _loops.while_loop(
+            _outer_cond, [i, B, C, onnx_out, nms_size_out, out], _outer_body
+        )
+
+        # Call the first loop, perform NMS
+        B, C, S = _op.split(_op.shape_of(scores, dtype="int64"), 3)
+        init_count = _op.const(np.array([0]), dtype="int64")
+        init_onnx_out = _op.const([1], dtype="int64")
+        init_onnx_out = _op.broadcast_to(init_onnx_out, _op.concatenate([B, one, S, three], 0))
+        init_nms_size_out = _op.const([1], dtype="int64")
+        init_nms_size_out = _op.broadcast_to(init_nms_size_out, _op.concatenate([B, one, one], 0))
+        loop_vals = first_loop(
+            init_count,
+            scores,
+            boxes,
+            B,
+            C,
+            S,
+            max_output_boxes_per_class,
+            iou_threshold,
+            score_threshold,
+            init_onnx_out,
+            init_nms_size_out,
+        )
+        onnx_output = _expr.TupleGetItem(loop_vals, 9)
+        nms_size_output = _expr.TupleGetItem(loop_vals, 10)
+
+        # Call the second loop, rework outputs into correct form
+        init_count = _op.const(np.array([0]).astype("int64"), dtype="int64")
+        init_out = _op.const(np.array([]).reshape([0, 3]).astype("int64"), dtype="int64")
+        loop_vals = outer_loop(init_count, B, C, onnx_output, nms_size_output, init_out)
+
+        return _expr.TupleGetItem(loop_vals, 5)
+
+
 # compatible operators that do NOT require any conversion.
 _identity_list = []
 
@@ -2415,6 +2683,7 @@ def _get_convert_map(opset):
         # defs/vision
         "MaxRoiPool": MaxRoiPool.get_converter(opset),
         "RoiAlign": RoiAlign.get_converter(opset),
+        "NonMaxSuppression": NonMaxSuppression.get_converter(opset),
         # defs/reduction
         "ReduceMax": ReduceMax.get_converter(opset),
         "ReduceMin": ReduceMin.get_converter(opset),

python/tvm/relay/op/strategy/generic.py

Lines changed: 6 additions & 2 deletions
@@ -885,9 +885,11 @@ def wrap_compute_get_valid_counts(topi_compute):
     """wrap get_valid_counts topi compute"""
 
     def _compute_get_valid_counts(attrs, inputs, out_type):
-        score_threshold = get_const_float(attrs.score_threshold)
+        score_threshold = inputs[1]
         id_index = get_const_int(attrs.id_index)
         score_index = get_const_int(attrs.score_index)
+        if attrs.score_threshold is not None:
+            score_threshold = get_const_float(attrs.score_threshold)
         return topi_compute(inputs[0], score_threshold, id_index, score_index)
 
     return _compute_get_valid_counts
@@ -911,10 +913,12 @@ def wrap_compute_nms(topi_compute):
 
     def _compute_nms(attrs, inputs, out_type):
         max_output_size = inputs[3]
+        iou_threshold = inputs[4]
         if attrs.max_output_size is not None:
            max_output_size = attrs.max_output_size
+        if attrs.iou_threshold is not None:
+            iou_threshold = get_const_float(attrs.iou_threshold)
         return_indices = bool(get_const_int(attrs.return_indices))
-        iou_threshold = get_const_float(attrs.iou_threshold)
         force_suppress = bool(get_const_int(attrs.force_suppress))
         top_k = get_const_int(attrs.top_k)
         coord_start = get_const_int(attrs.coord_start)
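
Both wrappers above follow the same dispatch pattern: prefer the threshold recorded on the attrs when it was known at compile time, and otherwise fall back to the extra runtime input the op now carries. A simplified, hypothetical illustration of that pattern (not the actual TVM wrapper code):

def pick_threshold(attr_value, runtime_input):
    """Prefer a statically-known threshold; otherwise use the runtime tensor."""
    if attr_value is not None:
        # Constant captured in the op attributes at graph-build time.
        return float(attr_value)
    # Dynamic case: the threshold arrives as one of the op's input tensors.
    return runtime_input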

python/tvm/relay/op/vision/nms.py

Lines changed: 6 additions & 2 deletions
@@ -48,6 +48,8 @@ def get_valid_counts(data, score_threshold, id_index=0, score_index=1):
     out_indices: relay.Expr
         Indices in input data
     """
+    if not isinstance(score_threshold, expr.Expr):
+        score_threshold = expr.const(score_threshold, "float32")
     return expr.TupleWrapper(
         _make.get_valid_counts(data, score_threshold, id_index, score_index), 3
     )
@@ -94,7 +96,7 @@ def non_max_suppression(
         Max number of output valid boxes for each instance.
         Return all valid boxes if the value of max_output_size is less than 0.
 
-    iou_threshold : float, optional
+    iou_threshold : float or relay.Expr, optional
         Non-maximum suppression threshold.
 
     force_suppress : bool, optional
@@ -126,8 +128,10 @@
         If return_indices is True, return relay.Tuple of two 2-D tensors, with
         shape [batch_size, num_anchors] and [batch_size, num_valid_anchors] respectively.
     """
-    if isinstance(max_output_size, int):
+    if not isinstance(max_output_size, expr.Expr):
        max_output_size = expr.const(max_output_size, "int32")
+    if not isinstance(iou_threshold, expr.Expr):
+        iou_threshold = expr.const(iou_threshold, "float32")
    out = _make.non_max_suppression(
        data,
        valid_count,
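
With these signature changes, the ONNX converter above can thread a runtime iou_threshold straight through get_valid_counts and non_max_suppression. A small sketch of that call pattern from user code (assuming this patch; the shapes and the default [id, score, x1, y1, x2, y2] box layout are illustrative):

from tvm import relay

data = relay.var("data", shape=(1, 8, 6), dtype="float32")
iou_threshold = relay.var("iou_threshold", shape=(), dtype="float32")  # runtime scalar

ct, trimmed, indices = relay.vision.get_valid_counts(
    data, score_threshold=0.0, id_index=0, score_index=1
)
nms_out = relay.vision.non_max_suppression(
    trimmed,
    ct,
    indices,
    max_output_size=-1,
    iou_threshold=iou_threshold,  # a relay.Expr is now accepted in place of a float
    force_suppress=False,
    top_k=-1,
    coord_start=2,
    score_index=1,
    id_index=0,
    return_indices=False,
)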

0 commit comments
