From 4f7b19fb9ff3981fd03b01f67778da32dd25d46b Mon Sep 17 00:00:00 2001 From: Trevor Morris Date: Mon, 19 Oct 2020 16:03:06 -0700 Subject: [PATCH] [BYOC][TensorRT] TensorRT BYOC integration (#6395) * TensorRT integration using JSONRuntime Support input nodes with multiple data entries Fix failing tests Support layout transform, add engine caching Add comment Add PruneSubgraph pass Use prune_subgraph pass, make params member of trt runtime class Hide deprecation warnings coming from TRT headers Remove general prune subgraph Save/load use_implicit_batch and workspace size Clean up Fix cpp lint Addressing review comments Refactor tests Use relay.bind instead of VarReplacer. Improve some annotation functions Add TRT docs Use DLOG, formatting Use logging.info instead of print also refactor integ tests also refactor integ tests Formatting Formatting Format python fix python format Fix pylint Fix sphinx precheck Add tensorrt.rst to toctree Allow codegen to be tested when TRT runtime is not available. Enable TRT codegen in CI linty Address more comments Formatting Formatting * Documentation changes * Address comments * Rename USE_TENSORRT->USE_TENSORRT_CODEGEN and USE_TENSORRT_GRAPH_RUNTIME->USE_TENSORRT_RUNTIME * Fix comment typo * Test CI without TRT codegen enabled * formatting * Enable USE_TENSORRT_CODEGEN in CI * Change file_util.h -> file_utils.h --- CMakeLists.txt | 3 + cmake/config.cmake | 10 + cmake/modules/contrib/TensorRT.cmake | 54 + docs/deploy/index.rst | 1 + docs/deploy/tensorrt.rst | 297 +++++ python/tvm/relay/op/contrib/__init__.py | 1 + python/tvm/relay/op/contrib/tensorrt.py | 769 ++++++++++++ src/relay/backend/contrib/tensorrt/codegen.cc | 240 ++++ .../contrib/tensorrt/tensorrt_builder.cc | 222 ++++ .../contrib/tensorrt/tensorrt_builder.h | 159 +++ .../contrib/tensorrt/tensorrt_logger.h | 78 ++ src/runtime/contrib/tensorrt/tensorrt_ops.cc | 1070 +++++++++++++++++ src/runtime/contrib/tensorrt/tensorrt_ops.h | 207 ++++ .../contrib/tensorrt/tensorrt_runtime.cc | 312 +++++ src/runtime/contrib/tensorrt/tensorrt_utils.h | 74 ++ tests/python/contrib/test_tensorrt.py | 905 ++++++++++++++ tests/scripts/task_config_build_gpu.sh | 3 +- 17 files changed, 4403 insertions(+), 2 deletions(-) create mode 100644 cmake/modules/contrib/TensorRT.cmake create mode 100644 docs/deploy/tensorrt.rst create mode 100644 python/tvm/relay/op/contrib/tensorrt.py create mode 100644 src/relay/backend/contrib/tensorrt/codegen.cc create mode 100644 src/runtime/contrib/tensorrt/tensorrt_builder.cc create mode 100644 src/runtime/contrib/tensorrt/tensorrt_builder.h create mode 100644 src/runtime/contrib/tensorrt/tensorrt_logger.h create mode 100644 src/runtime/contrib/tensorrt/tensorrt_ops.cc create mode 100644 src/runtime/contrib/tensorrt/tensorrt_ops.h create mode 100644 src/runtime/contrib/tensorrt/tensorrt_runtime.cc create mode 100644 src/runtime/contrib/tensorrt/tensorrt_utils.h create mode 100644 tests/python/contrib/test_tensorrt.py diff --git a/CMakeLists.txt b/CMakeLists.txt index 33c720c4cce4..d07f55f06ad0 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -79,6 +79,8 @@ tvm_option(USE_COREML "Build with coreml support" OFF) tvm_option(USE_TARGET_ONNX "Build with ONNX Codegen support" OFF) tvm_option(USE_ARM_COMPUTE_LIB "Build with Arm Compute Library" OFF) tvm_option(USE_ARM_COMPUTE_LIB_GRAPH_RUNTIME "Build with Arm Compute Library graph runtime" OFF) +tvm_option(USE_TENSORRT_CODEGEN "Build with TensorRT Codegen support" OFF) +tvm_option(USE_TENSORRT_RUNTIME "Build with TensorRT runtime" OFF) # include 
directories include_directories(${CMAKE_INCLUDE_PATH}) @@ -363,6 +365,7 @@ include(cmake/modules/contrib/TF_TVMDSOOP.cmake) include(cmake/modules/contrib/CoreML.cmake) include(cmake/modules/contrib/ONNX.cmake) include(cmake/modules/contrib/ArmComputeLib.cmake) +include(cmake/modules/contrib/TensorRT.cmake) include(cmake/modules/Git.cmake) include(cmake/modules/LibInfo.cmake) diff --git a/cmake/config.cmake b/cmake/config.cmake index 1d465b2fe389..b220f3b0b9f0 100644 --- a/cmake/config.cmake +++ b/cmake/config.cmake @@ -222,6 +222,16 @@ set(USE_ETHOSN OFF) # otherwise use ETHOSN_HW (OFF) to use the software test infrastructure set(USE_ETHOSN_HW OFF) +# Whether to build with TensorRT codegen or runtime +# Examples are available here: docs/deploy/tensorrt.rst. +# +# USE_TENSORRT_CODEGEN - Support for compiling a relay graph where supported operators are +# offloaded to TensorRT. OFF/ON +# USE_TENSORRT_RUNTIME - Support for running TensorRT compiled modules, requires presense of +# TensorRT library. OFF/ON/"path/to/TensorRT" +set(USE_TENSORRT_CODEGEN OFF) +set(USE_TENSORRT_RUNTIME OFF) + # Build ANTLR parser for Relay text format # Possible values: # - ON: enable ANTLR by searching default locations (cmake find_program for antlr4 and /usr/local for jar) diff --git a/cmake/modules/contrib/TensorRT.cmake b/cmake/modules/contrib/TensorRT.cmake new file mode 100644 index 000000000000..1536d23205a7 --- /dev/null +++ b/cmake/modules/contrib/TensorRT.cmake @@ -0,0 +1,54 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +# TensorRT Codegen only. This can be enabled independently of USE_TENSORRT_RUNTIME to enable +# compilation of TensorRT modules without requiring TensorRT to be installed. The compiled modules +# will only be able to be executed using a TVM built with USE_TENSORRT_RUNTIME=ON. 
+if(USE_TENSORRT_CODEGEN) + message(STATUS "Build with TensorRT codegen") + file(GLOB COMPILER_TENSORRT_SRCS src/relay/backend/contrib/tensorrt/*.cc) + set_source_files_properties(${COMPILER_TENSORRT_SRCS} PROPERTIES COMPILE_FLAGS "-Wno-deprecated-declarations") + file(GLOB RUNTIME_TENSORRT_SRCS src/runtime/contrib/tensorrt/tensorrt_runtime.cc) + set_source_files_properties(${RUNTIME_TENSORRT_SRCS} PROPERTIES COMPILE_FLAGS "-Wno-deprecated-declarations") + list(APPEND COMPILER_SRCS ${COMPILER_TENSORRT_SRCS}) + list(APPEND COMPILER_SRCS ${RUNTIME_TENSORRT_SRCS}) +endif() + +# TensorRT Runtime +if(USE_TENSORRT_RUNTIME) + if(IS_DIRECTORY ${USE_TENSORRT_RUNTIME}) + set(TENSORRT_ROOT_DIR ${USE_TENSORRT_RUNTIME}) + message(STATUS "Custom TensorRT path: " ${TENSORRT_ROOT_DIR}) + endif() + find_path(TENSORRT_INCLUDE_DIR NvInfer.h HINTS ${TENSORRT_ROOT_DIR} PATH_SUFFIXES include) + find_library(TENSORRT_LIB_DIR nvinfer HINTS ${TENSORRT_ROOT_DIR} PATH_SUFFIXES lib) + find_package_handle_standard_args(TENSORRT DEFAULT_MSG TENSORRT_INCLUDE_DIR TENSORRT_LIB_DIR) + if(NOT TENSORRT_FOUND) + message(ERROR "Could not find TensorRT.") + endif() + message(STATUS "TENSORRT_LIB_DIR: " ${TENSORRT_LIB_DIR}) + include_directories(${TENSORRT_INCLUDE_DIR}) + list(APPEND TVM_RUNTIME_LINKER_LIBS ${TENSORRT_LIB_DIR}) + + # TRT runtime sources + file(GLOB RUNTIME_TENSORRT_SRCS src/runtime/contrib/tensorrt/*.cc) + set_source_files_properties(${RUNTIME_TENSORRT_SRCS} PROPERTIES COMPILE_FLAGS "-Wno-deprecated-declarations") + list(APPEND RUNTIME_SRCS ${RUNTIME_TENSORRT_SRCS}) + + # Set defines + add_definitions(-DTVM_GRAPH_RUNTIME_TENSORRT) +endif() diff --git a/docs/deploy/index.rst b/docs/deploy/index.rst index b38a7f561ab3..68843ba18248 100644 --- a/docs/deploy/index.rst +++ b/docs/deploy/index.rst @@ -69,3 +69,4 @@ target device without relying on RPC. see the following resources on how to do s integrate hls arm_compute_lib + tensorrt diff --git a/docs/deploy/tensorrt.rst b/docs/deploy/tensorrt.rst new file mode 100644 index 000000000000..27f11e9b5377 --- /dev/null +++ b/docs/deploy/tensorrt.rst @@ -0,0 +1,297 @@ +.. Licensed to the Apache Software Foundation (ASF) under one + or more contributor license agreements. See the NOTICE file + distributed with this work for additional information + regarding copyright ownership. The ASF licenses this file + to you under the Apache License, Version 2.0 (the + "License"); you may not use this file except in compliance + with the License. You may obtain a copy of the License at + +.. http://www.apache.org/licenses/LICENSE-2.0 + +.. Unless required by applicable law or agreed to in writing, + software distributed under the License is distributed on an + "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + KIND, either express or implied. See the License for the + specific language governing permissions and limitations + under the License. + +Relay TensorRT Integration +========================== +**Author**: `Trevor Morris `_ + +Introduction +------------ + +NVIDIA TensorRT is a library for optimized deep learning inference. This integration will offload as +many operators as possible from Relay to TensorRT, providing a performance boost on NVIDIA GPUs +without the need to tune schedules. + +This guide will demonstrate how to install TensorRT and build TVM with TensorRT BYOC and runtime +enabled. It will also provide example code to compile and run a ResNet-18 model using TensorRT and +how to configure the compilation and runtime settings. 
Finally, we document the supported operators
+and how to extend the integration to support other operators.
+
+Installing TensorRT
+-------------------
+
+In order to download TensorRT, you will need to create an NVIDIA Developer program account. Please
+see NVIDIA's documentation for more info:
+https://docs.nvidia.com/deeplearning/tensorrt/install-guide/index.html. If you have a Jetson device
+such as a TX1, TX2, Xavier, or Nano, TensorRT will already be installed on the device via the
+JetPack SDK.
+
+There are two methods to install TensorRT:
+
+* System install via deb or rpm package.
+* Tar file installation.
+
+With the tar file installation method, you must provide the path of the extracted tar archive to
+USE_TENSORRT_RUNTIME=/path/to/TensorRT. With the system install method,
+USE_TENSORRT_RUNTIME=ON will automatically locate your installation.
+
+Building TVM with TensorRT support
+----------------------------------
+
+There are two separate build flags for TensorRT integration in TVM. These flags also enable
+cross-compilation: USE_TENSORRT_CODEGEN=ON will allow you to build a module with TensorRT support on
+a host machine, while USE_TENSORRT_RUNTIME=ON will enable the TVM runtime on an edge device to
+execute the TensorRT module. You should enable both if you want to compile and also execute models
+with the same TVM build.
+
+* USE_TENSORRT_CODEGEN=ON/OFF - This flag will enable compiling a TensorRT module, which does not require any
+  TensorRT library.
+* USE_TENSORRT_RUNTIME=ON/OFF/path-to-TensorRT - This flag will enable the TensorRT runtime module.
+  This will build TVM against the installed TensorRT library.
+
+Example setting in config.cmake file:
+
+.. code:: cmake
+
+    set(USE_TENSORRT_CODEGEN ON)
+    set(USE_TENSORRT_RUNTIME /home/ubuntu/TensorRT-7.0.0.11)
+
+
+Build and Deploy ResNet-18 with TensorRT
+----------------------------------------
+
+Create a Relay graph from an MXNet ResNet-18 model.
+
+.. code:: python
+
+    import tvm
+    from tvm import relay
+    import mxnet
+    from mxnet.gluon.model_zoo.vision import get_model
+
+    dtype = "float32"
+    input_shape = (1, 3, 224, 224)
+    block = get_model('resnet18_v1', pretrained=True)
+    mod, params = relay.frontend.from_mxnet(block, shape={'data': input_shape}, dtype=dtype)
+
+
+Annotate and partition the graph for TensorRT. All ops which are supported by the TensorRT
+integration will be marked and offloaded to TensorRT. The rest of the ops will go through the
+regular TVM CUDA compilation and code generation.
+
+.. code:: python
+
+    from tvm.relay.op.contrib.tensorrt import partition_for_tensorrt
+    mod, config = partition_for_tensorrt(mod, params)
+
+
+Build the Relay graph using the new module and config returned by partition_for_tensorrt. The
+target must always be a cuda target. ``partition_for_tensorrt`` will automatically fill out the
+required values in the config, so there is no need to modify it - just pass it along to the
+PassContext so the values can be read during compilation.
+
+.. code:: python
+
+    target = "cuda"
+    with tvm.transform.PassContext(opt_level=3, config={'relay.ext.tensorrt.options': config}):
+        lib = relay.build(mod, target=target, params=params)
+
+
+Export the module.
+
+.. code:: python
+
+    lib.export_library('compiled.so')
+
+
+Load the module and run inference on the target machine, which must be built with
+``USE_TENSORRT_RUNTIME`` enabled. The first run will take longer because the TensorRT engine will
+have to be built.
+
+.. code:: python
+
+    import numpy as np
+
+    ctx = tvm.gpu(0)
+    loaded_lib = tvm.runtime.load_module('compiled.so')
+    gen_module = tvm.contrib.graph_runtime.GraphModule(loaded_lib['default'](ctx))
+    input_data = np.random.uniform(0, 1, input_shape).astype(dtype)
+    gen_module.run(data=input_data)
+
+
+Partitioning and Compilation Settings
+-------------------------------------
+
+There are some options which can be configured in ``partition_for_tensorrt``.
+
+* ``version`` - TensorRT version to target as tuple of (major, minor, patch). If TVM is compiled
+  with USE_TENSORRT_RUNTIME=ON, the linked TensorRT version will be used instead. The version
+  will affect which ops can be partitioned to TensorRT.
+* ``use_implicit_batch`` - Use TensorRT implicit batch mode (default true). Setting to false will
+  enable explicit batch mode which will widen supported operators to include those which modify the
+  batch dimension, but may reduce performance for some models.
+* ``remove_no_mac_subgraphs`` - A heuristic to improve performance. Removes subgraphs which have
+  been partitioned for TensorRT if they do not have any multiply-accumulate operations. The removed
+  subgraphs will go through TVM's standard compilation instead.
+* ``max_workspace_size`` - How many bytes of workspace size to allow each subgraph to use for
+  TensorRT engine creation. See TensorRT documentation for more info. Can be overridden at runtime.
+
+
+Runtime Settings
+----------------
+
+There are some additional options which can be configured at runtime using environment variables.
+A short usage example is shown after the list.
+
+* Automatic FP16 Conversion - Environment variable ``TVM_TENSORRT_USE_FP16=1`` can be set to
+  automatically convert the TensorRT components of your model to 16-bit floating point precision.
+  This can greatly increase performance, but may cause a slight loss in model accuracy.
+* Caching TensorRT Engines - During the first inference, the runtime will invoke the TensorRT API
+  to build an engine. This can be time consuming, so you can set ``TVM_TENSORRT_CACHE_DIR`` to
+  point to a directory where the built engines will be saved on disk. The next time you load the
+  model and give it the same directory, the runtime will load the already built engines to avoid
+  the long warmup time. A unique directory is required for each model.
+* TensorRT has a parameter to configure the maximum amount of scratch space that each layer in the
+  model can use. It is generally best to use the highest value which does not cause you to run out
+  of memory. You can use ``TVM_TENSORRT_MAX_WORKSPACE_SIZE`` to override this by specifying the
+  workspace size in bytes you would like to use.
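+
+For example, the following sketch (illustrative only: the cache directory and workspace size are
+arbitrary values, not defaults) sets these variables from Python before the first inference so that
+they take effect when the TensorRT engine is built. You can equally export them in your shell
+before launching the application.
+
+.. code:: python
+
+    import os
+    import numpy as np
+    import tvm
+    from tvm.contrib import graph_runtime
+
+    # Set before the first run(), which is when the TensorRT engine is built.
+    os.environ["TVM_TENSORRT_CACHE_DIR"] = "/tmp/trt_cache_resnet18"  # one directory per model
+    os.environ["TVM_TENSORRT_USE_FP16"] = "1"                         # optional FP16 conversion
+    os.environ["TVM_TENSORRT_MAX_WORKSPACE_SIZE"] = str(2 << 30)      # 2 GB of scratch space
+
+    ctx = tvm.gpu(0)
+    loaded_lib = tvm.runtime.load_module('compiled.so')
+    gen_module = graph_runtime.GraphModule(loaded_lib['default'](ctx))
+    gen_module.run(data=np.random.uniform(0, 1, (1, 3, 224, 224)).astype("float32"))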
+ + +Operator support +---------------- ++------------------------+------------------------------------+ +| Relay Node | Remarks | ++========================+====================================+ +| nn.relu | | ++------------------------+------------------------------------+ +| sigmoid | | ++------------------------+------------------------------------+ +| tanh | | ++------------------------+------------------------------------+ +| nn.batch_norm | | ++------------------------+------------------------------------+ +| nn.softmax | | ++------------------------+------------------------------------+ +| nn.conv2d | | ++------------------------+------------------------------------+ +| nn.dense | | ++------------------------+------------------------------------+ +| nn.bias_add | | ++------------------------+------------------------------------+ +| add | | ++------------------------+------------------------------------+ +| subtract | | ++------------------------+------------------------------------+ +| multiply | | ++------------------------+------------------------------------+ +| divide | | ++------------------------+------------------------------------+ +| power | | ++------------------------+------------------------------------+ +| maximum | | ++------------------------+------------------------------------+ +| minimum | | ++------------------------+------------------------------------+ +| nn.max_pool2d | | ++------------------------+------------------------------------+ +| nn.avg_pool2d | | ++------------------------+------------------------------------+ +| nn.global_max_pool2d | | ++------------------------+------------------------------------+ +| nn.global_avg_pool2d | | ++------------------------+------------------------------------+ +| exp | | ++------------------------+------------------------------------+ +| log | | ++------------------------+------------------------------------+ +| sqrt | | ++------------------------+------------------------------------+ +| abs | | ++------------------------+------------------------------------+ +| negative | | ++------------------------+------------------------------------+ +| nn.batch_flatten | | ++------------------------+------------------------------------+ +| expand_dims | | ++------------------------+------------------------------------+ +| squeeze | | ++------------------------+------------------------------------+ +| concatenate | | ++------------------------+------------------------------------+ +| nn.conv2d_transpose | | ++------------------------+------------------------------------+ +| transpose | | ++------------------------+------------------------------------+ +| layout_transform | | ++------------------------+------------------------------------+ +| reshape | | ++------------------------+------------------------------------+ +| nn.pad | | ++------------------------+------------------------------------+ +| sum | | ++------------------------+------------------------------------+ +| prod | | ++------------------------+------------------------------------+ +| max | | ++------------------------+------------------------------------+ +| min | | ++------------------------+------------------------------------+ +| mean | | ++------------------------+------------------------------------+ +| nn.adaptive_max_pool2d | | ++------------------------+------------------------------------+ +| nn.adaptive_avg_pool2d | | ++------------------------+------------------------------------+ +| clip | Requires TensorRT 5.1.5 or greater | 
++------------------------+------------------------------------+ +| nn.leaky_relu | Requires TensorRT 5.1.5 or greater | ++------------------------+------------------------------------+ +| sin | Requires TensorRT 5.1.5 or greater | ++------------------------+------------------------------------+ +| cos | Requires TensorRT 5.1.5 or greater | ++------------------------+------------------------------------+ +| atan | Requires TensorRT 5.1.5 or greater | ++------------------------+------------------------------------+ +| ceil | Requires TensorRT 5.1.5 or greater | ++------------------------+------------------------------------+ +| floor | Requires TensorRT 5.1.5 or greater | ++------------------------+------------------------------------+ +| strided_slice | Requires TensorRT 5.1.5 or greater | ++------------------------+------------------------------------+ +| nn.conv3d | Requires TensorRT 6.0.1 or greater | ++------------------------+------------------------------------+ +| nn.max_pool3d | Requires TensorRT 6.0.1 or greater | ++------------------------+------------------------------------+ +| nn.avg_pool3d | Requires TensorRT 6.0.1 or greater | ++------------------------+------------------------------------+ +| nn.conv3d_transpose | Requires TensorRT 6.0.1 or greater | ++------------------------+------------------------------------+ + + +Adding a new operator +--------------------- +To add support for a new operator, there are a series of files we need to make changes to: + +* `src/runtime/contrib/tensorrt/tensorrt_ops.cc` Create a new op converter class which + implements the ``TensorRTOpConverter`` interface. You must implement the constructor to specify how + many inputs there are and whether they are tensors or weights. You must also implement the + ``Convert`` method to perform the conversion. This is done by using the inputs, attributes, and + network from params to add the new TensorRT layers and push the layer outputs. You can use the + existing converters as an example. Finally, register your new op conventer in the + ``GetOpConverters()`` map. +* `python/relay/op/contrib/tensorrt.py` This file contains the annotation rules for TensorRT. These + determine which operators and their attributes that are supported. You must register an annotation + function for the relay operator and specify which attributes are supported by your converter, by + checking the attributes are returning true or false. +* `tests/python/contrib/test_tensorrt.py` Add unit tests for the given operator. diff --git a/python/tvm/relay/op/contrib/__init__.py b/python/tvm/relay/op/contrib/__init__.py index dbcd8055d30b..49abf36134b4 100644 --- a/python/tvm/relay/op/contrib/__init__.py +++ b/python/tvm/relay/op/contrib/__init__.py @@ -22,3 +22,4 @@ from .dnnl import * from .coreml import * from .ethosn import * +from .tensorrt import * diff --git a/python/tvm/relay/op/contrib/tensorrt.py b/python/tvm/relay/op/contrib/tensorrt.py new file mode 100644 index 000000000000..a0e23a043a72 --- /dev/null +++ b/python/tvm/relay/op/contrib/tensorrt.py @@ -0,0 +1,769 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. 
You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# pylint: disable=invalid-name, unused-argument +"""TensorRT supported operators.""" +import logging +import numpy as np +import tvm +from tvm import relay +from tvm.relay import transform +from tvm.relay.build_module import bind_params_by_name +from tvm.relay.expr import Call, Constant, Tuple, GlobalVar +from tvm.relay.expr_functor import ExprMutator + +logger = logging.getLogger("TensorRT") + + +def is_tensorrt_runtime_enabled(): + """Check if the TensorRT graph runtime is present. + Returns + ------- + ret: bool + True if present, False if not. + """ + check_enabled = tvm.get_global_func("relay.op.is_tensorrt_runtime_enabled", True) + if check_enabled: + return check_enabled() + return False + + +def get_tensorrt_version(): + """Gets the version of TensorRT that TVM is built against or is targeting. + + Returns + ------- + ret: Tuple[int, int, int] + TensorRT version as a tuple of major, minor, and patch number. If TVM + is not built with TensorRT, the value set by set_tensorrt_version() is returned instead. + """ + pass_ctx = tvm.transform.PassContext.current() + if "relay.ext.tensorrt.options" in pass_ctx.config: + return tuple(pass_ctx.config["relay.ext.tensorrt.options"].tensorrt_version) + return tuple(tvm.get_global_func("relay.op.get_tensorrt_version")()) + + +def get_tensorrt_use_implicit_batch_mode(): + pass_ctx = tvm.transform.PassContext.current() + if "relay.ext.tensorrt.options" in pass_ctx.config: + return pass_ctx.config["relay.ext.tensorrt.options"].use_implicit_batch + logger.warning( + "PassContext has no relay.ext.tensorrt.options config, using default value " + "use_implicit_batch=True." + ) + return True + + +def get_tensorrt_remove_no_mac_subgraphs(): + pass_ctx = tvm.transform.PassContext.current() + if "relay.ext.tensorrt.options" in pass_ctx.config: + return pass_ctx.config["relay.ext.tensorrt.options"].remove_no_mac_subgraphs + logger.warning( + "PassContext has no relay.ext.tensorrt.options config, using default value " + "remove_no_mac_subgraphs=False." + ) + return False + + +def partition_for_tensorrt( + mod, + params=None, + version=None, + use_implicit_batch=True, + remove_no_mac_subgraphs=False, + max_workspace_size=1 << 30, +): + """Partition the graph greedily offloading supported operators to TensorRT. + + Parameters + ---------- + mod : Module + The module to run passes on. + params : Optional[Dict[str, NDArray]] + Constant input parameters. + version : Optional[Tuple[int, int, int]] + TensorRT version to target as tuple of (major, minor, patch). If TVM is compiled with + USE_TENSORRT_RUNTIME=ON, the linked TensorRT version will be used instead. + use_implicit_batch : Optional[bool] + Use TensorRT implicit batch mode (default true). Setting to false will enable explicit batch + mode which will widen supported operators to include those which modify the batch dimension, + but may reduce performance for some models. + remove_no_mac_subgraphs : Optional[bool] + Removes subgraphs which have been partitioned for TensorRT if they do not have any + multiply-accumulate operations. 
The removed subgraphs will go through TVM's standard + compilation instead. Can improve performance. + max_workspace_size : Optional[int] + How many bytes of workspace size to allow each subgraph to use for TensorRT engine creation. + See TensorRT documentation for more info. + Returns + ------- + mod_and_config : Tuple[Module, Dict[str, Any]] + A tuple of 1) annotated and partitioned module and 2) "relay.ext.tensorrt.options" + configuration which should be given to PassContext when building. + """ + config = { + "use_implicit_batch": use_implicit_batch, + "max_workspace_size": max_workspace_size, + "remove_no_mac_subgraphs": remove_no_mac_subgraphs, + } + if version: + assert isinstance(version, tuple) and len(version) == 3 + config["tensorrt_version"] = version + else: + linked_version = tuple(tvm.get_global_func("relay.op.get_tensorrt_version")()) + if not linked_version: + logger.warning( + "TVM was not built against TensorRT and no version was provided to " + "partition_for_tensorrt. Defaulting to 6.0.1" + ) + linked_version = (6, 0, 1) + config["tensorrt_version"] = linked_version + + if params: + mod["main"] = bind_params_by_name(mod["main"], params) + seq = tvm.transform.Sequential( + [ + transform.InferType(), + RemoveDropoutPass(), + transform.RemoveUnusedFunctions(), + transform.ConvertLayout( + {"nn.conv2d": ["NCHW", "default"], "nn.conv3d": ["NCDHW", "default"]} + ), + transform.FoldConstant(), + transform.AnnotateTarget("tensorrt"), + transform.MergeCompilerRegions(), + transform.PartitionGraph(), + transform.InferType(), + ] + ) + with tvm.transform.PassContext(opt_level=3, config={"relay.ext.tensorrt.options": config}): + mod = seq(mod) + mod = prune_tensorrt_subgraphs(mod) + return mod, config + + +def _register_external_op_helper_with_checker(op_name, checker): + @tvm.ir.register_op_attr(op_name, "target.tensorrt") + def _func_wrapper(attrs, args): + if any([x.checked_type.dtype != "float32" for x in args]): + logger.info("Only float32 inputs are supported for TensorRT.") + return False + return checker(attrs, args, op_name) + + return _func_wrapper + + +def _register_external_op_helper(op_name, supported=True): + return _register_external_op_helper_with_checker( + op_name, lambda attrs, args, op_name: supported + ) + + +# Ops which are always supported +_register_external_op_helper("nn.relu") +_register_external_op_helper("sigmoid") +_register_external_op_helper("tanh") +_register_external_op_helper("subtract") +_register_external_op_helper("multiply") +_register_external_op_helper("divide") +_register_external_op_helper("power") +_register_external_op_helper("maximum") +_register_external_op_helper("minimum") +_register_external_op_helper("exp") +_register_external_op_helper("log") +_register_external_op_helper("sqrt") +_register_external_op_helper("abs") +_register_external_op_helper("negative") +_register_external_op_helper("nn.batch_flatten") +_register_external_op_helper("clip") + + +@tvm.ir.register_op_attr("add", "target.tensorrt") +def add_annotate_fn(attrs, args): # pylint: disable=unused-variable + """Check if add is supported by TensorRT.""" + + if any([x.checked_type.dtype != "float32" for x in args]): + logger.info("Only float32 inputs are supported for TensorRT.") + return False + if ( + not get_tensorrt_use_implicit_batch_mode() + and (isinstance(args[0], Constant) or isinstance(args[1], Constant)) + and args[0].checked_type.shape[0] == args[1].checked_type.shape[0] + and args[0].checked_type.shape[0] != 1 + and (len(args[0].checked_type.shape) > 3 or 
len(args[1].checked_type.shape) > 3) + ): + logger.info("add: bug in TRT with adding batched constants.") + return False + return True + + +@tvm.ir.register_op_attr("nn.batch_norm", "target.tensorrt") +def batch_norm_annotate_fn(attrs, args): # pylint: disable=unused-variable + """Check if nn.batch_norm is supported by TensorRT.""" + if any([x.checked_type.dtype != "float32" for x in args]): + logger.info("Only float32 inputs are supported for TensorRT.") + return False + if int(attrs.axis) not in (1, 3): + logger.info("nn.batch_norm: axis is %d but must be 1 or 3.", int(attrs.axis)) + return False + return True + + +@tvm.ir.register_op_attr("nn.softmax", "target.tensorrt") +def softmax_annotate_fn(attrs, args): # pylint: disable=unused-variable + """Check if nn.softmax is supported by TensorRT.""" + if any([x.checked_type.dtype != "float32" for x in args]): + logger.info("Only float32 inputs are supported for TensorRT.") + return False + if get_tensorrt_use_implicit_batch_mode() and int(attrs.axis) == 0: + logger.info("nn.softmax: can't modify batch dimension.") + return False + return True + + +@tvm.ir.register_op_attr("nn.conv2d", "target.tensorrt") +def conv2d_annotate_fn(attrs, args): # pylint: disable=unused-variable + """Check if nn.conv2d is supported by TensorRT.""" + if any([x.checked_type.dtype != "float32" for x in args]): + logger.info("Only float32 inputs are supported for TensorRT.") + return False + if attrs.data_layout != "NCHW": + logger.info("nn.conv2d: data_layout is %s but must be NCHW.", attrs.data_layout) + return False + if attrs.kernel_layout != "OIHW": + logger.info("nn.conv2d: kernel_layout is %s but must be OIHW.", attrs.kernel_layout) + return False + if attrs.out_layout and attrs.out_layout != "NCHW": + logger.info("nn.conv2d: out_layout is %s but must be NCHW.", attrs.out_layout) + return False + return True + + +@tvm.ir.register_op_attr("nn.dense", "target.tensorrt") +def dense_annotate_fn(attrs, args): # pylint: disable=unused-variable + """Check if dense is supported by TensorRT.""" + if any([x.checked_type.dtype != "float32" for x in args]): + logger.info("Only float32 inputs are supported for TensorRT.") + return False + input_rank = len(args[0].checked_type.shape) + weight_rank = len(args[1].checked_type.shape) + if input_rank not in (2, 3, 4): + logger.info("nn.dense: input has rank %d but must be 2, 3 or 4.", input_rank) + return False + if weight_rank != 2: + logger.info("nn.dense: weight has rank %d but must be 2.", weight_rank) + return False + return True + + +@tvm.ir.register_op_attr("nn.bias_add", "target.tensorrt") +def bias_add_annotate_fn(attrs, args): # pylint: disable=unused-variable + """Check if nn.bias_add is supported by TensorRT.""" + if any([x.checked_type.dtype != "float32" for x in args]): + logger.info("Only float32 inputs are supported for TensorRT.") + return False + input_rank = len(args[0].checked_type.shape) + if input_rank not in (2, 3, 4): + logger.info("nn.bias_add: input rank is %d but must be 2, 3 or 4.", input_rank) + return False + return True + + +@tvm.ir.register_op_attr("nn.max_pool2d", "target.tensorrt") +def max_pool_2d_annotate_fn(attrs, args): # pylint: disable=unused-variable + """Check if nn.max_pool2d is supported by TensorRT.""" + if any([x.checked_type.dtype != "float32" for x in args]): + logger.info("Only float32 inputs are supported for TensorRT.") + return False + if attrs.layout != "NCHW": + logger.info("nn.max_pool2d: layout is %s but must be NCHW.", attrs.layout) + return False + if attrs.ceil_mode 
and get_tensorrt_version() < (5, 1, 5): + logger.info("nn.avg_pool2d: ceil_mode=True requires TensorRT 5.1.5 or greater.") + return False + return True + + +@tvm.ir.register_op_attr("nn.avg_pool2d", "target.tensorrt") +def avg_pool_2d_annotate_fn(attrs, args): # pylint: disable=unused-variable + """Check if nn.avg_pool2d is supported by TensorRT.""" + if any([x.checked_type.dtype != "float32" for x in args]): + logger.info("Only float32 inputs are supported for TensorRT.") + return False + if attrs.layout != "NCHW": + logger.info("nn.avg_pool2d: layout is %d but must be NCHW.", attrs.layout) + return False + if ( + attrs.count_include_pad + and len(attrs.padding) == 4 + and ( + int(attrs.padding[0]) != int(attrs.padding[2]) + or int(attrs.padding[1]) != int(attrs.padding[3]) + ) + ): + logger.info( + "nn.avg_pool2d: inclusive-counted blended or average " + "pooling is not supported in combination with asymmetric padding" + ) + return False + if attrs.ceil_mode and get_tensorrt_version() < (5, 1, 5): + logger.info("nn.avg_pool2d: ceil_mode=True requires TensorRT 5.1.5 or greater.") + return False + return True + + +@tvm.ir.register_op_attr("nn.global_max_pool2d", "target.tensorrt") +def global_max_pool_2d_annotate_fn(attrs, args): # pylint: disable=unused-variable + """Check if nn.global_max_pool2d is supported by TensorRT.""" + if any([x.checked_type.dtype != "float32" for x in args]): + logger.info("Only float32 inputs are supported for TensorRT.") + return False + if attrs.layout != "NCHW": + logger.info("nn.global_max_pool2d: layout is %s but must be NCHW.", attrs.layout) + return False + return True + + +@tvm.ir.register_op_attr("nn.global_avg_pool2d", "target.tensorrt") +def global_avg_pool_2d_annotate_fn(attrs, args): # pylint: disable=unused-variable + """Check if nn.global_avg_pool2d is supported by TensorRT.""" + if any([x.checked_type.dtype != "float32" for x in args]): + logger.info("Only float32 inputs are supported for TensorRT.") + return False + if attrs.layout != "NCHW": + logger.info("nn.global_avg_pool2d: layout is %s but must be NCHW.", attrs.layout) + return False + return True + + +@tvm.ir.register_op_attr("expand_dims", "target.tensorrt") +def expand_dims_annotate_fn(attrs, args): # pylint: disable=unused-variable + """Check if expand_dims is supported by TensorRT.""" + if any([x.checked_type.dtype != "float32" for x in args]): + logger.info("Only float32 inputs are supported for TensorRT.") + return False + if get_tensorrt_use_implicit_batch_mode() and int(attrs.axis) == 0: + logger.info("expand_dims: can't modify batch dimension.") + return False + return True + + +@tvm.ir.register_op_attr("squeeze", "target.tensorrt") +def squeeze_annotate_fn(attrs, args): # pylint: disable=unused-variable + """Check if squeeze is supported by TensorRT.""" + if any([x.checked_type.dtype != "float32" for x in args]): + logger.info("Only float32 inputs are supported for TensorRT.") + return False + if not attrs.axis: + logger.info("squeeze: must explicitly set axis.") + return False + if get_tensorrt_use_implicit_batch_mode() and any([axis == 0 for axis in map(int, attrs.axis)]): + logger.info("squeeze: can't modify batch dimension.") + return False + return True + + +@tvm.ir.register_op_attr("concatenate", "target.tensorrt") +def concatenate_annotate_fn(attrs, args): # pylint: disable=unused-variable + """Check if concatenate is supported by TensorRT.""" + if any([x.dtype != "float32" for x in args[0].checked_type.fields]): + logger.info("Only float32 inputs are supported for 
TensorRT.") + return False + if not get_tensorrt_use_implicit_batch_mode(): + return True + if int(attrs.axis) == 0: + logger.info("concatenate: can't modify batch dimension.") + return False + if isinstance(args[0], Tuple): + for tuple_input in args[0].fields: + if isinstance(tuple_input, Constant): + logger.info("concatenate: can't concatenate tensors with constants.") + return False + return True + + +@tvm.ir.register_op_attr("nn.conv2d_transpose", "target.tensorrt") +def conv2d_transpose_annotate_fn(attrs, args): # pylint: disable=unused-variable + """Check if nn.conv2d_transpose is supported by TensorRT.""" + if any([x.checked_type.dtype != "float32" for x in args]): + logger.info("Only float32 inputs are supported for TensorRT.") + return False + if attrs.data_layout != "NCHW": + logger.info("nn.conv2d_transpose: data_layout is %s but must be NCHW.", attrs.data_layout) + return False + if attrs.kernel_layout != "OIHW": + logger.info( + "nn.conv2d_transpose: kernel_layout is %s but must be OIHW.", attrs.kernel_layout + ) + return False + if attrs.out_layout and attrs.out_layout != "NCHW": + logger.info("nn.conv2d_transpose: out_layout is %s but must be NCHW.", attrs.out_layout) + return False + if attrs.dilation and any([rate != 1 for rate in map(int, attrs.dilation)]): + logger.info("nn.conv2d_transpose: dilation rate must be 1.") + return False + return True + + +@tvm.ir.register_op_attr("transpose", "target.tensorrt") +def transpose_annotate_fn(attrs, args): # pylint: disable=unused-variable + """Check if transpose is supported by TensorRT.""" + if any([x.checked_type.dtype != "float32" for x in args]): + logger.info("Only float32 inputs are supported for TensorRT.") + return False + if get_tensorrt_use_implicit_batch_mode() and int(attrs.axes[0]) != 0: + logger.info("transpose: can't modify batch dimension.") + return False + return True + + +@tvm.ir.register_op_attr("layout_transform", "target.tensorrt") +def layout_transform_annotate_fn(attrs, args): # pylint: disable=unused-variable + """Check if layout_transform is supported by TensorRT.""" + if any([x.checked_type.dtype != "float32" for x in args]): + logger.info("Only float32 inputs are supported for TensorRT.") + return False + if (attrs.src_layout, attrs.dst_layout) not in [ + ("NCHW", "NHWC"), + ("NHWC", "NCHW"), + ("NDHWC", "NCDHW"), + ("NCDHW", "NDHWC"), + ]: + logger.info( + "layout_transform: %s to %s is not supported.", attrs.src_layout, attrs.dst_layout + ) + return False + return True + + +@tvm.ir.register_op_attr("reshape", "target.tensorrt") +def reshape_annotate_fn(attrs, args): # pylint: disable=unused-variable + """Check if reshape is supported by TensorRT.""" + if args[0].checked_type.dtype != "float32": + logger.info("Only float32 inputs are supported for TensorRT.") + return False + if any([x < -1 for x in map(int, attrs.newshape)]): + logger.info("reshape: new shape dims must be explicit.") + return False + if get_tensorrt_use_implicit_batch_mode(): + shape = list(map(int, args[0].checked_type.shape)) + new_shape = list(map(int, attrs.newshape)) + if len(new_shape) == 0 or len(shape) == 0: + logger.info("reshape: Can't reshape to or from scalar.") + return False + # TRT cannot modify batch dimension. + original_volume = np.prod(shape) + # First, resolve 0. + for i, value in enumerate(new_shape): + if value == 0: + new_shape[i] = shape[i] + # Resolve -1. 
+ for i, value in enumerate(new_shape): + if value == -1: + new_shape[i] = original_volume // np.prod([x for x in new_shape if x != -1]) + if shape[0] != new_shape[0]: + logger.info("reshape: can't modify batch dimension.") + return False + return True + + +@tvm.ir.register_op_attr("nn.pad", "target.tensorrt") +def pad_annotate_fn(attrs, args): # pylint: disable=unused-variable + """Check if nn.pad is supported by TensorRT.""" + if any([x.checked_type.dtype != "float32" for x in args]): + logger.info("Only float32 inputs are supported for TensorRT.") + return False + if attrs.pad_mode != "constant": + logger.info("nn.pad: pad mode is %s but must be constant.", attrs.pad_mode) + return False + if float(attrs.pad_value) != 0.0: + logger.info("nn.pad: pad value is %f but must be 0.0.", float(attrs.pad_value)) + return False + if any([x != 0 for x in attrs.pad_width[0]]) or any([x != 0 for x in attrs.pad_width[1]]): + logger.info("nn.pad: can't pad batch or channel dimensions.") + return False + if len(attrs.pad_width) == 5 and any([x != 0 for x in attrs.pad_width[2]]): + logger.info("nn.pad: can only pad last two dimensions for 5D inputs.") + return True + + +def reduce_annotate_fn(attrs, args, op_name): + """Helper for reduce operations.""" + if not attrs.axis or len(attrs.axis) == 0: + logger.info("%s: cannot reduce to scalar.", op_name) + return False + if attrs.exclude: + logger.info("%s: exclude not supported.", op_name) + return False + if get_tensorrt_use_implicit_batch_mode() and any([x == 0 for x in map(int, attrs.axis)]): + logger.info("%s: can't modify batch dimension.", op_name) + return False + return True + + +_register_external_op_helper_with_checker("sum", reduce_annotate_fn) +_register_external_op_helper_with_checker("prod", reduce_annotate_fn) +_register_external_op_helper_with_checker("max", reduce_annotate_fn) +_register_external_op_helper_with_checker("min", reduce_annotate_fn) +_register_external_op_helper_with_checker("mean", reduce_annotate_fn) + + +def trt_version_annotate_fn(version): + """Helper for ops which require a minimum TRT version""" + + def _func_wrapper(attrs, args, op_name): + if get_tensorrt_version() < version: + logger.info( + "%s: requires TensorRT version %s or higher.", op_name, ".".join(map(str, version)) + ) + return False + return True + + return _func_wrapper + + +_register_external_op_helper_with_checker("nn.leaky_relu", trt_version_annotate_fn((5, 1, 5))) +_register_external_op_helper_with_checker("sin", trt_version_annotate_fn((5, 1, 5))) +_register_external_op_helper_with_checker("cos", trt_version_annotate_fn((5, 1, 5))) +_register_external_op_helper_with_checker("atan", trt_version_annotate_fn((5, 1, 5))) +_register_external_op_helper_with_checker("ceil", trt_version_annotate_fn((5, 1, 5))) + + +@tvm.ir.register_op_attr("strided_slice", "target.tensorrt") +def strided_slice_annotate_fn(attrs, args): # pylint: disable=unused-variable + """Check if strided_slice is supported by TensorRT.""" + if args[0].checked_type.dtype != "float32": + logger.info("Only float32 inputs are supported for TensorRT.") + return False + if not trt_version_annotate_fn((5, 1, 5))(attrs, args, "strided_slice"): + return False + if get_tensorrt_use_implicit_batch_mode(): + batch_dim_begin_modified = attrs.begin[0] is not None and int(attrs.begin[0]) != 0 + batch_dim_end_modified = ( + attrs.end[0] is not None + and int(attrs.end[0]) != -1 + and int(attrs.end[0]) != int(args[0].checked_type.shape[0]) + ) + if batch_dim_begin_modified or batch_dim_end_modified: + 
logger.info("strided_slice: can't modify batch dimension.") + return False + if any([x is not None and x <= 0 for x in attrs.strides]): + logger.info("strided_slice: stride must be positive") + return False + return True + + +@tvm.ir.register_op_attr("nn.adaptive_max_pool2d", "target.tensorrt") +def adapative_max_pool2d_annotate_fn(attrs, args): # pylint: disable=unused-variable + """Check if nn.adaptive_max_pool2d is supported by TensorRT.""" + if any([x.checked_type.dtype != "float32" for x in args]): + logger.info("Only float32 inputs are supported for TensorRT.") + return False + if len(attrs.output_size) == 0 or any([size != 1 for size in map(int, attrs.output_size)]): + logger.info("nn.adaptive_max_pool2d: output size must be (1, 1).") + return False + return True + + +@tvm.ir.register_op_attr("nn.adaptive_avg_pool2d", "target.tensorrt") +def adapative_avg_pool2d_annotate_fn(attrs, args): # pylint: disable=unused-variable + """Check if nn.adaptive_avg_pool2d is supported by TensorRT.""" + if any([x.checked_type.dtype != "float32" for x in args]): + logger.info("Only float32 inputs are supported for TensorRT.") + return False + if len(attrs.output_size) == 0 or any([size != 1 for size in map(int, attrs.output_size)]): + logger.info("nn.adaptive_avg_pool2d: output size must be (1, 1).") + return False + return True + + +@tvm.ir.register_op_attr("nn.conv3d", "target.tensorrt") +def conv3d_annotate_fn(attrs, args): # pylint: disable=unused-variable + """Check if nn.conv3d is supported by TensorRT.""" + if any([x.checked_type.dtype != "float32" for x in args]): + logger.info("Only float32 inputs are supported for TensorRT.") + return False + if not trt_version_annotate_fn((6, 0, 1))(attrs, args, "nn.conv3d"): + return False + if attrs.data_layout != "NCDHW": + logger.info("nn.conv3d: data_layout is %s but must be NCDHW.", attrs.data_layout) + return False + if attrs.kernel_layout != "OIDHW": + logger.info("nn.conv3d: kernel_layout is %s but must be OIDHW.", attrs.kernel_layout) + return False + if attrs.out_layout and attrs.out_layout != "NCDHW": + logger.info("nn.conv3d: out_layout is %s but must be NCDHW.", attrs.out_layout) + return False + return True + + +@tvm.ir.register_op_attr("nn.max_pool3d", "target.tensorrt") +def max_pool_3d_annotate_fn(attrs, args): # pylint: disable=unused-variable + """Check if nn.max_pool3d is supported by TensorRT.""" + if any([x.checked_type.dtype != "float32" for x in args]): + logger.info("Only float32 inputs are supported for TensorRT.") + return False + if not trt_version_annotate_fn((6, 0, 1))(attrs, args, "nn.max_pool3d"): + return False + if attrs.layout != "NCDHW": + logger.info("nn.max_pool3d: layout is %s but must be NCDHW.", attrs.layout) + return False + return True + + +@tvm.ir.register_op_attr("nn.avg_pool3d", "target.tensorrt") +def avg_pool_3d_annotate_fn(attrs, args): # pylint: disable=unused-variable + """Check if nn.avg_pool3d is supported by TensorRT.""" + if any([x.checked_type.dtype != "float32" for x in args]): + logger.info("Only float32 inputs are supported for TensorRT.") + return False + if not trt_version_annotate_fn((6, 0, 1))(attrs, args, "nn.avg_pool3d"): + return False + if attrs.layout != "NCDHW": + logger.info("nn.avg_pool3d: layout is %s but must be NCDHW.", attrs.layout) + return False + return True + + +@tvm.ir.register_op_attr("nn.conv3d_transpose", "target.tensorrt") +def conv3d_transpose_annotate_fn(attrs, args): # pylint: disable=unused-variable + """Check if nn.conv3d_transpose is supported by TensorRT.""" + if 
any([x.checked_type.dtype != "float32" for x in args]): + logger.info("Only float32 inputs are supported for TensorRT.") + return False + if not trt_version_annotate_fn((6, 0, 1))(attrs, args, "nn.conv3d_transpose"): + return False + if attrs.data_layout != "NCDHW": + logger.info("nn.conv3d_transpose: data_layout is %s but must be NCDHW.", attrs.data_layout) + return False + if attrs.kernel_layout != "OIDHW": + logger.info( + "nn.conv3d_transpose: kernel_layout is %s but must be OIDHW.", attrs.kernel_layout + ) + return False + if attrs.out_layout and attrs.out_layout != "NCDHW": + logger.info("nn.conv3d_transpose: out_layout is %s but must be NCDHW.", attrs.out_layout) + return False + if attrs.dilation and any([rate != 1 for rate in map(int, attrs.dilation)]): + logger.info("nn.conv3d_transpose: dilation rate must be 1.") + return False + if attrs.output_padding and any([x != 0 for x in map(int, attrs.output_padding)]): + logger.info("nn.conv3d_transpose: output padding is not supported.") + return False + return True + + +def is_valid_subgraph(params, body): + """Final check on whether the subgraph is valid and should be offloaded to TensorRT.""" + # Remove invalid subgraphs for implicit batch mode. + if get_tensorrt_use_implicit_batch_mode(): + input_batch_sizes = [] + for var in params: + # In implicit batch mode, all inputs must have same batch size + if isinstance(var.checked_type, relay.TupleType): + for tupe_type in var.checked_type.fields: + # Scalar inputs not allowed + if len(tupe_type.shape) == 0: + logger.info("tensorrt: scalar inputs not supported") + return False + input_batch_sizes.append(int(tupe_type.shape[0])) + else: + # Scalar inputs not allowed + if len(var.checked_type.shape) == 0: + logger.info("tensorrt: scalar inputs not supported") + return False + input_batch_sizes.append(int(var.checked_type.shape[0])) + if len(input_batch_sizes) > 1 and len(set(input_batch_sizes)) != 1: + logger.info("tensorrt: inputs have different batch sizes") + return False + # Remove subgraphs with no multiply-accumulates + if get_tensorrt_remove_no_mac_subgraphs() and relay.analysis.get_total_mac_number(body) == 0: + return False + return True + + +def prune_tensorrt_subgraphs(mod): + """ + Removes invalid subgraphs and those with no multiply-accumulates (if remove_no_max_subgraphs + is set). + """ + + class SubgraphRemover(ExprMutator): + """ + Reverts subgraphs in subgraphs_to_remove back to TVM instead of using an external codegen. + """ + + def __init__(self, subgraphs_to_remove, mod, new_mod): + ExprMutator.__init__(self) + self.subgraphs_to_remove = subgraphs_to_remove + self.mod = mod + self.new_mod = new_mod + + def visit_call(self, call): + if isinstance(call.op, GlobalVar): + name = call.op.name_hint + if name in self.subgraphs_to_remove: + # "Inline" the subgraph back into new main function. + func = self.mod[name] + var_map = {} + for arg, param in zip(call.args, func.params): + var_map[param] = super().visit(arg) + new_body = relay.bind(func.body, var_map) + return new_body + if name != "main": + # Copy the GlobalVar (subgraph function) to the new module and call. 
+ args = [] + for arg in call.args: + args.append(super().visit(arg)) + subgraph_gv = relay.GlobalVar(name) + self.new_mod[subgraph_gv] = self.mod[name] + return subgraph_gv(*args) + return super().visit_call(call) + + subgraphs_to_remove = [] + # Remove invalid subgraphs + for subgraph in mod.get_global_vars(): + name = subgraph.name_hint + if not mod[name].attrs or mod[name].attrs["Compiler"] != "tensorrt": + continue + if not is_valid_subgraph(mod[name].params, mod[name].body): + subgraphs_to_remove.append(name) + # Create new pruned module + new_mod = tvm.IRModule() + new_mod["main"] = SubgraphRemover(subgraphs_to_remove, mod, new_mod).visit(mod["main"]) + return new_mod + + +class RemoveDropout(ExprMutator): + """ + Removes all nn.dropout from an expr. + """ + + def visit_tuple_getitem(self, op): + visit = super().visit_tuple_getitem(op) + if ( + isinstance(visit.tuple_value, Call) + and visit.tuple_value.op.name == "nn.dropout" + and visit.index == 0 + ): + return visit.tuple_value.args[0] + return visit + + +@transform.function_pass(opt_level=0) +class RemoveDropoutPass: + def transform_function(self, func, mod, _): + return RemoveDropout().visit(func) diff --git a/src/relay/backend/contrib/tensorrt/codegen.cc b/src/relay/backend/contrib/tensorrt/codegen.cc new file mode 100644 index 000000000000..f692da3f31ac --- /dev/null +++ b/src/relay/backend/contrib/tensorrt/codegen.cc @@ -0,0 +1,240 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file src/relay/backend/contrib/tensorrt/codegen.cc + * \brief Implementation of the TensorRT JSON serializer. + */ +#include +#include +#include + +#include +#include +#include + +#include "../../utils.h" +#include "../codegen_json/codegen_json.h" + +#if TVM_GRAPH_RUNTIME_TENSORRT +#include "NvInfer.h" +#endif + +namespace tvm { +namespace relay { +namespace contrib { + +/*! \brief Attributes to store the compiler options for TensorRT. 
*/ +struct TensorRTCompilerConfigNode : public tvm::AttrsNode { + Array tensorrt_version; + bool use_implicit_batch; + size_t max_workspace_size; + bool remove_no_mac_subgraphs; + + TVM_DECLARE_ATTRS(TensorRTCompilerConfigNode, "ext.attrs.TensorRTCompilerConfigNode") { + TVM_ATTR_FIELD(tensorrt_version) + .describe("TensorRT version as (major, minor, patch).") + .set_default(Array({6, 0, 1})); + TVM_ATTR_FIELD(use_implicit_batch).set_default(true); + TVM_ATTR_FIELD(max_workspace_size).set_default(size_t(1) << 30); + TVM_ATTR_FIELD(remove_no_mac_subgraphs).set_default(false); + } +}; + +class TensorRTCompilerConfig : public Attrs { + public: + TVM_DEFINE_NOTNULLABLE_OBJECT_REF_METHODS(TensorRTCompilerConfig, Attrs, + TensorRTCompilerConfigNode); +}; + +TVM_REGISTER_NODE_TYPE(TensorRTCompilerConfigNode); +TVM_REGISTER_PASS_CONFIG_OPTION("relay.ext.tensorrt.options", TensorRTCompilerConfig); + +/*! + * \brief Generates an TensorRTModule from a relay expression by serializing the expression to a + * json representation. TensorRT is not required here because use of TensorRT APIs is deferred until + * runtime. + */ +class TensorRTJSONSerializer : public backend::contrib::JSONSerializer { + using JSONGraphNode = tvm::runtime::json::JSONGraphNode; + using JSONGraphNodeEntry = tvm::runtime::json::JSONGraphNodeEntry; + + public: + TensorRTJSONSerializer(const std::string& symbol, const Expr& expr) + : JSONSerializer(symbol, expr) {} + + std::vector VisitExpr_(const CallNode* cn) { + std::string name; + if (const auto* op_node = cn->op.as()) { + name = op_node->name; + } else { + return JSONSerializer::VisitExpr_(cn); + } + + std::vector inputs; + for (const auto& arg : cn->args) { + auto res = VisitExpr(arg); + inputs.insert(inputs.end(), res.begin(), res.end()); + } + auto node = std::make_shared(name, /* name_ */ + "kernel", /* op_type_ */ + inputs, 1 /* num_outputs_ */); + if (name == "nn.pad") { + SetPadNodeAttribute(node, cn); + } else if (name == "strided_slice") { + SetStridedSliceNodeAttribute(node, cn); + } else { + SetCallNodeAttribute(node, cn); + } + // These attributes are global to the whole module. + SaveGlobalAttributes(node); + return AddNode(node, GetRef(cn)); + } + + void SetPadNodeAttribute(std::shared_ptr node, const CallNode* cn) { + const auto* pad_attr = cn->attrs.as(); + CHECK(pad_attr); + auto p = pad_attr->pad_width; + const int dim_h = (p.size() == 5) ? 3 : 2; + const int dim_w = (p.size() == 5) ? 
4 : 3; + std::vector padding = {std::to_string(p[dim_h][0].as()->value), + std::to_string(p[dim_w][0].as()->value), + std::to_string(p[dim_h][1].as()->value), + std::to_string(p[dim_w][1].as()->value)}; + std::vector padding_attr; + padding_attr.emplace_back(padding); + node->SetAttr("padding", padding_attr); + } + + void SetStridedSliceNodeAttribute(std::shared_ptr node, const CallNode* cn) { + const auto* attrs = cn->attrs.as(); + CHECK(attrs && attrs->begin && attrs->end && attrs->strides) + << "StridedSlice must have static begin, end, and strides."; + const bool default_strides = + !attrs->strides.value().defined() || attrs->strides.value().size() == 0; + auto ishape = backend::GetShape(cn->args[0]->checked_type()); + + auto process_slice_index = [](Integer x, int default_value, int dim_value) { + if (!x.defined()) return default_value; + int value = x.as()->value; + if (value < 0) value += dim_value; + return value; + }; + + std::vector start, size, strides; + for (size_t i = 0; i < attrs->begin.value().size(); ++i) { + const int begin_value = process_slice_index(attrs->begin.value()[i], 0, ishape[i]); + const int end_value = process_slice_index(attrs->end.value()[i], ishape[i], ishape[i]); + const int stride_value = (default_strides || i >= attrs->strides.value().size() || + !attrs->strides.value()[i].defined()) + ? 1 + : attrs->strides.value()[i].as()->value; + CHECK_GT(stride_value, 0); + const int size_value = (end_value - begin_value + stride_value - 1) / stride_value; + CHECK_GE(begin_value, 0); + CHECK_GT(size_value, 0); + start.push_back(std::to_string(begin_value)); + size.push_back(std::to_string(size_value)); + strides.push_back(std::to_string(stride_value)); + } + std::vector start_attr, size_attr, strides_attr; + start_attr.emplace_back(start); + size_attr.emplace_back(size); + strides_attr.emplace_back(strides); + node->SetAttr("start", start_attr); + node->SetAttr("size", size_attr); + node->SetAttr("strides", strides_attr); + } + + void SaveGlobalAttributes(std::shared_ptr node) { + auto ctx = transform::PassContext::Current(); + auto cfg = ctx->GetConfig("relay.ext.tensorrt.options"); + if (!cfg.defined()) { + cfg = AttrsWithDefaultValues(); + } + CHECK_EQ(cfg.value()->tensorrt_version.size(), 3); + std::vector tensorrt_version = {std::to_string(cfg.value()->tensorrt_version[0]), + std::to_string(cfg.value()->tensorrt_version[1]), + std::to_string(cfg.value()->tensorrt_version[2])}; + std::vector use_implicit_batch = {std::to_string(cfg.value()->use_implicit_batch)}; + std::vector max_workspace_size = {std::to_string(cfg.value()->max_workspace_size)}; + std::vector tensorrt_version_attr, use_implicit_batch_attr, max_workspace_size_attr; + tensorrt_version_attr.emplace_back(tensorrt_version); + use_implicit_batch_attr.emplace_back(use_implicit_batch); + max_workspace_size_attr.emplace_back(max_workspace_size); + node->SetAttr("tensorrt_version", tensorrt_version_attr); + node->SetAttr("use_implicit_batch", use_implicit_batch_attr); + node->SetAttr("max_workspace_size", max_workspace_size_attr); + } +}; + +/*! + * \brief Create a runtime module for TensorRT. + * \param ref The ext_func Relay expression/module to be executed using extern ops. + * \return A runtime module. 
+ */ +runtime::Module TensorRTCompiler(const ObjectRef& ref) { + CHECK(ref->IsInstance()) << "The input ref is expected to be a Relay function."; + Function func = Downcast(ref); + std::string func_name = backend::GetExtSymbol(func); + + TensorRTJSONSerializer serializer(func_name, func); + serializer.serialize(); + std::string graph_json = serializer.GetJSON(); + auto param_names = serializer.GetParams(); + const auto* pf = runtime::Registry::Get("runtime.tensorrt_runtime_create"); + CHECK(pf != nullptr) << "Cannot find TensorRT runtime module create function."; + runtime::Module lib = (*pf)(func_name, graph_json, param_names); + return lib; +} + +TVM_REGISTER_GLOBAL("relay.ext.tensorrt").set_body_typed(TensorRTCompiler); + +/*! + * \brief Check whether TensorRT graph runtime is enabled. + * \return True if enabled, False if not. + */ +inline constexpr bool IsTensorRTRuntimeEnabled() { +#if TVM_GRAPH_RUNTIME_TENSORRT + return true; +#else + return false; +#endif // TVM_GRAPH_RUNTIME_TENSORRT +} + +/*! + * \brief Get TensorRT version that TVM is built against. + * \return Array of three integers for major, minor, and patch, or empty array if TensorRT graph + * runtime is not enabled. + */ +Array GetTensorRTVersion() { +#if TVM_GRAPH_RUNTIME_TENSORRT + return {Integer(NV_TENSORRT_MAJOR), Integer(NV_TENSORRT_MINOR), Integer(NV_TENSORRT_PATCH)}; +#else + return {}; +#endif // TVM_GRAPH_RUNTIME_TENSORRT +} + +TVM_REGISTER_GLOBAL("relay.op.is_tensorrt_runtime_enabled") + .set_body_typed(IsTensorRTRuntimeEnabled); +TVM_REGISTER_GLOBAL("relay.op.get_tensorrt_version").set_body_typed(GetTensorRTVersion); + +} // namespace contrib +} // namespace relay +} // namespace tvm diff --git a/src/runtime/contrib/tensorrt/tensorrt_builder.cc b/src/runtime/contrib/tensorrt/tensorrt_builder.cc new file mode 100644 index 000000000000..bf0dbfe724ed --- /dev/null +++ b/src/runtime/contrib/tensorrt/tensorrt_builder.cc @@ -0,0 +1,222 @@ +/* * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file runtime/contrib/tensorrt/tensorrt_builder.cc + * \brief The TensorRTBuilder class can be used to convert a JSONRuntime graph into a TRT engine + * which can be used for inference. + */ + +#include "tensorrt_builder.h" + +#include + +#include +#include + +#include "tensorrt_logger.h" +#include "tensorrt_ops.h" +#include "tensorrt_utils.h" + +namespace tvm { +namespace runtime { +namespace contrib { + +TensorRTBuilder::TensorRTBuilder(TensorRTLogger* logger, size_t max_workspace_size, + bool use_implicit_batch, bool use_fp16, int batch_size) + : max_workspace_size_(max_workspace_size), + use_implicit_batch_(use_implicit_batch), + use_fp16_(use_fp16), + batch_size_(batch_size) { + // Create TRT builder and network. 
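+  // For TensorRT >= 6.0.1 an explicit-batch INetworkDefinition is created unless implicit batch
+  // mode was requested; older TensorRT releases only provide the implicit-batch INetwork API, so
+  // the max batch size, workspace size, and FP16 flag are set directly on the builder there.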
+ builder_ = nvinfer1::createInferBuilder(*logger); +#if TRT_VERSION_GE(6, 0, 1) + // Use INetworkV2. + auto flags = + 1U << static_cast(nvinfer1::NetworkDefinitionCreationFlag::kEXPLICIT_BATCH); + if (use_implicit_batch_) { + flags = 0U; + builder_->setMaxBatchSize(batch_size_); + } + network_ = builder_->createNetworkV2(flags); +#else + // Use INetwork with implicit batch. + builder_->setMaxBatchSize(batch_size_); + builder_->setMaxWorkspaceSize(max_workspace_size_); + builder_->setFp16Mode(use_fp16_); + network_ = builder_->createNetwork(); +#endif +} + +void TensorRTBuilder::AddInput(int nid, const JSONGraphNode& node) { + auto node_name = node.GetOpName(); + auto shapes = node.GetOpShape(); + auto dtypes = node.GetOpDataType(); + CHECK_EQ(shapes.size(), dtypes.size()); + node_output_map_[nid] = {}; + for (size_t i = 0; i < shapes.size(); ++i) { + const std::string name = node_name + "_" + std::to_string(i); + auto shape = shapes[i]; + // Remove batch dim when not in explicit batch mode. + if (use_implicit_batch_ && shape.size() > 1) { + shape.erase(shape.begin()); + } + nvinfer1::Dims dims = VectorToTrtDims(shape); + CHECK(TypeMatch(dtypes[i], kDLFloat, 32)) << "Only FP32 inputs are supported."; + auto input_tensor = network_->addInput(name.c_str(), nvinfer1::DataType::kFLOAT, dims); + node_output_map_[nid].push_back(TensorRTOpInput(input_tensor)); + network_input_names_.push_back(input_tensor->getName()); + } +} + +void TensorRTBuilder::AddConstant(int nid, const DLTensor* data) { + nvinfer1::Weights weight = GetDLTensorAsWeights(data, kDLCPU); + std::vector shape(data->shape, data->shape + data->ndim); + // Remove batch dim when not in explicit batch mode. + if (use_implicit_batch_ && shape.size() > 1 && shape[0] == 1) { + shape.erase(shape.begin()); + } + node_output_map_[nid] = {TensorRTOpInput(weight, shape)}; +} + +void TensorRTBuilder::AddOutput(const JSONGraphNodeEntry& node) { + auto it = node_output_map_.find(node.id_); + CHECK(it != node_output_map_.end()) << "Output was not found."; + auto out_tensor = it->second[node.index_].tensor; + std::string name = "tensorrt_output_" + std::to_string(network_output_names_.size()); + out_tensor->setName(name.c_str()); + network_->markOutput(*out_tensor); + network_output_names_.push_back(out_tensor->getName()); +} + +void TensorRTBuilder::AddLayer(int nid, const JSONGraphNode& node) { + TensorRTOpConverterParams params(network_, node, &trt_weights_); + // Look up converter. + auto it = GetOpConverters()->find(params.op_name); + CHECK(it != GetOpConverters()->end()) + << "Unsupported operator conversion to TRT, op name: " << params.op_name; + const auto converter = it->second; + // Get inputs. + for (size_t i = 0; i < node.GetInputs().size(); ++i) { + auto in_node = node.GetInputs()[i]; + auto it = node_output_map_.find(in_node.id_); + CHECK(it != node_output_map_.end()) << "Input was not found."; + auto input = it->second[in_node.index_]; + if (!converter->variable_input_count) { + if (converter->input_types[i] == kTensor && input.type == kWeight) { + input = TensorRTOpInput(GetInputAsTensor(input)); + } else if (converter->input_types[i] == kWeight && input.type == kTensor) { + LOG(FATAL) << "Input " << i << " for " << params.op_name + << " requires weights but got a tensor."; + } + } + params.inputs.push_back(input); + } + CHECK(converter->variable_input_count || converter->input_types.size() == params.inputs.size()) + << "Op expected a different number of inputs."; + + // Convert op to TRT. 
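+  // Each converter reads its inputs from params->inputs and the op attributes from params->node,
+  // then appends the resulting ITensor*(s) to params->outputs.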
+ converter->Convert(¶ms); + + // Get outputs. + node_output_map_[nid] = {}; + for (auto out : params.outputs) { + node_output_map_[nid].push_back(TensorRTOpInput(out)); + } +} + +TensorRTEngineAndContext TensorRTBuilder::BuildEngine() { + // Process graph to create INetworkDefinition. +// Build engine. +#if TRT_VERSION_GE(6, 0, 1) + config_ = builder_->createBuilderConfig(); + config_->setMaxWorkspaceSize(max_workspace_size_); + if (use_fp16_) { + config_->setFlag(nvinfer1::BuilderFlag::kFP16); + } + // Add profiles. + if (!use_implicit_batch_) { + auto profile = builder_->createOptimizationProfile(); + for (int i = 0; i < network_->getNbInputs(); ++i) { + auto name = network_->getInput(i)->getName(); + auto dims = network_->getInput(i)->getDimensions(); + profile->setDimensions(name, nvinfer1::OptProfileSelector::kMIN, dims); + profile->setDimensions(name, nvinfer1::OptProfileSelector::kOPT, dims); + profile->setDimensions(name, nvinfer1::OptProfileSelector::kMAX, dims); + } + config_->addOptimizationProfile(profile); + } + nvinfer1::ICudaEngine* engine = builder_->buildEngineWithConfig(*network_, *config_); +#else + nvinfer1::ICudaEngine* engine = builder_->buildCudaEngine(*network_); +#endif + CHECK_EQ(engine->getNbBindings(), network_input_names_.size() + network_output_names_.size()); + nvinfer1::IExecutionContext* context = engine->createExecutionContext(); + CleanUp(); + return {engine, context, network_input_names_, network_output_names_}; +} + +nvinfer1::Weights TensorRTBuilder::GetDLTensorAsWeights(const DLTensor* dptr, + DLDeviceType src_device) { + CHECK_EQ(dptr->ctx.device_type, src_device); + CHECK(static_cast(dptr->dtype.code) == kDLFloat || + static_cast(dptr->dtype.code) == kDLInt); + const auto trt_dtype = static_cast(dptr->dtype.code) == kDLFloat + ? nvinfer1::DataType::kFLOAT + : nvinfer1::DataType::kINT32; + const size_t weight_bytes = GetDataSize(*dptr); + nvinfer1::Weights weight{trt_dtype, nullptr, 0}; + size_t count = 1; + for (tvm_index_t i = 0; i < dptr->ndim; ++i) { + count *= dptr->shape[i]; + } + CHECK_EQ(count * 4, weight_bytes); + weight.count = count; + weight.values = new float[count]; + CHECK_EQ(TVMArrayCopyToBytes(const_cast(dptr), const_cast(weight.values), + weight_bytes), + 0) + << TVMGetLastError(); + trt_weights_.push_back(weight); + return weight; +} + +nvinfer1::ITensor* TensorRTBuilder::GetInputAsTensor(const TensorRTOpInput& input) { + if (input.type == kTensor) return input.tensor; + auto dims = VectorToTrtDims(input.weight_shape); + return network_->addConstant(dims, input.weight)->getOutput(0); +} + +void TensorRTBuilder::CleanUp() { + network_->destroy(); +#if TRT_VERSION_GE(6, 0, 1) + config_->destroy(); +#endif + builder_->destroy(); + for (auto weight : trt_weights_) { + if (weight.type == nvinfer1::DataType::kFLOAT) { + delete[] static_cast(weight.values); + } else { + delete[] static_cast(weight.values); + } + } +} + +} // namespace contrib +} // namespace runtime +} // namespace tvm diff --git a/src/runtime/contrib/tensorrt/tensorrt_builder.h b/src/runtime/contrib/tensorrt/tensorrt_builder.h new file mode 100644 index 000000000000..efb4d8175650 --- /dev/null +++ b/src/runtime/contrib/tensorrt/tensorrt_builder.h @@ -0,0 +1,159 @@ +/* * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. 
The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file runtime/contrib/tensorrt/tensorrt_builder.h
+ * \brief The TensorRTBuilder class can be used to convert a JSONRuntime graph into a TRT engine
+ * which can be used for inference.
+ */
+
+#ifndef TVM_RUNTIME_CONTRIB_TENSORRT_TENSORRT_BUILDER_H_
+#define TVM_RUNTIME_CONTRIB_TENSORRT_TENSORRT_BUILDER_H_
+
+#include
+#include
+#include
+
+#include "../json/json_node.h"
+#include "NvInfer.h"
+#include "tensorrt_logger.h"
+#include "tensorrt_ops.h"
+
+namespace tvm {
+namespace runtime {
+namespace contrib {
+
+using JSONGraphNode = tvm::runtime::json::JSONGraphNode;
+using JSONGraphNodeEntry = tvm::runtime::json::JSONGraphNodeEntry;
+
+/*!
+ * \brief The product of TensorRTBuilder which provides everything needed to
+ * perform inference.
+ */
+struct TensorRTEngineAndContext {
+  nvinfer1::ICudaEngine* engine;
+  nvinfer1::IExecutionContext* context;
+  std::vector<std::string> inputs;
+  std::vector<std::string> outputs;
+};
+
+/*!
+ * \brief Converts a JSONRuntime graph into a TensorRT engine and execution context. Inputs,
+ * constants, layers, and outputs can be added to construct the TensorRT network definition.
+ * BuildEngine() will then use the network definition to build the TensorRT engine and context
+ * which can be used to run inference. This phase can take a long time because TensorRT will
+ * query the performance of all available kernels and fusions to optimize the engine.
+ */
+class TensorRTBuilder {
+ public:
+  /*!
+   * \brief Create TensorRT builder.
+   * \param logger TensorRT logger to use for errors and warnings.
+   * \param max_workspace_size Workspace size parameter for TensorRT engine build phase.
+   * \param use_implicit_batch Whether to use implicit batch mode (default).
+   * \param use_fp16 Whether to automatically convert the model to 16-bit floating point precision.
+   * \param batch_size Batch size to optimize for when use_implicit_batch is enabled.
+   */
+  TensorRTBuilder(TensorRTLogger* logger, size_t max_workspace_size, bool use_implicit_batch,
+                  bool use_fp16, int batch_size);
+
+  /*!
+   * \brief Add TensorRT input(s) for input node in network definition.
+   * \param nid The input node id.
+   * \param node The input node.
+   */
+  void AddInput(int nid, const JSONGraphNode& node);
+
+  /*!
+   * \brief Add TensorRT weight for input constant in network definition.
+   * \param nid The input node id.
+   * \param data The data tensor on CPU.
+   */
+  void AddConstant(int nid, const DLTensor* data);
+
+  /*!
+   * \brief Add TensorRT layer for op node in network definition.
+   * \param nid The input node id.
+   * \param node The op node.
+   */
+  void AddLayer(int nid, const JSONGraphNode& node);
+
+  /*!
+   * \brief Mark TensorRT output in network definition.
+   * \param entry The output node entry.
+   */
+  void AddOutput(const JSONGraphNodeEntry& entry);
+
+  /*!
+   * \brief Takes network definition and "compiles" a TensorRT engine which can be used for
+   * inference. This step is time consuming.
+   * \return TRT engine, context, and input/output information.
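+   * Note: the builder, network definition, and temporary weight buffers are released by
+   * CleanUp() once the engine has been built. A typical calling sequence (purely illustrative;
+   * the variable names are assumptions for the example, not part of this header) is:
+   *   TensorRTLogger logger;
+   *   TensorRTBuilder builder(&logger, max_workspace_size, use_implicit_batch, use_fp16, batch_size);
+   *   builder.AddInput(nid, node);      // once per graph input
+   *   builder.AddConstant(nid, data);   // once per weight/constant
+   *   builder.AddLayer(nid, node);      // once per op node, in topological order
+   *   builder.AddOutput(entry);         // once per graph output
+   *   TensorRTEngineAndContext trt = builder.BuildEngine();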
+ */ + TensorRTEngineAndContext BuildEngine(); + + private: + /*! \brief Convert a DLTensor to a TensorRT weight. */ + nvinfer1::Weights GetDLTensorAsWeights(const DLTensor* dptr, DLDeviceType src_device); + + /*! \brief Convert an input to a Tensor if it is a Weight */ + nvinfer1::ITensor* GetInputAsTensor(const TensorRTOpInput& input); + + /*! \brief Clean up resources used to create engine. */ + void CleanUp(); + + /*! \brief Maps a node to its outputs. */ + std::unordered_map> node_output_map_; + + /*! \brief TensorRT builder. */ + nvinfer1::IBuilder* builder_; + +#if TRT_VERSION_GE(6, 0, 1) + /*! \brief TensorRT builder config. */ + nvinfer1::IBuilderConfig* config_; +#endif + + /*! \brief TensorRT network definition. */ + nvinfer1::INetworkDefinition* network_; + + /*! \brief List of all weights held in memory. */ + std::vector trt_weights_; + + /*! \brief Max workspace size in bytes for TRT. */ + size_t max_workspace_size_; + + /*! \brief Whether to use implicit batch mode. */ + bool use_implicit_batch_; + + /*! \brief Whether to automatically convert model to 16-bit floating point precision. */ + bool use_fp16_; + + /*! \brief Batch size to optimize for. */ + int batch_size_; + + /*! \brief Input names. */ + std::vector network_input_names_; + + /*! \brief Output names. */ + std::vector network_output_names_; +}; + +} // namespace contrib +} // namespace runtime +} // namespace tvm + +#endif // TVM_RUNTIME_CONTRIB_TENSORRT_TENSORRT_BUILDER_H_ diff --git a/src/runtime/contrib/tensorrt/tensorrt_logger.h b/src/runtime/contrib/tensorrt/tensorrt_logger.h new file mode 100644 index 000000000000..53b6dfeea763 --- /dev/null +++ b/src/runtime/contrib/tensorrt/tensorrt_logger.h @@ -0,0 +1,78 @@ +/* * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file runtime/contrib/tensorrt/tensorrt_logger.h + * \brief Contains TensorRTLogger class which is required by TRT and used to + * print info, warnings, and errors. + */ + +#ifndef TVM_RUNTIME_CONTRIB_TENSORRT_TENSORRT_LOGGER_H_ +#define TVM_RUNTIME_CONTRIB_TENSORRT_TENSORRT_LOGGER_H_ + +#include + +#include "NvInfer.h" +#include "tensorrt_utils.h" + +namespace tvm { +namespace runtime { +namespace contrib { + +/*! \brief Logger for TensorRT info/warning/errors. 
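+ * Messages with a severity enum value greater than the reportable threshold (i.e. less severe)
+ * are dropped; kVERBOSE output (TRT >= 5.1.5) is routed through DLOG and is only emitted in
+ * debug builds.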
*/ +class TensorRTLogger : public nvinfer1::ILogger { + public: + TensorRTLogger() : TensorRTLogger(Severity::kWARNING) {} + explicit TensorRTLogger(Severity severity) : reportable_severity(severity) {} + void log(Severity severity, const char* msg) override { + // suppress messages with severity enum value greater than the reportable + if (severity > reportable_severity) return; + + switch (severity) { + case Severity::kINTERNAL_ERROR: + LOG(ERROR) << "INTERNAL_ERROR: " << msg; + break; + case Severity::kERROR: + LOG(ERROR) << "ERROR: " << msg; + break; + case Severity::kWARNING: + LOG(WARNING) << "WARNING: " << msg; + break; + case Severity::kINFO: + LOG(INFO) << "INFO: " << msg; + break; +#if TRT_VERSION_GE(5, 1, 5) + case Severity::kVERBOSE: + DLOG(INFO) << "VERBOSE: " << msg; + break; +#endif + default: + LOG(INFO) << "UNKNOWN: " << msg; + break; + } + } + + private: + Severity reportable_severity{Severity::kWARNING}; +}; + +} // namespace contrib +} // namespace runtime +} // namespace tvm + +#endif // TVM_RUNTIME_CONTRIB_TENSORRT_TENSORRT_LOGGER_H_ diff --git a/src/runtime/contrib/tensorrt/tensorrt_ops.cc b/src/runtime/contrib/tensorrt/tensorrt_ops.cc new file mode 100644 index 000000000000..a1da6c39f68e --- /dev/null +++ b/src/runtime/contrib/tensorrt/tensorrt_ops.cc @@ -0,0 +1,1070 @@ +/* * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file runtime/contrib/tensorrt/tensorrt_ops.cc + * \brief Converters from Relay ops into TensorRT layers. Converters should + * inherit from TensorRTOpConverter and implement the Convert() method. + */ + +#include "tensorrt_ops.h" + +#include +#include +#include +#include +#include +#include + +#include "../json/json_node.h" +#include "NvInfer.h" +#include "tensorrt_utils.h" + +namespace tvm { +namespace runtime { +namespace contrib { + +TensorRTOpConverter::TensorRTOpConverter(const std::vector& input_types, + bool variable_input_count) + : input_types(input_types), variable_input_count(variable_input_count) {} + +nvinfer1::ITensor* TensorRTOpConverter::Reshape(TensorRTOpConverterParams* params, + nvinfer1::ITensor* input, + const std::vector& new_shape) const { + auto layer = params->network->addShuffle(*input); + CHECK(layer != nullptr); + layer->setReshapeDimensions(VectorToTrtDims(new_shape)); + return layer->getOutput(0); +} + +nvinfer1::ITensor* TensorRTOpConverter::Transpose(TensorRTOpConverterParams* params, + nvinfer1::ITensor* input, + const std::vector& order) const { + auto layer = params->network->addShuffle(*input); + CHECK(layer != nullptr); + nvinfer1::Permutation perm; + if (TRT_HAS_IMPLICIT_BATCH(params)) { + // Batch dimension cannot be modified. 
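+    // The relay axis order includes the batch axis, so it is dropped and the remaining axes are
+    // shifted down by one. For example, order {0, 3, 1, 2} (NHWC -> NCHW) becomes the TRT
+    // permutation {2, 0, 1}.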
+ CHECK_EQ(input->getDimensions().nbDims, order.size() - 1); + CHECK_EQ(order[0], 0); + for (size_t i = 0; i < order.size(); ++i) { + perm.order[i] = order[i + 1] - 1; + } + } else { + CHECK_EQ(input->getDimensions().nbDims, order.size()); + for (size_t i = 0; i < order.size(); ++i) { + perm.order[i] = order[i]; + } + } + layer->setFirstTranspose(perm); + return layer->getOutput(0); +} + +int TensorRTOpConverter::ConvertAxis(TensorRTOpConverterParams* params, int axis, + int input_rank) const { + // Add 1 for missing batch dim. + if (TRT_HAS_IMPLICIT_BATCH(params)) { + input_rank += 1; + } + CHECK(axis >= -input_rank && axis < input_rank); + if (axis < 0) axis += input_rank; + if (TRT_HAS_IMPLICIT_BATCH(params)) { + // Can't modify batch dimenson. + CHECK_NE(axis, 0); + // Subtract 1 for implicit batch dim. + axis -= 1; + } + return axis; +} + +nvinfer1::ITensor* TensorRTOpConverter::CreateScalar( + TensorRTOpConverterParams* params, float value, const nvinfer1::Dims& broadcast_to_dims) const { + nvinfer1::Dims dims; + dims.nbDims = broadcast_to_dims.nbDims; + std::fill_n(dims.d, dims.nbDims, 1); + float* values = new float[1]; + values[0] = value; + nvinfer1::Weights weights{nvinfer1::DataType::kFLOAT, static_cast(values), 1}; + params->trt_weights->push_back(weights); + return params->network->addConstant(dims, weights)->getOutput(0); +} + +void TensorRTOpConverter::GetPadding(const std::vector& padding, + bool* use_asymmetric_padding, nvinfer1::DimsHW* prepadding, + nvinfer1::DimsHW* postpadding) const { + CHECK(padding.size() == 1 || padding.size() == 2 || padding.size() == 4); + if (padding.size() == 4) { + // four int : padding width in the order of (top, left, bottom, right). + *prepadding = nvinfer1::DimsHW(std::stoi(padding[0]), std::stoi(padding[1])); + *postpadding = nvinfer1::DimsHW(std::stoi(padding[2]), std::stoi(padding[3])); + *use_asymmetric_padding = true; + } else if (padding.size() == 2) { + // two int : bottom, right will use same padding as top, left + *prepadding = nvinfer1::DimsHW(std::stoi(padding[0]), std::stoi(padding[1])); + *postpadding = *prepadding; + *use_asymmetric_padding = false; + } else { + // one int : same padding used on all sides + *prepadding = nvinfer1::DimsHW(std::stoi(padding[0]), std::stoi(padding[0])); + *postpadding = *prepadding; + *use_asymmetric_padding = false; + } +} + +void TensorRTOpConverter::GetPadding3D(const std::vector& padding, + bool* use_asymmetric_padding, nvinfer1::Dims* prepadding, + nvinfer1::Dims* postpadding) const { + CHECK(padding.size() == 1 || padding.size() == 3 || padding.size() == 6); + if (padding.size() == 6) { + // six int : padding width in the order of (front, top, left, back, bottom, right) + *prepadding = + nvinfer1::Dims3(std::stoi(padding[0]), std::stoi(padding[1]), std::stoi(padding[2])); + *postpadding = + nvinfer1::Dims3(std::stoi(padding[3]), std::stoi(padding[4]), std::stoi(padding[5])); + *use_asymmetric_padding = true; + } else if (padding.size() == 3) { + // three int : back, bottom, right will use same padding as front, top, left + *prepadding = + nvinfer1::Dims3(std::stoi(padding[0]), std::stoi(padding[1]), std::stoi(padding[2])); + *postpadding = *prepadding; + *use_asymmetric_padding = false; + } else { + // one int : same padding used on all sides + *prepadding = + nvinfer1::Dims3(std::stoi(padding[0]), std::stoi(padding[0]), std::stoi(padding[0])); + *postpadding = *prepadding; + *use_asymmetric_padding = false; + } +} + +class ActivationOpConverter : public TensorRTOpConverter { + public: + 
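+  // Maps relay activation ops (nn.relu, sigmoid, tanh, and clip / nn.leaky_relu on TRT >= 5.1.5)
+  // onto a single IActivationLayer, setting alpha/beta for clip and leaky_relu.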
ActivationOpConverter() : TensorRTOpConverter({kTensor}) {} + + void Convert(TensorRTOpConverterParams* params) const { + static const std::unordered_map op_map = { + {"nn.relu", nvinfer1::ActivationType::kRELU}, + {"sigmoid", nvinfer1::ActivationType::kSIGMOID}, + {"tanh", nvinfer1::ActivationType::kTANH}, +#if TRT_VERSION_GE(5, 1, 5) + {"clip", nvinfer1::ActivationType::kCLIP}, + {"nn.leaky_relu", nvinfer1::ActivationType::kLEAKY_RELU}, +#endif + }; + auto it = op_map.find(params->op_name); + CHECK(it != op_map.end()) << "Unsupported activation type " << params->op_name; + nvinfer1::IActivationLayer* act_layer = + params->network->addActivation(*params->inputs.at(0).tensor, it->second); +#if TRT_VERSION_GE(5, 1, 5) + if (params->op_name == "clip") { + float a_min = std::stof(params->node.GetAttr>("a_min")[0]); + float a_max = std::stof(params->node.GetAttr>("a_max")[0]); + act_layer->setAlpha(a_min); + act_layer->setBeta(a_max); + } else if (params->op_name == "nn.leaky_relu") { + float alpha = std::stof(params->node.GetAttr>("alpha")[0]); + act_layer->setAlpha(alpha); + } +#endif + CHECK(act_layer != nullptr); + params->outputs.push_back(act_layer->getOutput(0)); + } +}; + +class ElementWiseBinaryOpConverter : public TensorRTOpConverter { + public: + ElementWiseBinaryOpConverter() : TensorRTOpConverter({kTensor, kTensor}) {} + + void Convert(TensorRTOpConverterParams* params) const { + static const std::unordered_map op_map = { + {"add", nvinfer1::ElementWiseOperation::kSUM}, + {"subtract", nvinfer1::ElementWiseOperation::kSUB}, + {"multiply", nvinfer1::ElementWiseOperation::kPROD}, + {"divide", nvinfer1::ElementWiseOperation::kDIV}, + {"power", nvinfer1::ElementWiseOperation::kPOW}, + {"maximum", nvinfer1::ElementWiseOperation::kMAX}, + {"minimum", nvinfer1::ElementWiseOperation::kMIN}}; + auto it = op_map.find(params->op_name); + CHECK(it != op_map.end()) << "Unsupported elementwise type " << params->op_name; + // Broadcast + auto input0 = params->inputs.at(0).tensor; + auto input0_dims = TrtDimsToVector(input0->getDimensions()); + auto input1 = params->inputs.at(1).tensor; + auto input1_dims = TrtDimsToVector(input1->getDimensions()); + const bool need_broadcast = input0_dims.size() != input1_dims.size(); + if (need_broadcast) { + if (input0_dims.size() < input1_dims.size()) { + std::vector new_shape(input0_dims); + while (new_shape.size() < input1_dims.size()) new_shape.insert(new_shape.begin(), 1); + input0 = Reshape(params, input0, new_shape); + } else if (input1_dims.size() < input0_dims.size()) { + std::vector new_shape(input1_dims); + while (new_shape.size() < input0_dims.size()) new_shape.insert(new_shape.begin(), 1); + input1 = Reshape(params, input1, new_shape); + } + } + + nvinfer1::IElementWiseLayer* elemwise_layer = + params->network->addElementWise(*input0, *input1, it->second); + CHECK(elemwise_layer != nullptr); + params->outputs.push_back(elemwise_layer->getOutput(0)); + } +}; + +class Conv2DOpConverter : public TensorRTOpConverter { + public: + Conv2DOpConverter() : TensorRTOpConverter({kTensor, kWeight}) {} + + void Convert(TensorRTOpConverterParams* params) const { + auto input_tensor = params->inputs.at(0).tensor; + auto input_dims = TrtDimsToVector(input_tensor->getDimensions()); + auto weight_shape = params->inputs.at(1).weight_shape; + CHECK_EQ(params->node.GetAttr>("data_layout")[0], "NCHW"); + CHECK(params->node.GetAttr>("out_layout")[0] == "" || + params->node.GetAttr>("out_layout")[0] == "NCHW"); + CHECK_EQ(params->node.GetAttr>("kernel_layout")[0], 
"OIHW"); + auto str_strides = params->node.GetAttr>("strides"); + auto str_dilation = params->node.GetAttr>("dilation"); + auto str_padding = params->node.GetAttr>("padding"); + int groups = std::stoi(params->node.GetAttr>("groups")[0]); + int channels = std::stoi(params->node.GetAttr>("channels")[0]); + // TRT conv2d op doesn't support asymmetric padding before 5.1, so we + // workaround by adding a padding layer before the pooling op. + nvinfer1::DimsHW prepadding, postpadding; + bool use_asymmetric_padding; + GetPadding(str_padding, &use_asymmetric_padding, &prepadding, &postpadding); +#if !TRT_VERSION_GE(5, 1, 5) + if (use_asymmetric_padding) { + auto pad_layer = params->network->addPadding(*input_tensor, prepadding, postpadding); + CHECK(pad_layer != nullptr); + input_tensor = pad_layer->getOutput(0); + // No need for conv op to do any padding. + use_asymmetric_padding = false; + prepadding = nvinfer1::DimsHW(0, 0); + } +#endif + + const auto kernel_size = nvinfer1::DimsHW(weight_shape[2], weight_shape[3]); + nvinfer1::Weights bias{nvinfer1::DataType::kFLOAT, nullptr, 0}; + auto conv_layer = params->network->addConvolution(*input_tensor, channels, kernel_size, + params->inputs.at(1).weight, bias); + CHECK(conv_layer != nullptr); + if (use_asymmetric_padding) { +#if TRT_VERSION_GE(5, 1, 5) + conv_layer->setPrePadding(prepadding); + conv_layer->setPostPadding(postpadding); +#endif + } else { + conv_layer->setPadding(prepadding); + } + CHECK_EQ(str_strides.size(), 2); + const auto strides = nvinfer1::DimsHW(std::stoi(str_strides[0]), std::stoi(str_strides[1])); + conv_layer->setStride(strides); + CHECK_EQ(str_dilation.size(), 2); + const auto dilation = nvinfer1::DimsHW(std::stoi(str_dilation[0]), std::stoi(str_dilation[1])); + conv_layer->setDilation(dilation); + conv_layer->setNbGroups(groups); + params->outputs.push_back(conv_layer->getOutput(0)); + } +}; + +#if TRT_VERSION_GE(6, 0, 1) +class Conv3DOpConverter : public TensorRTOpConverter { + public: + Conv3DOpConverter() : TensorRTOpConverter({kTensor, kWeight}) {} + + void Convert(TensorRTOpConverterParams* params) const { + auto input_tensor = params->inputs.at(0).tensor; + auto input_dims = TrtDimsToVector(input_tensor->getDimensions()); + auto weight_shape = params->inputs.at(1).weight_shape; + CHECK_EQ(params->node.GetAttr>("data_layout")[0], "NCDHW"); + CHECK(params->node.GetAttr>("out_layout")[0] == "" || + params->node.GetAttr>("out_layout")[0] == "NCDHW"); + CHECK_EQ(params->node.GetAttr>("kernel_layout")[0], "OIDHW"); + auto str_strides = params->node.GetAttr>("strides"); + auto str_dilation = params->node.GetAttr>("dilation"); + auto str_padding = params->node.GetAttr>("padding"); + int groups = std::stoi(params->node.GetAttr>("groups")[0]); + + nvinfer1::Dims prepadding, postpadding; + bool use_asymmetric_padding; + GetPadding3D(str_padding, &use_asymmetric_padding, &prepadding, &postpadding); + + // Could use attrs->channels.as()->value + const int num_outputs = weight_shape[0]; + const auto kernel_size = nvinfer1::Dims3(weight_shape[2], weight_shape[3], weight_shape[4]); + nvinfer1::Weights bias{nvinfer1::DataType::kFLOAT, nullptr, 0}; + auto conv_layer = params->network->addConvolutionNd(*input_tensor, num_outputs, kernel_size, + params->inputs.at(1).weight, bias); + CHECK(conv_layer != nullptr); + if (use_asymmetric_padding) { + conv_layer->setPrePadding(prepadding); + conv_layer->setPostPadding(postpadding); + } else { + conv_layer->setPaddingNd(prepadding); + } + CHECK_EQ(str_strides.size(), 3); + const auto strides 
= nvinfer1::Dims3(std::stoi(str_strides[0]), std::stoi(str_strides[1]), + std::stoi(str_strides[2])); + conv_layer->setStrideNd(strides); + CHECK_EQ(str_dilation.size(), 3); + const auto dilation = nvinfer1::Dims3(std::stoi(str_dilation[0]), std::stoi(str_dilation[1]), + std::stoi(str_dilation[2])); + conv_layer->setDilationNd(dilation); + conv_layer->setNbGroups(groups); + params->outputs.push_back(conv_layer->getOutput(0)); + } +}; +#endif // TRT_VERSION_GE(6, 0, 1) + +class DenseOpConverter : public TensorRTOpConverter { + public: + DenseOpConverter() : TensorRTOpConverter({kTensor, kWeight}) {} + + void Convert(TensorRTOpConverterParams* params) const { + auto input_tensor = params->inputs.at(0).tensor; + auto input_dims = TrtDimsToVector(input_tensor->getDimensions()); + CHECK(input_dims.size() > 0 && input_dims.size() <= 3); + const size_t required_rank = TRT_HAS_IMPLICIT_BATCH(params) ? 3 : 4; + const bool need_reshape_on_input = input_dims.size() != required_rank; + if (need_reshape_on_input) { + // Add dims of size 1 until rank is required_rank. + std::vector new_shape(input_dims); + while (new_shape.size() < required_rank) new_shape.insert(new_shape.end(), 1); + input_tensor = Reshape(params, input_tensor, new_shape); + } + // Weights are in KC format. + CHECK_EQ(params->inputs.at(1).weight_shape.size(), 2); + const int num_units = params->inputs.at(1).weight_shape[0]; + nvinfer1::Weights bias{nvinfer1::DataType::kFLOAT, nullptr, 0}; + nvinfer1::IFullyConnectedLayer* fc_layer = params->network->addFullyConnected( + *input_tensor, num_units, params->inputs.at(1).weight, bias); + CHECK(fc_layer != nullptr); + auto output_tensor = fc_layer->getOutput(0); + if (need_reshape_on_input) { + // Remove added dims. + input_dims[input_dims.size() - 1] = num_units; + output_tensor = Reshape(params, output_tensor, input_dims); + } + params->outputs.push_back(output_tensor); + } +}; + +class BatchNormOpConverter : public TensorRTOpConverter { + public: + BatchNormOpConverter() : TensorRTOpConverter({kTensor, kWeight, kWeight, kWeight, kWeight}) {} + + void Convert(TensorRTOpConverterParams* params) const { + auto input = params->inputs.at(0).tensor; + auto gamma = params->inputs.at(1).weight; + auto beta = params->inputs.at(2).weight; + auto mean = params->inputs.at(3).weight; + auto var = params->inputs.at(4).weight; + CHECK_EQ(gamma.count, beta.count); + CHECK_EQ(gamma.count, mean.count); + CHECK_EQ(gamma.count, var.count); + const float epsilon = std::stof(params->node.GetAttr>("epsilon")[0]); + const int axis = std::stoi(params->node.GetAttr>("axis")[0]); + const bool scale = std::stoi(params->node.GetAttr>("scale")[0]); + const bool center = std::stoi(params->node.GetAttr>("center")[0]); + CHECK(axis == 1 || axis == 3); + const bool need_transpose = axis == 3; + + void* weight_scale_ptr = new float[gamma.count]; + nvinfer1::Weights weight_scale{nvinfer1::DataType::kFLOAT, weight_scale_ptr, gamma.count}; + params->trt_weights->push_back(weight_scale); + void* weight_shift_ptr = new float[gamma.count]; + nvinfer1::Weights weight_shift{nvinfer1::DataType::kFLOAT, weight_shift_ptr, gamma.count}; + params->trt_weights->push_back(weight_shift); + nvinfer1::Weights power{nvinfer1::DataType::kFLOAT, nullptr, 0}; + + // fill in the content of weights for the Scale layer + const float* gamma_ptr = reinterpret_cast(gamma.values); + const float* beta_ptr = reinterpret_cast(beta.values); + const float* mean_ptr = reinterpret_cast(mean.values); + const float* var_ptr = reinterpret_cast(var.values); + 
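+    // Fold the batch norm into a single IScaleLayer: y = x * scale + shift, where
+    // scale = gamma / sqrt(var + epsilon) (gamma skipped when scale == false) and
+    // shift = beta - mean * scale (beta skipped when center == false).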
float* scale_ptr = reinterpret_cast(weight_scale_ptr); + float* shift_ptr = reinterpret_cast(weight_shift_ptr); + for (int i = 0; i < gamma.count; ++i) { + scale_ptr[i] = 1.0 / std::sqrt(var_ptr[i] + epsilon); + if (scale) { + scale_ptr[i] *= gamma_ptr[i]; + } + shift_ptr[i] = -mean_ptr[i] * scale_ptr[i]; + if (center) { + shift_ptr[i] += beta_ptr[i]; + } + } + if (need_transpose) { + input = Transpose(params, input, {0, 3, 1, 2}); + } + nvinfer1::IScaleLayer* scale_layer = params->network->addScale( + *input, nvinfer1::ScaleMode::kCHANNEL, weight_shift, weight_scale, power); + CHECK(scale_layer != nullptr); + auto output = scale_layer->getOutput(0); + if (need_transpose) { + output = Transpose(params, output, {0, 2, 3, 1}); + } + params->outputs.push_back(output); + } +}; + +class BatchFlattenOpConverter : public TensorRTOpConverter { + public: + BatchFlattenOpConverter() : TensorRTOpConverter({kTensor}) {} + + void Convert(TensorRTOpConverterParams* params) const { + std::vector new_shape{-1}; + if (!TRT_HAS_IMPLICIT_BATCH(params)) { + new_shape.insert(new_shape.begin(), params->inputs.at(0).tensor->getDimensions().d[0]); + } + params->outputs.push_back(Reshape(params, params->inputs.at(0).tensor, new_shape)); + } +}; + +class SoftmaxOpConverter : public TensorRTOpConverter { + public: + SoftmaxOpConverter() : TensorRTOpConverter({kTensor}) {} + + void Convert(TensorRTOpConverterParams* params) const { + auto input = params->inputs.at(0).tensor; + const int input_rank = input->getDimensions().nbDims; + const int original_axis = std::stoi(params->node.GetAttr>("axis")[0]); + const int axis = ConvertAxis(params, original_axis, input_rank); + nvinfer1::ISoftMaxLayer* softmax_layer = params->network->addSoftMax(*input); + softmax_layer->setAxes(1 << axis); + CHECK(softmax_layer != nullptr); + params->outputs.push_back(softmax_layer->getOutput(0)); + } +}; + +class PoolingOpConverter : public TensorRTOpConverter { + public: + PoolingOpConverter() : TensorRTOpConverter({kTensor}) {} + + void Convert(TensorRTOpConverterParams* params) const { + auto input = params->inputs.at(0).tensor; + static const std::unordered_map op_map = { + {"nn.max_pool2d", nvinfer1::PoolingType::kMAX}, + {"nn.avg_pool2d", nvinfer1::PoolingType::kAVERAGE}}; + auto it = op_map.find(params->op_name); + CHECK(it != op_map.end()) << "Unsupported pooling type " << params->op_name << " in TensorRT"; + CHECK_EQ(params->node.GetAttr>("layout")[0], "NCHW"); + auto str_pool_size = params->node.GetAttr>("pool_size"); + auto str_padding = params->node.GetAttr>("padding"); + auto str_strides = params->node.GetAttr>("strides"); + nvinfer1::DimsHW prepadding, postpadding; + bool use_asymmetric_padding; + GetPadding(str_padding, &use_asymmetric_padding, &prepadding, &postpadding); + bool ceil_mode = std::stoi(params->node.GetAttr>("ceil_mode")[0]); + +// TRT pooling op doesn't support asymmetric padding before 5.1, so we +// workaround by adding a padding layer before the pooling op. +#if !TRT_VERSION_GE(5, 1, 5) + if (use_asymmetric_padding) { + auto pad_layer = params->network->addPadding(*input, prepadding, postpadding); + CHECK(pad_layer != nullptr); + input = pad_layer->getOutput(0); + // No need for pooling op to do any padding. 
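+    // For example, padding = (0, 0, 1, 1) (top, left, bottom, right) becomes an explicit
+    // IPaddingLayer with prepadding (0, 0) and postpadding (1, 1), and the pooling layer itself
+    // is then left with zero padding.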
+ use_asymmetric_padding = false; + prepadding = nvinfer1::DimsHW(0, 0); + } +#endif + + nvinfer1::DimsHW window_size = + nvinfer1::DimsHW(std::stoi(str_pool_size[0]), std::stoi(str_pool_size[1])); + auto pool_layer = params->network->addPooling(*input, it->second, window_size); + CHECK(pool_layer != nullptr); + nvinfer1::DimsHW strides = + nvinfer1::DimsHW(std::stoi(str_strides[0]), std::stoi(str_strides[1])); + pool_layer->setStride(strides); + if (use_asymmetric_padding) { +#if TRT_VERSION_GE(5, 1, 5) + pool_layer->setPrePadding(prepadding); + pool_layer->setPostPadding(postpadding); +#endif + } else { + pool_layer->setPadding(prepadding); + } + if (params->op_name == "nn.avg_pool2d") { + bool count_include_pad = + std::stoi(params->node.GetAttr>("count_include_pad")[0]); + // count_include_pad=True is useless if there is no padding. TRT doesn't + // like count_include_pad in combination with strides even when there is + // no padding or assymetric padding even, so turn off inclusive to avoid + // error message. Note: Padding will always be symmetric with + // count_include_pad since partitioner will prevent unsupported case. + if (prepadding.h() == 0 && prepadding.w() == 0) { + count_include_pad = false; + } + pool_layer->setAverageCountExcludesPadding(!count_include_pad); + } +#if TRT_VERSION_GE(5, 1, 5) + if (ceil_mode) { + pool_layer->setPaddingMode(nvinfer1::PaddingMode::kEXPLICIT_ROUND_UP); + } +#else + CHECK(!ceil_mode); +#endif + params->outputs.push_back(pool_layer->getOutput(0)); + } +}; + +#if TRT_VERSION_GE(6, 0, 1) +class Pooling3DOpConverter : public TensorRTOpConverter { + public: + Pooling3DOpConverter() : TensorRTOpConverter({kTensor}) {} + + void Convert(TensorRTOpConverterParams* params) const { + auto input = params->inputs.at(0).tensor; + static const std::unordered_map op_map = { + {"nn.max_pool3d", nvinfer1::PoolingType::kMAX}, + {"nn.avg_pool3d", nvinfer1::PoolingType::kAVERAGE}}; + auto it = op_map.find(params->op_name); + CHECK(it != op_map.end()) << "Unsupported pooling type " << params->op_name << " in TensorRT"; + CHECK_EQ(params->node.GetAttr>("layout")[0], "NCDHW"); + auto str_pool_size = params->node.GetAttr>("pool_size"); + auto str_padding = params->node.GetAttr>("padding"); + auto str_strides = params->node.GetAttr>("strides"); + nvinfer1::DimsHW prepadding, postpadding; + bool use_asymmetric_padding; + GetPadding3D(str_padding, &use_asymmetric_padding, &prepadding, &postpadding); + bool ceil_mode = std::stoi(params->node.GetAttr>("ceil_mode")[0]); + nvinfer1::Dims window_size = nvinfer1::Dims3( + std::stoi(str_pool_size[0]), std::stoi(str_pool_size[1]), std::stoi(str_pool_size[2])); + auto pool_layer = params->network->addPoolingNd(*input, it->second, window_size); + CHECK(pool_layer != nullptr); + nvinfer1::Dims strides = nvinfer1::Dims3(std::stoi(str_strides[0]), std::stoi(str_strides[1]), + std::stoi(str_strides[2])); + pool_layer->setStrideNd(strides); + if (use_asymmetric_padding) { + pool_layer->setPrePadding(prepadding); + pool_layer->setPostPadding(postpadding); + } else { + pool_layer->setPaddingNd(prepadding); + } + if (params->op_name == "nn.avg_pool3d") { + bool count_include_pad = + std::stoi(params->node.GetAttr>("count_include_pad")[0]); + pool_layer->setAverageCountExcludesPadding(!count_include_pad); + } + if (ceil_mode) { + pool_layer->setPaddingMode(nvinfer1::PaddingMode::kEXPLICIT_ROUND_UP); + } + params->outputs.push_back(pool_layer->getOutput(0)); + } +}; +#endif // TRT_VERSION_GE(6, 0, 1) + +class GlobalPoolingOpConverter : 
public TensorRTOpConverter { + public: + GlobalPoolingOpConverter() : TensorRTOpConverter({kTensor}) {} + + void Convert(TensorRTOpConverterParams* params) const { + auto input_tensor = params->inputs.at(0).tensor; + auto input_dims = TrtDimsToVector(input_tensor->getDimensions()); + static const std::unordered_map op_map = { + {"nn.global_max_pool2d", nvinfer1::PoolingType::kMAX}, + {"nn.global_avg_pool2d", nvinfer1::PoolingType::kAVERAGE}}; + auto it = op_map.find(params->op_name); + CHECK(it != op_map.end()) << "Unsupported pooling type " << params->op_name << " in TensorRT"; + CHECK_EQ(params->node.GetAttr>("layout")[0], "NCHW"); + const int h = TRT_HAS_IMPLICIT_BATCH(params) ? input_dims[1] : input_dims[2]; + const int w = TRT_HAS_IMPLICIT_BATCH(params) ? input_dims[2] : input_dims[3]; + auto pool_layer = + params->network->addPooling(*input_tensor, it->second, nvinfer1::DimsHW(h, w)); + CHECK(pool_layer != nullptr); + params->outputs.push_back(pool_layer->getOutput(0)); + } +}; + +class ExpandDimsOpConverter : public TensorRTOpConverter { + public: + ExpandDimsOpConverter() : TensorRTOpConverter({kTensor}) {} + + void Convert(TensorRTOpConverterParams* params) const { + auto input_tensor = params->inputs.at(0).tensor; + auto input_dims = TrtDimsToVector(input_tensor->getDimensions()); + const int original_axis = std::stoi(params->node.GetAttr>("axis")[0]); + const int num_newaxis = + std::stoi(params->node.GetAttr>("num_newaxis")[0]); + const int axis = ConvertAxis(params, original_axis, input_dims.size() + 1); + for (int i = 0; i < num_newaxis; ++i) { + input_dims.insert(input_dims.begin() + axis, 1); + } + params->outputs.push_back(Reshape(params, params->inputs.at(0).tensor, input_dims)); + } +}; + +class SqueezeOpConverter : public TensorRTOpConverter { + public: + SqueezeOpConverter() : TensorRTOpConverter({kTensor}) {} + + void Convert(TensorRTOpConverterParams* params) const { + auto input_tensor = params->inputs.at(0).tensor; + auto input_dims = TrtDimsToVector(input_tensor->getDimensions()); + auto str_axis = params->node.GetAttr>("axis"); + for (size_t i = 0; i < str_axis.size(); ++i) { + const int axis = ConvertAxis(params, std::stoi(str_axis[i]), input_dims.size()); + input_dims[axis] = 0; + } + input_dims.erase(std::remove(input_dims.begin(), input_dims.end(), 0), input_dims.end()); + params->outputs.push_back(Reshape(params, params->inputs.at(0).tensor, input_dims)); + } +}; + +class UnaryOpConverter : public TensorRTOpConverter { + public: + UnaryOpConverter() : TensorRTOpConverter({kTensor}) {} + + void Convert(TensorRTOpConverterParams* params) const { + // The following ops are supported by TRT but don't exist in relay yet: + // recip, tan, sinh, cosh, asin, acos, asinh, acosh, atanh + static const std::unordered_map op_map = { + {"exp", nvinfer1::UnaryOperation::kEXP}, + {"log", nvinfer1::UnaryOperation::kLOG}, + {"sqrt", nvinfer1::UnaryOperation::kSQRT}, + {"abs", nvinfer1::UnaryOperation::kABS}, + {"negative", nvinfer1::UnaryOperation::kNEG}, +#if TRT_VERSION_GE(5, 1, 5) + {"sin", nvinfer1::UnaryOperation::kSIN}, + {"cos", nvinfer1::UnaryOperation::kCOS}, + {"atan", nvinfer1::UnaryOperation::kATAN}, + {"ceil", nvinfer1::UnaryOperation::kCEIL}, + {"floor", nvinfer1::UnaryOperation::kFLOOR}, +#endif + }; + auto it = op_map.find(params->op_name); + CHECK(it != op_map.end()) << "Unsupported unary type " << params->op_name; + nvinfer1::IUnaryLayer* unary_layer = + params->network->addUnary(*params->inputs.at(0).tensor, it->second); + CHECK(unary_layer != nullptr); + 
params->outputs.push_back(unary_layer->getOutput(0)); + } +}; + +class ConcatOpConverter : public TensorRTOpConverter { + public: + ConcatOpConverter() : TensorRTOpConverter({}, /*variable_input_count=*/true) {} + + void Convert(TensorRTOpConverterParams* params) const { + const int num_inputs = params->inputs.size(); + CHECK_GT(num_inputs, 0); + const int input_rank = params->inputs[0].tensor->getDimensions().nbDims; + std::vector input_tensors; + for (auto input : params->inputs) { + CHECK(input.type == kTensor); + CHECK_EQ(input_rank, input.tensor->getDimensions().nbDims); + input_tensors.push_back(input.tensor); + } + + const int original_axis = std::stoi(params->node.GetAttr>("axis")[0]); + const int axis = ConvertAxis(params, original_axis, input_rank); + + nvinfer1::IConcatenationLayer* concat_layer = + params->network->addConcatenation(input_tensors.data(), input_tensors.size()); + CHECK(concat_layer != nullptr); + concat_layer->setAxis(axis); + params->outputs.push_back(concat_layer->getOutput(0)); + } +}; + +class BiasAddOpConverter : public TensorRTOpConverter { + public: + BiasAddOpConverter() : TensorRTOpConverter({kTensor, kWeight}) {} + + void Convert(TensorRTOpConverterParams* params) const { + auto input_tensor = params->inputs.at(0).tensor; + auto input_dims = TrtDimsToVector(input_tensor->getDimensions()); + const size_t required_rank = TRT_HAS_IMPLICIT_BATCH(params) ? 3 : 4; + CHECK(input_dims.size() > 0 && input_dims.size() <= required_rank); + const bool need_reshape_on_input = input_dims.size() != required_rank; + if (need_reshape_on_input) { + // Add dims of size 1 until rank is required_rank. + std::vector new_shape(input_dims); + while (new_shape.size() < required_rank) new_shape.insert(new_shape.end(), 1); + input_tensor = Reshape(params, input_tensor, new_shape); + } + + nvinfer1::Weights shift{nvinfer1::DataType::kFLOAT, nullptr, 0}; + nvinfer1::Weights power{nvinfer1::DataType::kFLOAT, nullptr, 0}; + nvinfer1::IScaleLayer* scale_layer = params->network->addScale( + *input_tensor, nvinfer1::ScaleMode::kCHANNEL, params->inputs.at(1).weight, shift, power); + CHECK(scale_layer != nullptr); + auto output_tensor = scale_layer->getOutput(0); + if (need_reshape_on_input) { + // Remove added dims. + output_tensor = Reshape(params, output_tensor, input_dims); + } + params->outputs.push_back(output_tensor); + } +}; + +class Conv2DTransposeOpConverter : public TensorRTOpConverter { + public: + Conv2DTransposeOpConverter() : TensorRTOpConverter({kTensor, kWeight}) {} + + void Convert(TensorRTOpConverterParams* params) const { + auto input_tensor = params->inputs.at(0).tensor; + auto weight_shape = params->inputs.at(1).weight_shape; + CHECK_EQ(params->node.GetAttr>("data_layout")[0], "NCHW"); + CHECK(params->node.GetAttr>("out_layout")[0] == "" || + params->node.GetAttr>("out_layout")[0] == "NCHW"); + CHECK_EQ(params->node.GetAttr>("kernel_layout")[0], "OIHW"); + auto str_dilation = params->node.GetAttr>("dilation"); + CHECK(std::stoi(str_dilation[0]) == 1 && std::stoi(str_dilation[1]) == 1); + auto str_strides = params->node.GetAttr>("strides"); + auto str_padding = params->node.GetAttr>("padding"); + auto str_output_padding = params->node.GetAttr>("output_padding"); + int groups = std::stoi(params->node.GetAttr>("groups")[0]); + + // TRT deconv op doesn't support asymmetric padding before 5.1, so we + // workaround by adding a padding layer before the pooling op. 
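+    // The same pre-5.1.5 fallback used for conv2d and pooling applies to the deconvolution here:
+    // fold the asymmetric part into an explicit IPaddingLayer and zero out the op's own padding.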
+ nvinfer1::DimsHW prepadding, postpadding; + bool use_asymmetric_padding; + GetPadding(str_padding, &use_asymmetric_padding, &prepadding, &postpadding); +#if !TRT_VERSION_GE(5, 1, 5) + if (use_asymmetric_padding) { + auto pad_layer = params->network->addPadding(*input_tensor, prepadding, postpadding); + CHECK(pad_layer != nullptr); + input_tensor = pad_layer->getOutput(0); + // No need for conv op to do any padding. + use_asymmetric_padding = false; + prepadding = nvinfer1::DimsHW(0, 0); + } +#endif + + // Could use conv2d_attr->channels.as()->value + const int num_outputs = weight_shape[1]; + const auto kernel_size = nvinfer1::DimsHW(weight_shape[2], weight_shape[3]); + nvinfer1::Weights bias{nvinfer1::DataType::kFLOAT, nullptr, 0}; + auto deconv_layer = params->network->addDeconvolution(*input_tensor, num_outputs, kernel_size, + params->inputs.at(1).weight, bias); + CHECK(deconv_layer != nullptr); + if (use_asymmetric_padding) { +#if TRT_VERSION_GE(5, 1, 5) + deconv_layer->setPrePadding(prepadding); + deconv_layer->setPostPadding(postpadding); +#endif + } else { + deconv_layer->setPadding(prepadding); + } + const auto strides = nvinfer1::DimsHW(std::stoi(str_strides[0]), std::stoi(str_strides[1])); + deconv_layer->setStride(strides); + deconv_layer->setNbGroups(groups); + nvinfer1::ITensor* output = deconv_layer->getOutput(0); + // Output padding. + if (str_output_padding.size()) { + GetPadding(str_output_padding, &use_asymmetric_padding, &prepadding, &postpadding); + if (prepadding.h() != 0 || prepadding.w() != 0 || postpadding.h() != 0 || + postpadding.w() != 0) { + // Output padding for Conv2D transpose is always asymmetric and applied to post only. + prepadding = nvinfer1::DimsHW(0, 0); + auto pad_layer = params->network->addPadding(*output, prepadding, postpadding); + output = pad_layer->getOutput(0); + } + } + params->outputs.push_back(output); + } +}; + +#if TRT_VERSION_GE(6, 0, 1) +class Conv3DTransposeOpConverter : public TensorRTOpConverter { + public: + Conv3DTransposeOpConverter() : TensorRTOpConverter({kTensor, kWeight}) {} + + void Convert(TensorRTOpConverterParams* params) const { + auto input_tensor = params->inputs.at(0).tensor; + auto weight_shape = params->inputs.at(1).weight_shape; + CHECK_EQ(params->node.GetAttr>("data_layout")[0], "NCDHW"); + CHECK(params->node.GetAttr>("out_layout")[0] == "" || + params->node.GetAttr>("out_layout")[0] == "NCDHW"); + CHECK_EQ(params->node.GetAttr>("kernel_layout")[0], "OIDHW"); + auto str_dilation = params->node.GetAttr>("dilation"); + CHECK_EQ(str_dilation.size(), 3); + CHECK(std::stoi(str_dilation[0]) == 1 && std::stoi(str_dilation[1]) == 1 && + std::stoi(str_dilation[2]) == 1); + auto str_strides = params->node.GetAttr>("strides"); + auto str_padding = params->node.GetAttr>("padding"); + auto str_output_padding = params->node.GetAttr>("output_padding"); + int groups = std::stoi(params->node.GetAttr>("groups")[0]); + nvinfer1::Dims prepadding, postpadding; + bool use_asymmetric_padding; + GetPadding3D(str_padding, &use_asymmetric_padding, &prepadding, &postpadding); + + // Could use attrs->channels.as()->value + const int num_outputs = weight_shape[1]; + const auto kernel_size = nvinfer1::Dims3(weight_shape[2], weight_shape[3], weight_shape[4]); + nvinfer1::Weights bias{nvinfer1::DataType::kFLOAT, nullptr, 0}; + auto deconv_layer = params->network->addDeconvolutionNd(*input_tensor, num_outputs, kernel_size, + params->inputs.at(1).weight, bias); + CHECK(deconv_layer != nullptr); + if (use_asymmetric_padding) { + 
deconv_layer->setPrePadding(prepadding); + deconv_layer->setPostPadding(postpadding); + } else { + deconv_layer->setPaddingNd(prepadding); + } + CHECK_EQ(str_strides.size(), 3); + const auto strides = nvinfer1::Dims3(std::stoi(str_strides[0]), std::stoi(str_strides[1]), + std::stoi(str_strides[2])); + deconv_layer->setStrideNd(strides); + deconv_layer->setNbGroups(groups); + nvinfer1::ITensor* output = deconv_layer->getOutput(0); + // Output padding. + if (str_output_padding.size()) { + GetPadding3D(str_output_padding, &use_asymmetric_padding, &prepadding, &postpadding); + // Are any post-padding values non-zero? + CHECK(!std::any_of(postpadding.d, postpadding.d + postpadding.nbDims, [](int x) { + return x != 0; + })) << "TRT does not support padding on 3 dimensions."; + } + params->outputs.push_back(output); + } +}; +#endif // TRT_VERSION_GE(6, 0, 1) + +class TransposeOpConverter : public TensorRTOpConverter { + public: + TransposeOpConverter() : TensorRTOpConverter({kTensor}) {} + + void Convert(TensorRTOpConverterParams* params) const { + auto input = params->inputs.at(0).tensor; + auto str_axes = params->node.GetAttr>("axes"); + std::vector order; + for (size_t i = 0; i < str_axes.size(); ++i) { + order.push_back(std::stoi(str_axes[i])); + } + params->outputs.push_back(Transpose(params, input, order)); + } +}; + +class LayoutTransformOpConverter : public TensorRTOpConverter { + public: + LayoutTransformOpConverter() : TensorRTOpConverter({kTensor}) {} + + void Convert(TensorRTOpConverterParams* params) const { + auto input = params->inputs.at(0).tensor; + auto src = params->node.GetAttr>("src_layout")[0]; + auto dst = params->node.GetAttr>("dst_layout")[0]; + std::vector order; + if (src == "NCHW" && dst == "NHWC") { + order = {0, 2, 3, 1}; + } else if (src == "NHWC" && dst == "NCHW") { + order = {0, 3, 1, 2}; + } else if (src == "NDHWC" && dst == "NCDHW") { + order = {0, 4, 1, 2, 3}; + } else if (src == "NCDHW" && dst == "NDHWC") { + order = {0, 2, 3, 4, 1}; + } + params->outputs.push_back(Transpose(params, input, order)); + } +}; + +class ReshapeOpConverter : public TensorRTOpConverter { + public: + ReshapeOpConverter() : TensorRTOpConverter({kTensor}) {} + + void Convert(TensorRTOpConverterParams* params) const { + auto input = params->inputs.at(0).tensor; + CHECK_EQ(std::stoi(params->node.GetAttr>("reverse")[0]), false); + auto str_newshape = params->node.GetAttr>("newshape"); + std::vector new_shape; + const int start_index = TRT_HAS_IMPLICIT_BATCH(params) ? 
1 : 0; + for (size_t i = start_index; i < str_newshape.size(); ++i) { + const int value = std::stoi(str_newshape[i]); + CHECK_GE(value, -1); + new_shape.push_back(value); + } + params->outputs.push_back(Reshape(params, input, new_shape)); + } +}; + +class PadOpConverter : public TensorRTOpConverter { + public: + PadOpConverter() : TensorRTOpConverter({kTensor}) {} + + void Convert(TensorRTOpConverterParams* params) const { + auto input = params->inputs.at(0).tensor; + auto str_paddding = params->node.GetAttr>("padding"); + nvinfer1::DimsHW prepadding = + nvinfer1::DimsHW(std::stoi(str_paddding[0]), std::stoi(str_paddding[1])); + nvinfer1::DimsHW postpadding = + nvinfer1::DimsHW(std::stoi(str_paddding[2]), std::stoi(str_paddding[3])); + auto pad_layer = params->network->addPadding(*input, prepadding, postpadding); + params->outputs.push_back(pad_layer->getOutput(0)); + } +}; + +class ReduceOpConverter : public TensorRTOpConverter { + public: + ReduceOpConverter() : TensorRTOpConverter({kTensor}) {} + + void Convert(TensorRTOpConverterParams* params) const { + static const std::unordered_map op_map = { + {"sum", nvinfer1::ReduceOperation::kSUM}, + {"prod", nvinfer1::ReduceOperation::kPROD}, + {"max", nvinfer1::ReduceOperation::kMAX}, + {"min", nvinfer1::ReduceOperation::kMIN}, + {"mean", nvinfer1::ReduceOperation::kAVG}}; + auto it = op_map.find(params->op_name); + CHECK(it != op_map.end()) << "Unsupported reduce type " << params->op_name; + + auto input = params->inputs.at(0).tensor; + CHECK_EQ(std::stoi(params->node.GetAttr>("exclude")[0]), false); + bool keepdims = std::stoi(params->node.GetAttr>("keepdims")[0]); + auto str_axis = params->node.GetAttr>("axis"); + // TODO(trevmorr): Support reduce to scalar. + CHECK_GT(str_axis.size(), 0); + uint32_t reduce_axes = 0; + for (size_t i = 0; i < str_axis.size(); ++i) { + const int axis = ConvertAxis(params, std::stoi(str_axis[i]), input->getDimensions().nbDims); + reduce_axes |= 1 << axis; + } + auto reduce_layer = params->network->addReduce(*input, it->second, reduce_axes, keepdims); + params->outputs.push_back(reduce_layer->getOutput(0)); + } +}; + +#if TRT_VERSION_GE(5, 1, 5) +class StridedSliceOpConverter : public TensorRTOpConverter { + public: + StridedSliceOpConverter() : TensorRTOpConverter({kTensor, kWeight, kWeight, kWeight}) {} + + void Convert(TensorRTOpConverterParams* params) const { + auto input = params->inputs.at(0).tensor; + auto input_dims = TrtDimsToVector(input->getDimensions()); + auto str_start = params->node.GetAttr>("start"); + auto str_size = params->node.GetAttr>("size"); + auto str_strides = params->node.GetAttr>("strides"); + std::vector start, size, strides; + std::transform(str_start.begin(), str_start.end(), std::back_inserter(start), + [](const std::string& s) { return std::stoi(s); }); + std::transform(str_size.begin(), str_size.end(), std::back_inserter(size), + [](const std::string& s) { return std::stoi(s); }); + std::transform(str_strides.begin(), str_strides.end(), std::back_inserter(strides), + [](const std::string& s) { return std::stoi(s); }); + if (TRT_HAS_IMPLICIT_BATCH(params)) { + start.erase(start.begin()); + size.erase(size.begin()); + strides.erase(strides.begin()); + } + auto slice_layer = params->network->addSlice(*input, VectorToTrtDims(start), + VectorToTrtDims(size), VectorToTrtDims(strides)); + params->outputs.push_back(slice_layer->getOutput(0)); + } +}; +#endif + +class AdaptivePoolingOpConverter : public TensorRTOpConverter { + public: + AdaptivePoolingOpConverter() : 
TensorRTOpConverter({kTensor}) {} + + void Convert(TensorRTOpConverterParams* params) const { + auto input_tensor = params->inputs.at(0).tensor; + auto input_dims = TrtDimsToVector(input_tensor->getDimensions()); + static const std::unordered_map op_map = { + {"nn.adaptive_max_pool2d", nvinfer1::PoolingType::kMAX}, + {"nn.adaptive_avg_pool2d", nvinfer1::PoolingType::kAVERAGE}}; + auto it = op_map.find(params->op_name); + CHECK(it != op_map.end()) << "Unsupported pooling type " << params->op_name << " in TensorRT"; + CHECK_EQ(params->node.GetAttr>("layout")[0], "NCHW"); + + // This is an approximation of adaptive pooling. Results will not be + // mathematically exact except when output_size is (1, 1). + // Annotation rules will only allow output size of (1, 1). + auto output_size = nvinfer1::DimsHW(1, 1); + const int h = TRT_HAS_IMPLICIT_BATCH(params) ? input_dims[1] : input_dims[2]; + const int w = TRT_HAS_IMPLICIT_BATCH(params) ? input_dims[2] : input_dims[3]; + const auto stride = nvinfer1::DimsHW(h / output_size.h(), w / output_size.w()); + const auto window_size = nvinfer1::DimsHW(h - (output_size.h() - 1) * stride.h(), + w - (output_size.w() - 1) * stride.w()); + auto pool_layer = params->network->addPooling(*input_tensor, it->second, window_size); + CHECK(pool_layer != nullptr); + pool_layer->setStride(stride); + params->outputs.push_back(pool_layer->getOutput(0)); + } +}; + +const std::shared_ptr>> +GetOpConverters() { + static auto map = + std::make_shared>>(); + if (!map->empty()) return map; + map->emplace("nn.relu", std::make_shared()); + map->emplace("sigmoid", std::make_shared()); + map->emplace("tanh", std::make_shared()); + map->emplace("nn.batch_norm", std::make_shared()); + map->emplace("nn.softmax", std::make_shared()); + map->emplace("nn.conv2d", std::make_shared()); + map->emplace("nn.dense", std::make_shared()); + map->emplace("nn.bias_add", std::make_shared()); + map->emplace("add", std::make_shared()); + map->emplace("subtract", std::make_shared()); + map->emplace("multiply", std::make_shared()); + map->emplace("divide", std::make_shared()); + map->emplace("power", std::make_shared()); + map->emplace("maximum", std::make_shared()); + map->emplace("minimum", std::make_shared()); + map->emplace("nn.max_pool2d", std::make_shared()); + map->emplace("nn.avg_pool2d", std::make_shared()); + map->emplace("nn.global_max_pool2d", std::make_shared()); + map->emplace("nn.global_avg_pool2d", std::make_shared()); + map->emplace("exp", std::make_shared()); + map->emplace("log", std::make_shared()); + map->emplace("sqrt", std::make_shared()); + map->emplace("abs", std::make_shared()); + map->emplace("negative", std::make_shared()); + map->emplace("nn.batch_flatten", std::make_shared()); + map->emplace("expand_dims", std::make_shared()); + map->emplace("squeeze", std::make_shared()); + map->emplace("concatenate", std::make_shared()); + map->emplace("nn.conv2d_transpose", std::make_shared()); + map->emplace("transpose", std::make_shared()); + map->emplace("layout_transform", std::make_shared()); + map->emplace("reshape", std::make_shared()); + map->emplace("nn.pad", std::make_shared()); + map->emplace("sum", std::make_shared()); + map->emplace("prod", std::make_shared()); + map->emplace("max", std::make_shared()); + map->emplace("min", std::make_shared()); + map->emplace("mean", std::make_shared()); + map->emplace("nn.adaptive_max_pool2d", std::make_shared()); + map->emplace("nn.adaptive_avg_pool2d", std::make_shared()); +#if TRT_VERSION_GE(5, 1, 5) + map->emplace("clip", 
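// ---------------------------------------------------------------------------
// Worked example for the adaptive-pooling approximation above (illustrative
// values, not part of the patch). With a 32x32 NCHW input and output_size of
// (1, 1), the only size the annotation rules allow:
// ---------------------------------------------------------------------------
constexpr int kExampleH = 32, kExampleOutH = 1;
constexpr int kExampleStrideH = kExampleH / kExampleOutH;                          // 32
constexpr int kExampleWindowH = kExampleH - (kExampleOutH - 1) * kExampleStrideH;  // 32 - 0 = 32
// A single 32x32 window over a 32x32 input is exactly global pooling, so the
// result is exact for output_size == (1, 1). For any other output_size the
// same formula is only an approximation, which is why the annotation layer
// restricts this op to (1, 1).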
std::make_shared()); + map->emplace("nn.leaky_relu", std::make_shared()); + map->emplace("sin", std::make_shared()); + map->emplace("cos", std::make_shared()); + map->emplace("atan", std::make_shared()); + map->emplace("ceil", std::make_shared()); + map->emplace("floor", std::make_shared()); + map->emplace("strided_slice", std::make_shared()); +#endif // TRT_VERSION_GE(5, 1, 5) +#if TRT_VERSION_GE(6, 0, 1) + map->emplace("nn.conv3d", std::make_shared()); + map->emplace("nn.max_pool3d", std::make_shared()); + map->emplace("nn.avg_pool3d", std::make_shared()); + map->emplace("nn.conv3d_transpose", std::make_shared()); +#endif // TRT_VERSION_GE(6, 0, 1) + return map; +} + +} // namespace contrib +} // namespace runtime +} // namespace tvm diff --git a/src/runtime/contrib/tensorrt/tensorrt_ops.h b/src/runtime/contrib/tensorrt/tensorrt_ops.h new file mode 100644 index 000000000000..e9871d42146c --- /dev/null +++ b/src/runtime/contrib/tensorrt/tensorrt_ops.h @@ -0,0 +1,207 @@ +/* * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file runtime/contrib/tensorrt/tensorrt_ops.h + * \brief Converters from Relay ops into TensorRT layers. Converters should + * inherit from TensorRTOpConverter and implement the Convert() method. + */ + +#ifndef TVM_RUNTIME_CONTRIB_TENSORRT_TENSORRT_OPS_H_ +#define TVM_RUNTIME_CONTRIB_TENSORRT_TENSORRT_OPS_H_ + +#include +#include +#include +#include +#include +#include + +#include "../json/json_node.h" +#include "NvInfer.h" +#include "tensorrt_utils.h" + +#if TRT_VERSION_GE(6, 0, 1) +#define TRT_HAS_IMPLICIT_BATCH(params) (params->network->hasImplicitBatchDimension()) +#else +#define TRT_HAS_IMPLICIT_BATCH(params) (true) +#endif + +namespace tvm { +namespace runtime { +namespace contrib { + +using JSONGraphNode = tvm::runtime::json::JSONGraphNode; + +/*! + * \brief An input to a op may be either kTensor in the case of nvinfer::ITensor* + * or kWeight for nvinfer1::Weights. + */ +enum TensorRTInputType { + kTensor, + kWeight, +}; + +/*! + * \brief An input to a TensorRTOpConverter. The type of the input is either kTensor + * or kWeight. For kTensor, "tensor" contains the input tensor. For kWeight, + * "weight" contains the input weight and "weight_shape" contains the shape. + */ +struct TensorRTOpInput { + /*! \brief If type is kTensor, will store input tensor. */ + nvinfer1::ITensor* tensor; + + /*! \brief If type is kWeight, will store input weight. */ + nvinfer1::Weights weight; + + /*! \brief Whether the input is in tensor or weight. */ + TensorRTInputType type; + + /*! \brief If type is kWeight, will store weight shape. 
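// ---------------------------------------------------------------------------
// Illustrative sketch of extending the converter registry above. The class and
// the op name "my_custom_op" are hypothetical; the interfaces
// (TensorRTOpConverter, TensorRTOpConverterParams, GetOpConverters) are the
// ones defined in this patch.
// ---------------------------------------------------------------------------
class MyCustomOpConverter : public TensorRTOpConverter {
 public:
  // One tensor input; use {kTensor, kWeight} etc. when the op also takes weights.
  MyCustomOpConverter() : TensorRTOpConverter({kTensor}) {}

  void Convert(TensorRTOpConverterParams* params) const {
    auto input = params->inputs.at(0).tensor;
    // Add whatever TensorRT layers implement the op; kEXP is only a placeholder here.
    nvinfer1::IUnaryLayer* layer =
        params->network->addUnary(*input, nvinfer1::UnaryOperation::kEXP);
    CHECK(layer != nullptr);
    params->outputs.push_back(layer->getOutput(0));
  }
};
// Registration inside GetOpConverters(), optionally guarded by TRT_VERSION_GE
// when the required TensorRT layer only exists in newer releases:
//   map->emplace("my_custom_op", std::make_shared<MyCustomOpConverter>());
// The op also has to be marked as supported by the annotation rules in
// python/tvm/relay/op/contrib/tensorrt.py so the partitioner offloads it.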
*/ + std::vector weight_shape; + + explicit TensorRTOpInput(nvinfer1::ITensor* tensor) + : tensor(tensor), weight({nvinfer1::DataType::kFLOAT, nullptr, 0}), type(kTensor) {} + TensorRTOpInput(nvinfer1::Weights weight, const std::vector& shape) + : tensor(nullptr), weight(weight), type(kWeight), weight_shape(shape) {} +}; + +/*! \brief Parameters to convert an Op from Relay to TensorRT. */ +struct TensorRTOpConverterParams { + /*! \brief The TRT network that the new layer should be added to. */ + nvinfer1::INetworkDefinition* network; + /*! \brief The corresponding serialized node. */ + const JSONGraphNode& node; + /*! \brief The type of op. */ + std::string op_name; + /*! \brief Inputs to the op. */ + std::vector inputs; + /*! \brief Outputs of the op should be populated here during Convert(). */ + std::vector outputs; + /*! \brief Any newly allocated weights should be stored here also. */ + std::vector* trt_weights; + + TensorRTOpConverterParams(nvinfer1::INetworkDefinition* network, const JSONGraphNode& node, + std::vector* trt_weights) + : network(network), node(node), trt_weights(trt_weights) { + op_name = node.GetOpName(); + } +}; + +/*! \brief Base class for an op converter from Relay to TRT. */ +class TensorRTOpConverter { + public: + /*! \brief Used to specify whether each input is tensor or weight. */ + const std::vector input_types; + /*! \brief If set to true, any number of tensor inputs can be used for the op. + */ + const bool variable_input_count; + + /*! + * \brief Converter subclasses should call this constructor to set + * input_types or variable_input_count. + * \param input_types For each input to the op, there should be a + * corresponding entry in input_types to determine whether that input should + * be a tensor or a weight. TensorRTBuilder will prepare inputs in + * TensorRTOpConverter according to this. + * \param variable_input_count If the op can have multiple inputs, set this to + * true. input_types vector will be ignored and any number of input tensors + * can be used for this op. All inputs will be tensors and not weights. + */ + explicit TensorRTOpConverter(const std::vector& input_types, + bool variable_input_count = false); + + /*! + * \brief Convert to TRT. Implementation should use inputs and attributes + * from the CallNode to add the corresponding TRT layers to network. Outputs + * should be pushed to outputs vector. + * \param params Parameters for this op. + */ + virtual void Convert(TensorRTOpConverterParams* params) const = 0; + + /*! + * \brief Helper function to reshape a tensor. + * \param params Parameters for this op. + * \param input Tensor to reshape. + * \param new_shape New shape, does not include batch dim. + * \return Reshaped tensor + */ + nvinfer1::ITensor* Reshape(TensorRTOpConverterParams* params, nvinfer1::ITensor* input, + const std::vector& new_shape) const; + + /*! + * \brief Helper function to transpose a tensor. + * \param params Parameters for this op. + * \param input Tensor to transpose. + * \param order New order of axes, does include batch dim. + * \return Transposed tensor + */ + nvinfer1::ITensor* Transpose(TensorRTOpConverterParams* params, nvinfer1::ITensor* input, + const std::vector& order) const; + + /*! + * \brief Helper function to convert an axis to TRT format. + * \param axis Axis from TVM. + * \param input_rank Rank of input, does not include batch dim. + * \return Axis in TRT format. + */ + int ConvertAxis(TensorRTOpConverterParams* params, int axis, int input_rank) const; + + /*! 
+ * \brief Create constant that is broadcastable. + * \param params Parameters for this op. + * \param value Value of scalar. + * \param broadcast_to_dims Dims that scalar should be broadcastable against. + * \return Constant tensor. + */ + nvinfer1::ITensor* CreateScalar(TensorRTOpConverterParams* params, float value, + const nvinfer1::Dims& broadcast_to_dims) const; + + /*! + * \brief Get pre/post padding values from padding attributes array. + * \param padding Serialized padding from op attributes. + * \param padding_is_asymmetric True if both pre and post are needed for asymmetric padding. + * \param prepadding Prepadding value or symmetric padding values if !padding_is_asymmetric. + * \param postpadding Postpadding value if padding_is_asymmetric. + */ + void GetPadding(const std::vector& padding, bool* use_asymmetric_padding, + nvinfer1::DimsHW* prepadding, nvinfer1::DimsHW* postpadding) const; + + /*! + * \brief Get pre/post padding values from padding attributes array for volumetric ops. + * \param padding Serialized padding from op attributes. + * \param padding_is_asymmetric True if both pre and post are needed for asymmetric padding. + * \param prepadding Prepadding value or symmetric padding values if !padding_is_asymmetric. + * \param postpadding Postpadding value if padding_is_asymmetric. + */ + void GetPadding3D(const std::vector& padding, bool* use_asymmetric_padding, + nvinfer1::Dims* prepadding, nvinfer1::Dims* postpadding) const; +}; + +/*! + * \brief Get the map of available TensorRTOpConverters, where the key is the name of the relay op. + * \return Map of TensorRTOpConverters. + */ +const std::shared_ptr>> +GetOpConverters(); + +} // namespace contrib +} // namespace runtime +} // namespace tvm + +#endif // TVM_RUNTIME_CONTRIB_TENSORRT_TENSORRT_OPS_H_ diff --git a/src/runtime/contrib/tensorrt/tensorrt_runtime.cc b/src/runtime/contrib/tensorrt/tensorrt_runtime.cc new file mode 100644 index 000000000000..72c025695f7d --- /dev/null +++ b/src/runtime/contrib/tensorrt/tensorrt_runtime.cc @@ -0,0 +1,312 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file src/runtime/contrib/tensorrt/tensorrt_runtime.cc + * \brief JSON runtime implementation for TensorRT. + */ + +#include +#include +#include + +#include + +#include "../../file_utils.h" +#include "../json/json_node.h" +#include "../json/json_runtime.h" + +#ifdef TVM_GRAPH_RUNTIME_TENSORRT +#include "NvInfer.h" +#include "tensorrt_builder.h" +#endif + +namespace tvm { +namespace runtime { +namespace contrib { + +using namespace tvm::runtime::json; + +class TensorRTRuntime : public JSONRuntimeBase { + public: + /*! + * \brief The TensorRT runtime module. Deserialize the provided functions + * on creation and store in the layer cache. 
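// ---------------------------------------------------------------------------
// A minimal sketch of the GetPadding() contract documented above, assuming the
// common 4-element Relay padding attribute ordered (top, left, bottom, right).
// This is not the patch's implementation, only one realisation of the
// documented behaviour; the function name is hypothetical.
// ---------------------------------------------------------------------------
inline void GetPaddingSketch(const std::vector<std::string>& padding,
                             bool* use_asymmetric_padding, nvinfer1::DimsHW* prepadding,
                             nvinfer1::DimsHW* postpadding) {
  CHECK_EQ(padding.size(), 4);
  const int top = std::stoi(padding[0]), left = std::stoi(padding[1]);
  const int bottom = std::stoi(padding[2]), right = std::stoi(padding[3]);
  *use_asymmetric_padding = (top != bottom) || (left != right);
  // When padding is symmetric, prepadding alone carries the (h, w) values.
  *prepadding = nvinfer1::DimsHW(top, left);
  *postpadding = nvinfer1::DimsHW(bottom, right);
}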
+ * + * \param symbol_name The name of the function. + * \param graph_json serialized JSON representation of a sub-graph. + * \param const_names The names of each constant in the sub-graph. + */ + explicit TensorRTRuntime(const std::string& symbol_name, const std::string& graph_json, + const Array& const_names) + : JSONRuntimeBase(symbol_name, graph_json, const_names), + use_implicit_batch_(true), + max_workspace_size_(size_t(1) << 30) {} + + /*! + * \brief The type key of the module. + * + * \return module type key. + */ + const char* type_key() const override { return "tensorrt"; } + + /*! + * \brief Initialize runtime. Create TensorRT layer from JSON + * representation. + * + * \param consts The constant params from compiled model. + */ + void Init(const Array& consts) override { + CHECK_EQ(consts.size(), const_idx_.size()) + << "The number of input constants must match the number of required."; + LoadGlobalAttributes(); + if (GetCachedEnginesFromDisk()) return; + SetupConstants(consts); + BuildEngine(); + CacheEngineToDisk(); + } + + void LoadGlobalAttributes() { + // These settings are global to the entire subgraph. Codegen will add them as attributes to all + // op nodes. Read from first one. + for (size_t i = 0; i < nodes_.size(); ++i) { + if (nodes_[i].HasAttr("use_implicit_batch") && nodes_[i].HasAttr("max_workspace_size")) { + use_implicit_batch_ = + std::stoi(nodes_[i].GetAttr>("use_implicit_batch")[0]); + // Allow max_workspace_size to be overridden at runtime. + size_t runtime_max_workspace_size = + dmlc::GetEnv("TVM_TENSORRT_MAX_WORKSPACE_SIZE", size_t(0)); + if (runtime_max_workspace_size != 0) { + max_workspace_size_ = runtime_max_workspace_size; + } else { + max_workspace_size_ = + std::stoul(nodes_[i].GetAttr>("max_workspace_size")[0]); + } + return; + } + } + } + +#ifdef TVM_GRAPH_RUNTIME_TENSORRT + /*! \brief Run inference using built engine. */ + void Run() override { + auto& engine_and_context = trt_engine_cache_.at(symbol_name_); + auto engine = engine_and_context.engine; + auto context = engine_and_context.context; + std::vector bindings(engine->getNbBindings(), nullptr); + + for (size_t i = 0; i < input_nodes_.size(); ++i) { + auto nid = input_nodes_[i]; + if (nodes_[nid].GetOpType() == "input") { + for (size_t j = 0; j < nodes_[nid].GetOpShape().size(); ++j) { + uint32_t eid = EntryID(nid, j); + const std::string name = nodes_[nid].GetOpName() + "_" + std::to_string(j); + int binding_index = engine->getBindingIndex(name.c_str()); + CHECK_NE(binding_index, -1); + bindings[binding_index] = data_entry_[eid]->data; + } + } + } + + for (size_t i = 0; i < outputs_.size(); ++i) { + uint32_t eid = EntryID(outputs_[i]); + const std::string& name = engine_and_context.outputs[i]; + int binding_index = engine->getBindingIndex(name.c_str()); + CHECK_NE(binding_index, -1); + bindings[binding_index] = data_entry_[eid]->data; + } + +#if TRT_VERSION_GE(6, 0, 1) + if (use_implicit_batch_) { + CHECK(context->execute(batch_size_, bindings.data())) << "Running TensorRT failed."; + } else { + CHECK(context->executeV2(bindings.data())) << "Running TensorRT failed."; + } +#else + CHECK(context->execute(batch_size_, bindings.data())) << "Running TensorRT failed."; +#endif + } + + private: + /*! + * \brief Build TensorRT engine from JSON representation. 
+ */ + void BuildEngine() { + DLOG(INFO) << "Building new TensorRT engine for subgraph " << symbol_name_; + const bool use_fp16 = dmlc::GetEnv("TVM_TENSORRT_USE_FP16", false); + batch_size_ = GetBatchSize(); + TensorRTBuilder builder(&logger_, max_workspace_size_, use_implicit_batch_, use_fp16, + batch_size_); + + // Add inputs and constants. + for (size_t i = 0; i < input_nodes_.size(); ++i) { + auto nid = input_nodes_[i]; + const auto& node = nodes_[nid]; + std::string name = node.GetOpName(); + if (node.GetOpType() == "input") { + builder.AddInput(nid, node); + } else { + CHECK_EQ(node.GetOpType(), "const"); + uint32_t eid = EntryID(nid, 0); + builder.AddConstant(nid, data_entry_[eid]); + } + } + + // Add layers. + for (size_t nid = 0; nid < nodes_.size(); ++nid) { + const auto& node = nodes_[nid]; + if (node.GetOpType() != "kernel") continue; + builder.AddLayer(nid, node); + } + + // Add outputs. + for (size_t i = 0; i < outputs_.size(); ++i) { + builder.AddOutput(outputs_[i]); + } + + // Build engine. + trt_engine_cache_[symbol_name_] = builder.BuildEngine(); + DLOG(INFO) << "Finished building TensorRT engine for subgraph " << symbol_name_; + } + + /*! \brief If TVM_TENSORRT_CACHE_DIR is set, will check that directory for + * already built TRT engines and load into trt_engine_cache_ so they don't + * have to be built at first inference. + */ + bool GetCachedEnginesFromDisk() { + std::string cache_dir = dmlc::GetEnv("TVM_TENSORRT_CACHE_DIR", std::string("")); + if (cache_dir.empty()) return false; + std::string key = GetSubgraphKey(); + std::string path = cache_dir + "/" + key + ".plan"; + // Check if engine is in the cache. + std::ifstream infile(path, std::ios::binary); + if (!infile.good()) return false; + DLOG(INFO) << "Loading cached TensorRT engine from " << path; + infile.close(); + std::string serialized_engine; + LoadBinaryFromFile(path, &serialized_engine); + // Deserialize engine + nvinfer1::IRuntime* runtime = nvinfer1::createInferRuntime(logger_); + TensorRTEngineAndContext engine_and_context; + engine_and_context.engine = + runtime->deserializeCudaEngine(&serialized_engine[0], serialized_engine.size(), nullptr); + engine_and_context.context = engine_and_context.engine->createExecutionContext(); + // Load metadata + std::string meta_path = cache_dir + "/" + key + ".meta"; + std::string serialized_meta; + LoadBinaryFromFile(meta_path, &serialized_meta); + std::istringstream is(serialized_meta); + dmlc::JSONReader reader(&is); + dmlc::JSONObjectReadHelper helper; + helper.DeclareField("inputs", &engine_and_context.inputs); + helper.DeclareField("outputs", &engine_and_context.outputs); + helper.ReadAllFields(&reader); + trt_engine_cache_[symbol_name_] = engine_and_context; + return true; + } + + /*! \brief If TVM_TENSORRT_CACHE_DIR is set, will save the engine to that + * directory so it can be loaded later. 
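// ---------------------------------------------------------------------------
// Illustrative sketch (not part of the patch): the knobs referenced above are
// plain environment variables read through dmlc::GetEnv(), so a deploying
// process can set them before the TensorRT runtime module is initialized.
// The values below are examples only.
// ---------------------------------------------------------------------------
#include <cstdlib>

inline void ConfigureTensorRTRuntimeEnv() {
  setenv("TVM_TENSORRT_USE_FP16", "1", 1);                     // build engines in FP16 mode
  setenv("TVM_TENSORRT_MAX_WORKSPACE_SIZE", "2147483648", 1);  // override workspace to 2 GB
  setenv("TVM_TENSORRT_CACHE_DIR", "/tmp/trt_cache", 1);       // cache built engines on disk
}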
+ */ + void CacheEngineToDisk() { + std::string cache_dir = dmlc::GetEnv("TVM_TENSORRT_CACHE_DIR", std::string("")); + if (cache_dir.empty()) return; + std::string key = GetSubgraphKey(); + std::string path = cache_dir + "/" + key + ".plan"; + DLOG(INFO) << "Caching TensorRT engine to " << path; + // Serialize engine to disk + nvinfer1::IHostMemory* serialized_engine = trt_engine_cache_[symbol_name_].engine->serialize(); + SaveBinaryToFile(path, std::string(static_cast(serialized_engine->data()), + serialized_engine->size())); + serialized_engine->destroy(); + // Serialize metadata + std::ostringstream os; + dmlc::JSONWriter writer(&os); + writer.BeginObject(); + writer.WriteObjectKeyValue("inputs", trt_engine_cache_[symbol_name_].inputs); + writer.WriteObjectKeyValue("outputs", trt_engine_cache_[symbol_name_].outputs); + writer.EndObject(); + std::string meta_path = cache_dir + "/" + key + ".meta"; + SaveBinaryToFile(meta_path, os.str()); + } + + std::string GetSubgraphKey() { + // Using this key will only allow a single model per TVM_TENSORRT_CACHE_DIR directory. We could + // instead use a hash of graph_json and all weights to allow many models in the same directory, + // but the cost of computing the hash is high. + return symbol_name_ + (dmlc::GetEnv("TVM_TENSORRT_USE_FP16", false) ? "_fp16" : "_fp32"); + } + + /*! \brief Get the batch size when in implicit_batch mode. */ + int GetBatchSize() { + if (!use_implicit_batch_) return -1; + for (size_t i = 0; i < input_nodes_.size(); ++i) { + auto nid = input_nodes_[i]; + if (nodes_[nid].GetOpType() == "input") { + // Get batch size from first input. + return nodes_[nid].GetOpShape()[0][0]; + } + } + return -1; + } + + /*! \brief Map of function name to TRT engine if built already. */ + std::unordered_map trt_engine_cache_; + + /*! \brief TensorRT logger. */ + TensorRTLogger logger_; + + /*! \brief Batch size that the engine is optimized for. */ + int batch_size_; + +#else + void Run() override { + LOG(FATAL) << "TensorRT runtime is not enabled. " + << "Please build with USE_TENSORRT_RUNTIME."; + } + + void BuildEngine() { + LOG(WARNING) << "TensorRT runtime is not enabled. " + << "Please build with USE_TENSORRT_RUNTIME."; + } + + bool GetCachedEnginesFromDisk() { return false; } + + void CacheEngineToDisk() {} +#endif + + bool use_implicit_batch_; + + size_t max_workspace_size_; +}; + +runtime::Module TensorRTRuntimeCreate(const String& symbol_name, const String& graph_json, + const Array& const_names) { + auto n = make_object(symbol_name, graph_json, const_names); + return runtime::Module(n); +} + +TVM_REGISTER_GLOBAL("runtime.tensorrt_runtime_create").set_body_typed(TensorRTRuntimeCreate); + +TVM_REGISTER_GLOBAL("runtime.module.loadbinary_tensorrt") + .set_body_typed(JSONRuntimeBase::LoadFromBinary); + +} // namespace contrib +} // namespace runtime +} // namespace tvm diff --git a/src/runtime/contrib/tensorrt/tensorrt_utils.h b/src/runtime/contrib/tensorrt/tensorrt_utils.h new file mode 100644 index 000000000000..ab9b169f26d6 --- /dev/null +++ b/src/runtime/contrib/tensorrt/tensorrt_utils.h @@ -0,0 +1,74 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. 
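// ---------------------------------------------------------------------------
// Illustrative sketch (not part of the patch) of the on-disk layout produced
// by CacheEngineToDisk() above. Assuming a subgraph symbol "tensorrt_0" and
// FP32 mode, GetSubgraphKey() yields "tensorrt_0_fp32", so the cache directory
// holds:
//   <TVM_TENSORRT_CACHE_DIR>/tensorrt_0_fp32.plan  (serialized TensorRT engine)
//   <TVM_TENSORRT_CACHE_DIR>/tensorrt_0_fp32.meta  (JSON with engine input/output names)
// GetCachedEnginesFromDisk() reads the same pair on the next start and skips
// BuildEngine() entirely.
// ---------------------------------------------------------------------------
inline std::string PlanPath(const std::string& cache_dir, const std::string& key) {
  return cache_dir + "/" + key + ".plan";  // mirrors the path construction used above
}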
You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file runtime/contrib/tensorrt/tensorrt_utils.h
+ * \brief Helper functions used by TensorRTBuilder or TensorRTOpConverters.
+ */
+
+#ifndef TVM_RUNTIME_CONTRIB_TENSORRT_TENSORRT_UTILS_H_
+#define TVM_RUNTIME_CONTRIB_TENSORRT_TENSORRT_UTILS_H_
+
+#include <string>
+#include <vector>
+
+#include "NvInfer.h"
+
+// There is a conflict between cpplint and clang-format-10.
+// clang-format off
+#define TRT_VERSION_GE(major, minor, patch)                                                    \
+  ((NV_TENSORRT_MAJOR > major) || (NV_TENSORRT_MAJOR == major && NV_TENSORRT_MINOR > minor) || \
+   (NV_TENSORRT_MAJOR == major && NV_TENSORRT_MINOR == minor && NV_TENSORRT_PATCH >= patch))
+// clang-format on
+
+namespace tvm {
+namespace runtime {
+namespace contrib {
+
+/*!
+ * \brief Helper function to convert a vector to TRT Dims.
+ * \param vec Vector.
+ * \return TRT Dims.
+ */
+template <typename T>
+inline nvinfer1::Dims VectorToTrtDims(const std::vector<T>& vec) {
+  nvinfer1::Dims dims;
+  // Dims(nbDims=0, d[0]=1) is used to represent a scalar in TRT.
+  dims.d[0] = 1;
+  dims.nbDims = vec.size();
+  for (size_t i = 0; i < vec.size(); ++i) {
+    dims.d[i] = vec[i];
+  }
+  return dims;
+}
+
+/*!
+ * \brief Helper function to convert TRT Dims to a vector.
+ * \param dims TRT Dims.
+ * \return Vector.
+ */
+inline std::vector<int> TrtDimsToVector(const nvinfer1::Dims& dims) {
+  return std::vector<int>(dims.d, dims.d + dims.nbDims);
+}
+
+}  // namespace contrib
+}  // namespace runtime
+}  // namespace tvm
+
+#endif  // TVM_RUNTIME_CONTRIB_TENSORRT_TENSORRT_UTILS_H_
diff --git a/tests/python/contrib/test_tensorrt.py b/tests/python/contrib/test_tensorrt.py
new file mode 100644
index 000000000000..6f615397db58
--- /dev/null
+++ b/tests/python/contrib/test_tensorrt.py
@@ -0,0 +1,905 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+import numpy as np +import time +import pytest + +import tvm +import tvm.relay.testing +from tvm import relay +from tvm.relay.op.contrib import tensorrt +from tvm.contrib import graph_runtime + + +def skip_codegen_test(): + """Skip test if TensorRT and CUDA codegen are not present""" + if not tvm.runtime.enabled("cuda") or not tvm.gpu(0).exist: + print("Skip because CUDA is not enabled.") + return True + if not tvm.get_global_func("relay.ext.tensorrt", True): + print("Skip because TensorRT codegen is not available.") + return True + return False + + +def skip_runtime_test(): + if not tvm.runtime.enabled("cuda") or not tvm.gpu(0).exist: + print("Skip because CUDA is not enabled.") + return True + if not tensorrt.is_tensorrt_runtime_enabled(): + print("Skip because TensorRT runtime is not available.") + return True + return False + + +def run_and_verify_func(config): + """Test a Relay func by compiling, running, and comparing TVM and TRT outputs. + + Parameters + ---------- + config : Tuple[relay.Function, Dict[str, NDArray], List[str]] + A tuple containing 1) The function to test, 2) A dictionary of var names to input shapes and + 3) A list of which vars should be considered params. + """ + if skip_codegen_test(): + return + f, input_shapes, is_param = config + params = {x: np.random.uniform(-1, 1, input_shapes[x]).astype(np.float32) for x in is_param} + input_dict = { + k: np.random.uniform(-1, 1, v).astype(np.float32) + for k, v in input_shapes.items() + if k not in is_param + } + + # Run TRT + mod = tvm.IRModule() + mod["main"] = f + mod, config = tensorrt.partition_for_tensorrt(mod, params) + with tvm.transform.PassContext(opt_level=3, config={"relay.ext.tensorrt.options": config}): + graph, lib, graph_params = relay.build(mod, "cuda", params=params) + if skip_runtime_test(): + return + mod = graph_runtime.create(graph, lib, ctx=tvm.gpu(0)) + mod.set_input(**graph_params) + mod.run(**input_dict) + results = [mod.get_output(i) for i in range(mod.get_num_outputs())] + + # Run reference + mod = tvm.IRModule() + mod["main"] = f + with tvm.transform.PassContext(opt_level=3): + graph, lib, graph_params = relay.build(mod, "cuda", params=params) + mod = graph_runtime.create(graph, lib, ctx=tvm.gpu(0)) + mod.set_input(**graph_params) + mod.run(**input_dict) + ref_results = [mod.get_output(i) for i in range(mod.get_num_outputs())] + + assert len(results) == len(ref_results) + for i in range(len(results)): + res = results[i].asnumpy() + ref_res = ref_results[i].asnumpy() + assert res.shape == ref_res.shape + tvm.testing.assert_allclose(res, ref_res, rtol=1e-3, atol=1e-3) + + +def run_and_verify_model(model): + if skip_codegen_test(): + return + + def compile_and_run(i_data, input_shape, dtype, use_trt=True, num_iteration=1): + import mxnet as mx + from mxnet.gluon.model_zoo.vision import get_model + + def check_trt_used(graph): + import json + + graph = json.loads(graph) + num_trt_subgraphs = sum( + [ + 1 + for n in graph["nodes"] + if n.get("attrs", {}).get("func_name", "").startswith("tensorrt_") + ] + ) + assert num_trt_subgraphs >= 1 + + block = get_model(model, pretrained=True) + mod, params = relay.frontend.from_mxnet(block, shape={"data": input_shape}, dtype=dtype) + + if use_trt: + mod, config = tensorrt.partition_for_tensorrt(mod, params) + with tvm.transform.PassContext( + opt_level=3, config={"relay.ext.tensorrt.options": config} + ): + graph, lib, params = relay.build(mod, "cuda", params=params) + check_trt_used(graph) + else: + with tvm.transform.PassContext(opt_level=3): + graph, 
lib, params = relay.build(mod, "cuda", params=params) + + if skip_runtime_test(): + return + mod = graph_runtime.create(graph, lib, ctx=tvm.gpu(0)) + mod.set_input(**params) + # Warmup + for i in range(10): + mod.run(data=i_data) + # Time + times = [] + for i in range(num_iteration): + start_time = time.time() + mod.run(data=i_data) + res = mod.get_output(0) + times.append(time.time() - start_time) + latency = 1000.0 * np.mean(times) + print(model, latency) + return res + + dtype = "float32" + input_shape = (1, 3, 224, 224) + i_data = np.random.uniform(-1, 1, input_shape).astype(dtype) + res = compile_and_run(i_data, input_shape, dtype, use_trt=True) + if skip_runtime_test(): + return + ref_res = compile_and_run(i_data, input_shape, dtype, use_trt=False) + tvm.testing.assert_allclose(res.asnumpy(), ref_res.asnumpy(), rtol=1e-3, atol=1e-3) + + +def test_tensorrt_simple(): + if skip_codegen_test(): + return + dtype = "float32" + xshape = (1, 3, 2, 2) + yshape = (1, 3, 1, 1) + zshape = (1, 1, 1, 1) + x = relay.var("x", shape=(xshape), dtype=dtype) + y = relay.var("y", shape=(yshape), dtype=dtype) + z = relay.var("z", shape=(zshape), dtype=dtype) + w = z * (x + y) + out = relay.nn.relu(w) + f = relay.Function([x, y, z], out) + + mod = tvm.IRModule() + mod["main"] = f + mod, config = tensorrt.partition_for_tensorrt(mod) + with tvm.transform.PassContext(opt_level=3, config={"relay.ext.tensorrt.options": config}): + graph, lib, params = relay.build(mod, "cuda") + if skip_runtime_test(): + return + mod = graph_runtime.create(graph, lib, ctx=tvm.gpu(0)) + x_data = np.random.uniform(-1, 1, xshape).astype(dtype) + y_data = np.random.uniform(-1, 1, yshape).astype(dtype) + z_data = np.random.uniform(-1, 1, zshape).astype(dtype) + mod.run(x=x_data, y=y_data, z=z_data) + results = [mod.get_output(i).asnumpy() for i in range(mod.get_num_outputs())] + + +def test_tensorrt_not_compatible(): + if skip_codegen_test(): + return + dtype = "float32" + xshape = (1, 32, 14, 14) + x = relay.var("x", shape=(xshape), dtype=dtype) + y = relay.add(x, x) + z = relay.erf(y) + out = relay.nn.relu(z) + f = relay.Function([x], out) + mod = tvm.IRModule() + mod["main"] = f + mod, config = tensorrt.partition_for_tensorrt(mod) + with tvm.transform.PassContext(opt_level=3, config={"relay.ext.tensorrt.options": config}): + graph, lib, params = relay.build(mod, "cuda") + if skip_runtime_test(): + return + mod = graph_runtime.create(graph, lib, ctx=tvm.gpu(0)) + x_data = np.random.uniform(-1, 1, xshape).astype(dtype) + mod.run(x=x_data) + results = [mod.get_output(i).asnumpy() for i in range(mod.get_num_outputs())] + + +def test_tensorrt_serialize(): + if skip_codegen_test(): + return + import mxnet + from mxnet.gluon.model_zoo.vision import get_model + + block = get_model("resnet18_v1", pretrained=True) + mod, params = relay.frontend.from_mxnet( + block, shape={"data": (1, 3, 224, 224)}, dtype="float32" + ) + # Compile + mod, config = tensorrt.partition_for_tensorrt(mod, params) + with tvm.transform.PassContext(opt_level=3, config={"relay.ext.tensorrt.options": config}): + lib = relay.build(mod, "cuda", params=params) + # Serialize + lib.export_library("compiled.so") + # Deserialize + loaded_lib = tvm.runtime.load_module("compiled.so") + # Run + if skip_runtime_test(): + return + gen_module = tvm.contrib.graph_runtime.GraphModule(loaded_lib["default"](tvm.gpu(0))) + i_data = np.random.uniform(0, 1, (1, 3, 224, 224)).astype("float32") + gen_module.run(data=i_data) + + +def test_conv2d(): + def get_graph( + x_shape=(1, 32, 8, 8), + 
k_shape=(16, 32, 3, 3), + groups=1, + padding=(0, 0), + strides=(1, 1), + dilation=(1, 1), + ): + x = relay.var("x", shape=(x_shape), dtype="float32") + kernel = relay.var("kernel", shape=(k_shape), dtype="float32") + out = relay.nn.conv2d( + x, + kernel, + channels=k_shape[0], + kernel_size=k_shape[2:4], + groups=groups, + padding=padding, + strides=strides, + dilation=dilation, + ) + f = relay.Function([x, kernel], out) + return f, {"x": x_shape, "kernel": k_shape}, ["kernel"] + + for k_shape, groups in [((16, 32, 3, 3), 1), ((32, 1, 3, 3), 32)]: + for padding in [(0, 0), (1, 1)]: + for strides in [(1, 1), (2, 2)]: + for dilation in [(1, 1), (2, 2)]: + run_and_verify_func( + get_graph( + k_shape=k_shape, + groups=groups, + padding=padding, + strides=strides, + dilation=dilation, + ) + ) + + +def test_conv2d_nhwc(): + def get_graph(x_shape=(1, 8, 8, 32), k_shape=(3, 3, 32, 16)): + x = relay.var("x", shape=(x_shape), dtype="float32") + kernel = relay.var("kernel", shape=(k_shape), dtype="float32") + out = relay.nn.conv2d( + x, + kernel, + channels=16, + kernel_size=(3, 3), + data_layout="NHWC", + kernel_layout="HWIO", + ) + f = relay.Function([x, kernel], out) + return f, {"x": x_shape, "kernel": k_shape}, ["kernel"] + + run_and_verify_func(get_graph()) + + +def test_conv2d_weights_const(): + def get_graph( + x_shape=(1, 32, 8, 8), + k_shape=(16, 32, 3, 3), + groups=1, + padding=(0, 0), + strides=(1, 1), + dilation=(1, 1), + ): + x = relay.var("x", shape=(x_shape), dtype="float32") + kernel = relay.const(np.ones(k_shape).astype("float32")) + out = relay.nn.conv2d( + x, + kernel, + channels=k_shape[0], + kernel_size=k_shape[2:4], + groups=groups, + padding=padding, + strides=strides, + dilation=dilation, + ) + f = relay.Function([x], out) + return f, {"x": x_shape}, [] + + run_and_verify_func(get_graph()) + + +def test_conv2d_weights_transposed(): + def get_graph(x_shape=(1, 32, 9, 9), k_shape=(3, 3, 32, 16), order=(3, 2, 0, 1)): + x = relay.var("x", shape=(x_shape), dtype="float32") + kernel = relay.var("kernel", shape=(k_shape), dtype="float32") + kernel_t = relay.transpose(kernel, order) + # Conv2d requires constant weights in TensorRT, so the weights should be transposed by + # FoldConstant. + out = relay.nn.conv2d(x, kernel_t, channels=k_shape[order[0]], kernel_size=(3, 3)) + f = relay.Function([x, kernel], out) + return f, {"x": x_shape, "kernel": k_shape}, ["kernel"] + + run_and_verify_func(get_graph()) + + +def test_dense(): + def get_graph(x_shape=(1, 16), k_shape=(32, 16)): + x = relay.var("x", shape=(x_shape), dtype="float32") + kernel = relay.var("kernel", shape=(k_shape), dtype="float32") + # Dense requires constant weights in TensorRT, so the weights are transposed by us. 
+ out = relay.nn.dense(x, kernel, units=k_shape[0]) + f = relay.Function([x, kernel], out) + return f, {"x": x_shape, "kernel": k_shape}, ["kernel"] + + run_and_verify_func(get_graph()) + + +def test_bias_add(): + def get_graph(x_shape=(1, 16), channels=16): + x = relay.var("x", shape=(x_shape), dtype="float32") + bias = relay.var("bias", shape=(channels,), dtype="float32") + out = relay.nn.bias_add(x, bias) + f = relay.Function([x, bias], out) + return f, {"x": x_shape, "bias": (channels,)}, ["bias"] + + run_and_verify_func(get_graph()) + run_and_verify_func(get_graph((1, 6, 3, 4), 6)) + + +def test_pool2d(): + def get_graph( + op, + x_shape=(1, 3, 32, 32), + pool_size=(2, 2), + strides=(2, 2), + padding=(0, 0), + ceil_mode=False, + count_include_pad=None, + ): + x = relay.var("x", shape=(x_shape), dtype="float32") + if count_include_pad is not None: + out = op( + x, + pool_size=pool_size, + strides=strides, + padding=padding, + ceil_mode=ceil_mode, + count_include_pad=count_include_pad, + ) + else: + out = op( + x, + pool_size=pool_size, + strides=strides, + padding=padding, + ceil_mode=ceil_mode, + ) + f = relay.Function([x], out) + return f, {"x": x_shape}, [] + + for pool_size in [(2, 2), (3, 3)]: + for strides in [(1, 1), (2, 2)]: + for padding in [(0, 0), (1, 1), (0, 0, 1, 1)]: + for ceil_mode in [False, True]: + # Skip "the padding size is larger than or equal to the filter size for exclusive-counting pooling" + if pool_size == (2, 2) and padding == (0, 0, 1, 1): + continue + for count_include_pad in [False, True]: + # Skip "inclusive-counted blended or average pooling is not supported in combination with asymmetric padding" + if count_include_pad and (padding == (0, 0, 1, 1) or strides == (2, 2)): + continue + run_and_verify_func( + get_graph( + relay.nn.avg_pool2d, + pool_size=pool_size, + strides=strides, + padding=padding, + ceil_mode=ceil_mode, + count_include_pad=count_include_pad, + ) + ) + run_and_verify_func( + get_graph( + relay.nn.max_pool2d, + pool_size=pool_size, + strides=strides, + padding=padding, + ceil_mode=ceil_mode, + ) + ) + + +def test_global_pool2d(): + def get_graph(op, x_shape=(1, 3, 32, 32)): + x = relay.var("x", shape=(x_shape), dtype="float32") + out = op(x) + f = relay.Function([x], out) + return f, {"x": x_shape}, [] + + run_and_verify_func(get_graph(relay.nn.global_max_pool2d)) + run_and_verify_func(get_graph(relay.nn.global_avg_pool2d)) + + +def test_batch_flatten(): + def get_graph(x_shape=(1, 3, 4, 6)): + x = relay.var("x", shape=(x_shape), dtype="float32") + out = relay.nn.batch_flatten(x) + f = relay.Function([x], out) + return f, {"x": x_shape}, [] + + run_and_verify_func(get_graph()) + + +def test_expand_dims(): + def get_graph(x_shape=(1, 3), axis=1, num_newaxis=1): + x = relay.var("x", shape=(x_shape), dtype="float32") + out = relay.expand_dims(x, axis, num_newaxis) + f = relay.Function([x], out) + return f, {"x": x_shape}, [] + + run_and_verify_func(get_graph()) + + +def test_squeeze(): + def get_graph(x_shape, axis): + x = relay.var("x", shape=(x_shape), dtype="float32") + out = relay.squeeze(x, axis=axis) + f = relay.Function([x], out) + return f, {"x": x_shape}, [] + + run_and_verify_func(get_graph((1, 5, 1, 1), (2, 3))) + run_and_verify_func(get_graph((1, 3, 1), (-1,))) + + +def test_concatenate(): + def get_graph(input_shapes, axis): + concat_inputs = [] + shapes_dict = {} + for i in range(len(input_shapes)): + name = "input_{}".format(i) + concat_inputs.append(relay.var(name, shape=(input_shapes[i]), dtype="float32")) + 
shapes_dict[name] = input_shapes[i] + out = relay.concatenate(concat_inputs, axis) + f = relay.Function(concat_inputs, out) + return f, shapes_dict, [] + + run_and_verify_func(get_graph([(1, 2, 6, 6), (1, 3, 6, 6)], axis=1)) + + +def test_conv2d_transpose(): + def get_graph( + x_shape=(1, 32, 8, 8), + k_shape=(32, 16, 3, 3), + groups=1, + padding=(0, 0), + strides=(1, 1), + ): + x = relay.var("x", shape=(x_shape), dtype="float32") + kernel = relay.var("kernel", shape=(k_shape), dtype="float32") + out = relay.nn.conv2d_transpose( + x, + kernel, + channels=k_shape[1], + kernel_size=k_shape[2:4], + groups=groups, + padding=padding, + strides=strides, + ) + f = relay.Function([x, kernel], out) + return f, {"x": x_shape, "kernel": k_shape}, ["kernel"] + + for padding in [(0, 0), (1, 1)]: + for strides in [(1, 1), (2, 2)]: + run_and_verify_func(get_graph(padding=padding, strides=strides)) + + +def test_reshape(): + def get_graph(x_shape, new_shape): + x = relay.var("x", shape=(x_shape), dtype="float32") + out = relay.reshape(x, new_shape) + f = relay.Function([x], out) + return f, {"x": x_shape}, [] + + run_and_verify_func(get_graph((1, 1, 1, 10), (-1, 10))) + run_and_verify_func(get_graph((1, 10, 2, 3), (1, -1))) + run_and_verify_func(get_graph((1, 1, 2, 3), (1, 6))) + + +def test_transpose(): + def get_graph(x_shape, order): + x = relay.var("x", shape=(x_shape), dtype="float32") + out = relay.transpose(x, order) + f = relay.Function([x], out) + return f, {"x": x_shape}, [] + + run_and_verify_func(get_graph((1, 16, 7, 7), [0, 2, 3, 1])) + run_and_verify_func(get_graph((1, 7, 7, 16), [0, 3, 1, 2])) + + +def test_float_const(): + def get_graph(x_shape=(1, 16)): + x = relay.var("x", shape=(x_shape), dtype="float32") + beta = relay.const(1, dtype="float32") + out = relay.multiply(x, beta) + f = relay.Function([x], out) + return f, {"x": x_shape}, [] + + run_and_verify_func(get_graph()) + + +def test_pad(): + def get_graph(x_shape, pad_width): + x = relay.var("x", shape=(x_shape), dtype="float32") + out = relay.nn.pad(x, pad_width=pad_width) + f = relay.Function([x], out) + return f, {"x": x_shape}, [] + + run_and_verify_func(get_graph((1, 8, 16, 16), [[0, 0], [0, 0], [0, 0], [0, 0]])) + run_and_verify_func(get_graph((1, 8, 16, 16), [[0, 0], [0, 0], [1, 1], [1, 1]])) + run_and_verify_func(get_graph((1, 8, 16, 16), [[0, 0], [0, 0], [0, 1], [2, 0]])) + run_and_verify_func(get_graph((1, 8, 3, 16, 16), [[0, 0], [0, 0], [0, 0], [0, 0], [0, 0]])) + + +def test_softmax(): + def get_graph(x_shape, axis): + x = relay.var("x", shape=(x_shape), dtype="float32") + out = relay.nn.softmax(x, axis=axis) + f = relay.Function([x], out) + return f, {"x": x_shape}, [] + + run_and_verify_func(get_graph((1, 1000), axis=1)) + run_and_verify_func(get_graph((1, 1000), axis=-1)) + run_and_verify_func(get_graph((1, 3, 4), axis=-2)) + run_and_verify_func(get_graph((1, 3, 4), axis=1)) + + +def test_batch_norm(): + def get_graph(x_shape, param_shape, axis=1, epsilon=1e-5): + x = relay.var("x", shape=(x_shape), dtype="float32") + beta = relay.var("beta", shape=(param_shape), dtype="float32") + gamma = relay.var("gamma", shape=(param_shape), dtype="float32") + moving_mean = relay.var("moving_mean", shape=(param_shape), dtype="float32") + moving_var = relay.var("moving_var", shape=(param_shape), dtype="float32") + out, _, _ = relay.nn.batch_norm( + x, + gamma=gamma, + beta=beta, + moving_mean=moving_mean, + moving_var=moving_var, + axis=axis, + center=True, + scale=True, + epsilon=epsilon, + ) + f = relay.Function([x, gamma, beta, 
moving_mean, moving_var], out) + return ( + f, + { + "x": x_shape, + "beta": param_shape, + "gamma": param_shape, + "moving_mean": param_shape, + "moving_var": param_shape, + }, + ["beta", "gamma", "moving_mean", "moving_var"], + ) + + run_and_verify_func(get_graph((1, 64, 56, 56), (64,))) + run_and_verify_func(get_graph((1, 56, 56, 64), (64,), axis=3, epsilon=1.001e-05)) + + +def test_unary(): + def get_graph(op, x_shape=(1, 8, 3, 3)): + x = relay.var("x", shape=(x_shape), dtype="float32") + out = op(x) + f = relay.Function([x], out) + return f, {"x": x_shape}, [] + + for op in [ + relay.nn.relu, + relay.sigmoid, + relay.tanh, + relay.exp, + relay.log, + relay.sqrt, + relay.abs, + relay.negative, + relay.sin, + relay.cos, + relay.atan, + relay.ceil, + relay.floor, + ]: + run_and_verify_func(get_graph(op)) + + +def test_clip(): + def get_graph(x_shape=(1, 8, 3, 3)): + x = relay.var("x", shape=(x_shape), dtype="float32") + out = relay.clip(x, a_min=-0.2, a_max=0.4) + f = relay.Function([x], out) + return f, {"x": x_shape}, [] + + run_and_verify_func(get_graph()) + + +def test_leaky_relu(): + def get_graph(x_shape=(1, 8, 3, 3)): + x = relay.var("x", shape=(x_shape), dtype="float32") + out = relay.nn.leaky_relu(x, alpha=0.1) + f = relay.Function([x], out) + return f, {"x": x_shape}, [] + + run_and_verify_func(get_graph()) + + +def test_binary(): + def get_graph(op, x_shape, y_shape, y_is_const=False): + x = relay.var("x", shape=(x_shape), dtype="float32") + if y_is_const: + y = relay.const(np.ones(y_shape).astype("float32")) + out = op(x, y) + f = relay.Function([x], out) + return f, {"x": x_shape}, [] + y = relay.var("y", shape=(y_shape), dtype="float32") + out = op(x, y) + f = relay.Function([x, y], out) + return f, {"x": x_shape, "y": y_shape}, [] + + for op in [relay.add, relay.subtract, relay.multiply, relay.divide, relay.power]: + for y_is_const in [True, False]: + run_and_verify_func(get_graph(op, (1, 8, 3, 3), (1, 8, 3, 3), y_is_const)) + run_and_verify_func(get_graph(op, (1, 8, 1, 3), (1, 8, 3, 1), y_is_const)) + run_and_verify_func(get_graph(op, (1, 10), (10,), y_is_const)) + run_and_verify_func(get_graph(op, (1, 1, 1, 10), (10,), y_is_const)) + run_and_verify_func(get_graph(op, (1, 1, 1), (3,), y_is_const)) + + +def test_reduce(): + def get_graph(op, x_shape=(1, 2, 3, 4), axis=(2, 3), keepdims=False): + x = relay.var("x", shape=(x_shape), dtype="float32") + out = op(x, axis=axis, keepdims=keepdims) + f = relay.Function([x], out) + return f, {"x": x_shape}, [] + + for op in [relay.sum, relay.prod, relay.max, relay.min, relay.mean]: + for keepdims in [True, False]: + run_and_verify_func(get_graph(op, axis=(1), keepdims=keepdims)) + run_and_verify_func(get_graph(op, axis=(2, 3), keepdims=keepdims)) + run_and_verify_func(get_graph(op, axis=(1, 2), keepdims=keepdims)) + run_and_verify_func(get_graph(op, axis=(1, 2, 3), keepdims=keepdims)) + + +def test_strided_slice(): + def get_graph(x_shape, begin, end, strides=None): + x = relay.var("x", shape=(x_shape), dtype="float32") + if strides: + out = relay.strided_slice( + x, + relay.expr.const(begin, dtype="int32"), + relay.expr.const(end, dtype="int32"), + relay.expr.const(strides, dtype="int32"), + ) + else: + out = relay.strided_slice( + x, + relay.expr.const(begin, dtype="int32"), + relay.expr.const(end, dtype="int32"), + ) + f = relay.Function([x], out) + return f, {"x": x_shape}, [] + + run_and_verify_func(get_graph((1, 3, 6, 7), [0, 0, 0, 0], [1, 1, 6, 7])) + run_and_verify_func(get_graph((1, 3, 6, 7), [0, 1, 0, 0], [1, 2, 6, 6])) + 
run_and_verify_func(get_graph((1, 10), [0, 0], [1, 10], [1, 2])) + + +def test_adaptive_pool2d(): + def get_graph(op, x_shape=(1, 3, 32, 32), out_size=(1, 1)): + x = relay.var("x", shape=(x_shape), dtype="float32") + out = op(x, out_size) + f = relay.Function([x], out) + return f, {"x": x_shape}, [] + + run_and_verify_func(get_graph(relay.nn.adaptive_max_pool2d)) + run_and_verify_func(get_graph(relay.nn.adaptive_avg_pool2d)) + + +def test_multiple_outputs(): + def get_graph(): + x = relay.var("x", shape=(1, 3), dtype="float32") + y = relay.var("y", shape=(1, 3), dtype="float32") + z = relay.add(x, y) + w = relay.add(z, y) + out = relay.Tuple((z, w)) + f = relay.Function([x, y], out) + return f, {"x": (1, 3), "y": (1, 3)}, [] + + run_and_verify_func(get_graph()) + + +def test_conv3d(): + def get_graph( + x_shape=(1, 32, 8, 8, 8), + k_shape=(16, 32, 3, 3, 3), + groups=1, + padding=(0, 0, 0), + strides=(1, 1, 1), + dilation=(1, 1, 1), + ): + x = relay.var("x", shape=(x_shape), dtype="float32") + kernel = relay.var("kernel", shape=(k_shape), dtype="float32") + out = relay.nn.conv3d( + x, + kernel, + channels=k_shape[0], + kernel_size=k_shape[2:], + groups=groups, + padding=padding, + strides=strides, + dilation=dilation, + ) + f = relay.Function([x, kernel], out) + return f, {"x": x_shape, "kernel": k_shape}, ["kernel"] + + run_and_verify_func(get_graph()) + run_and_verify_func(get_graph(padding=(0, 0, 0, 1, 1, 1))) + + +def test_pool3d(): + def get_graph( + op, + x_shape=(1, 3, 8, 32, 32), + pool_size=(2, 2, 2), + strides=(2, 2, 2), + padding=(0, 0, 0), + ceil_mode=False, + count_include_pad=None, + ): + x = relay.var("x", shape=(x_shape), dtype="float32") + if count_include_pad is not None: + out = op( + x, + pool_size=pool_size, + strides=strides, + padding=padding, + ceil_mode=ceil_mode, + count_include_pad=count_include_pad, + ) + else: + out = op( + x, + pool_size=pool_size, + strides=strides, + padding=padding, + ceil_mode=ceil_mode, + ) + f = relay.Function([x], out) + return f, {"x": x_shape}, [] + + run_and_verify_func(get_graph(relay.nn.avg_pool3d)) + run_and_verify_func(get_graph(relay.nn.max_pool3d)) + run_and_verify_func(get_graph(relay.nn.max_pool3d, padding=(0, 0, 0, 1, 1, 1))) + run_and_verify_func(get_graph(relay.nn.max_pool3d, strides=(1, 1, 1))) + + +def test_conv3d_transpose(): + def get_graph( + x_shape=(1, 32, 8, 8, 8), + k_shape=(32, 16, 3, 3, 3), + groups=1, + padding=(0, 0, 0), + strides=(1, 1, 1), + output_padding=(0, 0, 0), + ): + x = relay.var("x", shape=(x_shape), dtype="float32") + kernel = relay.var("kernel", shape=(k_shape), dtype="float32") + out = relay.nn.conv3d_transpose( + x, + kernel, + channels=k_shape[1], + kernel_size=k_shape[2:5], + groups=groups, + padding=padding, + strides=strides, + output_padding=output_padding, + ) + f = relay.Function([x, kernel], out) + return f, {"x": x_shape, "kernel": k_shape}, ["kernel"] + + run_and_verify_func(get_graph()) + run_and_verify_func(get_graph(strides=(2, 2, 2))) + run_and_verify_func(get_graph(strides=(2, 2, 2), output_padding=(1, 1, 1))) + + +def test_alexnet(): + run_and_verify_model("alexnet") + + +def test_resnet18_v1(): + run_and_verify_model("resnet18_v1") + + +def test_resnet18_v2(): + run_and_verify_model("resnet18_v2") + + +def test_squeezenet(): + run_and_verify_model("squeezenet1.0") + + +def test_mobilenet(): + run_and_verify_model("mobilenet0.25") + + +def test_mobilenet_v2(): + run_and_verify_model("mobilenetv2_0.25") + + +def test_vgg11(): + run_and_verify_model("vgg11") + + +def 
test_densenet121(): + run_and_verify_model("densenet121") + + +if __name__ == "__main__": + test_tensorrt_not_compatible() + test_tensorrt_simple() + test_tensorrt_serialize() + + # Op tests + test_conv2d() + test_conv2d_nhwc() + test_conv2d_weights_const() + test_conv2d_weights_transposed() + test_dense() + test_bias_add() + test_pool2d() + test_global_pool2d() + test_batch_flatten() + test_expand_dims() + test_squeeze() + test_concatenate() + test_conv2d_transpose() + test_reshape() + test_transpose() + test_float_const() + test_pad() + test_softmax() + test_batch_norm() + test_unary() + test_clip() + test_leaky_relu() + test_binary() + test_reduce() + test_strided_slice() + test_adaptive_pool2d() + test_multiple_outputs() + test_conv3d() + test_pool3d() + test_conv3d_transpose() + + # Integration tests + test_alexnet() + test_resnet18_v1() + test_resnet18_v2() + test_squeezenet() + test_mobilenet() + test_mobilenet_v2() + test_vgg11() + test_densenet121() diff --git a/tests/scripts/task_config_build_gpu.sh b/tests/scripts/task_config_build_gpu.sh index 3fc7351c415f..0072fb59cf11 100755 --- a/tests/scripts/task_config_build_gpu.sh +++ b/tests/scripts/task_config_build_gpu.sh @@ -45,5 +45,4 @@ echo set\(USE_VTA_FSIM ON\) >> config.cmake echo set\(USE_BLAS openblas\) >> config.cmake echo set\(CMAKE_CXX_COMPILER g++\) >> config.cmake echo set\(CMAKE_CXX_FLAGS -Werror\) >> config.cmake -echo set\(USE_VTA_TSIM ON\) >> config.cmake -echo set\(USE_VTA_FSIM ON\) >> config.cmake +echo set\(USE_TENSORRT_CODEGEN ON\) >> config.cmake
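// ---------------------------------------------------------------------------
// Illustrative C++ deployment sketch (not part of the patch): the compiled.so
// produced in test_tensorrt_serialize() above can also be loaded without Python
// through the usual TVM graph-runtime factory flow. The input name "data" and
// the 1x3x224x224 shape match that test; error handling is omitted and the
// function name is hypothetical.
// ---------------------------------------------------------------------------
#include <tvm/runtime/module.h>
#include <tvm/runtime/ndarray.h>
#include <tvm/runtime/packed_func.h>

void RunCompiledTensorRTModule() {
  DLContext ctx{kDLGPU, 0};
  tvm::runtime::Module factory = tvm::runtime::Module::LoadFromFile("compiled.so");
  tvm::runtime::Module gmod = factory.GetFunction("default")(ctx);
  tvm::runtime::PackedFunc set_input = gmod.GetFunction("set_input");
  tvm::runtime::PackedFunc run = gmod.GetFunction("run");
  tvm::runtime::PackedFunc get_output = gmod.GetFunction("get_output");

  tvm::runtime::NDArray input =
      tvm::runtime::NDArray::Empty({1, 3, 224, 224}, DLDataType{kDLFloat, 32, 1}, ctx);
  set_input("data", input);
  run();  // TensorRT engines are built (or loaded from cache) on first use.
  tvm::runtime::NDArray output = get_output(0);
}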