
Commit f3bdc49

feat: Add support for output data types in Interpreter
- Add argument for specification of output data types of TRT engines in the interpreter, to avoid type mismatches at runtime
- Add support for output data type provision in the Dynamo compile path, which simultaneously tests the feature via the backend testing and e2e frameworks
1 parent 93eee96 commit f3bdc49

2 files changed: +29 −3 lines changed

py/torch_tensorrt/dynamo/backend/conversion.py

+10
```diff
@@ -27,11 +27,21 @@ def convert_module(
     Returns:
         TRTModule or TRTModuleNext
     """
+    # Specify module output data types to ensure TRT output types agree with
+    # that of the equivalent Torch module
+    module_outputs = module(*inputs)
+
+    if not isinstance(module_outputs, (list, tuple)):
+        module_outputs = [module_outputs]
+
+    output_dtypes = list(output.dtype for output in module_outputs)
+
     interpreter = TRTInterpreter(
         module,
         InputTensorSpec.from_tensors(inputs),
         explicit_batch_dimension=True,
         logger_level=(trt.Logger.VERBOSE if settings.debug else trt.Logger.WARNING),
+        output_dtypes=output_dtypes,
     )
 
     interpreter_result = interpreter.run(
```
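
The inference step added above can be exercised on its own. Below is a minimal sketch, with `ToyModule` as a hypothetical stand-in for the FX module that `convert_module` actually receives:

```python
import torch

class ToyModule(torch.nn.Module):
    """Hypothetical stand-in for a traced module with a single fp16 output."""

    def forward(self, x):
        return x.half()

module = ToyModule().eval()
inputs = [torch.randn(2, 3)]

# Run the Torch module once to observe its output dtypes
module_outputs = module(*inputs)

# Single outputs are normalized to a list so dtype collection is uniform
if not isinstance(module_outputs, (list, tuple)):
    module_outputs = [module_outputs]

output_dtypes = [output.dtype for output in module_outputs]
print(output_dtypes)  # [torch.float16]
```

Passing these dtypes through to the interpreter means the TRT engine's outputs are marked fp16 up front, rather than defaulting to fp32 and mismatching the Torch module at runtime.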

py/torch_tensorrt/dynamo/fx_ts_compat/fx2trt.py

+19 −3
```diff
@@ -1,6 +1,7 @@
 import logging
 import warnings
 from datetime import datetime
+from packaging import version
 from typing import Any, Callable, Dict, List, NamedTuple, Optional, Sequence
 
 import numpy
@@ -40,6 +41,7 @@ def __init__(
         explicit_batch_dimension: bool = True,
         explicit_precision: bool = False,
         logger_level=None,
+        output_dtypes=None,
     ):
         super().__init__(module)
 
@@ -78,6 +80,9 @@ def __init__(
             trt.tensorrt.ITensor, TensorMetadata
         ] = dict()
 
+        # Data types for TRT Module output Tensors
+        self.output_dtypes = output_dtypes
+
     def validate_input_specs(self):
         for shape, _, _, shape_ranges, has_batch_dim in self.input_specs:
             if not self.network.has_implicit_batch_dimension:
@@ -178,13 +183,17 @@ def run(
             algorithm_selector: set up algorithm selection for certain layer
             timing_cache: enable timing cache for TensorRT
             profiling_verbosity: TensorRT logging level
+            max_aux_streams: Maximum number of allowed auxiliary TRT streams for each engine
+            version_compatible: Provide version forward-compatibility for engine plan files
+            optimization_level: Builder optimization 0-5, higher levels imply longer build time,
+                searching for more optimization options. TRT defaults to 3
         Return:
             TRTInterpreterResult
         """
         TRT_INTERPRETER_CALL_PRE_OBSERVER.observe(self.module)
 
         # For float outputs, we set their dtype to fp16 only if lower_precision == LowerPrecision.FP16 and
-        # force_fp32_output=False.
+        # force_fp32_output=False. Overridden by specifying output_dtypes
         self.output_fp16 = (
             not force_fp32_output and lower_precision == LowerPrecision.FP16
         )
@@ -224,14 +233,14 @@ def run(
             cache = builder_config.create_timing_cache(b"")
         builder_config.set_timing_cache(cache, False)
 
-        if trt.__version__ >= "8.2":
+        if version.parse(trt.__version__) >= version.parse("8.2"):
             builder_config.profiling_verbosity = (
                 profiling_verbosity
                 if profiling_verbosity
                 else trt.ProfilingVerbosity.LAYER_NAMES_ONLY
            )
 
-        if trt.__version__ >= "8.6":
+        if version.parse(trt.__version__) >= version.parse("8.6"):
             if max_aux_streams is not None:
                 _LOGGER.info(f"Setting max aux streams to {max_aux_streams}")
                 builder_config.max_aux_streams = max_aux_streams
```
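
The switch to `packaging.version` in the last hunk fixes a latent bug: Python compares strings lexicographically, so a future TensorRT release such as "10.0" would sort below "8.2" and silently skip these branches. A quick illustration:

```python
from packaging import version

# String comparison is lexicographic: "1" < "8", so "10.0" < "8.2"
print("10.0" >= "8.2")                                # False
print(version.parse("10.0") >= version.parse("8.2"))  # True
```
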
```diff
@@ -372,6 +381,11 @@ def output(self, target, args, kwargs):
         if not all(isinstance(output, trt.tensorrt.ITensor) for output in outputs):
             raise RuntimeError("TensorRT requires all outputs to be Tensor!")
 
+        if self.output_dtypes is not None and len(self.output_dtypes) != len(outputs):
+            raise RuntimeError(
+                f"Specified output dtypes ({len(self.output_dtypes)}) differ from number of outputs ({len(outputs)})"
+            )
+
         for i, output in enumerate(outputs):
             if any(
                 op_name in output.name.split("_")
@@ -396,6 +410,8 @@ def output(self, target, args, kwargs):
             self.network.mark_output(output)
             if output_bool:
                 output.dtype = trt.bool
+            elif self.output_dtypes is not None:
+                output.dtype = torch_dtype_to_trt(self.output_dtypes[i])
             elif self.output_fp16 and output.dtype == trt.float32:
                 output.dtype = trt.float16
             self._output_names.append(name)
```
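
Taken together, the branches above give output dtypes a clear precedence: boolean outputs stay `trt.bool`, an explicit `output_dtypes` entry wins next, and the fp16 downcast applies only as a fallback. A hypothetical restatement of that branch order as a standalone function (`resolve_output_dtype` and the `_TORCH_TO_TRT` mapping are illustrative, not part of the commit; the real code calls `torch_dtype_to_trt`):

```python
import torch
import tensorrt as trt

# Illustrative subset of a Torch-to-TRT dtype mapping; the commit itself
# relies on the existing torch_dtype_to_trt helper instead
_TORCH_TO_TRT = {
    torch.float32: trt.float32,
    torch.float16: trt.float16,
    torch.int32: trt.int32,
    torch.bool: trt.bool,
}

def resolve_output_dtype(i, output_bool, output_dtypes, output_fp16, current_dtype):
    """Mirror the branch order in output(): bool, then output_dtypes, then fp16."""
    if output_bool:
        return trt.bool
    elif output_dtypes is not None:
        return _TORCH_TO_TRT[output_dtypes[i]]
    elif output_fp16 and current_dtype == trt.float32:
        return trt.float16
    return current_dtype
```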
