@@ -2303,6 +2303,274 @@ def _impl_v1(cls, inputs, attr, params):
         return _expr.If(cond, then_expr, else_expr)


+class NonMaxSuppression(OnnxOpConverter):
+    """Operator converter for NonMaxSuppression."""
+
+    @classmethod
+    def _impl_v10(cls, inputs, attr, params):
+        """
+        High level note: ONNX implements what TF calls combined_non_max_suppression.
+        It passes in a score for every class of every box and expects each class to be
+        analyzed independently.
+
+        It also asks for the data to be returned in a particular format.
+
+        To support these, we implement a series of loops:
+        The first loop splits over class number, performs NMS, and collects the outputs.
+        The second (nested) loop takes the outputs and transforms them into the format
+        ONNX wants.
+        """
+        # Get parameter values
+        boxes = inputs[0]
+        scores = inputs[1]
+        max_output_boxes_per_class = inputs[2]
+        iou_threshold = inputs[3]
+        score_threshold = inputs[4]
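+        # ONNX layouts: boxes is (batch, num_boxes, 4), scores is (batch, num_classes, num_boxes);
+        # the three threshold inputs are optional one-element tensors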
+
+        dtype = infer_type(boxes).checked_type.dtype
+
+        if "center_point_box" in attr:
+            assert (
+                attr["center_point_box"] == 0
+            ), "Only support center_point_box = 0 in onnx importer right now"
+
+        if iou_threshold is None:
+            iou_threshold = _expr.const(0.0, dtype="float32")
+        if score_threshold is None:
+            score_threshold = _expr.const(0.0, dtype="float32")
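+        # ONNX supplies the thresholds as optional one-element tensors; squeeze them
+        # down to the scalars used by the Relay ops below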
+
+        def conditionally_squeeze_scalar(x):
+            rank = len(infer_shape(x))
+            assert rank <= 1, "nms thresholds must be scalars"
+            if rank == 1:
+                return _op.squeeze(x, [0])
+            return x
+
+        max_output_boxes_per_class = conditionally_squeeze_scalar(max_output_boxes_per_class)
+        iou_threshold = conditionally_squeeze_scalar(iou_threshold)
+        score_threshold = conditionally_squeeze_scalar(score_threshold)
+
+        ## prepare utility constants
+        zero = _op.const(np.array([0]), dtype="int64")
+        one = _op.const(np.array([1]), dtype="int64")
+        two = _op.const(np.array([2]), dtype="int64")
+        three = _op.const(np.array([3]), dtype="int64")
+        three_ones = _op.const(np.array([1, 1, 1]), dtype="int64")
+        four_ones = _op.const(np.array([1, 1, 1, 1]), dtype="int64")
+
+        ## First loop: split by class and perform NMS
+        # Create Loop Vars
+        i = _expr.var("i", shape=(1,), dtype="int64")
+        scores_var = _expr.var("scores_var", shape=(_ty.Any(), _ty.Any(), _ty.Any()), dtype=dtype)
+        boxes_var = _expr.var("boxes_var", shape=(_ty.Any(), _ty.Any(), 4), dtype=dtype)
+        max_output_boxes_per_class_var = _expr.var(
+            "max_output_boxes_per_class_var", shape=(), dtype="int64"
+        )
+        iou_threshold_var = _expr.var("iou_threshold_var", shape=(), dtype="float32")
+        score_threshold_var = _expr.var("score_threshold_var", shape=(), dtype="float32")
+        B = _expr.var("B", shape=(1,), dtype="int64")
+        C = _expr.var("C", shape=(1,), dtype="int64")
+        S = _expr.var("S", shape=(1,), dtype="int64")
+        # Outputs of first loop should be padded nms values shape (B, C, S, 3)
+        onnx_out = _expr.var("onnx_out", shape=(_ty.Any(), _ty.Any(), _ty.Any(), 3), dtype="int64")
+        # and sizes of valid outputs, shape (B, C, 1)
+        nms_size_out = _expr.var("nms_size_out", shape=(_ty.Any(), _ty.Any(), 1), dtype="int64")
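+        # Loop state is (i, scores, boxes, B, C, S, max_output_boxes_per_class,
+        # iou_threshold, score_threshold, onnx_out, nms_size_out); onnx_out and
+        # nms_size_out grow along the class axis each iteration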
+
+        def _first_cond(
+            i,
+            scores,
+            boxes,
+            B,
+            C,
+            S,
+            max_output_boxes_per_class,
+            iou_threshold,
+            score_threshold,
+            onnx_out,
+            nms_size_out,
+        ):
+            # Loop over classes, end when i == C
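+            # i and C are one-element tensors, so reduce the elementwise comparison to
+            # the scalar boolean that while_loop expects as its condition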
+            return _op.min(_op.less(i, C))
+
+        def _first_body(
+            i,
+            scores,
+            boxes,
+            B,
+            C,
+            S,
+            max_output_boxes_per_class,
+            iou_threshold,
+            score_threshold,
+            onnx_out,
+            nms_size_out,
+        ):
+            # slice to get current class
+            begin = _op.concatenate([zero, i, zero], axis=0)
+            end = _op.concatenate([B, i + one, S], axis=0)
+            class_scores = _op.strided_slice(scores, begin, end, three_ones)
+            class_scores = _op.expand_dims(_op.squeeze(class_scores, [1]), -1, 1)
+            # combine scores and boxes
+            data = _op.concatenate([class_scores, boxes], axis=-1)
+
+            # get valid counts
+            ct, data, indices = _op.vision.get_valid_counts(
+                data, score_threshold=score_threshold, id_index=-1, score_index=0
+            )
+            # get_valid_counts is used here for inference performance
+            # ONNX NMS doesn't have a top_k parameter
+            top_k = -1
+            # ONNX doesn't have a class id in the nms input
+            score_index = 0
+            # perform nms on current class
+            nms_ret = _op.vision.non_max_suppression(
+                data=data,
+                valid_count=ct,
+                indices=indices,
+                max_output_size=max_output_boxes_per_class,
+                iou_threshold=iou_threshold,
+                force_suppress=True,
+                top_k=top_k,
+                coord_start=1,
+                score_index=score_index,
+                id_index=-1,
+                return_indices=True,
+                invalid_to_bottom=False,
+            )
+            # partially prepare ONNX output format by labeling batch_num, class_id
+            nms_padded_out = _op.expand_dims(nms_ret[0], -1, 1)
+            batch_num = _op.expand_dims(_op.arange(_op.squeeze(B, [0]), dtype="int64"), -1, 1)
+            batch_num = _op.broadcast_to(batch_num, _op.shape_of(nms_ret[0], dtype="int64"))
+            batch_num = _op.expand_dims(batch_num, -1, 1)
+            class_num = _op.broadcast_to(i, _op.shape_of(nms_padded_out, dtype="int64"))
+            new_onnx_out = _op.concatenate(
+                [batch_num, class_num, _op.cast(nms_padded_out, "int64")], -1
+            )
+            new_onnx_out = _op.expand_dims(new_onnx_out, 1, 1)
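+            # each row of new_onnx_out is now (batch_index, class_index, box_index),
+            # the layout ONNX uses for selected_indices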
+            # store valid nms outputs for this class
+            nms_size = _op.cast(nms_ret[1], "int64")
+            nms_size = _op.expand_dims(nms_size, 1, 1)
+            return [
+                i + one,
+                scores,
+                boxes,
+                B,
+                C,
+                S,
+                max_output_boxes_per_class,
+                iou_threshold,
+                score_threshold,
+                _op.concatenate([onnx_out, new_onnx_out], axis=1),
+                _op.concatenate([nms_size_out, nms_size], axis=1),
+            ]
+
+        # create the first loop
+        first_loop = _loops.while_loop(
+            _first_cond,
+            [
+                i,
+                scores_var,
+                boxes_var,
+                B,
+                C,
+                S,
+                max_output_boxes_per_class_var,
+                iou_threshold_var,
+                score_threshold_var,
+                onnx_out,
+                nms_size_out,
+            ],
+            _first_body,
+        )
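+        # while_loop returns a callable loop; it is invoked further down with the
+        # initial state and yields the final state as a tuple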
+
+        ## Second loop slices outputs of the first loop for valid boxes and
+        ## concats in the order ONNX wants
+        # Second inner Loop Vars
+        i = _expr.var("i", shape=(1,), dtype="int64")
+        j = _expr.var("j", shape=(1,), dtype="int64")
+        B = _expr.var("B", shape=(1,), dtype="int64")
+        C = _expr.var("C", shape=(1,), dtype="int64")
+        # Outputs of first loop should be padded nms values shape (B, C, S, 3)
+        onnx_out = _expr.var("onnx_out", shape=(_ty.Any(), _ty.Any(), _ty.Any(), 3), dtype="int64")
+        # and sizes of valid outputs, shape (B, C, 1)
+        nms_size_out = _expr.var("nms_size_out", shape=(_ty.Any(), _ty.Any(), 1), dtype="int64")
+        out = _expr.var("out", shape=(_ty.Any(), 3), dtype="int64")
+
+        def _inner_cond(i, j, C, onnx_out, nms_size, out):
+            # inner loop over number of classes
+            return _op.min(_op.less(j, C))
+
+        def _inner_body(i, j, C, onnx_out, nms_size, out):
+            # slice to get current batch and class for valid box indicator
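+            # the j + one offsets skip the placeholder slice that onnx_out and
+            # nms_size_out were seeded with before the first loop ran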
+            start = _op.concatenate([i, j + one, zero], axis=0)
+            end = _op.concatenate([i + one, j + two, one], axis=0)
+            num_valid_boxes = _op.reshape(_op.strided_slice(nms_size, start, end, three_ones), [1])
+            # slice to get current batch, class, and valid outputs
+            start = _op.concatenate([i, j + one, zero, zero], axis=0)
+            end = _op.concatenate([i + one, j + two, num_valid_boxes, three], axis=0)
+            new_out = _op.squeeze(_op.strided_slice(onnx_out, start, end, four_ones), [0, 1])
+            return i, j + one, C, onnx_out, nms_size, _op.concatenate([out, new_out], axis=0)
+
+        inner_loop = _loops.while_loop(
+            _inner_cond, [i, j, C, onnx_out, nms_size_out, out], _inner_body
+        )
+
+        # Second Outer Loop Vars
+        i = _expr.var("i", shape=(1,), dtype="int64")
+        j = _expr.var("j", shape=(1,), dtype="int64")
+        B = _expr.var("B", shape=(1,), dtype="int64")
+        C = _expr.var("C", shape=(1,), dtype="int64")
+        # Outputs of first loop should be padded nms values shape (B, C, S, 3)
+        onnx_out = _expr.var("onnx_out", shape=(_ty.Any(), _ty.Any(), _ty.Any(), 3), dtype="int64")
+        # and sizes of valid outputs, shape (B, C, 1)
+        nms_size_out = _expr.var("nms_size_out", shape=(_ty.Any(), _ty.Any(), 1), dtype="int64")
+        out = _expr.var("out", shape=(_ty.Any(), 3), dtype="int64")
+
+        def _outer_cond(i, B, C, onnx_out, nms_size_out, out):
+            # Outer loop is over batch size
+            return _op.min(_op.less(i, B))
+
+        def _outer_body(i, B, C, onnx_out, nms_size_out, out):
+            # Outer loop just calls inner loop
+            init_count = _op.const(np.array([0]), dtype="int64")
+            inner_loop_vals = inner_loop(i, init_count, C, onnx_out, nms_size_out, out)
+            return i + one, B, C, onnx_out, nms_size_out, _expr.TupleGetItem(inner_loop_vals, 5)
+
+        # Create the second loop
+        outer_loop = _loops.while_loop(
+            _outer_cond, [i, B, C, onnx_out, nms_size_out, out], _outer_body
+        )
+
+        # Call the first loop, perform NMS
+        B, C, S = _op.split(_op.shape_of(scores, dtype="int64"), 3)
+        init_count = _op.const(np.array([0]), dtype="int64")
+        init_onnx_out = _op.const([1], dtype="int64")
+        init_onnx_out = _op.broadcast_to(init_onnx_out, _op.concatenate([B, one, S, three], 0))
+        init_nms_size_out = _op.const([1], dtype="int64")
+        init_nms_size_out = _op.broadcast_to(init_nms_size_out, _op.concatenate([B, one, one], 0))
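+        # the loop accumulators need a non-empty starting value, so seed them with one
+        # placeholder slice along the class axis; the second loop skips that slice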
+        loop_vals = first_loop(
+            init_count,
+            scores,
+            boxes,
+            B,
+            C,
+            S,
+            max_output_boxes_per_class,
+            iou_threshold,
+            score_threshold,
+            init_onnx_out,
+            init_nms_size_out,
+        )
+        onnx_output = _expr.TupleGetItem(loop_vals, 9)
+        nms_size_output = _expr.TupleGetItem(loop_vals, 10)
+
+        # Call the second loop, rework outputs into correct form
+        init_count = _op.const(np.array([0]).astype("int64"), dtype="int64")
+        init_out = _op.const(np.array([]).reshape([0, 3]).astype("int64"), dtype="int64")
+        loop_vals = outer_loop(init_count, B, C, onnx_output, nms_size_output, init_out)
+
+        return _expr.TupleGetItem(loop_vals, 5)
+
+
 # compatible operators that do NOT require any conversion.
 _identity_list = []
@@ -2415,6 +2683,7 @@ def _get_convert_map(opset):
         # defs/vision
         "MaxRoiPool": MaxRoiPool.get_converter(opset),
         "RoiAlign": RoiAlign.get_converter(opset),
+        "NonMaxSuppression": NonMaxSuppression.get_converter(opset),
         # defs/reduction
         "ReduceMax": ReduceMax.get_converter(opset),
         "ReduceMin": ReduceMin.get_converter(opset),