Commit 0f3e460

feat: Add support for exempting full-support blocks
- When a graph is fully supported, we can ignore the minimum block size argument, which is primarily helpful in reducing segmentation. If the minimum block size is above the total number of operators in the graph, and we support all of those operators, the whole graph will run in Torch regardless. As a result, we can exempt fully supported graphs from the min block size requirement.
- Alternatively, if preferable, we could display a warning in such a case but still respect the minimum block size argument.

refactor: Add require_full_compilation in Dynamo
- Add support for the full-compilation argument in Dynamo paths

feat: Disable require_full_compilation in torch compile
1 parent 036d992 commit 0f3e460
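For context, a minimal usage sketch of the new flag through the Dynamo compile path. This is an illustration only: the module, input shape, and precision choice are placeholders, and the exact entry point and model preparation can differ between torch_tensorrt versions.

import torch
import torch_tensorrt


class TinyModel(torch.nn.Module):  # placeholder, fully supported module
    def forward(self, x):
        return torch.relu(x + 1.0)


model = TinyModel().eval().cuda()
inputs = [torch.randn(1, 3, 224, 224).cuda()]  # placeholder input

# With require_full_compilation=True, a fully supported graph is exempted from
# the min_block_size requirement; if any operator lacks a TRT converter,
# partitioning raises an error instead of silently falling back to Torch.
trt_model = torch_tensorrt.dynamo.compile(
    model,
    inputs=inputs,
    enabled_precisions={torch.float},
    min_block_size=5,
    require_full_compilation=True,
)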

8 files changed: +171 -14 lines changed


py/torch_tensorrt/dynamo/_defaults.py

Lines changed: 1 addition & 0 deletions
@@ -12,3 +12,4 @@
 USE_PYTHON_RUNTIME = False
 USE_FAST_PARTITIONER = True
 ENABLE_EXPERIMENTAL_DECOMPOSITIONS = False
+REQUIRE_FULL_COMPILATION = False

py/torch_tensorrt/dynamo/_settings.py

Lines changed: 2 additions & 0 deletions
@@ -10,6 +10,7 @@
     OPTIMIZATION_LEVEL,
     PASS_THROUGH_BUILD_FAILURES,
     PRECISION,
+    REQUIRE_FULL_COMPILATION,
     TRUNCATE_LONG_AND_DOUBLE,
     USE_FAST_PARTITIONER,
     USE_PYTHON_RUNTIME,
@@ -54,3 +55,4 @@ class CompilationSettings:
     truncate_long_and_double: bool = TRUNCATE_LONG_AND_DOUBLE
     use_fast_partitioner: bool = USE_FAST_PARTITIONER
     enable_experimental_decompositions: bool = ENABLE_EXPERIMENTAL_DECOMPOSITIONS
+    require_full_compilation: bool = REQUIRE_FULL_COMPILATION

py/torch_tensorrt/dynamo/compile.py

Lines changed: 7 additions & 3 deletions
@@ -19,6 +19,7 @@
     OPTIMIZATION_LEVEL,
     PASS_THROUGH_BUILD_FAILURES,
     PRECISION,
+    REQUIRE_FULL_COMPILATION,
     TRUNCATE_LONG_AND_DOUBLE,
     USE_FAST_PARTITIONER,
     USE_PYTHON_RUNTIME,
@@ -52,7 +53,7 @@ def compile(
     dla_global_dram_size: int = 536870912,
     calibrator: object = None,
     truncate_long_and_double: bool = TRUNCATE_LONG_AND_DOUBLE,
-    require_full_compilation: bool = False,
+    require_full_compilation: bool = REQUIRE_FULL_COMPILATION,
     min_block_size: int = MIN_BLOCK_SIZE,
     torch_executed_ops: Optional[List[str]] = None,
     torch_executed_modules: Optional[List[str]] = None,
@@ -75,8 +76,10 @@ def compile(
         "The Dynamo backend is an experimental feature, for which only the "
         "following arguments are supported: "
         "{enabled_precisions, debug, workspace_size, min_block_size, "
-        "torch_executed_ops, pass_through_build_failures, use_fast_partitioner, "
-        "enable_experimental_decompositions}"
+        "max_aux_streams, version_compatible, optimization_level, "
+        "torch_executed_ops, pass_through_build_failures, "
+        "use_fast_partitioner, enable_experimental_decompositions, "
+        "require_full_compilation}"
     )

     if not isinstance(inputs, collections.abc.Sequence):
@@ -118,6 +121,7 @@ def compile(
         "truncate_long_and_double": truncate_long_and_double,
         "use_fast_partitioner": use_fast_partitioner,
         "enable_experimental_decompositions": enable_experimental_decompositions,
+        "require_full_compilation": require_full_compilation,
     }

     settings = CompilationSettings(**compilation_options)

py/torch_tensorrt/dynamo/partitioning/_adjacency_partitioner.py

Lines changed: 47 additions & 8 deletions
@@ -12,7 +12,11 @@
     _SplitterSettingBase,
 )
 from torch.fx.passes.tools_common import CALLABLE_NODE_OPS, NodeSet
-from torch_tensorrt.dynamo._defaults import DEBUG, MIN_BLOCK_SIZE
+from torch_tensorrt.dynamo._defaults import (
+    DEBUG,
+    MIN_BLOCK_SIZE,
+    REQUIRE_FULL_COMPILATION,
+)
 from torch_tensorrt.dynamo.conversion.converter_registry import (
     DYNAMO_CONVERTERS as CONVERTERS,
 )
@@ -92,6 +96,7 @@ class TRTPartitioner(_SplitterBase):  # type: ignore
         allowed_single_node_partition_ops: Nodes which can be included in single-node partitons.
             Generally useful for module-level exclusion ops which are intensive despite being single functions
         min_block_size: Minimum number of computational operators per block
+        require_full_compilation: Require that all computational operators be run in TRT
     Returns:
         torch.fx.GraphModule
     """
@@ -104,6 +109,7 @@ def __init__(
             Collection[str]
         ] = DEFAULT_SINGLE_NODE_PARTITIONS,
         min_block_size: int = MIN_BLOCK_SIZE,
+        require_full_compilation: bool = REQUIRE_FULL_COMPILATION,
     ):
         """
         Preprocesses graph before splitting:
@@ -142,6 +148,7 @@ def __init__(

         self.num_trt_accelerated_subgraphs: Optional[int] = None
         self.allowed_single_node_partition_ops = allowed_single_node_partition_ops
+        self.require_full_compilation = require_full_compilation

     def remove_small_acc_subgraphs(self, subgraphs: List[Subgraph]) -> List[Subgraph]:
         """
@@ -151,12 +158,16 @@ def remove_small_acc_subgraphs(self, subgraphs: List[Subgraph]) -> List[Subgraph]:
         result: List[Subgraph] = []
         for subgraph in subgraphs:
             if subgraph.is_acc:
-                if len(subgraph.nodes) >= self.settings.min_acc_module_size or (
-                    self.allowed_single_node_partition_ops is not None
-                    and any(
-                        ConverterRegistry.qualified_name_or_str(node.target)
-                        in self.allowed_single_node_partition_ops
-                        for node in subgraph.nodes
+                if (
+                    len(subgraph.nodes) >= self.settings.min_acc_module_size
+                    or self.require_full_compilation
+                    or (
+                        self.allowed_single_node_partition_ops is not None
+                        and any(
+                            ConverterRegistry.qualified_name_or_str(node.target)
+                            in self.allowed_single_node_partition_ops
+                            for node in subgraph.nodes
+                        )
                     )
                 ):
                     result.append(subgraph)
@@ -185,6 +196,27 @@ def partition_graph(self) -> torch.fx.GraphModule:
         # Delegate nodes based on operator coverage
         subgraphs = self.put_nodes_into_subgraphs()

+        # A graph is fully supported if there is a single partition and all operators are supported/convertible
+        full_support = len([s for s in subgraphs if s.is_acc]) == 1 and not getattr(
+            self.operator_support, "unsupported_operators", True
+        )
+
+        if not full_support and self.require_full_compilation:
+            raise AssertionError(
+                "require_full_compilation=True was specified, but model is not fully supported"
+            )
+
+        if (
+            full_support
+            and self.require_full_compilation
+            and self.settings.min_acc_module_size != MIN_BLOCK_SIZE
+        ):
+            logger.warning(
+                "Detected both require_full_compilation and min_block_size compilation "
+                "arguments were specified. Disregarding min_block_size argument for "
+                "fully supported model."
+            )
+
         # Remove segments smaller than the block size (with exceptions)
         subgraphs = self.remove_small_acc_subgraphs(subgraphs)

@@ -217,6 +249,7 @@ def partition(
     verbose: bool = DEBUG,
     min_block_size: int = MIN_BLOCK_SIZE,
     torch_executed_ops: Collection[Target] = set(),
+    require_full_compilation: bool = REQUIRE_FULL_COMPILATION,
 ) -> torch.fx.GraphModule:
     """Partition an FX GraphModule with aten ops into TRT engines
     Partitioning is based on converter operator support
@@ -226,6 +259,7 @@ def partition(
         verbose: Bool representing whether to print operator support
         min_block_size: Minimum number of operators per TRT-Engine Block
         torch_executed_ops: Collection of operations to run in Torch, regardless of converter coverage
+        require_full_compilation: Require that all computational operators be run in TRT
     Returns:
         torch.fx.GraphModule
     """
@@ -236,7 +270,12 @@ def partition(

     # Construct
     supported_ops = OpSupportTester(torch_executed_ops=torch_executed_ops)
-    partitioner = TRTPartitioner(gm, supported_ops, min_block_size=min_block_size)
+    partitioner = TRTPartitioner(
+        gm,
+        supported_ops,
+        min_block_size=min_block_size,
+        require_full_compilation=require_full_compilation,
+    )

     partitioned_graph = partitioner.partition_graph()
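To show the new error path end to end, here is a hedged sketch that runs the adjacency ("fast") partitioner standalone, mirroring the style of the tests later in this commit. Everything here is illustrative: the module is a placeholder, and torch.ops.aten.nonzero.default is only assumed to lack a TRT converter.

from copy import deepcopy

import torch
from torch_tensorrt.dynamo import partitioning


class PartiallySupported(torch.nn.Module):
    def forward(self, x, y):
        added = torch.ops.aten.add.Tensor(x, y)       # convertible
        return torch.ops.aten.nonzero.default(added)  # assumed unsupported


gm = torch.fx.symbolic_trace(PartiallySupported())

try:
    partitioning.fast_partition(deepcopy(gm), require_full_compilation=True)
except AssertionError as err:
    # Expected per the hunk above: "require_full_compilation=True was
    # specified, but model is not fully supported"
    print(err)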

py/torch_tensorrt/dynamo/partitioning/_global_partitioner.py

Lines changed: 42 additions & 3 deletions
@@ -5,7 +5,11 @@
 from torch.fx.graph_module import GraphModule
 from torch.fx.passes.infra.partitioner import CapabilityBasedPartitioner, Partition
 from torch.fx.passes.operator_support import OperatorSupport, SupportDict
-from torch_tensorrt.dynamo._defaults import DEBUG, MIN_BLOCK_SIZE
+from torch_tensorrt.dynamo._defaults import (
+    DEBUG,
+    MIN_BLOCK_SIZE,
+    REQUIRE_FULL_COMPILATION,
+)
 from torch_tensorrt.dynamo.conversion.converter_registry import (
     DYNAMO_CONVERTERS as CONVERTERS,
 )
@@ -26,6 +30,7 @@ class TRTPartitioner(CapabilityBasedPartitioner):  # type: ignore[misc]
         allowed_single_node_partition_ops: Nodes which can be included in single-node partitons.
             Generally useful for module-level exclusion ops which are intensive despite being single functions
         min_block_size: Minimum number of computational operators per block
+        require_full_compilation: Require that all computational operators be run in TRT
     Returns:
         torch.fx.GraphModule
     """
@@ -40,6 +45,7 @@ def __init__(
             Collection[str]
         ] = DEFAULT_SINGLE_NODE_PARTITIONS,
         min_block_size: int = MIN_BLOCK_SIZE,
+        require_full_compilation: bool = REQUIRE_FULL_COMPILATION,
     ) -> None:
         super().__init__(
             graph_module,
@@ -50,12 +56,34 @@ def __init__(
         )

         self.min_block_size = min_block_size
+        self.require_full_compilation = require_full_compilation

     def propose_partitions(self) -> List[Partition]:
         # Propose partitions using the default, then refine the results
         initial_proposed_partitions = super().propose_partitions()
         partitions = dict(enumerate(initial_proposed_partitions))

+        # A graph is fully supported if there is a single partition and all operators are supported/convertible
+        full_support = len(partitions) == 1 and not getattr(
+            self.operator_support, "unsupported_operators", True
+        )
+
+        if not full_support and self.require_full_compilation:
+            raise AssertionError(
+                "require_full_compilation=True was specified, but model is not fully supported"
+            )
+
+        if (
+            full_support
+            and self.require_full_compilation
+            and self.min_block_size != MIN_BLOCK_SIZE
+        ):
+            logger.warning(
+                "Detected both require_full_compilation and min_block_size compilation "
+                "arguments were specified. Disregarding min_block_size argument for "
+                "fully supported model."
+            )
+
         # For each partition, determine whether or not the number of computational operators
         # exceeds the threshold, and if not, remove that partition
         partitions_to_remove = {}
@@ -81,7 +109,11 @@ def propose_partitions(self) -> List[Partition]:
                 ):
                     compute_node_count += 1

-            if compute_node_count < self.min_block_size and not exempted_partition:
+            if (
+                compute_node_count < self.min_block_size
+                and not exempted_partition
+                and not (full_support and self.require_full_compilation)
+            ):
                 partitions_to_remove[id] = compute_node_count

         # Remove any nodes violating the criteria specified by the user
@@ -172,6 +204,7 @@ def partition(
     verbose: bool = DEBUG,
     min_block_size: int = MIN_BLOCK_SIZE,
     torch_executed_ops: Optional[Set[str]] = None,
+    require_full_compilation: bool = REQUIRE_FULL_COMPILATION,
 ) -> torch.fx.GraphModule:
     """Partition an FX GraphModule with aten ops into TRT engines
     Partitioning is based on converter operator support
@@ -181,6 +214,7 @@ def partition(
         verbose: Bool representing whether to print operator support
         min_block_size: Minimum number of operators per TRT-Engine Block
         torch_executed_ops: Sequence of operations to run in Torch, regardless of converter coverage
+        require_full_compilation: Whether to require that all operators be run in TRT
     Returns:
         torch.fx.GraphModule
     """
@@ -189,7 +223,12 @@ def partition(
         if torch_executed_ops is not None
         else set()
     )
-    partitioner = TRTPartitioner(gm, supported_ops, min_block_size=min_block_size)
+    partitioner = TRTPartitioner(
+        gm,
+        supported_ops,
+        min_block_size=min_block_size,
+        require_full_compilation=require_full_compilation,
+    )

     # Determine partitions based on user specifications and operator support
     # Then, fuse partitions and display overview of supported/unsupported operators
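To illustrate the min_block_size exemption described in the commit message, a small sketch against the global partitioner, mirroring the style of the tests below. The two-operator module and the chosen block size are placeholders; both operators are assumed to have converters.

from copy import deepcopy

import torch
from torch_tensorrt.dynamo import partitioning


class FullySupportedTwoOps(torch.nn.Module):
    def forward(self, x, y):
        summed = torch.ops.aten.add.Tensor(x, y)
        return torch.ops.aten.relu.default(summed)


gm = torch.fx.symbolic_trace(FullySupportedTwoOps())

# min_block_size exceeds the operator count, but because the graph is fully
# supported and require_full_compilation=True, the requirement is disregarded
# (with a warning) and the single accelerated partition is kept.
partitioned = partitioning.global_partition(
    deepcopy(gm),
    min_block_size=10,
    require_full_compilation=True,
)
print(len(list(partitioned.named_children())))  # expected: 1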

py/torch_tensorrt/dynamo/utils.py

Lines changed: 8 additions & 0 deletions
@@ -184,6 +184,14 @@ def parse_dynamo_kwargs(kwargs: Any) -> CompilationSettings:
     # Parse input runtime specification
     settings.use_python_runtime = use_python_runtime_parser(settings.use_python_runtime)

+    # Ignore and warn about require_full_compilation flag
+    if settings.require_full_compilation:
+        logger.warning(
+            "Detected require_full_compilation=True for a torch.compile run. "
+            "This option has no effect in torch.compile."
+        )
+        settings.require_full_compilation = False
+
     logger.info("Compilation Settings: %s\n", settings)

     return settings
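For the torch.compile path guarded above, a hedged sketch of what a user-facing call might look like. The backend name and option plumbing are as assumed here, and the module and input are placeholders.

import torch
import torch_tensorrt  # assumed to register the "torch_tensorrt" backend on import


class TinyModel(torch.nn.Module):  # placeholder module
    def forward(self, x):
        return torch.relu(x + 1.0)


model = TinyModel().eval().cuda()
x = torch.randn(1, 3, 224, 224).cuda()  # placeholder input

# require_full_compilation is accepted but has no effect here:
# parse_dynamo_kwargs warns and resets it to False for torch.compile runs.
compiled = torch.compile(
    model,
    backend="torch_tensorrt",
    options={"require_full_compilation": True, "min_block_size": 1},
)
out = compiled(x)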

tests/py/dynamo/partitioning/test_fast_partitioning.py

Lines changed: 46 additions & 0 deletions
@@ -31,6 +31,52 @@ def forward(self, x, y):
             "Single operators should not be segmented",
         )

+    def test_partition_fully_supported_one_op_require_full_compilation(self):
+        class FullySupportedOneOp(torch.nn.Module):
+            def __init__(self, *args, **kwargs) -> None:
+                super().__init__(*args, **kwargs)
+
+            def forward(self, x, y):
+                return torch.ops.aten.add.Tensor(x, y)
+
+        fx_graph = torch.fx.symbolic_trace(FullySupportedOneOp())
+        partitioned_graph = partitioning.fast_partition(
+            deepcopy(fx_graph), require_full_compilation=True
+        )
+        self.assertEquals(
+            len(
+                [
+                    1
+                    for submod in list(partitioned_graph.named_children())
+                    if "_run_on_acc" in submod[0]
+                ]
+            ),
+            1,
+            "Single operators can be segmented if full compilation is required",
+        )
+
+    def test_partition_fully_supported_one_op(self):
+        class FullySupportedOneOp(torch.nn.Module):
+            def __init__(self, *args, **kwargs) -> None:
+                super().__init__(*args, **kwargs)
+
+            def forward(self, x, y):
+                return torch.ops.aten.add.Tensor(x, y)
+
+        fx_graph = torch.fx.symbolic_trace(FullySupportedOneOp())
+        partitioned_graph = partitioning.fast_partition(deepcopy(fx_graph))
+        self.assertEquals(
+            len(
+                [
+                    1
+                    for submod in list(partitioned_graph.named_children())
+                    if "_run_on_acc" in submod[0]
+                ]
+            ),
+            0,
+            "Single operators should not be segmented",
+        )
+
     def test_partition_fully_supported_multi_op(self):
         class FullySupportedMultiOp(torch.nn.Module):
             def __init__(self, *args, **kwargs) -> None:

tests/py/dynamo/partitioning/test_global_partitioning.py

Lines changed: 18 additions & 0 deletions
@@ -25,6 +25,24 @@ def forward(self, x, y):
             "Single operators should not be segmented",
         )

+    def test_partition_fully_supported_one_op_require_full_compilation(self):
+        class FullySupportedOneOp(torch.nn.Module):
+            def __init__(self, *args, **kwargs) -> None:
+                super().__init__(*args, **kwargs)
+
+            def forward(self, x, y):
+                return torch.ops.aten.add.Tensor(x, y)
+
+        fx_graph = torch.fx.symbolic_trace(FullySupportedOneOp())
+        partitioned_graph = partitioning.global_partition(
+            deepcopy(fx_graph), require_full_compilation=True
+        )
+        self.assertEquals(
+            len(list(partitioned_graph.named_children())),
+            1,
+            "Single operators can be segmented if full compilation is required",
+        )
+
     def test_partition_fully_supported_multi_op(self):
         class FullySupportedMultiOp(torch.nn.Module):
             def __init__(self, *args, **kwargs) -> None:
