Skip to content

Commit a3dd35d

Browse files
author
mbrookhart
committed
improve nms parallelization
1 parent 5daff0f commit a3dd35d

File tree

1 file changed

+105
-75
lines changed

1 file changed

+105
-75
lines changed

python/tvm/topi/cuda/nms.py

Lines changed: 105 additions & 75 deletions
Original file line numberDiff line numberDiff line change
@@ -514,7 +514,6 @@ def calculate_overlap(out_tensor, box_a_idx, box_b_idx):
514514
indices = ib.buffer_ptr(indices)
515515
out = ib.buffer_ptr(out)
516516
box_indices = ib.buffer_ptr(box_indices)
517-
num_valid_boxes = ib.allocate("int32", (1,), name="num_valid_boxes", scope="local")
518517

519518
if isinstance(iou_threshold, float):
520519
iou_threshold = tvm.tir.FloatImm("float32", iou_threshold)
@@ -527,86 +526,117 @@ def calculate_overlap(out_tensor, box_a_idx, box_b_idx):
527526
max_threads = int(tvm.target.Target.current(allow_none=False).max_num_threads)
528527

529528
with ib.new_scope():
530-
bx = te.thread_axis("blockIdx.x")
531-
ib.scope_attr(bx, "thread_extent", 1)
532-
533-
with ib.for_range(0, batch_size) as i:
534-
base_idx = i * num_anchors * box_data_length
535-
with ib.if_scope(tvm.tir.all(iou_threshold > 0, valid_count[i] > 0)):
536-
# Reorder output
537-
nkeep = if_then_else(
538-
tvm.tir.all(top_k > 0, top_k < valid_count[i]), top_k, valid_count[i]
539-
)
540-
with ib.for_range(0, nkeep) as j:
529+
nthread_by = batch_size
530+
by = te.thread_axis("blockIdx.y")
531+
ib.scope_attr(by, "thread_extent", nthread_by)
532+
i = by
533+
base_idx = i * num_anchors * box_data_length
534+
with ib.if_scope(tvm.tir.all(iou_threshold > 0, valid_count[i] > 0)):
535+
# Reorder output
536+
nkeep = if_then_else(
537+
tvm.tir.all(top_k > 0, top_k < valid_count[i]), top_k, valid_count[i]
538+
)
539+
with ib.for_range(0, nkeep) as j:
540+
with ib.for_range(0, box_data_length) as k:
541+
out[(base_idx + j * box_data_length + k)] = data[
542+
(base_idx + sorted_index[i * num_anchors + j] * box_data_length + k)
543+
]
544+
box_indices[i * num_anchors + j] = sorted_index[i * num_anchors + j]
545+
with ib.if_scope(tvm.tir.all(top_k > 0, top_k < valid_count[i])):
546+
with ib.for_range(0, valid_count[i] - nkeep) as j:
541547
with ib.for_range(0, box_data_length) as k:
542-
out[(base_idx + j * box_data_length + k)] = data[
543-
(base_idx + sorted_index[i * num_anchors + j] * box_data_length + k)
544-
]
545-
box_indices[i * num_anchors + j] = sorted_index[i * num_anchors + j]
546-
with ib.if_scope(tvm.tir.all(top_k > 0, top_k < valid_count[i])):
547-
with ib.for_range(0, valid_count[i] - nkeep) as j:
548-
with ib.for_range(0, box_data_length) as k:
549-
out[(base_idx + (j + nkeep) * box_data_length + k)] = -1.0
550-
box_indices[i * num_anchors + (j + nkeep)] = -1
551-
# Apply nms
552-
with ib.for_range(0, valid_count[i]) as j:
553-
with ib.for_range(0, j) as k:
554-
offset_k = k * box_data_length
548+
out[(base_idx + (j + nkeep) * box_data_length + k)] = -1.0
549+
box_indices[i * num_anchors + (j + nkeep)] = -1
550+
with ib.new_scope():
551+
nthread_by = batch_size
552+
by = te.thread_axis("blockIdx.y")
553+
ib.scope_attr(by, "thread_extent", nthread_by)
554+
i = by
555+
base_idx = i * num_anchors * box_data_length
556+
with ib.if_scope(tvm.tir.all(iou_threshold > 0, valid_count[i] > 0)):
557+
# Apply nms
558+
with ib.for_range(0, valid_count[i]) as j:
559+
with ib.for_range(0, j) as k:
560+
offset_k = k * box_data_length
561+
with ib.if_scope(
562+
tvm.tir.all(
563+
out[base_idx + offset_k + score_index] > 0,
564+
tvm.tir.any(id_index < 0, out[base_idx + offset_k + id_index] >= 0),
565+
)
566+
):
567+
offset_j = j * box_data_length
555568
with ib.if_scope(
556569
tvm.tir.all(
570+
j > k,
557571
out[base_idx + offset_k + score_index] > 0,
558-
tvm.tir.any(id_index < 0, out[base_idx + offset_k + id_index] >= 0),
572+
tvm.tir.any(id_index < 0, out[base_idx + offset_j + id_index] >= 0),
573+
tvm.tir.any(
574+
force_suppress > 0,
575+
id_index < 0,
576+
out[base_idx + offset_k + id_index]
577+
== out[base_idx + offset_j + id_index],
578+
),
559579
)
560580
):
561-
offset_j = j * box_data_length
562-
with ib.if_scope(
563-
tvm.tir.all(
564-
j > k,
565-
out[base_idx + offset_k + score_index] > 0,
566-
tvm.tir.any(
567-
id_index < 0, out[base_idx + offset_j + id_index] >= 0
568-
),
569-
tvm.tir.any(
570-
force_suppress > 0,
571-
id_index < 0,
572-
out[base_idx + offset_k + id_index]
573-
== out[base_idx + offset_j + id_index],
574-
),
575-
)
576-
):
577-
iou = calculate_overlap(
578-
out,
579-
base_idx + offset_j + coord_start,
580-
base_idx + offset_k + coord_start,
581-
)
582-
with ib.if_scope(iou >= iou_threshold):
583-
out[base_idx + offset_j + score_index] = -1.0
584-
with ib.if_scope(id_index >= 0):
585-
out[base_idx + offset_j + id_index] = -1.0
586-
box_indices[i * num_anchors + j] = -1
587-
with ib.else_scope():
588-
with ib.for_range(0, valid_count[i]) as j:
589-
offset_j = j * box_data_length
590-
with ib.for_range(0, box_data_length) as k:
591-
out[(base_idx + offset_j + k)] = data[base_idx + offset_j + k]
592-
box_indices[i * num_anchors + j] = j
593-
# Set invalid entry to be -1
594-
with ib.for_range(0, num_anchors - valid_count[i]) as j:
595-
with ib.for_range(0, box_data_length) as k:
596-
out[base_idx + (j + valid_count[i]) * box_data_length + k] = -1.0
597-
box_indices[i * num_anchors + j + valid_count[i]] = -1
598-
# Only return max_output_size number of valid boxes
599-
num_valid_boxes[0] = 0
600-
with ib.if_scope(max_output_size > 0):
601-
with ib.for_range(0, valid_count[i]) as j:
602-
offset_j = j * box_data_length
603-
with ib.if_scope(out[base_idx + offset_j] >= 0):
604-
with ib.if_scope(num_valid_boxes[0] == max_output_size):
605-
with ib.for_range(0, box_data_length) as k:
606-
out[base_idx + offset_j + k] = -1.0
607-
box_indices[i * num_anchors + j] = -1
608-
with ib.else_scope():
609-
num_valid_boxes[0] += 1
581+
iou = calculate_overlap(
582+
out,
583+
base_idx + offset_j + coord_start,
584+
base_idx + offset_k + coord_start,
585+
)
586+
with ib.if_scope(iou >= iou_threshold):
587+
out[base_idx + offset_j + score_index] = -1.0
588+
with ib.if_scope(id_index >= 0):
589+
out[base_idx + offset_j + id_index] = -1.0
590+
box_indices[i * num_anchors + j] = -1
591+
with ib.new_scope():
592+
nthread_tx = max_threads
593+
nthread_bx = num_anchors // max_threads + 1
594+
nthread_by = batch_size
595+
nthread_bz = box_data_length
596+
tx = te.thread_axis("threadIdx.x")
597+
bx = te.thread_axis("blockIdx.x")
598+
by = te.thread_axis("blockIdx.y")
599+
bz = te.thread_axis("blockIdx.z")
600+
ib.scope_attr(tx, "thread_extent", nthread_tx)
601+
ib.scope_attr(bx, "thread_extent", nthread_bx)
602+
ib.scope_attr(by, "thread_extent", nthread_by)
603+
ib.scope_attr(bz, "thread_extent", nthread_bz)
604+
tid = bx * max_threads + tx
605+
i = by
606+
j = tid
607+
k = bz
608+
base_idx = i * num_anchors * box_data_length
609+
with ib.if_scope(tvm.tir.all(iou_threshold > 0, valid_count[i] > 0)):
610+
pass
611+
with ib.else_scope():
612+
with ib.if_scope(j < valid_count[i]):
613+
offset_j = j * box_data_length
614+
out[(base_idx + offset_j + k)] = data[base_idx + offset_j + k]
615+
box_indices[i * num_anchors + j] = j
616+
617+
with ib.new_scope():
618+
num_valid_boxes = ib.allocate("int32", (1,), name="num_valid_boxes", scope="local")
619+
bx = te.thread_axis("blockIdx.x")
620+
ib.scope_attr(bx, "thread_extent", batch_size)
621+
i = bx
622+
base_idx = i * num_anchors * box_data_length
623+
# Set invalid entry to be -1
624+
with ib.for_range(0, num_anchors - valid_count[i]) as j:
625+
with ib.for_range(0, box_data_length) as k:
626+
out[base_idx + (j + valid_count[i]) * box_data_length + k] = -1.0
627+
box_indices[i * num_anchors + j + valid_count[i]] = -1
628+
# Only return max_output_size number of valid boxes
629+
num_valid_boxes[0] = 0
630+
with ib.if_scope(max_output_size > 0):
631+
with ib.for_range(0, valid_count[i]) as j:
632+
offset_j = j * box_data_length
633+
with ib.if_scope(out[base_idx + offset_j] >= 0):
634+
with ib.if_scope(num_valid_boxes[0] == max_output_size):
635+
with ib.for_range(0, box_data_length) as k:
636+
out[base_idx + offset_j + k] = -1.0
637+
box_indices[i * num_anchors + j] = -1
638+
with ib.else_scope():
639+
num_valid_boxes[0] += 1
610640

611641
if return_indices:
612642
with ib.new_scope():

0 commit comments

Comments
 (0)