
Commit 9175c6d

TRT Dynamic Reshape Fix (#7412)
Authored by codeislife99 and Ubuntu.

* Dynamic Reshape
* Changes
* Add test cases
* Add test cases
* PR Comments
* CI Error
* EmptyCommitCIError

Co-authored-by: Ubuntu <ubuntu@ip-172-31-42-251.us-east-2.compute.internal>

1 parent 3863e09 commit 9175c6d

File tree

2 files changed: +107 -7 lines changed

python/tvm/relay/op/contrib/tensorrt.py
tests/python/contrib/test_tensorrt.py

python/tvm/relay/op/contrib/tensorrt.py

Lines changed: 6 additions & 7 deletions
@@ -615,7 +615,6 @@ def layout_transform_annotate_fn(expr):  # pylint: disable=unused-variable
 @_register_external_dynamic_check_func("reshape")
 def reshape_annotate_fn(expr):  # pylint: disable=unused-variable
     """Check if reshape is supported by TensorRT."""
-
     attrs, args = expr.attrs, expr.args
     if args[0].checked_type.dtype != "float32":
         logger.info("Only float32 inputs are supported for TensorRT.")
@@ -629,23 +628,23 @@ def reshape_annotate_fn(expr):  # pylint: disable=unused-variable
     if len(new_shape) == 0 or len(shape) == 0:
         logger.info("reshape: Can't reshape to or from scalar.")
         return False
-
     dynamic_reshape = any([isinstance(x, tvm.tir.expr.Any) for x in shape])

     if dynamic_reshape:
         # Make sure that the batch dim is unmodified.
         if int(new_shape[0]) < 0:
-            for shape_val, new_shape_val in enumerate(shape[1:], new_shape[1:]):
+            for shape_val, new_shape_val in zip(shape[1:], new_shape[1:]):
                 if not (
-                    isinstance(shape_val, int)
-                    and isinstance(new_shape_val, int)
+                    isinstance(shape_val, (int, tvm.tir.expr.IntImm))
+                    and isinstance(new_shape_val, (int, tvm.tir.expr.IntImm))
                     and int(shape_val) == int(new_shape_val)
                 ):
                     return False
         elif int(new_shape[0]) > 0:
+            # Currently we only allow dim[0] to be Any, so this branch will always be False
             if not (
-                isinstance(shape[0], int)
-                and isinstance(new_shape[0], int)
+                isinstance(shape[0], (int, tvm.tir.expr.IntImm))
+                and isinstance(new_shape[0], (int, tvm.tir.expr.IntImm))
                 and int(shape[0]) == int(new_shape[0])
            ):
                 return False
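
To make the new rule concrete: when the input shape has a dynamic (Any) batch dimension, the reshape is offloaded only if every non-batch dimension is a known integer that stays exactly the same. Below is a minimal standalone sketch of that decision, not TVM code: plain Python, with None standing in for a relay.Any() dimension and the helper name chosen for illustration (the real check above additionally accepts tvm.tir.expr.IntImm dims).

# Hypothetical helper mirroring the rule reshape_annotate_fn enforces in the
# new_shape[0] < 0 branch (illustration only, not part of the commit).
def batch_dim_only_dynamic_reshape_ok(shape, new_shape):
    # Offload only when dim 0 carries the dynamic extent (-1 in new_shape) and
    # every remaining dimension is a concrete int that is left unchanged.
    if new_shape[0] < 0:
        return all(
            isinstance(s, int) and isinstance(n, int) and s == n
            for s, n in zip(shape[1:], new_shape[1:])
        )
    return False

# (None, 3, 2, 3) -> (-1, 3, 2, 3): non-batch dims untouched, accepted (cf. first test case below).
assert batch_dim_only_dynamic_reshape_ok([None, 3, 2, 3], [-1, 3, 2, 3])
# (None, 3, 2, 3) -> (-1, 1, 2, 3): dim 1 changes, rejected (cf. second test case below).
assert not batch_dim_only_dynamic_reshape_ok([None, 3, 2, 3], [-1, 1, 2, 3])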

tests/python/contrib/test_tensorrt.py

Lines changed: 101 additions & 0 deletions
@@ -27,6 +27,7 @@
 from tvm.contrib import graph_runtime, utils
 from tvm.runtime.vm import VirtualMachine
 from tvm.relay import Any, GlobalVar, transform
+from tvm.relay.expr_functor import ExprVisitor
 from typing import Dict, Tuple, Union
 from tvm.contrib.download import download
 from tvm.relay.op.contrib import tensorrt
@@ -631,6 +632,106 @@ def get_graph(x_shape, new_shape):
     run_and_verify_func(get_graph((1, 1, 2, 3), (1, 6)))


+class AreOpsOnGraph(ExprVisitor):
+    """
+    Visits the graph recursively and checks if it contains ops in the op_list.
+    """
+
+    def __init__(self, op_list):
+        ExprVisitor.__init__(self)
+        self.op_list = op_list
+        self.on_graph = False
+
+    def visit_call(self, call):
+        if isinstance(call.op, tvm.tir.op.Op):
+            if str(call.op) in self.op_list:
+                self.on_graph = True
+
+        return super().visit_call(call)
+
+    def are_ops_on_graph(self, subgraph) -> bool:
+        """
+        Recursively visits the graph and checks whether any op in op_list is on the graph.
+        """
+        self.visit(subgraph)
+        return self.on_graph
+
+
+def are_ops_on_trt(mod, op_list):
+    for subgraph in mod.get_global_vars():
+        name = subgraph.name_hint
+        op_on_trt = False
+        op_on_tvm = True
+        if name == "main":
+            op_on_tvm = AreOpsOnGraph(op_list).are_ops_on_graph(mod[name].body)
+        elif mod[name].attrs and mod[name].attrs["Compiler"] == "tensorrt":
+            op_on_trt = AreOpsOnGraph(op_list).are_ops_on_graph(mod[name].body)
+        else:
+            op_on_tvm &= AreOpsOnGraph(op_list).are_ops_on_graph(mod[name].body)
+
+        if not op_on_trt or op_on_tvm:
+            return False
+
+    return True
+
+
+def test_dynamic_reshape():
+    if skip_codegen_test():
+        return
+
+    def test_run(x_data_list, x_shape, new_shape, should_offload_to_trt):
+        result_arr = [{} for _ in range(len(x_data_list))]
+        for use_trt in [True, False]:
+            x = relay.var("x", shape=x_shape, dtype="float32")
+            out = relay.reshape(x, new_shape)
+            f = relay.Function([x], out)
+            mod = tvm.IRModule()
+            mod["main"] = f
+            if use_trt:
+                mod, _ = tensorrt.partition_for_tensorrt(
+                    mod, params={}, remove_no_mac_subgraphs=False
+                )
+                assert are_ops_on_trt(mod, op_list=["reshape"]) == should_offload_to_trt
+            if not skip_runtime_test():
+                with relay.build_config(opt_level=3):
+                    relay_exec = relay.create_executor("vm", mod=mod, ctx=tvm.cpu(0), target="llvm")
+
+                for i, x_data in enumerate(x_data_list):
+                    result_arr[i][use_trt] = relay_exec.evaluate()(x_data)
+
+        if not skip_runtime_test():
+            for i in range(len(x_data_list)):
+                assert_result_dict_holds(result_arr[i])
+
+    dim_values = [1, 1, 0, 2, 3, 0, 1, 3, 2]
+    x_shape = (relay.Any(), 3, 2, 3)
+    x_data_list = [
+        np.ones([dim_value] + list(x_shape)[1:]).astype("float32") for dim_value in dim_values
+    ]
+    new_shape = (-1, 3, 2, 3)
+    should_offload_to_trt = True
+    test_run(x_data_list, x_shape, new_shape, should_offload_to_trt)
+
+    dim_values = [1, 1, 0, 2, 3, 0, 1, 3, 2]
+    x_shape = (relay.Any(), 3, 2, 3)
+    x_data_list = [
+        np.ones([dim_value] + list(x_shape)[1:]).astype("float32") for dim_value in dim_values
+    ]
+    new_shape = (-1, 1, 2, 3)
+    should_offload_to_trt = False
+    test_run(x_data_list, x_shape, new_shape, should_offload_to_trt)
+
+    dim_values = [1, 1, 0, 2, 3, 0, 1, 3, 2]
+    x_shape = (1, relay.Any(), 2, 3)
+    x_data_list = [
+        np.ones(list(x_shape[:1]) + [dim_value] + list(x_shape)[2:]).astype("float32")
+        for dim_value in dim_values
+    ]
+    new_shape = (1, -1, 2, 3)
+    should_offload_to_trt = False
+    test_run(x_data_list, x_shape, new_shape, should_offload_to_trt)
+
+
 def test_transpose():
     def get_graph(x_shape, order):
         x = relay.var("x", shape=(x_shape), dtype="float32")
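
For orientation, the scenario exercised by the new test can also be reproduced directly. The sketch below is a usage example under stated assumptions, not part of the commit: it assumes a TVM build with the TensorRT codegen enabled, and it checks the partition result by looking for subgraphs carrying the Compiler="tensorrt" attribute, the same condition the are_ops_on_trt helper above relies on.

import tvm
from tvm import relay
from tvm.relay.op.contrib import tensorrt

# A reshape whose only dynamic dimension is the batch dim and whose other dims
# are unchanged -- the case this fix allows to be offloaded.
x = relay.var("x", shape=(relay.Any(), 3, 2, 3), dtype="float32")
mod = tvm.IRModule.from_expr(relay.Function([x], relay.reshape(x, (-1, 3, 2, 3))))

# Partition for TensorRT (assumes the TensorRT codegen is registered in this build).
mod, config = tensorrt.partition_for_tensorrt(mod, params={})

# With the fix, the reshape should land in a subgraph marked Compiler="tensorrt".
trt_subgraphs = [
    gv.name_hint
    for gv in mod.get_global_vars()
    if gv.name_hint != "main"
    and mod[gv.name_hint].attrs
    and mod[gv.name_hint].attrs["Compiler"] == "tensorrt"
]
print(trt_subgraphs)  # expected: one TensorRT subgraph containing the reshape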
