
Commit 220ee6a

refactor: Enable require_full_compilation in Dynamo
1 parent de70b64

File tree

6 files changed (+58, -5 lines)

py/torch_tensorrt/dynamo/_defaults.py

Lines changed: 1 addition & 0 deletions
@@ -10,3 +10,4 @@
 OPTIMIZATION_LEVEL = None
 USE_PYTHON_RUNTIME = None
 TRUNCATE_LONG_AND_DOUBLE = False
+REQUIRE_FULL_COMPILATION = False

py/torch_tensorrt/dynamo/_settings.py

Lines changed: 2 additions & 0 deletions
@@ -12,6 +12,7 @@
     OPTIMIZATION_LEVEL,
     USE_PYTHON_RUNTIME,
     TRUNCATE_LONG_AND_DOUBLE,
+    REQUIRE_FULL_COMPILATION,
 )


@@ -28,3 +29,4 @@ class CompilationSettings:
     optimization_level: Optional[int] = OPTIMIZATION_LEVEL
     use_python_runtime: Optional[bool] = USE_PYTHON_RUNTIME
     truncate_long_and_double: bool = TRUNCATE_LONG_AND_DOUBLE
+    require_full_compilation: bool = REQUIRE_FULL_COMPILATION
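For reference, the new field behaves like the other dataclass fields; a minimal sketch (in practice the settings object is assembled by compile() rather than built by hand):

from torch_tensorrt.dynamo._settings import CompilationSettings

# Omitting the argument falls back to REQUIRE_FULL_COMPILATION (False).
settings = CompilationSettings(require_full_compilation=True)
print(settings.require_full_compilation)  # True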

py/torch_tensorrt/dynamo/backend/backends.py

Lines changed: 1 addition & 0 deletions
@@ -123,6 +123,7 @@ def _compile_module(
         verbose=settings.debug,
         min_block_size=settings.min_block_size,
         torch_executed_ops=settings.torch_executed_ops,
+        require_full_compilation=settings.require_full_compilation,
     )

     # Store TRT replicas of Torch subgraphs

py/torch_tensorrt/dynamo/compile.py

Lines changed: 5 additions & 2 deletions
@@ -31,6 +31,7 @@
     OPTIMIZATION_LEVEL,
     USE_PYTHON_RUNTIME,
     TRUNCATE_LONG_AND_DOUBLE,
+    REQUIRE_FULL_COMPILATION,
 )


@@ -55,7 +56,7 @@ def compile(
     dla_global_dram_size=536870912,
     calibrator=None,
     truncate_long_and_double=TRUNCATE_LONG_AND_DOUBLE,
-    require_full_compilation=False,
+    require_full_compilation=REQUIRE_FULL_COMPILATION,
     min_block_size=MIN_BLOCK_SIZE,
     torch_executed_ops=[],
     torch_executed_modules=[],
@@ -73,7 +74,8 @@ def compile(
         "The Dynamo backend is an experimental feature, for which only the "
         + "following arguments are supported: "
         + "{enabled_precisions, debug, workspace_size, min_block_size, "
-        + "torch_executed_ops, pass_through_build_failures}"
+        + "max_aux_streams, version_compatible, optimization_level, "
+        + "torch_executed_ops, pass_through_build_failures, require_full_compilation}"
     )

     if not isinstance(inputs, collections.abc.Sequence):
@@ -111,6 +113,7 @@ def compile(
         "optimization_level": optimization_level,
         "use_python_runtime": use_python_runtime,
         "truncate_long_and_double": truncate_long_and_double,
+        "require_full_compilation": require_full_compilation,
     }

     settings = CompilationSettings(**compilation_options)
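For context, a sketch of how a caller reaches this path; the toy module, input shapes, and positional model argument are illustrative assumptions, not part of this commit:

import torch
from torch_tensorrt.dynamo.compile import compile as dynamo_compile

class TwoAdds(torch.nn.Module):  # hypothetical toy model
    def forward(self, x, y):
        return torch.ops.aten.add.Tensor(torch.ops.aten.add.Tensor(x, y), y)

model = TwoAdds().eval().cuda()
inputs = [torch.randn((2, 4)).cuda(), torch.randn((2, 4)).cuda()]

# With the flag enabled, partitioning raises an error if any operator in the
# graph lacks a TensorRT converter, instead of leaving it to run in Torch.
trt_module = dynamo_compile(
    model,
    inputs=inputs,
    enabled_precisions={torch.float},
    require_full_compilation=True,
)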

py/torch_tensorrt/dynamo/lowering/_partition.py

Lines changed: 17 additions & 3 deletions
@@ -4,7 +4,7 @@
 import torch

 from torch_tensorrt.dynamo.lowering import SUBSTITUTION_REGISTRY
-from torch_tensorrt.dynamo._defaults import MIN_BLOCK_SIZE
+from torch_tensorrt.dynamo._defaults import MIN_BLOCK_SIZE, REQUIRE_FULL_COMPILATION
 from torch.fx.passes.infra.partitioner import CapabilityBasedPartitioner, Partition
 from torch.fx.graph_module import GraphModule
 from torch.fx.node import _get_qualified_name
@@ -45,6 +45,7 @@ def __init__(
             Sequence[str]
         ] = DEFAULT_SINGLE_NODE_PARTITIONS,
         min_block_size=MIN_BLOCK_SIZE,
+        require_full_compilation=REQUIRE_FULL_COMPILATION,
     ) -> None:
         super().__init__(
             graph_module,
@@ -55,6 +56,7 @@ def __init__(
         )

         self.min_block_size = min_block_size
+        self.require_full_compilation = require_full_compilation

     def propose_partitions(self) -> List[Partition]:
         # Propose partitions using the default, then refine the results
@@ -66,6 +68,11 @@ def propose_partitions(self) -> List[Partition]:
             self.operator_support, "unsupported_operators", True
         )

+        if not full_support and self.require_full_compilation:
+            raise AssertionError(
+                "require_full_compilation=True was specified, but model is not fully supported"
+            )
+
         # For each partition, determine whether or not the number of computational operators
         # exceeds the threshold, and if not, remove that partition
         partitions_to_remove = {}
@@ -93,7 +100,7 @@ def propose_partitions(self) -> List[Partition]:
             if (
                 compute_node_count < self.min_block_size
                 and not exempted_partition
-                and not full_support
+                and not (full_support and self.require_full_compilation)
             ):
                 partitions_to_remove[id] = compute_node_count

@@ -178,6 +185,7 @@ def partition(
     verbose: bool = True,
     min_block_size: int = MIN_BLOCK_SIZE,
     torch_executed_ops: Sequence[str] = set(),
+    require_full_compilation: bool = REQUIRE_FULL_COMPILATION,
 ) -> torch.fx.GraphModule:
     """Partition an FX GraphModule with aten ops into TRT engines
     Partitioning is based on converter operator support
@@ -187,11 +195,17 @@ def partition(
         verbose: Bool representing whether to print operator support
         min_block_size: Minimum number of operators per TRT-Engine Block
         torch_executed_ops: Sequence of operations to run in Torch, regardless of converter coverage
+        require_full_compilation: Whether to require that all operators be run in TRT
     Returns:
         torch.fx.GraphModule
     """
     supported_ops = TorchTensorRTOperatorSupport(torch_executed_ops=torch_executed_ops)
-    partitioner = TRTPartitioner(gm, supported_ops, min_block_size=min_block_size)
+    partitioner = TRTPartitioner(
+        gm,
+        supported_ops,
+        min_block_size=min_block_size,
+        require_full_compilation=require_full_compilation,
+    )

     # Determine partitions based on user specifications and operator support
     # Then, fuse partitions and display overview of supported/unsupported operators
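A rough sketch of the resulting partition() behavior; the "unsupported" operator below is a hypothetical stand-in, since actual converter coverage depends on the installed Torch-TensorRT version:

import torch
from copy import deepcopy
from torch_tensorrt.dynamo.lowering._partition import partition

class AddThenUnknown(torch.nn.Module):
    def forward(self, x, y):
        added = torch.ops.aten.add.Tensor(x, y)
        # Hypothetical stand-in for an operator without a TensorRT converter.
        return torch.ops.aten.erfinv.default(added)

gm = torch.fx.symbolic_trace(AddThenUnknown())

# Default: nodes without converters are simply left to execute in Torch.
partitioned = partition(deepcopy(gm))

# New behavior: an AssertionError is raised up front when full compilation is
# required but the graph contains unsupported operators.
try:
    partition(deepcopy(gm), require_full_compilation=True)
except AssertionError as err:
    print(err)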

tests/py/dynamo/backend/test_partitioning.py

Lines changed: 32 additions & 0 deletions
@@ -7,6 +7,38 @@


 class TestPartitioning(TestCase):
+    def test_partition_fully_supported_one_op(self):
+        class FullySupportedOneOp(torch.nn.Module):
+            def __init__(self, *args, **kwargs) -> None:
+                super().__init__(*args, **kwargs)
+
+            def forward(self, x, y):
+                return torch.ops.aten.add.Tensor(x, y)
+
+        fx_graph = torch.fx.symbolic_trace(FullySupportedOneOp())
+        partitioned_graph = partition(deepcopy(fx_graph))
+        self.assertEquals(
+            len(list(partitioned_graph.named_children())),
+            0,
+            "Single operators should not be segmented",
+        )
+
+    def test_partition_fully_supported_one_op_require_full_compilation(self):
+        class FullySupportedOneOp(torch.nn.Module):
+            def __init__(self, *args, **kwargs) -> None:
+                super().__init__(*args, **kwargs)
+
+            def forward(self, x, y):
+                return torch.ops.aten.add.Tensor(x, y)
+
+        fx_graph = torch.fx.symbolic_trace(FullySupportedOneOp())
+        partitioned_graph = partition(deepcopy(fx_graph), require_full_compilation=True)
+        self.assertEquals(
+            len(list(partitioned_graph.named_children())),
+            1,
+            "Single operators can be segmented if full compilation is required",
+        )
+
     def test_partition_fully_supported_multi_op(self):
         class FullySupportedMultiOp(torch.nn.Module):
             def __init__(self, *args, **kwargs) -> None: