Skip to content

Commit e109049

Browse files
authored
feat: Add support for passing through build issues in Dynamo compile (#1952)
1 parent d3a47c4 commit e109049

File tree

8 files changed

+36
-10
lines changed

8 files changed

+36
-10
lines changed

py/torch_tensorrt/dynamo/backend/__init__.py

+8-1
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616
DEBUG,
1717
MAX_WORKSPACE_SIZE,
1818
MIN_BLOCK_SIZE,
19+
PASS_THROUGH_BUILD_FAILURES,
1920
)
2021

2122

@@ -52,7 +53,8 @@ def compile(
5253
logger.warn(
5354
"The Dynamo backend is an experimental feature, for which only the "
5455
+ "following arguments are supported: "
55-
+ "{enabled_precisions, debug, workspace_size, min_block_size, torch_executed_ops}"
56+
+ "{enabled_precisions, debug, workspace_size, min_block_size, "
57+
+ "torch_executed_ops, pass_through_build_failures}"
5658
)
5759

5860
if not isinstance(inputs, collections.abc.Sequence):
@@ -106,6 +108,7 @@ def create_backend(
106108
workspace_size: int = MAX_WORKSPACE_SIZE,
107109
min_block_size: int = MIN_BLOCK_SIZE,
108110
torch_executed_ops: Sequence[str] = set(),
111+
pass_through_build_failures: bool = PASS_THROUGH_BUILD_FAILURES,
109112
**kwargs,
110113
):
111114
"""Create torch.compile backend given specified arguments
@@ -118,12 +121,16 @@ def create_backend(
118121
Returns:
119122
Backend for torch.compile
120123
"""
124+
if debug:
125+
logger.setLevel(logging.DEBUG)
126+
121127
settings = CompilationSettings(
122128
debug=debug,
123129
precision=precision,
124130
workspace_size=workspace_size,
125131
min_block_size=min_block_size,
126132
torch_executed_ops=torch_executed_ops,
133+
pass_through_build_failures=pass_through_build_failures,
127134
)
128135

129136
return partial(

py/torch_tensorrt/dynamo/backend/_defaults.py

+1
Original file line numberDiff line numberDiff line change
@@ -5,3 +5,4 @@
55
DEBUG = False
66
MAX_WORKSPACE_SIZE = 20 << 30
77
MIN_BLOCK_SIZE = 5
8+
PASS_THROUGH_BUILD_FAILURES = False

py/torch_tensorrt/dynamo/backend/_settings.py

+2
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
DEBUG,
88
MAX_WORKSPACE_SIZE,
99
MIN_BLOCK_SIZE,
10+
PASS_THROUGH_BUILD_FAILURES,
1011
)
1112

1213

@@ -17,3 +18,4 @@ class CompilationSettings:
1718
workspace_size: int = MAX_WORKSPACE_SIZE
1819
min_block_size: int = MIN_BLOCK_SIZE
1920
torch_executed_ops: Sequence[str] = field(default_factory=set)
21+
pass_through_build_failures: bool = PASS_THROUGH_BUILD_FAILURES

py/torch_tensorrt/dynamo/backend/backends.py

+18-5
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
1+
import logging
12
from typing import Sequence
23
import torch
3-
import traceback
44
from functools import partial
55
import torch._dynamo as td
66

@@ -19,6 +19,9 @@
1919
from torch._functorch.aot_autograd import aot_module_simplified, make_boxed_compiler
2020

2121

22+
logger = logging.getLogger(__name__)
23+
24+
2225
@td.register_backend(name="torch_tensorrt")
2326
@fake_tensor_unsupported
2427
def torch_tensorrt_backend(
@@ -75,12 +78,22 @@ def _pretraced_backend(
7578
)
7679
return trt_compiled
7780
except:
78-
traceback.print_exc()
79-
print(
81+
logger.error(
8082
"FX2TRT conversion failed on the subgraph. See trace above. "
81-
+ "Returning GraphModule forward instead."
83+
+ "Returning GraphModule forward instead.",
84+
exc_info=True,
8285
)
83-
return gm.forward
86+
87+
if not settings.pass_through_build_failures:
88+
return gm.forward
89+
else:
90+
raise AssertionError(
91+
"Halting compilation on build failure since "
92+
+ "pass_through_build_failures was specified as True. "
93+
+ "To return the default Torch implementation and avoid "
94+
+ "halting compilation on engine build failures, "
95+
+ "specify pass_through_build_failures=False."
96+
)
8497

8598

8699
def _compile_module(

py/torch_tensorrt/dynamo/backend/test/utils.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -124,7 +124,7 @@ def lower_graph_testing(
124124
torch_executed_ops: Sequence[str] = set(),
125125
testing_partitioning: bool = False,
126126
):
127-
"""Helper function to assist with graph lowering for testing of Dynamo torch_compile
127+
"""Helper function to assist with graph lowering for testing of Dynamo compile
128128
129129
Args:
130130
fx_graph: Graph to lower

py/torch_tensorrt/dynamo/common_utils/__init__.py

Whitespace-only changes.

py/torch_tensorrt/dynamo/test/test_dynamo_backend.py

+6-3
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,10 @@
77

88
from transformers import BertModel
99

10-
from utils import COSINE_THRESHOLD, cosine_similarity
10+
from torch_tensorrt.dynamo.common_utils.test_utils import (
11+
COSINE_THRESHOLD,
12+
cosine_similarity,
13+
)
1114

1215

1316
@pytest.mark.unit
@@ -30,7 +33,7 @@ def test_resnet18(ir):
3033
cos_sim = cosine_similarity(model(input), trt_mod(input))
3134
assert (
3235
cos_sim > COSINE_THRESHOLD,
33-
f"Resnet50 TRT outputs don't match with the original model. Cosine sim score: {cos_sim} Threshold: {COSINE_THRESHOLD}",
36+
f"Resnet18 TRT outputs don't match with the original model. Cosine sim score: {cos_sim} Threshold: {COSINE_THRESHOLD}",
3437
)
3538

3639
# Clean up model env
@@ -163,7 +166,7 @@ def test_resnet18_half(ir):
163166
cos_sim = cosine_similarity(model(input), trt_mod(input))
164167
assert (
165168
cos_sim > COSINE_THRESHOLD,
166-
f"Resnet50 Half TRT outputs don't match with the original model. Cosine sim score: {cos_sim} Threshold: {COSINE_THRESHOLD}",
169+
f"Resnet18 Half TRT outputs don't match with the original model. Cosine sim score: {cos_sim} Threshold: {COSINE_THRESHOLD}",
167170
)
168171

169172
# Clean up model env

0 commit comments

Comments (0)