Commit bad149e

[TOPI] Fix GPU Dynamic Op Schedule (#7117)

* Fix GPU dynamic op schedules
* Fix dynamic shape nms
* Fix
* Fix test format

1 parent fb8de5a, commit bad149e

File tree

7 files changed, +117 -10 lines changed

python/tvm/topi/cuda/conv2d_transpose_nchw.py

Lines changed: 6 additions & 1 deletion

@@ -179,7 +179,10 @@ def _callback(op):
             ##### space definition begin #####
             n, f, y, x = s[conv].op.axis
             rc = s[conv].op.reduce_axis[0]
-            cfg.define_split("tile_n", cfg.axis(n), num_outputs=4)
+            # TODO(@kevinthesun): Support tuning/optimization for dynamic shape.
+            bs = pad_data.shape[0]
+            n_tuning_axis = n if isinstance(bs, tvm.tir.IntImm) else 1
+            cfg.define_split("tile_n", cfg.axis(n_tuning_axis), num_outputs=4)
             cfg.define_split("tile_f", cfg.axis(f), num_outputs=4)
             cfg.define_split("tile_y", cfg.axis(y), num_outputs=4)
             cfg.define_split("tile_x", cfg.axis(x), num_outputs=4)
@@ -194,6 +197,8 @@ def _callback(op):

             if cfg.is_fallback:
                 N, F, Y, X = get_const_tuple(conv.shape)
+                if not isinstance(N, int):
+                    N = 1
                 _fallback_schedule(N, F, Y, X)

             ##### space definition end #####
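Note: the new guard only matters when the batch dimension is symbolic. A minimal sketch (not part of the commit; the variable names are illustrative) of how a dynamic dimension differs from a constant one:

import tvm
from tvm import te

static_n = tvm.tir.IntImm("int32", 8)   # compile-time constant batch
dynamic_n = te.size_var("n")            # symbolic batch, unknown until runtime

# Only the constant case is a tvm.tir.IntImm, so only it can seed a tuning split;
# the schedule above falls back to a batch split of 1 otherwise.
print(isinstance(static_n, tvm.tir.IntImm))   # True
print(isinstance(dynamic_n, tvm.tir.IntImm))  # False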

python/tvm/topi/cuda/injective.py

Lines changed: 12 additions & 1 deletion

@@ -44,8 +44,16 @@ def schedule_injective_from_existing(sch, out):
     # bandwidth.
     vector_width = 4 if out.dtype == "float16" else 1

+    is_dynamic_output = False
+    for dim in out.shape:
+        if not isinstance(dim, tvm.tir.IntImm):
+            is_dynamic_output = True
+            break
+
+    out_len = utils.prod(out.shape)
+
     try:
-        const_size = utils.get_const_int(utils.prod(out.shape))
+        const_size = utils.get_const_int(out_len)
         need_block_split = const_size > max_block * num_thread * vector_width
     except ValueError:
         need_block_split = False
@@ -61,6 +69,9 @@ def schedule_injective_from_existing(sch, out):
         sch[out].bind(bx, te.thread_axis("blockIdx.x"))
         sch[out].bind(tx, te.thread_axis("threadIdx.x"))
     else:
+        # Use less threads for dynamic shape ops to avoid runtime error.
+        if is_dynamic_output:
+            num_thread //= 2
         bx, tx = sch[out].split(fused, factor=num_thread)
         sch[out].bind(tx, te.thread_axis("threadIdx.x"))
         sch[out].bind(bx, te.thread_axis("blockIdx.x"))
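For context, a hedged sketch of the dynamic-output check the schedule now performs; the helper name below is hypothetical and only illustrates the isinstance test:

import tvm

def has_dynamic_dim(shape):
    """Illustrative helper: True if any dimension is not a compile-time IntImm."""
    return any(not isinstance(dim, tvm.tir.IntImm) for dim in shape)

# When the output is dynamic, the schedule halves num_thread before binding
# threadIdx.x, trading a little occupancy for a launch that stays in bounds.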

python/tvm/topi/cuda/nms.py

Lines changed: 46 additions & 3 deletions

@@ -22,7 +22,6 @@

 from tvm.tir import if_then_else
 from .sort import argsort, argsort_thrust
-from .. import tag


 def cuda_atomic_add_rule(op):
@@ -95,7 +94,7 @@ def rearrange_indices_out_ir(data, output, valid_box_count):
     with ib.new_scope():
         i = te.thread_axis("blockIdx.x")
         ib.scope_attr(i, "thread_extent", batch_size)
-        valid_idx = ib.allocate("int32", (1), name="valid_idx", scope="local")
+        valid_idx = ib.allocate("int32", (1,), name="valid_idx", scope="local")
         valid_idx[0] = 0
         with ib.for_range(0, num_anchors, name="j") as j:
             with ib.if_scope(data[i, j] >= 0):
@@ -654,6 +653,35 @@ def calculate_overlap(out_tensor, box_a_idx, box_b_idx):
     return ib.get()


+def _fetch_score_ir(data, score, axis):
+    """
+    Fetch score from data.
+    This routine is required for dynamic shape nms.
+    """
+    batch_size = data.shape[0]
+    num_anchors = data.shape[1]
+    elem_length = data.shape[2]
+
+    ib = tvm.tir.ir_builder.create()
+
+    data = ib.buffer_ptr(data)
+    score = ib.buffer_ptr(score)
+    with ib.if_scope(num_anchors > 0):
+        max_threads = int(tvm.target.Target.current(allow_none=False).max_num_threads)
+        nthread_tx = max_threads
+        nthread_bx = batch_size * num_anchors // max_threads + 1
+        tx = te.thread_axis("threadIdx.x")
+        bx = te.thread_axis("blockIdx.x")
+        ib.scope_attr(tx, "thread_extent", nthread_tx)
+        ib.scope_attr(bx, "thread_extent", nthread_bx)
+
+        tid = bx * max_threads + tx
+        with ib.if_scope(tid < batch_size * num_anchors):
+            score[tid] = data[tid * elem_length + axis]
+
+    return ib.get()
+
+
 def non_max_suppression(
     data,
     valid_count,
@@ -754,7 +782,22 @@ def non_max_suppression(
     )
     score_axis = score_index
     score_shape = (batch_size, num_anchors)
-    score_tensor = te.compute(score_shape, lambda i, j: data[i, j, score_axis], tag=tag.ELEMWISE)
+    data_buf = tvm.tir.decl_buffer(data.shape, data.dtype, "data_buf", data_alignment=8)
+    score_buf = tvm.tir.decl_buffer(score_shape, data.dtype, "score_buf", data_alignment=8)
+    score_tensor = te.extern(
+        [score_shape],
+        [data],
+        lambda ins, outs: _fetch_score_ir(
+            ins[0],
+            outs[0],
+            score_axis,
+        ),
+        dtype=[data.dtype],
+        in_buffers=[data_buf],
+        out_buffers=[score_buf],
+        name="fetch_score",
+        tag="fetch_score",
+    )
     target = tvm.target.Target.current()
     if (
         target
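The te.compute score extraction is replaced by a te.extern kernel so the score gather still compiles when num_anchors is dynamic. Below is a minimal, self-contained sketch of the same te.extern plus ir_builder pattern, using an identity copy instead of the score fetch; all names in it are illustrative and not part of the commit:

import tvm
from tvm import te

def _copy_ir(src, dst):
    """Illustrative IR routine: copy a 1-D buffer element by element."""
    ib = tvm.tir.ir_builder.create()
    n = src.shape[0]
    src_ptr = ib.buffer_ptr(src)
    dst_ptr = ib.buffer_ptr(dst)
    with ib.for_range(0, n, name="i") as i:
        dst_ptr[i] = src_ptr[i]
    return ib.get()

n = te.size_var("n")  # dynamic extent, analogous to a dynamic num_anchors
x = te.placeholder((n,), name="x", dtype="float32")
y = te.extern(
    [(n,)],
    [x],
    lambda ins, outs: _copy_ir(ins[0], outs[0]),
    dtype=["float32"],
    name="copy",
)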

python/tvm/topi/cuda/sort.py

Lines changed: 3 additions & 0 deletions

@@ -565,6 +565,9 @@ def topk_thrust(data, k=1, axis=-1, ret_type="both", is_ascend=False, dtype="int
         tag="topk_gpu",
     )

+    if isinstance(k, tvm.tir.IntImm):
+        k = k.value
+
     if not isinstance(k, int) or k > 0:
         beg = [0] * ndim
         end = data.shape[:-1] + [k if isinstance(k, int) else tvm.te.size_var("dim")]
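The normalization above is needed because a constant k coming from Relay can arrive as a tvm.tir.IntImm rather than a Python int; a quick illustration:

import tvm

k = tvm.tir.IntImm("int64", 5)
if isinstance(k, tvm.tir.IntImm):
    k = k.value  # unwrap to a plain Python int
assert isinstance(k, int) and k == 5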

src/runtime/contrib/thrust/thrust.cu

Lines changed: 9 additions & 0 deletions

@@ -205,6 +205,9 @@ TVM_REGISTER_GLOBAL("tvm.contrib.thrust.stable_sort_by_key")
     if (value_dtype == "int32") {
       thrust_stable_sort_by_key<int, int>(keys_in, values_in, keys_out, values_out,
                                           for_scatter);
+    } else if (value_dtype == "int64") {
+      thrust_stable_sort_by_key<int, int64_t>(keys_in, values_in, keys_out, values_out,
+                                              for_scatter);
     } else if (value_dtype == "float32") {
       thrust_stable_sort_by_key<int, float>(keys_in, values_in, keys_out, values_out,
                                             for_scatter);
@@ -215,6 +218,9 @@ TVM_REGISTER_GLOBAL("tvm.contrib.thrust.stable_sort_by_key")
     if (value_dtype == "int32") {
       thrust_stable_sort_by_key<int64_t, int>(keys_in, values_in, keys_out, values_out,
                                               for_scatter);
+    } else if (value_dtype == "int64") {
+      thrust_stable_sort_by_key<int64_t, int64_t>(keys_in, values_in, keys_out, values_out,
+                                                  for_scatter);
     } else if (value_dtype == "float32") {
       thrust_stable_sort_by_key<int64_t, float>(keys_in, values_in, keys_out, values_out,
                                                 for_scatter);
@@ -225,6 +231,9 @@ TVM_REGISTER_GLOBAL("tvm.contrib.thrust.stable_sort_by_key")
     if (value_dtype == "int32") {
       thrust_stable_sort_by_key<float, int>(keys_in, values_in, keys_out, values_out,
                                             for_scatter);
+    } else if (value_dtype == "int64") {
+      thrust_stable_sort_by_key<float, int64_t>(keys_in, values_in, keys_out, values_out,
+                                                for_scatter);
     } else if (value_dtype == "float32") {
       thrust_stable_sort_by_key<float, float>(keys_in, values_in, keys_out, values_out,
                                               for_scatter);
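These branches extend the Thrust-backed stable sort to int64 value buffers alongside the existing int32 and float32 cases. A hedged availability check from Python (the packed function is only registered when TVM is built with USE_THRUST):

import tvm

f = tvm.get_global_func("tvm.contrib.thrust.stable_sort_by_key", allow_missing=True)
print("thrust stable_sort_by_key registered:", f is not None)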

src/runtime/vm/vm.cc

Lines changed: 15 additions & 2 deletions

@@ -245,6 +245,7 @@ void VirtualMachine::InvokePacked(Index packed_index, const PackedFunc& func, In
   std::vector<int> codes(arity);
   runtime::TVMArgsSetter setter(values.data(), codes.data());
   int idx = 0;
+  bool is_empty_output = false;
   for (Index i = 0; i < arg_count; i++) {
     if (const auto* dt_cell = args[i].as<ADTObj>()) {
       for (size_t fi = 0; fi < dt_cell->size; ++fi) {
@@ -254,12 +255,24 @@ void VirtualMachine::InvokePacked(Index packed_index, const PackedFunc& func, In
       }
     } else {
       auto nd_array = Downcast<NDArray>(args[i]);
+      // We can safely skip CallPacked if there is only one
+      // output and it is empty.
+      if (i == arg_count - 1 && output_size == 1) {
+        for (const auto& dim : nd_array.Shape()) {
+          if (!dim) {
+            is_empty_output = true;
+            break;
+          }
+        }
+      }
       setter(idx++, nd_array);
     }
   }

-  TVMRetValue rv;
-  func.CallPacked(TVMArgs(values.data(), codes.data(), arity), &rv);
+  if (!is_empty_output) {
+    TVMRetValue rv;
+    func.CallPacked(TVMArgs(values.data(), codes.data(), arity), &rv);
+  }
 }

 void VirtualMachine::LoadExecutable(const Executable* exec) {
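The guard lets the VM skip the packed call when the single output is a zero-size tensor, which dynamic-shape programs can legitimately produce at runtime. A hedged end-to-end illustration using the Relay VM executor; the tiny model here is illustrative and not from this commit:

import numpy as np
import tvm
from tvm import relay

x = relay.var("x", shape=(relay.Any(),), dtype="float32")
mod = tvm.IRModule.from_expr(relay.Function([x], x + relay.const(1.0)))

vm = relay.create_executor("vm", mod=mod, ctx=tvm.cpu(), target="llvm")
out = vm.evaluate()(np.zeros((0,), dtype="float32"))  # empty input -> empty output
print(out.asnumpy().shape)  # (0,)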

tests/python/relay/test_any.py

Lines changed: 26 additions & 3 deletions

@@ -54,6 +54,7 @@ def check_result(
     for kind in ["debug", "vm"]:
         targets = targets or tvm.testing.enabled_targets()
         for tgt, ctx in targets:
+            print(tgt)
             if disable_targets and tgt in disable_targets:
                 continue
             if kind == "debug" and (only_vm or ctx.device_type != tvm.cpu().device_type):
@@ -199,6 +200,15 @@ def test_any_concat():
     ref = np.concatenate([x_np - 3.0, y_np * 5.0], axis=0)
     check_result([x_np, y_np], mod, ref)

+    num_inputs = 25
+    x = [relay.var("x", shape=(relay.Any(),), dtype="float32") for _ in range(num_inputs)]
+    z = relay.op.concatenate(x, axis=0)
+    mod = tvm.IRModule()
+    mod["main"] = relay.Function(x, z)
+    x_np = [np.random.uniform(size=(1,)).astype("float32") for _ in range(num_inputs)]
+    ref = np.concatenate(x_np, axis=0)
+    check_result(x_np, mod, ref)
+

 def verify_any_reshape(x_shape, newshape, x_np_shape, out_shape, variable_newshape=False):
     x = relay.var("x", shape=x_shape, dtype="float32")
@@ -572,9 +582,7 @@ def verify_any_conv2d_transpose_nchw(
     mod["main"] = relay.Function([data, kernel], y)
     data_np = np.random.uniform(size=static_data_shape).astype(dtype)
     kernel_np = np.random.uniform(size=kernel_shape).astype(dtype)
-    check_result(
-        [data_np, kernel_np], mod, ref_out_shape, assert_shape=True, targets=[("llvm", tvm.cpu())]
-    )
+    check_result([data_np, kernel_np], mod, ref_out_shape, assert_shape=True)


# TODO(@kevinthesun): Support dynamic input height and width.
@@ -1430,6 +1438,21 @@ def test_non_max_suppression():
         disable_targets=["nvptx"],
     )

+    np_data = np.zeros((1, 0, 6)).astype("float32")
+    np_valid_count = np.array([0]).astype("int32")
+    np_indices = np.zeros((1, 0)).astype("int32")
+    np_max_output_size = -1
+    np_indices_result = np.zeros((1, 0))
+    np_valid_box_count = np.array([[0]]).astype("int32")
+
+    check_result(
+        [np_data, np_valid_count, np_indices, np_max_output_size],
+        mod,
+        [np_indices_result, np_valid_box_count],
+        only_vm=False,
+        disable_targets=["nvptx"],
+    )
+

if __name__ == "__main__":
    pytest.main([__file__])
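To exercise just the cases touched by this commit, the tests can be filtered with pytest's -k expression; a small illustrative runner (test names taken from the hunk headers above):

import pytest

pytest.main(["tests/python/relay/test_any.py", "-k", "any_concat or non_max_suppression"])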
