 from tvm.contrib import graph_runtime, utils
 from tvm.runtime.vm import VirtualMachine
 from tvm.relay import Any, GlobalVar, transform
+from tvm.relay.expr_functor import ExprVisitor
 from typing import Dict, Tuple, Union
 from tvm.contrib.download import download
 from tvm.relay.op.contrib import tensorrt

@@ -631,6 +632,106 @@ def get_graph(x_shape, new_shape):
     run_and_verify_func(get_graph((1, 1, 2, 3), (1, 6)))


+class AreOpsOnGraph(ExprVisitor):
+    """
+    Visits the graph recursively and checks if it contains ops from op_list.
+    """
+
+    def __init__(self, op_list):
+        ExprVisitor.__init__(self)
+        self.op_list = op_list
+        self.on_graph = False
+
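+    # call.op is only an operator (tvm.ir.Op) for primitive calls; calls to
+    # composite functions or partitioned subgraphs carry a Function or
+    # GlobalVar instead, so only named operators are compared against op_list.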
+    def visit_call(self, call):
+        if isinstance(call.op, tvm.ir.Op):
+            if str(call.op) in self.op_list:
+                self.on_graph = True
+
+        return super().visit_call(call)
+
+    def are_ops_on_graph(self, subgraph) -> bool:
+        """
+        Recursively visits the graph and checks whether any ops from op_list are on it.
+        """
+        self.visit(subgraph)
+        return self.on_graph
+
+
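+# Check that the ops in op_list were offloaded: they must appear in at least
+# one subgraph annotated with Compiler == "tensorrt" and in none of the
+# functions left for TVM to compile.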
+def are_ops_on_trt(mod, op_list):
+    op_on_trt = False
+    op_on_tvm = False
+    for subgraph in mod.get_global_vars():
+        name = subgraph.name_hint
+        if name == "main":
+            op_on_tvm |= AreOpsOnGraph(op_list).are_ops_on_graph(mod[name].body)
+        elif mod[name].attrs and mod[name].attrs["Compiler"] == "tensorrt":
+            op_on_trt |= AreOpsOnGraph(op_list).are_ops_on_graph(mod[name].body)
+        else:
+            # Any other function without the "tensorrt" Compiler attribute
+            # is compiled by TVM itself.
+            op_on_tvm |= AreOpsOnGraph(op_list).are_ops_on_graph(mod[name].body)
+
+    return op_on_trt and not op_on_tvm
+
+
+def test_dynamic_reshape():
+    if skip_codegen_test():
+        return
+
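+    # Builds a reshape graph with a dynamic input dimension, partitions it for
+    # TensorRT, checks the offload decision, and (when the runtime is enabled)
+    # compares the TRT and pure-TVM results for every input in x_data_list.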
+    def test_run(x_data_list, x_shape, new_shape, should_offload_to_trt):
+        result_arr = [{} for _ in range(len(x_data_list))]
+        for use_trt in [True, False]:
+            x = relay.var("x", shape=x_shape, dtype="float32")
+            out = relay.reshape(x, new_shape)
+            f = relay.Function([x], out)
+            mod = tvm.IRModule()
+            mod["main"] = f
+            if use_trt:
+                mod, _ = tensorrt.partition_for_tensorrt(
+                    mod, params={}, remove_no_mac_subgraphs=False
+                )
+                assert are_ops_on_trt(mod, op_list=["reshape"]) == should_offload_to_trt
+            if not skip_runtime_test():
+                with relay.build_config(opt_level=3):
+                    relay_exec = relay.create_executor("vm", mod=mod, ctx=tvm.cpu(0), target="llvm")
+
+                for i, x_data in enumerate(x_data_list):
+                    result_arr[i][use_trt] = relay_exec.evaluate()(x_data)
+
+        if not skip_runtime_test():
+            for i in range(len(x_data_list)):
+                assert_result_dict_holds(result_arr[i])
+
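+    # new_shape folds the dynamic batch dim into -1 and keeps all other dims,
+    # so the partitioner should be able to prove the batch dim is unmodified
+    # and offload the reshape to TensorRT.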
+    dim_values = [1, 1, 0, 2, 3, 0, 1, 3, 2]
+    x_shape = (relay.Any(), 3, 2, 3)
+    x_data_list = [
+        np.ones([dim_value] + list(x_shape)[1:]).astype("float32") for dim_value in dim_values
+    ]
+    new_shape = (-1, 3, 2, 3)
+    should_offload_to_trt = True
+    test_run(x_data_list, x_shape, new_shape, should_offload_to_trt)
+
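+    # Reshaping (Any, 3, 2, 3) to (-1, 1, 2, 3) changes a non-batch dim, so
+    # the -1 no longer maps to the batch dim alone; the reshape is expected
+    # to stay on TVM.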
+    dim_values = [1, 1, 0, 2, 3, 0, 1, 3, 2]
+    x_shape = (relay.Any(), 3, 2, 3)
+    x_data_list = [
+        np.ones([dim_value] + list(x_shape)[1:]).astype("float32") for dim_value in dim_values
+    ]
+    new_shape = (-1, 1, 2, 3)
+    should_offload_to_trt = False
+    test_run(x_data_list, x_shape, new_shape, should_offload_to_trt)
+
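+    # Here the dynamic dim is not the batch dim. With implicit batch mode the
+    # partitioner cannot prove the reshape leaves the batch dim untouched, so
+    # it is expected to stay on TVM.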
+    dim_values = [1, 1, 0, 2, 3, 0, 1, 3, 2]
+    x_shape = (1, relay.Any(), 2, 3)
+    x_data_list = [
+        np.ones(list(x_shape[:1]) + [dim_value] + list(x_shape)[2:]).astype("float32")
+        for dim_value in dim_values
+    ]
+    new_shape = (1, -1, 2, 3)
+    should_offload_to_trt = False
+    test_run(x_data_list, x_shape, new_shape, should_offload_to_trt)
+
+
 def test_transpose():
     def get_graph(x_shape, order):
         x = relay.var("x", shape=(x_shape), dtype="float32")