@@ -512,51 +512,62 @@ def calculate_overlap(out_tensor, box_a_idx, box_b_idx):
512
512
513
513
with ib .new_scope ():
514
514
nthread_by = batch_size
515
+ nthread_tx = max_threads
516
+ nthread_bx = ceil_div (num_anchors , max_threads )
517
+
515
518
by = te .thread_axis ("blockIdx.y" )
519
+ tx = te .thread_axis ("threadIdx.x" )
520
+ bx = te .thread_axis ("blockIdx.x" )
516
521
ib .scope_attr (by , "thread_extent" , nthread_by )
522
+ ib .scope_attr (by , "thread_extent" , nthread_by )
523
+ ib .scope_attr (tx , "thread_extent" , nthread_tx )
524
+
517
525
i = by
526
+ k = bx * nthread_tx + tx
518
527
base_idx = i * num_anchors * box_data_length
519
528
num_valid_boxes_local = ib .allocate (
520
529
"int32" , (1 ,), name = "num_valid_boxes_local" , scope = "local"
521
530
)
522
531
num_valid_boxes_local [0 ] = 0
523
532
524
533
def nms_inner_loop (ib , j ):
525
- offset_j = j * box_data_length
534
+ # box j is valid, invalidate other boxes that overlap with j above iou_threshold
526
535
527
- with ib .for_range (0 , j ) as k :
528
- offset_k = k * box_data_length
529
-
530
- with ib .if_scope (
531
- tvm .tir .all (
532
- out [base_idx + offset_j + score_index ] > - 1.0 , # if already surpressed
533
- out [base_idx + offset_k + score_index ] > 0 ,
534
- tvm .tir .any (id_index < 0 , out [base_idx + offset_k + id_index ] >= 0 ),
535
- tvm .tir .any (
536
- force_suppress > 0 ,
537
- id_index < 0 ,
538
- out [base_idx + offset_k + id_index ]
539
- == out [base_idx + offset_j + id_index ],
540
- ),
541
- )
542
- ):
543
- iou = calculate_overlap (
544
- out ,
545
- base_idx + offset_j + coord_start ,
546
- base_idx + offset_k + coord_start ,
547
- )
548
- with ib .if_scope (iou >= iou_threshold ):
549
- out [base_idx + offset_j + score_index ] = - 1.0
550
- with ib .if_scope (id_index >= 0 ):
551
- out [base_idx + offset_j + id_index ] = - 1.0
552
-
553
- # Has the box j survived IOU tests?
554
- with ib .if_scope (out [base_idx + offset_j + score_index ] > - 1.0 ):
555
- # When return_indices is False, no need to populate box_indices
556
- if return_indices :
536
+ # When return_indices is False, no need to populate box_indices
537
+ if return_indices :
538
+ # Only one thread needs to do this write
539
+ with ib .if_scope (k == 0 ):
557
540
orig_idx = sorted_index [i * num_anchors + j ]
558
541
box_indices [i , num_valid_boxes_local [0 ]] = indices [i , orig_idx ]
559
- num_valid_boxes_local [0 ] += 1
542
+
543
+ num_valid_boxes_local [0 ] += 1
544
+
545
+ offset_j = j * box_data_length
546
+ offset_k = k * box_data_length
547
+
548
+ with ib .if_scope (
549
+ tvm .tir .all (
550
+ j < k ,
551
+ out [base_idx + offset_k + score_index ] > 0 ,
552
+ tvm .tir .any (id_index < 0 , out [base_idx + offset_k + id_index ] >= 0 ),
553
+ tvm .tir .any (
554
+ force_suppress > 0 ,
555
+ id_index < 0 ,
556
+ out [base_idx + offset_k + id_index ] == out [base_idx + offset_j + id_index ],
557
+ ),
558
+ )
559
+ ):
560
+ iou = calculate_overlap (
561
+ out ,
562
+ base_idx + offset_j + coord_start ,
563
+ base_idx + offset_k + coord_start ,
564
+ )
565
+ with ib .if_scope (iou >= iou_threshold ):
566
+ out [base_idx + offset_k + score_index ] = - 1.0
567
+ with ib .if_scope (id_index >= 0 ):
568
+ out [base_idx + offset_k + id_index ] = - 1.0
569
+
570
+ ib .emit (tvm .tir .Call (None , "tir.tvm_storage_sync" , tvm .runtime .convert (["shared" ])))
560
571
561
572
if isinstance (max_output_size , int ):
562
573
max_output_size = tvm .tir .const (max_output_size )
@@ -565,7 +576,12 @@ def nms_inner_loop(ib, j):
565
576
# Apply nms
566
577
with ib .for_range (0 , valid_count [i ]) as j :
567
578
with ib .if_scope (
568
- tvm .tir .any (id_index < 0 , out [base_idx + j * box_data_length + id_index ] >= 0 )
579
+ tvm .tir .all (
580
+ out [base_idx + (j * box_data_length ) + score_index ] > - 1.0 ,
581
+ tvm .tir .any (
582
+ id_index < 0 , out [base_idx + j * box_data_length + id_index ] >= 0
583
+ ),
584
+ )
569
585
):
570
586
with ib .if_scope (max_output_size > 0 ):
571
587
# No need to iterate further once max_output_size boxes have been reached
0 commit comments