@@ -293,29 +293,30 @@ def run_tensorrt(
     input_tensors,
     params,
     precision,
-    is_trt_engine=False,
     batch_size=1,
 ):
-    engine = None
-
-    # If the model file is a TensorRT engine then directly deserialize and run inference
-    # else convert the torch module to a TensorRT engine first and then run inference
-    if not is_trt_engine:
-        compile_settings = {
-            "inputs": input_tensors,
-            "enabled_precisions": {precision_to_dtype(precision)},
-            "truncate_long_and_double": params.get("truncate", False),
-        }
-
-        print("Converting method to TensorRT engine...")
-        with torch.no_grad(), torchtrt.logging.errors():
-            model = torchtrt.ts.convert_method_to_trt_engine(
-                model, "forward", **compile_settings
-            )
-
+    # Export an ONNX model and convert to TRT
+    torch.onnx.export(model.eval().cuda(), tuple(input_tensors), "./tmp.onnx")
+    logger = trt.Logger(trt.Logger.WARNING)
+    builder = trt.Builder(logger)
+    network = builder.create_network(
+        1 << int(trt.NetworkDefinitionCreationFlag.EXPLICIT_BATCH)
+    )
+    parser = trt.OnnxParser(network, logger)
+    success = parser.parse_from_file("./tmp.onnx")
+    if not success:
+        raise ValueError("ONNX conversion failed")
+
+    config = builder.create_builder_config()
+    if precision == "fp16":
+        config.set_flag(trt.BuilderFlag.FP16)
+    start_compile = time.time_ns()
+    serialized_engine = builder.build_serialized_network(network, config)
+    end_compile = time.time_ns()
+    compile_time_s = (end_compile - start_compile) / 1e9
     # Deserialize the TensorRT engine
-    with trt.Logger() as logger, trt.Runtime(logger) as runtime:
-        engine = runtime.deserialize_cuda_engine(model)
+    with trt.Runtime(logger) as runtime:
+        engine = runtime.deserialize_cuda_engine(serialized_engine)
 
     print("Running TensorRT for precision: ", precision, " batch_size : ", batch_size)
     iters = params.get("iterations", 20)
@@ -350,7 +351,7 @@ def run_tensorrt(
         meas_time = end_time - start_time
         timings.append(meas_time)
 
-    recordStats("TensorRT", timings, precision, batch_size)
+    recordStats("TensorRT", timings, precision, batch_size, compile_time_s)
 
 
 # Deploys inference run for different backend configurations
@@ -426,11 +427,10 @@ def run(
         )
     elif backend == "tensorrt":
         run_tensorrt(
-            model,
+            model_torch,
             input_tensors,
             params,
             precision,
-            is_trt_engine,
             batch_size,
         )
     elif backend == "dynamo":
@@ -439,9 +439,6 @@ def run(
     elif backend == "torch_compile":
         run_torch_compile(model_torch, input_tensors, params, precision, batch_size)
 
-    elif backend == "torch_compile":
-        run_torch_compile(model_torch, input_tensors, params, precision, batch_size)
-
     elif backend == "inductor":
         run_inductor(model_torch, input_tensors, params, precision, batch_size)
 
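For reference, a minimal sketch of how the deserialized engine might be driven inside the timed loop that these hunks elide. This is not code from this commit: the infer_trt helper, the fp32-output assumption, and the use of torch CUDA tensors as device buffers are illustrative only, written against the TensorRT 8.x binding API (execute_v2 and the binding_* accessors; later releases move to name-based tensor I/O).

import torch

def infer_trt(engine, input_tensors):
    # Hypothetical helper (not in this commit): one forward pass on a
    # deserialized engine, using torch CUDA tensors as device buffers.
    context = engine.create_execution_context()
    bindings, outputs, in_idx = [], [], 0
    for b in range(engine.num_bindings):
        if engine.binding_is_input(b):
            bindings.append(input_tensors[in_idx].data_ptr())
            in_idx += 1
        else:
            shape = tuple(engine.get_binding_shape(b))
            # Assumes fp32 outputs; a real harness would map
            # engine.get_binding_dtype(b) to the matching torch dtype.
            out = torch.empty(shape, device="cuda")
            outputs.append(out)
            bindings.append(out.data_ptr())
    context.execute_v2(bindings)  # synchronous run on the default stream
    return outputs

Wrapping each call in start_time/end_time stamps (with torch.cuda.synchronize() beforehand so no prior GPU work is counted) would produce the meas_time values appended to timings above.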