@@ -514,7 +514,6 @@ def calculate_overlap(out_tensor, box_a_idx, box_b_idx):
514
514
indices = ib .buffer_ptr (indices )
515
515
out = ib .buffer_ptr (out )
516
516
box_indices = ib .buffer_ptr (box_indices )
517
- num_valid_boxes = ib .allocate ("int32" , (1 ,), name = "num_valid_boxes" , scope = "local" )
518
517
519
518
if isinstance (iou_threshold , float ):
520
519
iou_threshold = tvm .tir .FloatImm ("float32" , iou_threshold )
@@ -527,86 +526,117 @@ def calculate_overlap(out_tensor, box_a_idx, box_b_idx):
527
526
max_threads = int (tvm .target .Target .current (allow_none = False ).max_num_threads )
528
527
529
528
with ib .new_scope ():
530
- bx = te .thread_axis ("blockIdx.x" )
531
- ib .scope_attr (bx , "thread_extent" , 1 )
532
-
533
- with ib .for_range (0 , batch_size ) as i :
534
- base_idx = i * num_anchors * box_data_length
535
- with ib .if_scope (tvm .tir .all (iou_threshold > 0 , valid_count [i ] > 0 )):
536
- # Reorder output
537
- nkeep = if_then_else (
538
- tvm .tir .all (top_k > 0 , top_k < valid_count [i ]), top_k , valid_count [i ]
539
- )
540
- with ib .for_range (0 , nkeep ) as j :
529
+ nthread_by = batch_size
530
+ by = te .thread_axis ("blockIdx.y" )
531
+ ib .scope_attr (by , "thread_extent" , nthread_by )
532
+ i = by
533
+ base_idx = i * num_anchors * box_data_length
534
+ with ib .if_scope (tvm .tir .all (iou_threshold > 0 , valid_count [i ] > 0 )):
535
+ # Reorder output
536
+ nkeep = if_then_else (
537
+ tvm .tir .all (top_k > 0 , top_k < valid_count [i ]), top_k , valid_count [i ]
538
+ )
539
+ with ib .for_range (0 , nkeep ) as j :
540
+ with ib .for_range (0 , box_data_length ) as k :
541
+ out [(base_idx + j * box_data_length + k )] = data [
542
+ (base_idx + sorted_index [i * num_anchors + j ] * box_data_length + k )
543
+ ]
544
+ box_indices [i * num_anchors + j ] = sorted_index [i * num_anchors + j ]
545
+ with ib .if_scope (tvm .tir .all (top_k > 0 , top_k < valid_count [i ])):
546
+ with ib .for_range (0 , valid_count [i ] - nkeep ) as j :
541
547
with ib .for_range (0 , box_data_length ) as k :
542
- out [(base_idx + j * box_data_length + k )] = data [
543
- (base_idx + sorted_index [i * num_anchors + j ] * box_data_length + k )
544
- ]
545
- box_indices [i * num_anchors + j ] = sorted_index [i * num_anchors + j ]
546
- with ib .if_scope (tvm .tir .all (top_k > 0 , top_k < valid_count [i ])):
547
- with ib .for_range (0 , valid_count [i ] - nkeep ) as j :
548
- with ib .for_range (0 , box_data_length ) as k :
549
- out [(base_idx + (j + nkeep ) * box_data_length + k )] = - 1.0
550
- box_indices [i * num_anchors + (j + nkeep )] = - 1
551
- # Apply nms
552
- with ib .for_range (0 , valid_count [i ]) as j :
553
- with ib .for_range (0 , j ) as k :
554
- offset_k = k * box_data_length
548
+ out [(base_idx + (j + nkeep ) * box_data_length + k )] = - 1.0
549
+ box_indices [i * num_anchors + (j + nkeep )] = - 1
550
+ with ib .new_scope ():
551
+ nthread_by = batch_size
552
+ by = te .thread_axis ("blockIdx.y" )
553
+ ib .scope_attr (by , "thread_extent" , nthread_by )
554
+ i = by
555
+ base_idx = i * num_anchors * box_data_length
556
+ with ib .if_scope (tvm .tir .all (iou_threshold > 0 , valid_count [i ] > 0 )):
557
+ # Apply nms
558
+ with ib .for_range (0 , valid_count [i ]) as j :
559
+ with ib .for_range (0 , j ) as k :
560
+ offset_k = k * box_data_length
561
+ with ib .if_scope (
562
+ tvm .tir .all (
563
+ out [base_idx + offset_k + score_index ] > 0 ,
564
+ tvm .tir .any (id_index < 0 , out [base_idx + offset_k + id_index ] >= 0 ),
565
+ )
566
+ ):
567
+ offset_j = j * box_data_length
555
568
with ib .if_scope (
556
569
tvm .tir .all (
570
+ j > k ,
557
571
out [base_idx + offset_k + score_index ] > 0 ,
558
- tvm .tir .any (id_index < 0 , out [base_idx + offset_k + id_index ] >= 0 ),
572
+ tvm .tir .any (id_index < 0 , out [base_idx + offset_j + id_index ] >= 0 ),
573
+ tvm .tir .any (
574
+ force_suppress > 0 ,
575
+ id_index < 0 ,
576
+ out [base_idx + offset_k + id_index ]
577
+ == out [base_idx + offset_j + id_index ],
578
+ ),
559
579
)
560
580
):
561
- offset_j = j * box_data_length
562
- with ib .if_scope (
563
- tvm .tir .all (
564
- j > k ,
565
- out [base_idx + offset_k + score_index ] > 0 ,
566
- tvm .tir .any (
567
- id_index < 0 , out [base_idx + offset_j + id_index ] >= 0
568
- ),
569
- tvm .tir .any (
570
- force_suppress > 0 ,
571
- id_index < 0 ,
572
- out [base_idx + offset_k + id_index ]
573
- == out [base_idx + offset_j + id_index ],
574
- ),
575
- )
576
- ):
577
- iou = calculate_overlap (
578
- out ,
579
- base_idx + offset_j + coord_start ,
580
- base_idx + offset_k + coord_start ,
581
- )
582
- with ib .if_scope (iou >= iou_threshold ):
583
- out [base_idx + offset_j + score_index ] = - 1.0
584
- with ib .if_scope (id_index >= 0 ):
585
- out [base_idx + offset_j + id_index ] = - 1.0
586
- box_indices [i * num_anchors + j ] = - 1
587
- with ib .else_scope ():
588
- with ib .for_range (0 , valid_count [i ]) as j :
589
- offset_j = j * box_data_length
590
- with ib .for_range (0 , box_data_length ) as k :
591
- out [(base_idx + offset_j + k )] = data [base_idx + offset_j + k ]
592
- box_indices [i * num_anchors + j ] = j
593
- # Set invalid entry to be -1
594
- with ib .for_range (0 , num_anchors - valid_count [i ]) as j :
595
- with ib .for_range (0 , box_data_length ) as k :
596
- out [base_idx + (j + valid_count [i ]) * box_data_length + k ] = - 1.0
597
- box_indices [i * num_anchors + j + valid_count [i ]] = - 1
598
- # Only return max_output_size number of valid boxes
599
- num_valid_boxes [0 ] = 0
600
- with ib .if_scope (max_output_size > 0 ):
601
- with ib .for_range (0 , valid_count [i ]) as j :
602
- offset_j = j * box_data_length
603
- with ib .if_scope (out [base_idx + offset_j ] >= 0 ):
604
- with ib .if_scope (num_valid_boxes [0 ] == max_output_size ):
605
- with ib .for_range (0 , box_data_length ) as k :
606
- out [base_idx + offset_j + k ] = - 1.0
607
- box_indices [i * num_anchors + j ] = - 1
608
- with ib .else_scope ():
609
- num_valid_boxes [0 ] += 1
581
+ iou = calculate_overlap (
582
+ out ,
583
+ base_idx + offset_j + coord_start ,
584
+ base_idx + offset_k + coord_start ,
585
+ )
586
+ with ib .if_scope (iou >= iou_threshold ):
587
+ out [base_idx + offset_j + score_index ] = - 1.0
588
+ with ib .if_scope (id_index >= 0 ):
589
+ out [base_idx + offset_j + id_index ] = - 1.0
590
+ box_indices [i * num_anchors + j ] = - 1
591
+ with ib .new_scope ():
592
+ nthread_tx = max_threads
593
+ nthread_bx = num_anchors // max_threads + 1
594
+ nthread_by = batch_size
595
+ nthread_bz = box_data_length
596
+ tx = te .thread_axis ("threadIdx.x" )
597
+ bx = te .thread_axis ("blockIdx.x" )
598
+ by = te .thread_axis ("blockIdx.y" )
599
+ bz = te .thread_axis ("blockIdx.z" )
600
+ ib .scope_attr (tx , "thread_extent" , nthread_tx )
601
+ ib .scope_attr (bx , "thread_extent" , nthread_bx )
602
+ ib .scope_attr (by , "thread_extent" , nthread_by )
603
+ ib .scope_attr (bz , "thread_extent" , nthread_bz )
604
+ tid = bx * max_threads + tx
605
+ i = by
606
+ j = tid
607
+ k = bz
608
+ base_idx = i * num_anchors * box_data_length
609
+ with ib .if_scope (tvm .tir .all (iou_threshold > 0 , valid_count [i ] > 0 )):
610
+ pass
611
+ with ib .else_scope ():
612
+ with ib .if_scope (j < valid_count [i ]):
613
+ offset_j = j * box_data_length
614
+ out [(base_idx + offset_j + k )] = data [base_idx + offset_j + k ]
615
+ box_indices [i * num_anchors + j ] = j
616
+
617
+ with ib .new_scope ():
618
+ num_valid_boxes = ib .allocate ("int32" , (1 ,), name = "num_valid_boxes" , scope = "local" )
619
+ bx = te .thread_axis ("blockIdx.x" )
620
+ ib .scope_attr (bx , "thread_extent" , batch_size )
621
+ i = bx
622
+ base_idx = i * num_anchors * box_data_length
623
+ # Set invalid entry to be -1
624
+ with ib .for_range (0 , num_anchors - valid_count [i ]) as j :
625
+ with ib .for_range (0 , box_data_length ) as k :
626
+ out [base_idx + (j + valid_count [i ]) * box_data_length + k ] = - 1.0
627
+ box_indices [i * num_anchors + j + valid_count [i ]] = - 1
628
+ # Only return max_output_size number of valid boxes
629
+ num_valid_boxes [0 ] = 0
630
+ with ib .if_scope (max_output_size > 0 ):
631
+ with ib .for_range (0 , valid_count [i ]) as j :
632
+ offset_j = j * box_data_length
633
+ with ib .if_scope (out [base_idx + offset_j ] >= 0 ):
634
+ with ib .if_scope (num_valid_boxes [0 ] == max_output_size ):
635
+ with ib .for_range (0 , box_data_length ) as k :
636
+ out [base_idx + offset_j + k ] = - 1.0
637
+ box_indices [i * num_anchors + j ] = - 1
638
+ with ib .else_scope ():
639
+ num_valid_boxes [0 ] += 1
610
640
611
641
if return_indices :
612
642
with ib .new_scope ():
0 commit comments