#!/usr/bin/env python
# -*- coding:UTF-8 -*-
"""Create a TensorRT-optimized inference graph from a frozen RPN3D model.

Loads ``<save_model_dir>/frozen.pb`` for the given training tag, converts it
with TF-TensorRT in FP16 precision, and writes the optimized GraphDef back
into the same directory as ``newFrozenModel_TRT_.pb``.
"""

import glob
import argparse
import os
import time
import tensorflow as tf

from config import cfg
from model import RPN3D

from utils import *
from utils.kitti_loader import iterate_data, sample_test_data

from tensorflow.contrib import tensorrt as trt
from tensorflow.python.platform import gfile

parser = argparse.ArgumentParser(description='testing')
parser.add_argument('-n', '--tag', type=str, nargs='?', default='pre_trained_car',
                    help='set log tag')
parser.add_argument('-b', '--single-batch-size', type=int, nargs='?', default=1,
                    help='set batch size for each gpu')
parser.add_argument('-o', '--output-path', type=str, nargs='?',
                    default='./predictions', help='results output dir')
parser.add_argument('-v', '--vis', type=bool, nargs='?', default=False,
                    help='set the flag to True if dumping visualizations')
args = parser.parse_args()


dataset_dir = cfg.DATA_DIR
test_dir = os.path.join(dataset_dir, 'testing')
save_model_dir = os.path.join('.', 'save_model', args.tag)

# Upper bound (in bytes) on the GPU scratch space TensorRT may use while
# building engines. The original code passed an undefined name here
# (NameError); 1 GiB is the value commonly used in TF-TRT examples.
MAX_WORKSPACE_SIZE_BYTES = 1 << 30


def main(_):
    """Convert the tag's frozen.pb into a TensorRT-optimized graph on disk.

    Side effects: reads ``<save_model_dir>/frozen.pb`` and writes
    ``<save_model_dir>/newFrozenModel_TRT_.pb``.
    """
    with tf.Graph().as_default():

        gpu_options = tf.GPUOptions(
            per_process_gpu_memory_fraction=cfg.GPU_MEMORY_FRACTION,
            visible_device_list=cfg.GPU_AVAILABLE,
            allow_growth=True)

        config = tf.ConfigProto(
            gpu_options=gpu_options,
            device_count={
                "GPU": cfg.GPU_USE_COUNT,
            },
            allow_soft_placement=True,
        )

        # BUG FIX: the proto was previously bound to `conf` while the session
        # opened with the undefined name `config` -> NameError at runtime.
        with tf.Session(config=config) as sess:
            # The model is built only to query its output node names below.
            model = RPN3D(
                cls=cfg.DETECT_OBJ,
                single_batch_size=args.single_batch_size,
                avail_gpus=cfg.GPU_AVAILABLE.split(',')
            )

            # We need the names of the tensors (op name + ':0'), not the ops.
            node_list = [nd + ':0' for nd in model.get_output_node_names()]

            # NOTE(review): the original additionally called
            # `load_graph(save_model_dir + "/frozen.pb")` into an unused
            # `calib_graph` variable — a redundant second read of the same
            # file, removed here.
            frozen_pb_path = os.path.join(save_model_dir, "frozen.pb")
            with gfile.FastGFile(frozen_pb_path, 'rb') as f:
                graph_def = tf.GraphDef()
                graph_def.ParseFromString(f.read())

            # BUG FIX: `max_workspace_size_bytes` previously referenced an
            # undefined name; it now uses the module-level constant above.
            trt_graph = trt.create_inference_graph(
                input_graph_def=graph_def,
                outputs=node_list,
                max_batch_size=32,
                max_workspace_size_bytes=MAX_WORKSPACE_SIZE_BYTES,
                minimum_segment_size=1,
                precision_mode="FP16")

            path_new_frozen_pb = os.path.join(save_model_dir, "newFrozenModel_TRT_.pb")
            with gfile.FastGFile(path_new_frozen_pb, 'wb') as fp:
                fp.write(trt_graph.SerializeToString())
            print("TRT graph written to path ", path_new_frozen_pb)


if __name__ == '__main__':
    tf.app.run(main)