iluvatar_infer_resnet50 #259

Merged: 1 commit, Sep 22, 2023
2 changes: 1 addition & 1 deletion inference/benchmarks/resnet50/README.md
@@ -120,6 +120,6 @@ find ./val -name "*JPEG" | wc -l
| tensorrt | fp16 | 256 |613.4 | 1358.9 | 4469.4 | 1391.4 | 12698.7 | 16.8% | 76.2/76.2 | 19.7/40.0 |
| tensorrt | fp32 | 256 | 474.4 | 1487.3 | 2653.2 | 1560.3 | 6091.6 | 16.1% | 76.2/76.2 | 28.86/40.0 |
| torchtrt | fp16 | 256 | 716.4 | 1370.4 | 4282.6 | 1320.0 | 4723.0 | 6.3% | 76.2/76.2 | 9.42/40.0 |
| ixrt | fp16 | 256 | 136.4 | / | / | 1146.6 | 2679.9 | 11.5% | 76.2 | 4.3/32.0 |
| ixrt | fp16 (W16A32) | 256 | 261.467 | / | / | 1389.332 | 2721.402 | 11.7% | 76.2/76.2 | 8.02/32.0 |
| kunlunxin_xtcl | fp32 | 128 | 311.215 | / | / | 837.507 | 1234.727 | / | 76.2/76.2 | / |

193 changes: 102 additions & 91 deletions inference/inference_engine/iluvatar/ixrt.py
@@ -1,10 +1,13 @@
from ixrt import IxRT, RuntimeConfig, RuntimeContext
import torch
import os
import subprocess
from loguru import logger
import torch
from torch import autocast
import tensorrt as trt

import numpy as np
import pycuda.driver as cuda
import pycuda.autoinit
import time
import subprocess


class InferModel:
@@ -16,114 +19,122 @@ def __init__(self, host_mem, device_mem):
self.device = device_mem

def __str__(self):
return "Host:\n" + str(self.host) + "\nDevice:\n" + str(self.device)
return "Host:\n" + str(self.host) + "\nDevice:\n" + str(
self.device)

def __repr__(self):
return self.__str__()

def __init__(self, config, onnx_path, model):
self.str_to_numpy_dict = {
"int32": np.int32,
"float16": np.float16,
"float32": np.float32,
}
self.engine = self.build_engine(config, onnx_path)
self.outputs = self.allocate_buffers(self.engine)

def config_init_engine(self, config, onnx_path):
quant_file = None

runtime_config = RuntimeConfig()
self.config = config

input_shapes = [config.batch_size, 3, config.image_size, config.image_size]
runtime_config.input_shapes = [("input", input_shapes)]
runtime_config.device_idx = 0
self.logger = trt.Logger(trt.Logger.WARNING)
self.runtime = trt.Runtime(self.logger)

precision = "float16"
if precision == "int8":
            assert quant_file, "Quant file must be provided for int8 inferencing."

runtime_config.runtime_context = RuntimeContext(
precision,
"nhwc",
use_gpu=True,
pipeline_sync=True,
input_types=config.input_types,
output_types=config.output_types,
input_device="gpu",
output_device="gpu",
)
self.engine = self.build_engine(config, onnx_path)

runtime = IxRT.from_onnx(onnx_path, quant_file, runtime_config)
return runtime
self.inputs, self.outputs, self.bindings, self.stream = self.allocate_buffers(
self.engine)

self.context = self.engine.create_execution_context()
self.numpy_to_torch_dtype_dict = {
bool: torch.bool,
np.uint8: torch.uint8,
np.int8: torch.int8,
np.int16: torch.int16,
np.int32: torch.int32,
np.int64: torch.int64,
np.float16: torch.float16,
np.float32: torch.float32,
np.float64: torch.float64,
np.complex64: torch.complex64,
np.complex128: torch.complex128,
}
self.str_to_torch_dtype_dict = {
"bool": torch.bool,
"uint8": torch.uint8,
"int8": torch.int8,
"int16": torch.int16,
"int32": torch.int32,
"int64": torch.int64,
"float16": torch.float16,
"float32": torch.float32,
"float64": torch.float64,
"complex64": torch.complex64,
"complex128": torch.complex128,
}

def build_engine(self, config, onnx_path):
if config.exist_compiler_path is None:
output_path = config.log_dir + "/" + config.ixrt_tmp_path
trt_path = config.log_dir + "/" + config.ixrt_tmp_path

dir_output_path = os.path.dirname(output_path)
os.makedirs(dir_output_path, exist_ok=True)
dir_trt_path = os.path.dirname(trt_path)
os.makedirs(dir_trt_path, exist_ok=True)

time.sleep(10)

runtime = self.config_init_engine(config, onnx_path)
print(f"Build Engine File: {output_path}")
runtime.BuildEngine()
runtime.SerializeEngine(output_path)
print("Build Engine done!")
trtexec_cmd = "ixrtexec --onnx=" + onnx_path + " --save_engine=" + trt_path
if config.fp16:
trtexec_cmd += " --precision fp16"
if config.has_dynamic_axis:
trtexec_cmd += " --minShapes=" + config.minShapes
trtexec_cmd += " --optShapes=" + config.optShapes
trtexec_cmd += " --maxShapes=" + config.maxShapes

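            # Build and serialize the engine with the ixrtexec CLI, blocking until it finishes.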
p = subprocess.Popen(trtexec_cmd, shell=True)
p.wait()
else:
output_path = config.exist_compiler_path
print(f"Use existing engine: {output_path}")
trt_path = config.exist_compiler_path

runtime = IxRT()
runtime.LoadEngine(output_path, config.batch_size)
return runtime
with open(trt_path, "rb") as f:
return self.runtime.deserialize_cuda_engine(f.read())

def allocate_buffers(self, engine):
output_map = engine.GetOutputShape()
output_io_buffers = []
output_types = {}
config = engine.GetConfig()
for key, val in config.runtime_context.output_types.items():
output_types[key] = str(val)
for name, shape in output_map.items():
# 1. apply memory buffer for output of the shape
buffer = np.zeros(
shape.dims, dtype=self.str_to_numpy_dict[output_types[name]]
)
buffer = torch.tensor(buffer).cuda()
# 2. put the buffer to a list
output_io_buffers.append([name, buffer, shape])
inputs = []
outputs = []
bindings = []
stream = cuda.Stream()

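        # Allocate one pagelocked host buffer and one device buffer per engine binding,
        # recording device pointers in binding order for execute_async_v2.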
for binding in range(engine.num_bindings):
size = trt.volume(engine.get_binding_shape(binding))
dtype = trt.nptype(engine.get_binding_dtype(binding))

host_mem = cuda.pagelocked_empty(size, dtype)
device_mem = cuda.mem_alloc(host_mem.nbytes)
bindings.append(int(device_mem))

if engine.binding_is_input(binding):
inputs.append(self.HostDeviceMem(host_mem, device_mem))
else:
outputs.append(self.HostDeviceMem(host_mem, device_mem))

engine.BindIOBuffers(output_io_buffers)
return output_io_buffers
return inputs, outputs, bindings, stream

def __call__(self, model_inputs: list):
batch_size = np.unique(np.array([i.size(dim=0) for i in model_inputs]))
batch_size = batch_size[0]
input_map = self.engine.GetInputShape()
input_io_buffers = []

for i, model_input in enumerate(model_inputs):
model_input = torch.tensor(model_input.numpy(), dtype=torch.float32).cuda()
if not model_input.is_contiguous():
model_input = model_input.contiguous()
name, shape = list(input_map.items())[0]
_shape, _padding = shape.dims, shape.padding
_shape = [i + j for i, j in zip(_shape, _padding)]
_shape = [_shape[0], *_shape[2:], _shape[1]]
input_io_buffers.append([name, model_input, shape])

self.engine.BindIOBuffers(self.outputs)
self.engine.LoadInput(input_io_buffers)

# torch.cuda.synchronize()
self.engine.Execute()
# torch.cuda.synchronize()

gpu_io_buffers = []
for buffer in self.outputs:
# gpu_io_buffers.append([buffer[0], buffer[1], buffer[2]])
gpu_io_buffers.append(buffer[1])

return gpu_io_buffers, 0
model_input = model_input.cuda()

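            # Device-to-device copy from the torch input tensor into this binding's device buffer.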
cuda.memcpy_dtod_async(
self.inputs[i].device,
model_input.data_ptr(),
model_input.element_size() * model_input.nelement(),
self.stream,
)

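        # Enqueue inference on the same CUDA stream used for the input copies.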
self.context.execute_async_v2(bindings=self.bindings,
stream_handle=self.stream.handle)
result = []
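        # Copy each output binding back into a CUDA torch tensor of the matching dtype.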
for out in self.outputs:
out_tensor = torch.empty(out.host.shape, device="cuda").to(
self.str_to_torch_dtype_dict[str(out.host.dtype)])
cuda.memcpy_dtod_async(
out_tensor.data_ptr(),
out.device,
out_tensor.element_size() * out_tensor.nelement(),
self.stream,
)
result.append(out_tensor)

self.stream.synchronize()
return result, 0
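
For reference, a minimal, hypothetical usage sketch of the rewritten `InferModel` (not part of this PR): it assumes an Iluvatar GPU with `ixrt`'s TensorRT-compatible Python API, `pycuda`, and `torch` installed, a ResNet-50 ONNX export at `resnet50.onnx`, and a config object exposing the fields the new code path reads. The `SimpleNamespace`, paths, and values below are illustrative only; the field names follow the diff.

```python
# Hypothetical driver for the InferModel defined above; all values are assumptions.
from types import SimpleNamespace

import torch

config = SimpleNamespace(
    batch_size=256,
    image_size=224,
    fp16=True,                 # adds "--precision fp16" to the ixrtexec command
    has_dynamic_axis=False,    # static shapes, so no min/opt/maxShapes flags
    exist_compiler_path=None,  # None -> build a fresh engine via ixrtexec
    log_dir="./logs",
    ixrt_tmp_path="ixrt_tmp/resnet50.engine",
)

model = InferModel(config, "resnet50.onnx", model=None)

# The input dtype should match the engine's input binding (fp16 here).
dummy = torch.randn(config.batch_size, 3, config.image_size,
                    config.image_size).half()
outputs, _ = model([dummy])
print([o.shape for o in outputs])
```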