1
1
import logging
2
2
import warnings
3
3
from datetime import datetime
4
+ from packaging import version
4
5
from typing import Any , Callable , Dict , List , NamedTuple , Optional , Sequence
5
6
6
7
import numpy
@@ -40,6 +41,7 @@ def __init__(
40
41
explicit_batch_dimension : bool = True ,
41
42
explicit_precision : bool = False ,
42
43
logger_level = None ,
44
+ output_dtypes = None ,
43
45
):
44
46
super ().__init__ (module )
45
47
@@ -78,6 +80,9 @@ def __init__(
78
80
trt .tensorrt .ITensor , TensorMetadata
79
81
] = dict ()
80
82
83
+ # Data types for TRT Module output Tensors
84
+ self .output_dtypes = output_dtypes
85
+
81
86
def validate_input_specs (self ):
82
87
for shape , _ , _ , shape_ranges , has_batch_dim in self .input_specs :
83
88
if not self .network .has_implicit_batch_dimension :
@@ -178,13 +183,17 @@ def run(
178
183
algorithm_selector: set up algorithm selection for certain layer
179
184
timing_cache: enable timing cache for TensorRT
180
185
profiling_verbosity: TensorRT logging level
186
+ max_aux_streams: Maximum number of allowed auxiliary TRT streams for each engine
187
+ version_compatible: Provide version forward-compatibility for engine plan files
188
+ optimization_level: Builder optimization 0-5, higher levels imply longer build time,
189
+ searching for more optimization options. TRT defaults to 3
181
190
Return:
182
191
TRTInterpreterResult
183
192
"""
184
193
TRT_INTERPRETER_CALL_PRE_OBSERVER .observe (self .module )
185
194
186
195
# For float outputs, we set their dtype to fp16 only if lower_precision == LowerPrecision.FP16 and
187
- # force_fp32_output=False.
196
+ # force_fp32_output=False. Overriden by specifying output_dtypes
188
197
self .output_fp16 = (
189
198
not force_fp32_output and lower_precision == LowerPrecision .FP16
190
199
)
@@ -224,14 +233,14 @@ def run(
224
233
cache = builder_config .create_timing_cache (b"" )
225
234
builder_config .set_timing_cache (cache , False )
226
235
227
- if trt .__version__ >= "8.2" :
236
+ if version . parse ( trt .__version__ ) >= version . parse ( "8.2" ) :
228
237
builder_config .profiling_verbosity = (
229
238
profiling_verbosity
230
239
if profiling_verbosity
231
240
else trt .ProfilingVerbosity .LAYER_NAMES_ONLY
232
241
)
233
242
234
- if trt .__version__ >= "8.6" :
243
+ if version . parse ( trt .__version__ ) >= version . parse ( "8.6" ) :
235
244
if max_aux_streams is not None :
236
245
_LOGGER .info (f"Setting max aux streams to { max_aux_streams } " )
237
246
builder_config .max_aux_streams = max_aux_streams
@@ -372,6 +381,11 @@ def output(self, target, args, kwargs):
372
381
if not all (isinstance (output , trt .tensorrt .ITensor ) for output in outputs ):
373
382
raise RuntimeError ("TensorRT requires all outputs to be Tensor!" )
374
383
384
+ if self .output_dtypes is not None and len (self .output_dtypes ) != len (outputs ):
385
+ raise RuntimeError (
386
+ f"Specified output dtypes ({ len (self .output_dtypes )} ) differ from number of outputs ({ len (outputs )} )"
387
+ )
388
+
375
389
for i , output in enumerate (outputs ):
376
390
if any (
377
391
op_name in output .name .split ("_" )
@@ -396,6 +410,8 @@ def output(self, target, args, kwargs):
396
410
self .network .mark_output (output )
397
411
if output_bool :
398
412
output .dtype = trt .bool
413
+ elif self .output_dtypes is not None :
414
+ output .dtype = torch_dtype_to_trt (self .output_dtypes [i ])
399
415
elif self .output_fp16 and output .dtype == trt .float32 :
400
416
output .dtype = trt .float16
401
417
self ._output_names .append (name )
0 commit comments