From 6d87a841a968fd16a62b6d8aefc40d17c92804a5 Mon Sep 17 00:00:00 2001
From: "q.yao"
Date: Mon, 10 May 2021 17:02:03 +0800
Subject: [PATCH] add onnx to tensorrt tools

---
 docs/useful_tools.md   |  42 ++++++-
 tools/onnx2tensorrt.py | 275 +++++++++++++++++++++++++++++++++++++++++
 2 files changed, 316 insertions(+), 1 deletion(-)
 create mode 100644 tools/onnx2tensorrt.py

diff --git a/docs/useful_tools.md b/docs/useful_tools.md
index 556c531663..8ae19f5bee 100644
--- a/docs/useful_tools.md
+++ b/docs/useful_tools.md
@@ -90,7 +90,7 @@ We provide `tools/ort_test.py` to evaluate ONNX model with ONNXRuntime backend.
 
 #### Usage
 
-```python
+```bash
 python tools/ort_test.py \
     ${CONFIG_FILE} \
     ${ONNX_FILE} \
@@ -164,6 +164,46 @@ Examples:
     --shape 512 1024
 ```
 
+### Convert to TensorRT (experimental)
+
+A script to convert an [ONNX](https://github.com/onnx/onnx) model to the [TensorRT](https://developer.nvidia.com/tensorrt) format.
+
+Prerequisites
+
+- Install `mmcv-full` with ONNXRuntime custom ops and TensorRT plugins, following [ONNXRuntime in mmcv](https://mmcv.readthedocs.io/en/latest/onnxruntime_op.html) and [TensorRT plugin in mmcv](https://github.com/open-mmlab/mmcv/blob/master/docs/tensorrt_plugin.md).
+- Use [pytorch2onnx](#convert-to-onnx-experimental) to convert the model from PyTorch to ONNX.
+
+Usage
+
+```bash
+python ${MMSEG_PATH}/tools/onnx2tensorrt.py \
+    ${CFG_PATH} \
+    ${ONNX_PATH} \
+    --trt-file ${OUTPUT_TRT_PATH} \
+    --min-shape ${MIN_SHAPE} \
+    --max-shape ${MAX_SHAPE} \
+    --input-img ${INPUT_IMG} \
+    --show \
+    --verify
+```
+
+Description of all arguments
+
+- `config` : Config file of the model.
+- `model` : Path to the input ONNX model.
+- `--trt-file` : Path to the output TensorRT engine.
+- `--max-shape` : Maximum shape of model input.
+- `--min-shape` : Minimum shape of model input.
+- `--fp16` : Enable fp16 model conversion.
+- `--workspace-size` : Max workspace size in GiB.
+- `--input-img` : Image used for visualization.
+- `--show` : Show the segmentation results.
+- `--dataset` : Dataset that provides the palette, `CityscapesDataset` by default.
+- `--verify` : Verify the outputs of ONNXRuntime and TensorRT.
+- `--verbose` : Whether to print verbose logging messages while creating the TensorRT engine. Defaults to False.
+
+**Note**: Only tested in `whole` inference mode.
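+
+For example, to convert a Cityscapes FCN model exported by `pytorch2onnx` and check the engine against ONNXRuntime (the config, ONNX, and engine paths below are illustrative; any model exported the same way works):
+
+```bash
+python tools/onnx2tensorrt.py \
+    configs/fcn/fcn_r50-d8_512x1024_40k_cityscapes.py \
+    fcn_r50-d8_512x1024_40k_cityscapes.onnx \
+    --trt-file fcn_r50-d8_512x1024_40k_cityscapes.trt \
+    --min-shape 1 3 512 1024 \
+    --max-shape 1 3 512 1024 \
+    --input-img demo/demo.png \
+    --verify
+```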
+
 ## Miscellaneous
 
 ### Print the entire config
diff --git a/tools/onnx2tensorrt.py b/tools/onnx2tensorrt.py
new file mode 100644
index 0000000000..203ae82a88
--- /dev/null
+++ b/tools/onnx2tensorrt.py
@@ -0,0 +1,275 @@
+import argparse
+import os
+import os.path as osp
+from typing import Iterable, Optional, Union
+
+import matplotlib.pyplot as plt
+import mmcv
+import numpy as np
+import onnxruntime as ort
+import torch
+from mmcv.ops import get_onnxruntime_op_path
+from mmcv.tensorrt import (TRTWraper, is_tensorrt_plugin_loaded, onnx2trt,
+                           save_trt_engine)
+
+from mmseg.apis.inference import LoadImage
+from mmseg.datasets import DATASETS
+from mmseg.datasets.pipelines import Compose
+
+
+def get_GiB(x: int):
+    """Return x GiB in bytes."""
+    return x * (1 << 30)
+
+
+def _prepare_input_img(img_path: str,
+                       test_pipeline: Iterable[dict],
+                       shape: Optional[Iterable] = None,
+                       rescale_shape: Optional[Iterable] = None) -> dict:
+    # build the data pipeline
+    if shape is not None:
+        test_pipeline[1]['img_scale'] = (shape[1], shape[0])
+    test_pipeline[1]['transforms'][0]['keep_ratio'] = False
+    test_pipeline = [LoadImage()] + test_pipeline[1:]
+    test_pipeline = Compose(test_pipeline)
+    # prepare data
+    data = dict(img=img_path)
+    data = test_pipeline(data)
+    imgs = data['img']
+    img_metas = [i.data for i in data['img_metas']]
+
+    if rescale_shape is not None:
+        for img_meta in img_metas:
+            img_meta['ori_shape'] = tuple(rescale_shape) + (3, )
+
+    mm_inputs = {'imgs': imgs, 'img_metas': img_metas}
+
+    return mm_inputs
+
+
+def _update_input_img(img_list: Iterable, img_meta_list: Iterable):
+    # update img and its meta list
+    N = img_list[0].size(0)
+    img_meta = img_meta_list[0][0]
+    img_shape = img_meta['img_shape']
+    ori_shape = img_meta['ori_shape']
+    pad_shape = img_meta['pad_shape']
+    new_img_meta_list = [[{
+        'img_shape': img_shape,
+        'ori_shape': ori_shape,
+        'pad_shape': pad_shape,
+        'filename': img_meta['filename'],
+        'scale_factor': (img_shape[1] / ori_shape[1],
+                         img_shape[0] / ori_shape[0]) * 2,
+        'flip': False,
+    } for _ in range(N)]]
+
+    return img_list, new_img_meta_list
+
+
+def show_result_pyplot(img: Union[str, np.ndarray],
+                       result: np.ndarray,
+                       palette: Optional[Iterable] = None,
+                       fig_size: Iterable[int] = (15, 10),
+                       opacity: float = 0.5,
+                       title: str = '',
+                       block: bool = True):
+    img = mmcv.imread(img)
+    img = img.copy()
+    seg = result[0]
+    seg = mmcv.imresize(seg, img.shape[:2][::-1])
+    palette = np.array(palette)
+    assert len(palette.shape) == 2
+    assert palette.shape[1] == 3
+    assert 0 < opacity <= 1.0
+    color_seg = np.zeros((seg.shape[0], seg.shape[1], 3), dtype=np.uint8)
+    for label, color in enumerate(palette):
+        color_seg[seg == label, :] = color
+    # convert to BGR
+    color_seg = color_seg[..., ::-1]
+
+    img = img * (1 - opacity) + color_seg * opacity
+    img = img.astype(np.uint8)
+
+    plt.figure(figsize=fig_size)
+    plt.imshow(mmcv.bgr2rgb(img))
+    plt.title(title)
+    plt.tight_layout()
+    plt.show(block=block)
+
+
+def onnx2tensorrt(onnx_file: str,
+                  trt_file: str,
+                  config: dict,
+                  input_config: dict,
+                  fp16: bool = False,
+                  verify: bool = False,
+                  show: bool = False,
+                  dataset: str = 'CityscapesDataset',
+                  workspace_size: int = 1,
+                  verbose: bool = False):
+    import tensorrt as trt
+    min_shape = input_config['min_shape']
+    max_shape = input_config['max_shape']
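+    # mmcv's onnx2trt builds a TensorRT optimization profile from the
+    # [min, opt, max] shapes given for each input; the optimal shape is
+    # pinned to min_shape here, so the engine is tuned for the smallest
+    # allowed input.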
+    # create trt engine and wrapper
+    opt_shape_dict = {'input': [min_shape, min_shape, max_shape]}
+    max_workspace_size = get_GiB(workspace_size)
+    trt_engine = onnx2trt(
+        onnx_file,
+        opt_shape_dict,
+        log_level=trt.Logger.VERBOSE if verbose else trt.Logger.ERROR,
+        fp16_mode=fp16,
+        max_workspace_size=max_workspace_size)
+    save_dir, _ = osp.split(trt_file)
+    if save_dir:
+        os.makedirs(save_dir, exist_ok=True)
+    save_trt_engine(trt_engine, trt_file)
+    print(f'Successfully created TensorRT engine: {trt_file}')
+
+    if verify:
+        inputs = _prepare_input_img(
+            input_config['input_path'],
+            config.data.test.pipeline,
+            shape=min_shape[2:])
+
+        imgs = inputs['imgs']
+        img_metas = inputs['img_metas']
+        img_list = [img[None, :] for img in imgs]
+        img_meta_list = [[img_meta] for img_meta in img_metas]
+        # update img_meta
+        img_list, img_meta_list = _update_input_img(img_list, img_meta_list)
+
+        if max_shape[0] > 1:
+            # concatenate flipped images for batch testing
+            flip_img_list = [_.flip(-1) for _ in img_list]
+            img_list = [
+                torch.cat((ori_img, flip_img), 0)
+                for ori_img, flip_img in zip(img_list, flip_img_list)
+            ]
+
+        # Get results from ONNXRuntime
+        ort_custom_op_path = get_onnxruntime_op_path()
+        session_options = ort.SessionOptions()
+        if osp.exists(ort_custom_op_path):
+            session_options.register_custom_ops_library(ort_custom_op_path)
+        sess = ort.InferenceSession(onnx_file, session_options)
+        sess.set_providers(['CPUExecutionProvider'], [{}])  # use cpu mode
+        onnx_output = sess.run(['output'],
+                               {'input': img_list[0].detach().numpy()})[0][0]
+
+        # Get results from TensorRT
+        trt_model = TRTWraper(trt_file, ['input'], ['output'])
+        with torch.no_grad():
+            trt_outputs = trt_model({'input': img_list[0].contiguous().cuda()})
+        trt_output = trt_outputs['output'][0].cpu().detach().numpy()
+
+        if show:
+            dataset = DATASETS.get(dataset)
+            assert dataset is not None
+            palette = dataset.PALETTE
+
+            show_result_pyplot(
+                input_config['input_path'],
+                (onnx_output[0].astype(np.uint8), ),
+                palette=palette,
+                title='ONNXRuntime',
+                block=False)
+            show_result_pyplot(
+                input_config['input_path'],
+                (trt_output[0].astype(np.uint8), ),
+                palette=palette,
+                title='TensorRT')
+
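+        # Note: these tolerances assume an fp32 engine; with --fp16 the
+        # outputs may drift beyond them, so verification can fail even
+        # when the engine is usable.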
+        np.testing.assert_allclose(
+            onnx_output, trt_output, rtol=1e-03, atol=1e-05)
+        print('TensorRT and ONNXRuntime outputs are all close.')
+
+
+def parse_args():
+    parser = argparse.ArgumentParser(
+        description='Convert MMSegmentation models from ONNX to TensorRT')
+    parser.add_argument('config', help='Config file of the model')
+    parser.add_argument('model', help='Path to the input ONNX model')
+    parser.add_argument(
+        '--trt-file', type=str, help='Path to the output TensorRT engine')
+    parser.add_argument(
+        '--max-shape',
+        type=int,
+        nargs=4,
+        default=[1, 3, 400, 600],
+        help='Maximum shape of model input.')
+    parser.add_argument(
+        '--min-shape',
+        type=int,
+        nargs=4,
+        default=[1, 3, 400, 600],
+        help='Minimum shape of model input.')
+    parser.add_argument('--fp16', action='store_true', help='Enable fp16 mode')
+    parser.add_argument(
+        '--workspace-size',
+        type=int,
+        default=1,
+        help='Max workspace size in GiB')
+    parser.add_argument(
+        '--input-img', type=str, default='', help='Image for test')
+    parser.add_argument(
+        '--show', action='store_true', help='Whether to show output results')
+    parser.add_argument(
+        '--dataset',
+        type=str,
+        default='CityscapesDataset',
+        help='Dataset name')
+    parser.add_argument(
+        '--verify',
+        action='store_true',
+        help='Verify the outputs of ONNXRuntime and TensorRT')
+    parser.add_argument(
+        '--verbose',
+        action='store_true',
+        help='Whether to print verbose logging messages while creating the '
+        'TensorRT engine.')
+    args = parser.parse_args()
+    return args
+
+
+if __name__ == '__main__':
+
+    assert is_tensorrt_plugin_loaded(), 'TensorRT plugin should be compiled.'
+    args = parse_args()
+
+    if not args.input_img:
+        args.input_img = osp.join(osp.dirname(__file__), '../demo/demo.png')
+
+    # check arguments
+    assert osp.exists(args.config), 'Config {} not found.'.format(args.config)
+    assert osp.exists(args.model), \
+        'ONNX model {} not found.'.format(args.model)
+    assert args.workspace_size >= 0, 'Workspace size should be non-negative.'
+    assert DATASETS.get(args.dataset) is not None, \
+        'Dataset {} not found.'.format(args.dataset)
+    for max_value, min_value in zip(args.max_shape, args.min_shape):
+        assert max_value >= min_value, \
+            'max_shape should be no smaller than min_shape.'
+
+    input_config = {
+        'min_shape': args.min_shape,
+        'max_shape': args.max_shape,
+        'input_path': args.input_img
+    }
+
+    cfg = mmcv.Config.fromfile(args.config)
+    onnx2tensorrt(
+        args.model,
+        args.trt_file,
+        cfg,
+        input_config,
+        fp16=args.fp16,
+        verify=args.verify,
+        show=args.show,
+        dataset=args.dataset,
+        workspace_size=args.workspace_size,
+        verbose=args.verbose)
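+
+# A minimal sketch of reusing a saved engine outside this script, kept as a
+# comment so the tool's behavior is unchanged ('model.trt' and the input
+# shape are placeholders; the binding names match those used above):
+#
+#   import torch
+#   from mmcv.tensorrt import TRTWraper
+#   trt_model = TRTWraper('model.trt', ['input'], ['output'])
+#   with torch.no_grad():
+#       outputs = trt_model({'input': torch.randn(1, 3, 512, 1024).cuda()})
+#   result = outputs['output']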