Skip to content

Commit ff2f9ae

Browse files
committed
fix: add explicit cast for i64 outputs as they may not be supported in
all layers Signed-off-by: Naren Dasan <naren@narendasan.com> Signed-off-by: Naren Dasan <narens@nvidia.com>
1 parent da25720 commit ff2f9ae

File tree

3 files changed

+24
-10
lines changed

3 files changed

+24
-10
lines changed

py/torch_tensorrt/dynamo/conversion/_TRTInterpreter.py

Lines changed: 17 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,6 @@
44
from typing import Any, Callable, Dict, List, NamedTuple, Optional, Sequence, Set
55

66
import numpy as np
7-
import tensorrt as trt
87
import torch
98
import torch.fx
109
from torch.fx.node import _get_qualified_name
@@ -26,6 +25,7 @@
2625
from torch_tensorrt.fx.observer import Observer
2726
from torch_tensorrt.logging import TRT_LOGGER
2827

28+
import tensorrt as trt
2929
from packaging import version
3030

3131
_LOGGER: logging.Logger = logging.getLogger(__name__)
@@ -498,6 +498,9 @@ def output(self, target: str, args: Any, kwargs: Any) -> List[Any]:
498498
)
499499

500500
for i, output in enumerate(outputs):
501+
name = f"output{i}"
502+
503+
output_dtype = dtype.unknown
501504
if any(
502505
op_name in output.name.split("_")
503506
for op_name in (
@@ -514,16 +517,20 @@ def output(self, target: str, args: Any, kwargs: Any) -> List[Any]:
514517
"any",
515518
)
516519
):
517-
output_bool = True
518-
else:
519-
output_bool = False
520-
name = f"output{i}"
521-
output.name = name
522-
self.ctx.net.mark_output(output)
523-
if output_bool:
524-
output.dtype = trt.DataType.BOOL
520+
output_dtype = dtype.b
525521
elif self.output_dtypes is not None:
526-
output.dtype = self.output_dtypes[i].to(trt.DataType)
522+
if self.output_dtypes[i] == dtype.i64:
523+
output = self.ctx.net.add_cast(
524+
output, dtype.i64.to(trt.DataType)
525+
).get_output(0)
526+
output_dtype = dtype.i64
527+
else:
528+
output_dtype = self.output_dtypes[i]
529+
530+
self.ctx.net.mark_output(output)
531+
if output_dtype is not dtype.unknown:
532+
output.dtype = output_dtype.to(trt.DataType, use_default=True)
533+
output.name = name
527534

528535
self._output_names.append(name)
529536
_LOGGER.debug(

py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,5 @@
1+
# mypy: disallow-untyped-decorators=False
2+
13
import logging
24
import operator
35
from typing import Any, Callable, Dict, Optional, Sequence, Tuple, Union
@@ -858,6 +860,7 @@ def validate_dtype(to_copy_node: Node) -> bool:
858860
allowed_casts = {
859861
torch.float,
860862
torch.int32,
863+
torch.int64,
861864
torch.bool,
862865
torch.int8,
863866
torch.float16,

tests/py/dynamo/conversion/harness.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -251,6 +251,10 @@ def run_test(
251251
truncate_double=compilation_settings.truncate_double,
252252
)
253253

254+
_LOGGER.debug(f"Compilation settings: {compilation_settings}")
255+
_LOGGER.debug(f"Inputs: {input_specs}")
256+
_LOGGER.debug(f"Output types: {output_dtypes}")
257+
254258
interp = TRTInterpreter(
255259
mod,
256260
input_specs,

0 commit comments

Comments
 (0)