Skip to content

Commit 0afc619

Browse files
peri044 authored and gs-olive committed
chore: Add TRT runner via onnx (#2503)
Signed-off-by: Dheeraj Peri <peri.dheeraj@gmail.com>
1 parent 5770e00 commit 0afc619

File tree

3 files changed

+27
-28
lines changed

3 files changed

+27
-28
lines changed

tools/perf/benchmark.sh

+2-2
Original file line numberDiff line numberDiff line change
@@ -7,8 +7,8 @@ python hub.py
77

88
batch_sizes=(1 2 4 8 16 32 64 128 256)
99
large_model_batch_sizes=(1 2 4 8 16 32 64)
10-
backends=("torch" "ts_trt" "dynamo" "torch_compile" "inductor")
11-
backends_no_torchscript=("torch" "dynamo" "torch_compile" "inductor")
10+
backends=("torch" "ts_trt" "dynamo" "torch_compile" "inductor" "tensorrt")
11+
backends_no_torchscript=("torch" "dynamo" "torch_compile" "inductor" "tensorrt")
1212

1313

1414
# Benchmark VGG16 model

tools/perf/perf_run.py

+23-26
Original file line numberDiff line numberDiff line change
@@ -293,29 +293,30 @@ def run_tensorrt(
293293
input_tensors,
294294
params,
295295
precision,
296-
is_trt_engine=False,
297296
batch_size=1,
298297
):
299-
engine = None
300-
301-
# If the model file is a TensorRT engine then directly deserialize and run inference
302-
# else convert the torch module to a TensorRT engine first and then run inference
303-
if not is_trt_engine:
304-
compile_settings = {
305-
"inputs": input_tensors,
306-
"enabled_precisions": {precision_to_dtype(precision)},
307-
"truncate_long_and_double": params.get("truncate", False),
308-
}
309-
310-
print("Converting method to TensorRT engine...")
311-
with torch.no_grad(), torchtrt.logging.errors():
312-
model = torchtrt.ts.convert_method_to_trt_engine(
313-
model, "forward", **compile_settings
314-
)
315-
298+
# Export an ONNX model and convert to TRT
299+
torch.onnx.export(model.eval().cuda(), tuple(input_tensors), "./tmp.onnx")
300+
logger = trt.Logger(trt.Logger.WARNING)
301+
builder = trt.Builder(logger)
302+
network = builder.create_network(
303+
1 << int(trt.NetworkDefinitionCreationFlag.EXPLICIT_BATCH)
304+
)
305+
parser = trt.OnnxParser(network, logger)
306+
success = parser.parse_from_file("./tmp.onnx")
307+
if not success:
308+
raise ValueError("ONNX conversion failed")
309+
310+
config = builder.create_builder_config()
311+
if precision == "fp16":
312+
config.set_flag(trt.BuilderFlag.FP16)
313+
start_compile = time.time_ns()
314+
serialized_engine = builder.build_serialized_network(network, config)
315+
end_compile = time.time_ns()
316+
compile_time_s = (end_compile - start_compile) / 1e9
316317
# Deserialize the TensorRT engine
317-
with trt.Logger() as logger, trt.Runtime(logger) as runtime:
318-
engine = runtime.deserialize_cuda_engine(model)
318+
with trt.Runtime(logger) as runtime:
319+
engine = runtime.deserialize_cuda_engine(serialized_engine)
319320

320321
print("Running TensorRT for precision: ", precision, " batch_size : ", batch_size)
321322
iters = params.get("iterations", 20)
@@ -350,7 +351,7 @@ def run_tensorrt(
350351
meas_time = end_time - start_time
351352
timings.append(meas_time)
352353

353-
recordStats("TensorRT", timings, precision, batch_size)
354+
recordStats("TensorRT", timings, precision, batch_size, compile_time_s)
354355

355356

356357
# Deploys inference run for different backend configurations
@@ -426,11 +427,10 @@ def run(
426427
)
427428
elif backend == "tensorrt":
428429
run_tensorrt(
429-
model,
430+
model_torch,
430431
input_tensors,
431432
params,
432433
precision,
433-
is_trt_engine,
434434
batch_size,
435435
)
436436
elif backend == "dynamo":
@@ -439,9 +439,6 @@ def run(
439439
elif backend == "torch_compile":
440440
run_torch_compile(model_torch, input_tensors, params, precision, batch_size)
441441

442-
elif backend == "torch_compile":
443-
run_torch_compile(model_torch, input_tensors, params, precision, batch_size)
444-
445442
elif backend == "inductor":
446443
run_inductor(model_torch, input_tensors, params, precision, batch_size)
447444

tools/perf/requirements.txt

+2
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,9 @@
11
numpy
22
argparse
33
pyyaml
4+
onnx
45
transformers==4.33.2
56
diffusers==0.21.4
67
pandas==2.0.1
78
timm==0.9.8
9+

0 commit comments

Comments (0)