import torch
import tensorrt as trt
# TensorRT 7.2.3, context style
def build_engine(model_path, max_batch_size=1, max_workspace_size=1 << 30):
    engine_path = model_path[:model_path.rfind('.onnx')] + '.engine'
    TRT_LOGGER = trt.Logger(trt.Logger.WARNING)
    # EXPLICIT_BATCH is a flag bitmask, so the shift base must be 1,
    # not max_batch_size.
    EXPLICIT_BATCH = 1 << int(trt.NetworkDefinitionCreationFlag.EXPLICIT_BATCH)
    with trt.Builder(TRT_LOGGER) as builder, \
            builder.create_network(EXPLICIT_BATCH) as network, \
            trt.OnnxParser(network, TRT_LOGGER) as parser:
        # Ignored by explicit-batch networks; kept for older-style callers.
        builder.max_batch_size = max_batch_size
        with open(model_path, 'rb') as model:
            if not parser.parse(model.read()):
                for error in range(parser.num_errors):
                    print(parser.get_error(error))
                # Abort: the parsed network is incomplete, building would fail.
                return None
            print("Correctly loaded ONNX model!")
        with builder.create_builder_config() as config:
            config.max_workspace_size = max_workspace_size
            with builder.build_engine(network, config) as engine:
                serialized_engine = engine.serialize()
                save_engine(serialized_engine, engine_path)
    return engine_path
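# Usage sketch (hypothetical path): one-off conversion of an ONNX file.
# build_engine returns the path of the serialized engine, or None when
# parsing failed, e.g.:
#
#   engine_path = build_engine('model.onnx', max_workspace_size=1 << 30)
#   if engine_path is not None:
#       print('Engine written to', engine_path)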
def save_engine(engine, engine_path):
    # Serialized engines are not portable across platforms or TensorRT
    # versions, and are specific to the exact GPU model they were built on.
    with open(engine_path, "wb") as f:
        f.write(engine)
def load_engine(engine_path):
    # Load a serialized engine from disk.
    with open(engine_path, "rb") as f:
        serialized_engine = f.read()
    return serialized_engine
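# Round-trip sketch (hypothetical path and input shape): the bytes written
# by save_engine are exactly what load_engine returns, and they can be fed
# straight into inference_trt defined below, e.g.:
#
#   serialized = load_engine('model.engine')
#   result = inference_trt(serialized,
#                          torch.randn(1, 3, 224, 224, device='cuda'))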
def torch_device_from_trt(device):
    """Convert a TensorRT TensorLocation to a torch device."""
    if device == trt.TensorLocation.DEVICE:
        return torch.device('cuda')
    elif device == trt.TensorLocation.HOST:
        return torch.device('cpu')
    else:
        raise TypeError('%s is not supported by torch' % device)
def torch_dtype_from_trt(dtype):
    """Convert a TensorRT dtype to a torch dtype."""
    if dtype == trt.bool:
        return torch.bool
    elif dtype == trt.int8:
        return torch.int8
    elif dtype == trt.int32:
        return torch.int32
    elif dtype == trt.float16:
        return torch.float16
    elif dtype == trt.float32:
        return torch.float32
    else:
        raise TypeError('%s is not supported by torch' % dtype)
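# Together these two helpers let us allocate a torch buffer that matches an
# engine binding's dtype and memory location, e.g. (assuming an existing
# engine, context, and binding index idx):
#
#   dtype = torch_dtype_from_trt(engine.get_binding_dtype(idx))
#   device = torch_device_from_trt(engine.get_location(idx))
#   buf = torch.empty(tuple(context.get_binding_shape(idx)),
#                     dtype=dtype, device=device)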
def inference_trt(engine_binary, input_tensor):
    # Because the engine was converted from an ONNX model, the input and
    # output binding names match the names in the ONNX model.
    TRT_LOGGER = trt.Logger(trt.Logger.WARNING)
    runtime = trt.Runtime(TRT_LOGGER)
    engine = runtime.deserialize_cuda_engine(engine_binary)
    context = engine.create_execution_context()
    # Inference: iterating the engine yields binding names in binding order.
    names = [name for name in engine]
    input_names = list(filter(engine.binding_is_input, names))
    # Preserve binding order (a set difference would lose it).
    output_names = [name for name in names if name not in input_names]
    num_bindings = len(input_names) + len(output_names)
    bindings = [None] * num_bindings
    # Keep a reference to the contiguous tensor so its buffer stays alive
    # until execution completes.
    input_tensor = input_tensor.contiguous()
    for name in input_names:
        input_idx = engine.get_binding_index(name)
        context.set_binding_shape(input_idx, tuple(input_tensor.shape))
        bindings[input_idx] = input_tensor.data_ptr()
    outputs = {}
    for name in output_names:
        output_idx = engine.get_binding_index(name)
        dtype = torch_dtype_from_trt(engine.get_binding_dtype(output_idx))
        shape = tuple(context.get_binding_shape(output_idx))
        device = torch_device_from_trt(engine.get_location(output_idx))
        output = torch.empty(size=shape, dtype=dtype, device=device)
        outputs[name] = output
        bindings[output_idx] = output.data_ptr()
    context.execute_async_v2(bindings, torch.cuda.current_stream().cuda_stream)
    # Wait for the async launch to finish before handing the outputs back.
    torch.cuda.current_stream().synchronize()
    return outputs
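# End-to-end sketch, guarded so importing this module stays side-effect free.
# 'model.onnx' and the (1, 3, 224, 224) input shape are assumptions for
# illustration; substitute your own model and shape.
if __name__ == '__main__':
    engine_path = build_engine('model.onnx')
    if engine_path is not None:
        engine_binary = load_engine(engine_path)
        dummy_input = torch.randn(1, 3, 224, 224, device='cuda')
        outputs = inference_trt(engine_binary, dummy_input)
        for name, tensor in outputs.items():
            print(name, tuple(tensor.shape), tensor.dtype)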