Skip to content

Commit ec1911f

Browse files
author
Ubuntu
committed
Dynamic Reshape
1 parent 132cf6b commit ec1911f

File tree

2 files changed

+35
-1
lines changed

2 files changed

+35
-1
lines changed

python/tvm/relay/op/contrib/tensorrt.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -635,7 +635,7 @@ def reshape_annotate_fn(expr): # pylint: disable=unused-variable
635635
if dynamic_reshape:
636636
# Make sure that the batch dim is unmodified.
637637
if int(new_shape[0]) < 0:
638-
for shape_val, new_shape_val in enumerate(shape[1:], new_shape[1:]):
638+
for shape_val, new_shape_val in zip(shape[1:], new_shape[1:]):
639639
if not (
640640
isinstance(shape_val, int)
641641
and isinstance(new_shape_val, int)

tests/python/contrib/test_tensorrt.py

Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -631,6 +631,40 @@ def get_graph(x_shape, new_shape):
631631
run_and_verify_func(get_graph((1, 1, 2, 3), (1, 6)))
632632

633633

634+
def test_dynamic_reshape():
635+
if skip_codegen_test():
636+
return
637+
638+
def test_run(batches_to_test, x_shape, new_shape):
639+
x_data = np.ones([max(batches_to_test)] + list(x_shape)[1:]).astype("float32")
640+
result_arr = [{} for _ in range(len(batches_to_test))]
641+
for use_trt in [True]:
642+
x = relay.var("x", shape=x_shape, dtype="float32")
643+
out = relay.reshape(x, new_shape)
644+
f = relay.Function([x], out)
645+
mod = tvm.IRModule()
646+
mod["main"] = f
647+
if use_trt:
648+
mod, _ = tensorrt.partition_for_tensorrt(mod, params={})
649+
print(mod)
650+
if not skip_runtime_test():
651+
with relay.build_config(opt_level=3):
652+
relay_exec = relay.create_executor("vm", mod=mod, ctx=tvm.cpu(0), target="llvm")
653+
654+
for i, batch_size in enumerate(batches_to_test):
655+
result_arr[i][use_trt] = relay_exec.evaluate()(x_data[:batch_size, ...])
656+
print(x_data[:batch_size, ...].shape, result_arr[i][use_trt].shape)
657+
658+
if not skip_runtime_test():
659+
for i in range(len(batches_to_test)):
660+
assert_result_dict_holds(result_arr[i])
661+
662+
batches_to_test = [1, 1, 0, 2, 3, 0, 1, 3, 2]
663+
x_shape = (relay.Any(), 3, 2, 3)
664+
new_shape = (-1, 1, 2, 3)
665+
test_run(batches_to_test, x_shape, new_shape)
666+
667+
634668
def test_transpose():
635669
def get_graph(x_shape, order):
636670
x = relay.var("x", shape=(x_shape), dtype="float32")

0 commit comments

Comments
 (0)