@@ -2241,6 +2241,17 @@ class NonMaxSuppression(OnnxOpConverter):
2241
2241
2242
2242
@classmethod
2243
2243
def _impl_v10 (cls , inputs , attr , params ):
2244
+ """
2245
+ High-level note: ONNX implements what TF calls combined_non_max_suppression.
2246
+ It passes in scores for each box for every class in the output and expects boxes to be
2247
+ analyzed for each class independently.
2248
+
2249
+ It also asks for the data to be returned in a particular format.
2250
+
2251
+ To support these, we implement a series of loops:
2252
+ The first loop splits over class number, performs NMS, and collects the outputs.
2253
+ The second (nested) loop takes the outputs and transforms them into the format ONNX wants.
2254
+ """
2244
2255
# Get parameter values
2245
2256
boxes = inputs [0 ]
2246
2257
scores = inputs [1 ]
@@ -2270,17 +2281,17 @@ def conditionally_squeeze_scalar(x):
2270
2281
max_output_boxes_per_class = conditionally_squeeze_scalar (max_output_boxes_per_class )
2271
2282
iou_threshold = conditionally_squeeze_scalar (iou_threshold )
2272
2283
score_threshold = conditionally_squeeze_scalar (score_threshold )
2284
+
2285
+ ## prepare utility constants
2273
2286
zero = _op .const (np .array ([0 ]), dtype = "int64" )
2274
2287
one = _op .const (np .array ([1 ]), dtype = "int64" )
2288
+ two = _op .const (np .array ([2 ]), dtype = "int64" )
2275
2289
three = _op .const (np .array ([3 ]), dtype = "int64" )
2276
- two_ones = _op .const (np .array ([1 , 1 ]), dtype = "int64" )
2277
2290
three_ones = _op .const (np .array ([1 , 1 , 1 ]), dtype = "int64" )
2278
2291
four_ones = _op .const (np .array ([1 , 1 , 1 , 1 ]), dtype = "int64" )
2279
2292
2280
- def pad_last_dim (x ):
2281
- return _op .expand_dims (x , - 1 , 1 )
2282
-
2283
- # First Loop Vars
2293
+ ## First loop: split by class and perform NMS
2294
+ # Create Loop Vars
2284
2295
i = _expr .var ("i" , shape = (1 ,), dtype = "int64" )
2285
2296
scores_var = _expr .var ("scores_var" , shape = (_ty .Any (), _ty .Any (), _ty .Any ()), dtype = dtype )
2286
2297
boxes_var = _expr .var ("boxes_var" , shape = (_ty .Any (), _ty .Any (), 4 ), dtype = dtype )
@@ -2292,7 +2303,7 @@ def pad_last_dim(x):
2292
2303
B = _expr .var ("B" , shape = (1 ,), dtype = "int64" )
2293
2304
C = _expr .var ("C" , shape = (1 ,), dtype = "int64" )
2294
2305
S = _expr .var ("S" , shape = (1 ,), dtype = "int64" )
2295
- # Outputs of first loop should be padded nms values shape (B, C, 3)
2306
+ # Outputs of first loop should be padded nms values shape (B, C, S, 3)
2296
2307
onnx_out = _expr .var ("onnx_out" , shape = (_ty .Any (), _ty .Any (), _ty .Any (), 3 ), dtype = "int64" )
2297
2308
# and sizes of valid outputs, shape (B, C, 1)
2298
2309
nms_size_out = _expr .var ("nms_size_out" , shape = (_ty .Any (), _ty .Any (), 1 ), dtype = "int64" )
@@ -2310,6 +2321,7 @@ def _first_cond(
2310
2321
onnx_out ,
2311
2322
nms_size_out ,
2312
2323
):
2324
+ # Loop over classes, end when i == C
2313
2325
return _op .min (_op .less (i , C ))
2314
2326
2315
2327
def _first_body (
@@ -2325,12 +2337,15 @@ def _first_body(
2325
2337
onnx_out ,
2326
2338
nms_size_out ,
2327
2339
):
2340
+ # slice to get current class
2328
2341
begin = _op .concatenate ([zero , i , zero ], axis = 0 )
2329
2342
end = _op .concatenate ([B , i + one , S ], axis = 0 )
2330
2343
class_scores = _op .strided_slice (scores , begin , end , three_ones )
2331
2344
class_scores = _op .expand_dims (_op .squeeze (class_scores , [1 ]), - 1 , 1 )
2345
+ # combine scores and boxes
2332
2346
data = _op .concatenate ([class_scores , boxes ], axis = - 1 )
2333
2347
2348
+ # get valid counts
2334
2349
ct , data , indices = _op .vision .get_valid_counts (
2335
2350
data , score_threshold = score_threshold , id_index = - 1 , score_index = 0
2336
2351
)
@@ -2339,6 +2354,7 @@ def _first_body(
2339
2354
top_k = - 1
2340
2355
# ONNX doesn't have class id for nms input
2341
2356
score_index = 0
2357
+ # perform nms on current class
2342
2358
nms_ret = _op .vision .non_max_suppression (
2343
2359
data = data ,
2344
2360
valid_count = ct ,
@@ -2353,6 +2369,7 @@ def _first_body(
2353
2369
return_indices = True ,
2354
2370
invalid_to_bottom = False ,
2355
2371
)
2372
+ # partially prepare ONNX output format by labeling batch_num, class_id
2356
2373
nms_padded_out = _op .expand_dims (nms_ret [0 ], - 1 , 1 )
2357
2374
batch_num = _op .expand_dims (_op .arange (_op .squeeze (B , [0 ]), dtype = "int64" ), - 1 , 1 )
2358
2375
batch_num = _op .broadcast_to (batch_num , _op .shape_of (nms_ret [0 ], dtype = "int64" ))
@@ -2362,6 +2379,7 @@ def _first_body(
2362
2379
[batch_num , class_num , _op .cast (nms_padded_out , "int64" )], - 1
2363
2380
)
2364
2381
new_onnx_out = _op .expand_dims (new_onnx_out , 1 , 1 )
2382
+ # store the number of valid nms outputs for this class
2365
2383
nms_size = _op .cast (nms_ret [1 ], "int64" )
2366
2384
nms_size = _op .expand_dims (nms_size , 1 , 1 )
2367
2385
return [
@@ -2378,6 +2396,7 @@ def _first_body(
2378
2396
_op .concatenate ([nms_size_out , nms_size ], axis = 1 ),
2379
2397
]
2380
2398
2399
+ # create the first loop
2381
2400
first_loop = _loops .while_loop (
2382
2401
_first_cond ,
2383
2402
[
@@ -2396,6 +2415,8 @@ def _first_body(
2396
2415
_first_body ,
2397
2416
)
2398
2417
2418
+ ## Second loop slices outputs of the first loop for valid boxes and
2419
+ ## concatenates them in the order ONNX wants
2399
2420
# Second inner Loop Vars
2400
2421
i = _expr .var ("i" , shape = (1 ,), dtype = "int64" )
2401
2422
j = _expr .var ("j" , shape = (1 ,), dtype = "int64" )
@@ -2408,14 +2429,17 @@ def _first_body(
2408
2429
out = _expr .var ("out" , shape = (_ty .Any (), 3 ), dtype = "int64" )
2409
2430
2410
2431
def _inner_cond (i , j , C , onnx_out , nms_size , out ):
2432
+ # inner loop over number of classes
2411
2433
return _op .min (_op .less (j , C ))
2412
2434
2413
2435
def _inner_body (i , j , C , onnx_out , nms_size , out ):
2414
- start = _op .concatenate ([i , j , zero ], axis = 0 )
2415
- end = _op .concatenate ([i + one , j + one , one ], axis = 0 )
2436
+ # slice out the number of valid boxes for the current batch and class;
+ # the j + 1 offset skips the placeholder entry at class index 0 of nms_size
2437
+ start = _op .concatenate ([i , j + one , zero ], axis = 0 )
2438
+ end = _op .concatenate ([i + one , j + two , one ], axis = 0 )
2416
2439
num_valid_boxes = _op .reshape (_op .strided_slice (nms_size , start , end , three_ones ), [1 ])
2417
- start = _op .concatenate ([i , j , zero , zero ], axis = 0 )
2418
- end = _op .concatenate ([i + one , j + one , num_valid_boxes , three ], axis = 0 )
2440
+ # slice to get current batch, class, and valid outputs
2441
+ start = _op .concatenate ([i , j + one , zero , zero ], axis = 0 )
2442
+ end = _op .concatenate ([i + one , j + two , num_valid_boxes , three ], axis = 0 )
2419
2443
new_out = _op .squeeze (_op .strided_slice (onnx_out , start , end , four_ones ), [0 , 1 ])
2420
2444
return i , j + one , C , onnx_out , nms_size , _op .concatenate ([out , new_out ], axis = 0 )
2421
2445
@@ -2435,23 +2459,27 @@ def _inner_body(i, j, C, onnx_out, nms_size, out):
2435
2459
out = _expr .var ("out" , shape = (_ty .Any (), 3 ), dtype = "int64" )
2436
2460
2437
2461
def _outer_cond (i , B , C , onnx_out , nms_size_out , out ):
2462
+ # Outer loop is over batch size
2438
2463
return _op .min (_op .less (i , B ))
2439
2464
2440
2465
def _outer_body (i , B , C , onnx_out , nms_size_out , out ):
2466
+ # Outer loop just calls inner loop
2441
2467
init_count = _op .const (np .array ([0 ]), dtype = "int64" )
2442
2468
inner_loop_vals = inner_loop (i , init_count , C , onnx_out , nms_size_out , out )
2443
2469
return i + one , B , C , onnx_out , nms_size_out , _expr .TupleGetItem (inner_loop_vals , 5 )
2444
2470
2471
+ # Create the second loop
2445
2472
outer_loop = _loops .while_loop (
2446
2473
_outer_cond , [i , B , C , onnx_out , nms_size_out , out ], _outer_body
2447
2474
)
2448
2475
2476
+ # Call the first loop, perform NMS
2449
2477
B , C , S = _op .split (_op .shape_of (scores , dtype = "int64" ), 3 )
2450
2478
init_count = _op .const (np .array ([0 ]), dtype = "int64" )
2451
- init_onnx_out = _op .const ([], dtype = "int64" )
2452
- init_onnx_out = _op .broadcast_to (init_onnx_out , _op .concatenate ([B , zero , S , three ], 0 ))
2453
- init_nms_size_out = _op .const ([], dtype = "int64" )
2454
- init_nms_size_out = _op .broadcast_to (init_nms_size_out , _op .concatenate ([B , zero , one ], 0 ))
2479
+ init_onnx_out = _op .const ([1 ], dtype = "int64" )
2480
+ init_onnx_out = _op .broadcast_to (init_onnx_out , _op .concatenate ([B , one , S , three ], 0 ))
2481
+ init_nms_size_out = _op .const ([1 ], dtype = "int64" )
2482
+ init_nms_size_out = _op .broadcast_to (init_nms_size_out , _op .concatenate ([B , one , one ], 0 ))
2455
2483
loop_vals = first_loop (
2456
2484
init_count ,
2457
2485
scores ,
@@ -2468,9 +2496,11 @@ def _outer_body(i, B, C, onnx_out, nms_size_out, out):
2468
2496
onnx_output = _expr .TupleGetItem (loop_vals , 9 )
2469
2497
nms_size_output = _expr .TupleGetItem (loop_vals , 10 )
2470
2498
2499
+ # Call the second loop, rework outputs into correct form
2471
2500
init_count = _op .const (np .array ([0 ]).astype ("int64" ), dtype = "int64" )
2472
2501
init_out = _op .const (np .array ([]).reshape ([0 , 3 ]).astype ("int64" ), dtype = "int64" )
2473
2502
loop_vals = outer_loop (init_count , B , C , onnx_output , nms_size_output , init_out )
2503
+
2474
2504
return _expr .TupleGetItem (loop_vals , 5 )
2475
2505
2476
2506
0 commit comments