Skip to content

Commit a3f3dc7

Browse files
apivovarovtqchen
authored andcommitted
Make topi cuda nms_gpu method signature similar to non_max_suppression (#2780)
1 parent d8abc73 commit a3f3dc7

File tree

1 file changed

+12
-5
lines changed

1 file changed

+12
-5
lines changed

topi/python/topi/cuda/nms.py

Lines changed: 12 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -182,8 +182,15 @@ def calculate_overlap(out_tensor, box_a_idx, box_b_idx):
182182

183183

184184
@non_max_suppression.register(["cuda", "gpu"])
185-
def nms_gpu(data, valid_count, return_indices, iou_threshold=0.5, force_suppress=False,
186-
topk=-1, id_index=0, invalid_to_bottom=False):
185+
def nms_gpu(data,
186+
valid_count,
187+
max_output_size=-1,
188+
iou_threshold=0.5,
189+
force_suppress=False,
190+
top_k=-1,
191+
id_index=0,
192+
return_indices=True,
193+
invalid_to_bottom=False):
187194
"""Non-maximum suppression operator for object detection.
188195
189196
Parameters
@@ -205,7 +212,7 @@ def nms_gpu(data, valid_count, return_indices, iou_threshold=0.5, force_suppress
205212
force_suppress : optional, boolean
206213
Whether to suppress all detections regardless of class_id.
207214
208-
topk : optional, int
215+
top_k : optional, int
209216
Keep maximum top k detections before nms, -1 for no limit.
210217
211218
id_index : optional, int
@@ -229,7 +236,7 @@ def nms_gpu(data, valid_count, return_indices, iou_threshold=0.5, force_suppress
229236
valid_count = tvm.placeholder((dshape[0],), dtype="int32", name="valid_count")
230237
iou_threshold = 0.7
231238
force_suppress = True
232-
topk = -1
239+
top_k = -1
233240
out = nms(data, valid_count, iou_threshold, force_suppress, topk)
234241
np_data = np.random.uniform(dshape)
235242
np_valid_count = np.array([4])
@@ -273,7 +280,7 @@ def nms_gpu(data, valid_count, return_indices, iou_threshold=0.5, force_suppress
273280
[data, sort_tensor, valid_count],
274281
lambda ins, outs: nms_ir(
275282
ins[0], ins[1], ins[2], outs[0], iou_threshold,
276-
force_suppress, topk),
283+
force_suppress, top_k),
277284
dtype="float32",
278285
in_buffers=[data_buf, sort_tensor_buf, valid_count_buf],
279286
tag="nms")

0 commit comments

Comments
 (0)