Skip to content

Commit 30ef927

Browse files
committed
fix: Improve logging and kwarg passing in Dynamo
- Improve kwarg passing through Dynamo frontend - Improve logging in case of TRT compilation failures
1 parent 6e4aa0b commit 30ef927

File tree

2 files changed

+11
-16
lines changed

2 files changed

+11
-16
lines changed

py/torch_tensorrt/dynamo/backend/__init__.py

+6-10
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,6 @@
88
from torch_tensorrt import EngineCapability, Device
99
from torch_tensorrt.fx.utils import LowerPrecision
1010

11-
from torch_tensorrt.dynamo.backend._settings import CompilationSettings
1211
from torch_tensorrt.dynamo.backend.utils import prepare_inputs, prepare_device
1312
from torch_tensorrt.dynamo.backend.backends import torch_tensorrt_backend
1413
from torch_tensorrt.dynamo.backend._defaults import (
@@ -62,6 +61,10 @@ def compile(
6261

6362
inputs = prepare_inputs(inputs, prepare_device(device))
6463

64+
if not isinstance(enabled_precisions, collections.abc.Collection):
65+
enabled_precisions = [enabled_precisions]
66+
67+
# Parse user-specified enabled precisions
6568
if (
6669
torch.float16 in enabled_precisions
6770
or torch_tensorrt.dtype.half in enabled_precisions
@@ -123,19 +126,12 @@ def create_backend(
123126
Returns:
124127
Backend for torch.compile
125128
"""
126-
if debug:
127-
logger.setLevel(logging.DEBUG)
128-
129-
settings = CompilationSettings(
129+
return partial(
130+
torch_tensorrt_backend,
130131
debug=debug,
131132
precision=precision,
132133
workspace_size=workspace_size,
133134
min_block_size=min_block_size,
134135
torch_executed_ops=torch_executed_ops,
135136
pass_through_build_failures=pass_through_build_failures,
136137
)
137-
138-
return partial(
139-
torch_tensorrt_backend,
140-
settings=settings,
141-
)

py/torch_tensorrt/dynamo/backend/backends.py

+5-6
Original file line numberDiff line numberDiff line change
@@ -77,13 +77,12 @@ def _pretraced_backend(
7777
)
7878
return trt_compiled
7979
except:
80-
logger.error(
81-
"FX2TRT conversion failed on the subgraph. See trace above. "
82-
+ "Returning GraphModule forward instead.",
83-
exc_info=True,
84-
)
85-
8680
if not settings.pass_through_build_failures:
81+
logger.warning(
82+
"TRT conversion failed on the subgraph. See trace above. "
83+
+ "Returning GraphModule forward instead.",
84+
exc_info=True,
85+
)
8786
return gm.forward
8887
else:
8988
raise AssertionError(

0 commit comments

Comments
 (0)