Skip to content

Commit 75af88f

Browse files
author
mbrookhart
committed
better parallelize get_valid_counts
1 parent f332512 commit 75af88f

File tree

1 file changed

+153
-41
lines changed

1 file changed

+153
-41
lines changed

python/tvm/topi/cuda/nms.py

Lines changed: 153 additions & 41 deletions
Original file line numberDiff line numberDiff line change
@@ -118,9 +118,89 @@ def rearrange_indices_out_ir(data, output, valid_box_count):
118118
return ib.get()
119119

120120

121-
def get_valid_counts_ir(
122-
data, valid_count, out, out_indices, score_threshold, id_index, score_index
123-
):
121+
def get_valid_boxes_ir(data, valid_boxes, score_threshold, id_index, score_index):
122+
batch_size = data.shape[0]
123+
num_anchors = data.shape[1]
124+
elem_length = data.shape[2]
125+
126+
ib = tvm.tir.ir_builder.create()
127+
128+
data = ib.buffer_ptr(data)
129+
130+
valid_boxes = ib.buffer_ptr(valid_boxes)
131+
if isinstance(score_threshold, float):
132+
score_threshold = tvm.tir.FloatImm("float32", score_threshold)
133+
id_index = tvm.tir.IntImm("int32", id_index)
134+
score_index = tvm.tir.IntImm("int32", score_index)
135+
136+
max_threads = int(tvm.target.Target.current(allow_none=False).max_num_threads)
137+
with ib.new_scope():
138+
nthread_tx = max_threads
139+
nthread_bx = num_anchors // max_threads + 1
140+
nthread_by = batch_size
141+
tx = te.thread_axis("threadIdx.x")
142+
bx = te.thread_axis("blockIdx.x")
143+
by = te.thread_axis("blockIdx.y")
144+
ib.scope_attr(tx, "thread_extent", nthread_tx)
145+
ib.scope_attr(bx, "thread_extent", nthread_bx)
146+
ib.scope_attr(by, "thread_extent", nthread_by)
147+
tid = bx * max_threads + tx
148+
149+
with ib.if_scope(tid < num_anchors):
150+
i = by
151+
j = tid
152+
score = data[(i * num_anchors + j) * elem_length + score_index]
153+
with ib.if_scope(
154+
tvm.tir.all(
155+
score > score_threshold,
156+
tvm.tir.any(
157+
id_index < 0, data[(i * num_anchors + j) * elem_length + id_index] >= 0
158+
),
159+
)
160+
):
161+
valid_boxes[i * num_anchors + j] = 1
162+
with ib.else_scope():
163+
valid_boxes[i * num_anchors + j] = 0
164+
return ib.get()
165+
166+
167+
def get_valid_indices_ir(valid_boxes, valid_count, valid_indices):
168+
batch_size = valid_boxes.shape[0]
169+
num_anchors = valid_boxes.shape[1]
170+
171+
ib = tvm.tir.ir_builder.create()
172+
173+
valid_boxes = ib.buffer_ptr(valid_boxes)
174+
175+
valid_count = ib.buffer_ptr(valid_count)
176+
valid_indices = ib.buffer_ptr(valid_indices)
177+
178+
max_threads = int(tvm.target.Target.current(allow_none=False).max_num_threads)
179+
with ib.new_scope():
180+
nthread_tx = max_threads
181+
nthread_bx = batch_size // max_threads + 1
182+
tx = te.thread_axis("threadIdx.x")
183+
bx = te.thread_axis("blockIdx.x")
184+
ib.scope_attr(tx, "thread_extent", nthread_tx)
185+
ib.scope_attr(bx, "thread_extent", nthread_bx)
186+
tid = bx * max_threads + tx
187+
# TODO(mbrookhart): Parallelize the sum and cumsum here
188+
current_index = ib.allocate("int32", (1,), name="current_index", scope="local")
189+
with ib.if_scope(tid < batch_size):
190+
current_index[0] = 0
191+
valid_count[tid] = 0
192+
with ib.for_range(0, num_anchors) as j:
193+
idx = tid * num_anchors + j
194+
valid_count[tid] = valid_count[tid] + valid_boxes[idx]
195+
with ib.if_scope(valid_boxes[idx] == 1):
196+
valid_indices[idx] = current_index[0]
197+
current_index[0] = current_index[0] + 1
198+
with ib.else_scope():
199+
valid_indices[idx] = -1
200+
return ib.get()
201+
202+
203+
def get_valid_counts_ir(data, valid_indices, out, out_indices):
124204
"""Low level IR to get valid count of bounding boxes
125205
given a score threshold. Also prepares to move valid boxes to the
126206
top of input data.
@@ -158,47 +238,51 @@ def get_valid_counts_ir(
158238

159239
data = ib.buffer_ptr(data)
160240

161-
valid_count = ib.buffer_ptr(valid_count)
241+
valid_indices = ib.buffer_ptr(valid_indices)
162242
out = ib.buffer_ptr(out)
163243
out_indices = ib.buffer_ptr(out_indices)
164244
one = tvm.tir.const(1, dtype=out.dtype)
165-
if isinstance(score_threshold, float):
166-
score_threshold = tvm.tir.FloatImm("float32", score_threshold)
167-
id_index = tvm.tir.IntImm("int32", id_index)
168-
score_index = tvm.tir.IntImm("int32", score_index)
169245

170246
max_threads = int(tvm.target.Target.current(allow_none=False).max_num_threads)
247+
nthread_tx = max_threads
248+
nthread_bx = num_anchors // max_threads + 1
249+
nthread_by = batch_size
250+
nthread_bz = elem_length
171251
with ib.new_scope():
172-
nthread_tx = max_threads
173-
nthread_bx = batch_size // max_threads + 1
174252
tx = te.thread_axis("threadIdx.x")
175253
bx = te.thread_axis("blockIdx.x")
254+
by = te.thread_axis("blockIdx.y")
255+
bz = te.thread_axis("blockIdx.z")
176256
ib.scope_attr(tx, "thread_extent", nthread_tx)
177257
ib.scope_attr(bx, "thread_extent", nthread_bx)
258+
ib.scope_attr(by, "thread_extent", nthread_by)
259+
ib.scope_attr(bz, "thread_extent", nthread_bz)
178260
tid = bx * max_threads + tx
179-
with ib.if_scope(tid < batch_size):
180-
valid_count[tid] = 0
181-
i = tid
182-
with ib.for_range(0, num_anchors) as j:
183-
score = data[(i * num_anchors + j) * elem_length + score_index]
184-
with ib.if_scope(
185-
tvm.tir.all(
186-
score > score_threshold,
187-
tvm.tir.any(
188-
id_index < 0, data[(i * num_anchors + j) * elem_length + id_index] >= 0
189-
),
190-
)
191-
):
192-
with ib.for_range(0, elem_length) as k:
193-
out[(i * num_anchors + valid_count[i]) * elem_length + k] = data[
194-
(i * num_anchors + j) * elem_length + k
195-
]
196-
out_indices[i * num_anchors + valid_count[i]] = j
197-
valid_count[i] += 1
198-
with ib.if_scope(j >= valid_count[i]):
199-
with ib.for_range(0, elem_length) as k:
200-
out[(i * num_anchors + j) * elem_length + k] = -one
201-
out_indices[i * num_anchors + j] = -1
261+
with ib.if_scope(tid < num_anchors):
262+
i = by
263+
j = tid
264+
k = bz
265+
out[(i * num_anchors + j) * elem_length + k] = -one
266+
out_indices[i * num_anchors + j] = -1
267+
with ib.new_scope():
268+
tx = te.thread_axis("threadIdx.x")
269+
bx = te.thread_axis("blockIdx.x")
270+
by = te.thread_axis("blockIdx.y")
271+
bz = te.thread_axis("blockIdx.z")
272+
ib.scope_attr(tx, "thread_extent", nthread_tx)
273+
ib.scope_attr(bx, "thread_extent", nthread_bx)
274+
ib.scope_attr(by, "thread_extent", nthread_by)
275+
ib.scope_attr(bz, "thread_extent", nthread_bz)
276+
tid = bx * max_threads + tx
277+
with ib.if_scope(tid < num_anchors):
278+
i = by
279+
j = tid
280+
k = bz
281+
with ib.if_scope(valid_indices[i, tid] >= 0):
282+
out[(i * num_anchors + valid_indices[i, tid]) * elem_length + k] = data[
283+
(i * num_anchors + j) * elem_length + k
284+
]
285+
out_indices[i * num_anchors + valid_indices[i, tid]] = j
202286
return ib.get()
203287

204288

@@ -231,23 +315,51 @@ def get_valid_counts(data, score_threshold=0, id_index=0, score_index=1):
231315
batch_size = data.shape[0]
232316
num_anchors = data.shape[1]
233317
data_buf = tvm.tir.decl_buffer(data.shape, data.dtype, "data_buf", data_alignment=8)
318+
valid_boxes_buf = tvm.tir.decl_buffer(
319+
(batch_size, num_anchors), "int32", "valid_boxes_buf", data_alignment=8
320+
)
321+
valid_boxes = te.extern(
322+
[(batch_size, num_anchors)],
323+
[data],
324+
lambda ins, outs: get_valid_boxes_ir(
325+
ins[0], outs[0], score_threshold, id_index, score_index
326+
),
327+
dtype=["int32"],
328+
in_buffers=[data_buf],
329+
out_buffers=[valid_boxes_buf],
330+
name="get_valid_boxes",
331+
tag="get_valid_boxes_gpu",
332+
)
333+
334+
valid_indices_buf = tvm.tir.decl_buffer(
335+
(batch_size, num_anchors), "int32", "valid_indices_buf", data_alignment=8
336+
)
234337
valid_count_buf = tvm.tir.decl_buffer(
235338
(batch_size,), "int32", "valid_count_buf", data_alignment=8
236339
)
340+
valid_count, valid_indices = te.extern(
341+
[(batch_size,), (batch_size, num_anchors)],
342+
[valid_boxes],
343+
lambda ins, outs: get_valid_indices_ir(ins[0], outs[0], outs[1]),
344+
dtype=["int32"],
345+
in_buffers=[valid_boxes_buf],
346+
out_buffers=[valid_count_buf, valid_indices_buf],
347+
name="get_valid_indices",
348+
tag="get_valid_indices_gpu",
349+
)
350+
237351
out_buf = tvm.tir.decl_buffer(data.shape, data.dtype, "out_buf", data_alignment=8)
238352
out_indices_buf = tvm.tir.decl_buffer(
239353
(batch_size, num_anchors), "int32", "out_buf", data_alignment=8
240354
)
241355

242-
valid_count, out, out_indices = te.extern(
243-
[(batch_size,), data.shape, (batch_size, num_anchors)],
244-
[data],
245-
lambda ins, outs: get_valid_counts_ir(
246-
ins[0], outs[0], outs[1], outs[2], score_threshold, id_index, score_index
247-
),
356+
out, out_indices = te.extern(
357+
[data.shape, (batch_size, num_anchors)],
358+
[data, valid_indices],
359+
lambda ins, outs: get_valid_counts_ir(ins[0], ins[1], outs[0], outs[1]),
248360
dtype=["int32", data.dtype],
249-
in_buffers=[data_buf],
250-
out_buffers=[valid_count_buf, out_buf, out_indices_buf],
361+
in_buffers=[data_buf, valid_indices_buf],
362+
out_buffers=[out_buf, out_indices_buf],
251363
name="get_valid_counts",
252364
tag="get_valid_counts_gpu",
253365
)

0 commit comments

Comments
 (0)