Skip to content

Commit 2d0fa75

Browse files
committed
fix: add explicit cast for i64 outputs as they may not be supported in all layers

Signed-off-by: Naren Dasan <naren@narendasan.com>
Signed-off-by: Naren Dasan <narens@nvidia.com>
1 parent 77df69d commit 2d0fa75

File tree

4 files changed

+24
-14
lines changed

4 files changed

+24
-14
lines changed

py/torch_tensorrt/dynamo/conversion/_TRTInterpreter.py

+17-10
Original file line number | Diff line number | Diff line change
@@ -4,7 +4,6 @@
44
from typing import Any, Callable, Dict, List, NamedTuple, Optional, Sequence, Set
55

66
import numpy as np
7-
import tensorrt as trt
87
import torch
98
import torch.fx
109
from torch.fx.node import _get_qualified_name
@@ -26,6 +25,7 @@
2625
from torch_tensorrt.fx.observer import Observer
2726
from torch_tensorrt.logging import TRT_LOGGER
2827

28+
import tensorrt as trt
2929
from packaging import version
3030

3131
_LOGGER: logging.Logger = logging.getLogger(__name__)
@@ -498,6 +498,9 @@ def output(self, target: str, args: Any, kwargs: Any) -> List[Any]:
498498
)
499499

500500
for i, output in enumerate(outputs):
501+
name = f"output{i}"
502+
503+
output_dtype = dtype.unknown
501504
if any(
502505
op_name in output.name.split("_")
503506
for op_name in (
@@ -514,16 +517,20 @@ def output(self, target: str, args: Any, kwargs: Any) -> List[Any]:
514517
"any",
515518
)
516519
):
517-
output_bool = True
518-
else:
519-
output_bool = False
520-
name = f"output{i}"
521-
output.name = name
522-
self.ctx.net.mark_output(output)
523-
if output_bool:
524-
output.dtype = trt.DataType.BOOL
520+
output_dtype = dtype.b
525521
elif self.output_dtypes is not None:
526-
output.dtype = self.output_dtypes[i].to(trt.DataType)
522+
if self.output_dtypes[i] == dtype.i64:
523+
output = self.ctx.net.add_cast(
524+
output, dtype.i64.to(trt.DataType)
525+
).get_output(0)
526+
output_dtype = dtype.i64
527+
else:
528+
output_dtype = self.output_dtypes[i]
529+
530+
self.ctx.net.mark_output(output)
531+
if output_dtype is not dtype.unknown:
532+
output.dtype = output_dtype.to(trt.DataType, use_default=True)
533+
output.name = name
527534

528535
self._output_names.append(name)
529536
_LOGGER.debug(

py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py

+3
Original file line number | Diff line number | Diff line change
@@ -1,3 +1,5 @@
1+
# mypy: disallow-untyped-decorators=False
2+
13
import logging
24
import operator
35
from typing import Any, Callable, Dict, Optional, Sequence, Tuple, Union
@@ -858,6 +860,7 @@ def validate_dtype(to_copy_node: Node) -> bool:
858860
allowed_casts = {
859861
torch.float,
860862
torch.int32,
863+
torch.int64,
861864
torch.bool,
862865
torch.int8,
863866
torch.float16,

tests/py/dynamo/conversion/harness.py

+4
Original file line number | Diff line number | Diff line change
@@ -251,6 +251,10 @@ def run_test(
251251
truncate_double=compilation_settings.truncate_double,
252252
)
253253

254+
_LOGGER.debug(f"Compilation settings: {compilation_settings}")
255+
_LOGGER.debug(f"Inputs: {input_specs}")
256+
_LOGGER.debug(f"Output types: {output_dtypes}")
257+
254258
interp = TRTInterpreter(
255259
mod,
256260
input_specs,

tests/py/dynamo/models/test_dtype_support.py

-4
Original file line number | Diff line number | Diff line change
@@ -39,7 +39,6 @@ def forward(self, x):
3939
inputs=[in_tensor],
4040
pass_through_build_failures=True,
4141
truncate_double=True,
42-
output_format="fx",
4342
min_block_size=1,
4443
use_python_runtime=False,
4544
)
@@ -78,7 +77,6 @@ def forward(self, x):
7877
inputs=[in_tensor],
7978
pass_through_build_failures=True,
8079
truncate_double=True,
81-
output_format="fx",
8280
min_block_size=1,
8381
use_python_runtime=True,
8482
)
@@ -123,7 +121,6 @@ def forward(self, x):
123121
inputs=[in_tensor],
124122
pass_through_build_failures=True,
125123
truncate_double=False,
126-
output_format="fx",
127124
min_block_size=1,
128125
use_python_runtime=False,
129126
)
@@ -163,7 +160,6 @@ def forward(self, x):
163160
inputs=[in_tensor],
164161
pass_through_build_failures=True,
165162
truncate_double=False,
166-
output_format="fx",
167163
min_block_size=1,
168164
use_python_runtime=True,
169165
)

0 commit comments

Comments (0)