Skip to content

Commit c75c6ef

Browse files
committed
make NMS inner loop parallel
1 parent 3d8fd2a commit c75c6ef

File tree

1 file changed

+49
-33
lines changed

1 file changed

+49
-33
lines changed

python/tvm/topi/cuda/nms.py

Lines changed: 49 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -512,51 +512,62 @@ def calculate_overlap(out_tensor, box_a_idx, box_b_idx):
512512

513513
with ib.new_scope():
514514
nthread_by = batch_size
515+
nthread_tx = max_threads
516+
nthread_bx = ceil_div(num_anchors, max_threads)
517+
515518
by = te.thread_axis("blockIdx.y")
519+
tx = te.thread_axis("threadIdx.x")
520+
bx = te.thread_axis("blockIdx.x")
516521
ib.scope_attr(by, "thread_extent", nthread_by)
522+
ib.scope_attr(by, "thread_extent", nthread_by)
523+
ib.scope_attr(tx, "thread_extent", nthread_tx)
524+
517525
i = by
526+
k = bx * nthread_tx + tx
518527
base_idx = i * num_anchors * box_data_length
519528
num_valid_boxes_local = ib.allocate(
520529
"int32", (1,), name="num_valid_boxes_local", scope="local"
521530
)
522531
num_valid_boxes_local[0] = 0
523532

524533
def nms_inner_loop(ib, j):
525-
offset_j = j * box_data_length
534+
# box j is valid, invalidate other boxes that overlap with j above iou_threshold
526535

527-
with ib.for_range(0, j) as k:
528-
offset_k = k * box_data_length
529-
530-
with ib.if_scope(
531-
tvm.tir.all(
532-
out[base_idx + offset_j + score_index] > -1.0, # if already surpressed
533-
out[base_idx + offset_k + score_index] > 0,
534-
tvm.tir.any(id_index < 0, out[base_idx + offset_k + id_index] >= 0),
535-
tvm.tir.any(
536-
force_suppress > 0,
537-
id_index < 0,
538-
out[base_idx + offset_k + id_index]
539-
== out[base_idx + offset_j + id_index],
540-
),
541-
)
542-
):
543-
iou = calculate_overlap(
544-
out,
545-
base_idx + offset_j + coord_start,
546-
base_idx + offset_k + coord_start,
547-
)
548-
with ib.if_scope(iou >= iou_threshold):
549-
out[base_idx + offset_j + score_index] = -1.0
550-
with ib.if_scope(id_index >= 0):
551-
out[base_idx + offset_j + id_index] = -1.0
552-
553-
# Has the box j survived IOU tests?
554-
with ib.if_scope(out[base_idx + offset_j + score_index] > -1.0):
555-
# When return_indices is False, no need to populate box_indices
556-
if return_indices:
536+
# When return_indices is False, no need to populate box_indices
537+
if return_indices:
538+
# Only one thread needs to this write
539+
with ib.if_scope(k == 0):
557540
orig_idx = sorted_index[i * num_anchors + j]
558541
box_indices[i, num_valid_boxes_local[0]] = indices[i, orig_idx]
559-
num_valid_boxes_local[0] += 1
542+
543+
num_valid_boxes_local[0] += 1
544+
545+
offset_j = j * box_data_length
546+
offset_k = k * box_data_length
547+
548+
with ib.if_scope(
549+
tvm.tir.all(
550+
j < k,
551+
out[base_idx + offset_k + score_index] > 0,
552+
tvm.tir.any(id_index < 0, out[base_idx + offset_k + id_index] >= 0),
553+
tvm.tir.any(
554+
force_suppress > 0,
555+
id_index < 0,
556+
out[base_idx + offset_k + id_index] == out[base_idx + offset_j + id_index],
557+
),
558+
)
559+
):
560+
iou = calculate_overlap(
561+
out,
562+
base_idx + offset_j + coord_start,
563+
base_idx + offset_k + coord_start,
564+
)
565+
with ib.if_scope(iou >= iou_threshold):
566+
out[base_idx + offset_k + score_index] = -1.0
567+
with ib.if_scope(id_index >= 0):
568+
out[base_idx + offset_k + id_index] = -1.0
569+
570+
ib.emit(tvm.tir.Call(None, "tir.tvm_storage_sync", tvm.runtime.convert(["shared"])))
560571

561572
if isinstance(max_output_size, int):
562573
max_output_size = tvm.tir.const(max_output_size)
@@ -565,7 +576,12 @@ def nms_inner_loop(ib, j):
565576
# Apply nms
566577
with ib.for_range(0, valid_count[i]) as j:
567578
with ib.if_scope(
568-
tvm.tir.any(id_index < 0, out[base_idx + j * box_data_length + id_index] >= 0)
579+
tvm.tir.all(
580+
out[base_idx + (j * box_data_length) + score_index] > -1.0,
581+
tvm.tir.any(
582+
id_index < 0, out[base_idx + j * box_data_length + id_index] >= 0
583+
),
584+
)
569585
):
570586
with ib.if_scope(max_output_size > 0):
571587
# No need to do more iteration if we already reach max_output_size boxes

0 commit comments

Comments
 (0)