
Commit bbf514f

feat: Add feature to toggle partitioner
1 parent 3e948a6 commit bbf514f

8 files changed, +182 −38 lines changed

py/torch_tensorrt/dynamo/_defaults.py

+1
@@ -10,3 +10,4 @@
 OPTIMIZATION_LEVEL = None
 USE_PYTHON_RUNTIME = None
 TRUNCATE_LONG_AND_DOUBLE = False
+USE_FAST_PARTITIONER = True

py/torch_tensorrt/dynamo/_settings.py

+2
@@ -12,6 +12,7 @@
     OPTIMIZATION_LEVEL,
     USE_PYTHON_RUNTIME,
     TRUNCATE_LONG_AND_DOUBLE,
+    USE_FAST_PARTITIONER,
 )
 
 
@@ -28,3 +29,4 @@ class CompilationSettings:
     optimization_level: Optional[int] = OPTIMIZATION_LEVEL
     use_python_runtime: Optional[bool] = USE_PYTHON_RUNTIME
     truncate_long_and_double: bool = TRUNCATE_LONG_AND_DOUBLE
+    use_fast_partitioner: bool = USE_FAST_PARTITIONER
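
For reference, a minimal sketch of how the new flag slots into CompilationSettings, with the remaining fields of the real dataclass elided and defaults mirroring _defaults.py:

from dataclasses import dataclass
from typing import Optional

# Mirrors the new constant added to _defaults.py above.
USE_FAST_PARTITIONER = True


@dataclass
class CompilationSettings:
    # Only the fields touched by this commit's hunks; the real class has more.
    optimization_level: Optional[int] = None
    use_python_runtime: Optional[bool] = None
    truncate_long_and_double: bool = False
    use_fast_partitioner: bool = USE_FAST_PARTITIONER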

py/torch_tensorrt/dynamo/backend/backends.py

+17 −12
@@ -11,10 +11,7 @@
 from torch_tensorrt.dynamo.lowering._pre_aot_lowering import (
     pre_aot_substitutions,
 )
-from torch_tensorrt.dynamo.partitioning import (
-    partition,
-    get_submod_inputs,
-)
+from torch_tensorrt.dynamo import partitioning
 from torch_tensorrt.dynamo.utils import parse_dynamo_kwargs
 from torch_tensorrt.dynamo.conversion import (
     convert_module,
@@ -118,12 +115,20 @@ def _compile_module(
         Compiled FX GraphModule
     """
     # Partition module into components that can be TRT-accelerated
-    partitioned_module = partition(
-        gm,
-        verbose=settings.debug,
-        min_block_size=settings.min_block_size,
-        torch_executed_ops=settings.torch_executed_ops,
-    )
+    if settings.use_fast_partitioner:
+        partitioned_module = partitioning.fast_partition(
+            gm,
+            verbose=settings.debug,
+            min_block_size=settings.min_block_size,
+            torch_executed_ops=settings.torch_executed_ops,
+        )
+    else:
+        partitioned_module = partitioning.global_partition(
+            gm,
+            verbose=settings.debug,
+            min_block_size=settings.min_block_size,
+            torch_executed_ops=settings.torch_executed_ops,
+        )
 
     # Store TRT replicas of Torch subgraphs
     trt_modules = {}
@@ -133,13 +138,13 @@ def _compile_module(
     for name, _ in partitioned_module.named_children():
 
         # Criteria for a module to be convertible to TRT
-        if "_run_on_acc" not in name:
+        if settings.use_fast_partitioner and "_run_on_acc" not in name:
             continue
 
         submodule = getattr(partitioned_module, name)
 
         # Get submodule inputs
-        submodule_inputs = get_submod_inputs(
+        submodule_inputs = partitioning.get_submod_inputs(
             partitioned_module, submodule, sample_inputs
         )
 
py/torch_tensorrt/dynamo/compile.py

+6 −8
@@ -31,6 +31,7 @@
     OPTIMIZATION_LEVEL,
     USE_PYTHON_RUNTIME,
     TRUNCATE_LONG_AND_DOUBLE,
+    USE_FAST_PARTITIONER,
 )
 
 
@@ -64,6 +65,7 @@ def compile(
     version_compatible=VERSION_COMPATIBLE,
     optimization_level=OPTIMIZATION_LEVEL,
     use_python_runtime=USE_PYTHON_RUNTIME,
+    use_fast_partitioner=USE_FAST_PARTITIONER,
     **kwargs,
 ):
     if debug:
@@ -73,7 +75,7 @@ def compile(
         "The Dynamo backend is an experimental feature, for which only the "
         + "following arguments are supported: "
         + "{enabled_precisions, debug, workspace_size, min_block_size, "
-        + "torch_executed_ops, pass_through_build_failures}"
+        + "torch_executed_ops, pass_through_build_failures, use_fast_partitioner}"
     )
 
     if not isinstance(inputs, collections.abc.Sequence):
@@ -111,17 +113,13 @@ def compile(
         "optimization_level": optimization_level,
         "use_python_runtime": use_python_runtime,
         "truncate_long_and_double": truncate_long_and_double,
+        "use_fast_partitioner": use_fast_partitioner,
     }
 
     settings = CompilationSettings(**compilation_options)
-    if kwargs.get("use_capability_partitioner", None):
-        model = lower_model(gm, torch_inputs)
-        return _compile_module(model, torch_inputs, settings)
-    else:
-        split_result = lower_model_using_trt_splitter(gm, torch_inputs)
-        trt_module = _compile_graph(split_result, torch_inputs, settings)
 
-        return trt_module
+    model = lower_model(gm, torch_inputs)
+    return _compile_module(model, torch_inputs, settings)
 
 
 def _compile_graph(
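
A usage sketch of the toggle through this compile entry point; ToyModel, the inputs, and passing an nn.Module straight in are placeholders for whatever the dynamo front-end normally hands to compile, and precision/device options are left at their defaults:

import torch
import torch_tensorrt


class ToyModel(torch.nn.Module):
    def forward(self, x, y):
        return torch.relu(x + y)


model = ToyModel().eval().cuda()
inputs = [torch.randn(4, 8).cuda(), torch.randn(4, 8).cuda()]

# Default behavior: the fast adjacency-based partitioner (USE_FAST_PARTITIONER = True).
trt_mod_fast = torch_tensorrt.dynamo.compile(model, inputs, min_block_size=1)

# Opt back into the global, capability-based partitioner.
trt_mod_global = torch_tensorrt.dynamo.compile(
    model, inputs, min_block_size=1, use_fast_partitioner=False
)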
py/torch_tensorrt/dynamo/partitioning/__init__.py

+2 −3
@@ -1,4 +1,3 @@
 from .common import get_submod_inputs
-from ._adjacency_partitioner import (
-    partition,
-)
+from ._adjacency_partitioner import partition as fast_partition
+from ._global_partitioner import partition as global_partition
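
The renamed exports can also be exercised directly, mirroring the pattern used in the test suite below; TwoOps is an illustrative stand-in module:

import torch
from copy import deepcopy
from torch_tensorrt.dynamo import partitioning


class TwoOps(torch.nn.Module):
    def forward(self, x, y):
        return torch.ops.aten.relu.default(torch.ops.aten.add.Tensor(x, y))


gm = torch.fx.symbolic_trace(TwoOps())

# New default: adjacency-based ("fast") partitioner.
fast_mod = partitioning.fast_partition(deepcopy(gm), min_block_size=2)

# Previous behavior: capability-based ("global") partitioner.
global_mod = partitioning.global_partition(deepcopy(gm), min_block_size=2)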

tests/py/dynamo/backend/test_backend_compiler.py

+3 −3
@@ -1,6 +1,6 @@
 import torch
 import torch_tensorrt
-from torch_tensorrt.dynamo.partitioning import partition
+from torch_tensorrt.dynamo.partitioning import fast_partition
 from torch.testing._internal.common_utils import run_tests, TestCase
 from copy import deepcopy
 from utils import lower_graph_testing, DECIMALS_OF_AGREEMENT
@@ -17,7 +17,7 @@ def forward(self, x, y):
                 return torch.mean(out, dim=1)
 
         fx_graph = torch.fx.symbolic_trace(FullySupportedMultiOp())
-        partitioned_graph = partition(deepcopy(fx_graph), min_block_size=3)
+        partitioned_graph = fast_partition(deepcopy(fx_graph), min_block_size=3)
 
         self.assertEquals(
             len(
@@ -192,7 +192,7 @@ def forward(self, x, y):
         )
 
         fx_graph = torch.fx.symbolic_trace(FullySupportedMultiOp())
-        partitioned_graph = partition(deepcopy(fx_graph), min_block_size=3)
+        partitioned_graph = fast_partition(deepcopy(fx_graph), min_block_size=3)
 
         self.assertEquals(
             len(list(partitioned_graph.named_children())),

tests/py/dynamo/backend/test_partitioning.py

+129 −5
@@ -1,12 +1,12 @@
-from torch_tensorrt.dynamo.partitioning import partition
+from torch_tensorrt.dynamo import partitioning
 from torch.testing._internal.common_utils import run_tests, TestCase
 from utils import lower_graph_testing
 import torch
 from copy import deepcopy
 import numpy as np
 
 
-class TestPartitioning(TestCase):
+class TestFastPartitioning(TestCase):
     def test_partition_fully_supported_one_op(self):
         class FullySupportedOneOp(torch.nn.Module):
             def __init__(self, *args, **kwargs) -> None:
@@ -16,7 +16,7 @@ def forward(self, x, y):
                 return torch.ops.aten.add.Tensor(x, y)
 
         fx_graph = torch.fx.symbolic_trace(FullySupportedOneOp())
-        partitioned_graph = partition(deepcopy(fx_graph))
+        partitioned_graph = partitioning.fast_partition(deepcopy(fx_graph))
         self.assertEquals(
             len(
                 [
@@ -42,7 +42,9 @@ def forward(self, x, y):
             return pow_
 
         fx_graph = torch.fx.symbolic_trace(FullySupportedMultiOp())
-        partitioned_graph = partition(deepcopy(fx_graph), min_block_size=2)
+        partitioned_graph = partitioning.fast_partition(
+            deepcopy(fx_graph), min_block_size=2
+        )
         self.assertEquals(
             len(
                 [
@@ -69,7 +71,9 @@ def forward(self, x, y):
             return pow_
 
         fx_graph = torch.fx.symbolic_trace(PartiallySupportedMultiOp())
-        partitioned_graph = partition(deepcopy(fx_graph), min_block_size=2)
+        partitioned_graph = partitioning.fast_partition(
+            deepcopy(fx_graph), min_block_size=2
+        )
         self.assertEquals(
             len(
                 [
@@ -118,6 +122,7 @@ def forward(self, x, y):
             min_block_size=2,
             torch_executed_ops={"torch.ops.aten.add.Tensor"},
             testing_partitioning=True,
+            use_fast_partitioner=True,
         )
 
         self.assertEquals(
@@ -144,5 +149,124 @@ def forward(self, x, y):
         )
 
 
+class TestGlobalPartitioning(TestCase):
+    def test_partition_fully_supported_one_op(self):
+        class FullySupportedOneOp(torch.nn.Module):
+            def __init__(self, *args, **kwargs) -> None:
+                super().__init__(*args, **kwargs)
+
+            def forward(self, x, y):
+                return torch.ops.aten.add.Tensor(x, y)
+
+        fx_graph = torch.fx.symbolic_trace(FullySupportedOneOp())
+        partitioned_graph = partitioning.global_partition(deepcopy(fx_graph))
+        self.assertEquals(
+            len(list(partitioned_graph.named_children())),
+            0,
+            "Single operators should not be segmented",
+        )
+
+    def test_partition_fully_supported_multi_op(self):
+        class FullySupportedMultiOp(torch.nn.Module):
+            def __init__(self, *args, **kwargs) -> None:
+                super().__init__(*args, **kwargs)
+
+            def forward(self, x, y):
+                sum_ = torch.ops.aten.sub.Tensor(x, y)
+                concat_ = torch.ops.aten.cat.default(x, sum_)
+                relu_ = torch.ops.aten.relu.default(concat_)
+                pow_ = torch.ops.aten.pow.Tensor_Scalar(relu_, 2)
+                return pow_
+
+        fx_graph = torch.fx.symbolic_trace(FullySupportedMultiOp())
+        partitioned_graph = partitioning.global_partition(
+            deepcopy(fx_graph), min_block_size=2
+        )
+        self.assertEquals(
+            len(list(partitioned_graph.named_children())),
+            1,
+            "All operators are supported, there should be one segment",
+        )
+
+    def test_partition_partially_supported_multi_op(self):
+        class PartiallySupportedMultiOp(torch.nn.Module):
+            def __init__(self, *args, **kwargs) -> None:
+                super().__init__(*args, **kwargs)
+
+            def forward(self, x, y):
+                sum_1 = torch.ops.aten.add.Tensor(x, y)
+                sum_2 = torch.ops.aten.add.Tensor(x, sum_1)
+                sum_ = np.sum(sum_1) + np.sum(sum_2)
+                relu_ = torch.ops.aten.relu.default(sum_)
+                pow_ = torch.ops.aten.pow.Tensor_Scalar(relu_, 2)
+                return pow_
+
+        fx_graph = torch.fx.symbolic_trace(PartiallySupportedMultiOp())
+        partitioned_graph = partitioning.global_partition(
+            deepcopy(fx_graph), min_block_size=2
+        )
+        self.assertEquals(
+            len(list(partitioned_graph.named_children())),
+            2,
+            "Unsupported operators interleave supported ones, expected 2 segments",
+        )
+
+    def test_partition_partially_supported_with_torch_executed_ops(self):
+        class PartiallySupportedMultiOp(torch.nn.Module):
+            def __init__(self, *args, **kwargs) -> None:
+                super().__init__(*args, **kwargs)
+
+            def forward(self, x, y):
+                sum_1 = torch.ops.aten.add.Tensor(x, y)
+                sum_2 = torch.ops.aten.add.Tensor(x, sum_1)
+                sum_ = torch.ops.aten.add.Tensor(sum_1, sum_2)
+                relu_ = torch.ops.aten.relu.default(sum_)
+                pow_ = torch.ops.aten.pow.Tensor_Scalar(relu_, 2)
+                return pow_
+
+        unexpected_ops = {torch.ops.aten.add.Tensor}
+
+        inputs = [
+            torch.randint(
+                1,
+                10,
+                (5,),
+            ),
+            torch.randint(
+                1,
+                10,
+                (5,),
+            ),
+        ]
+
+        fx_graph = torch.fx.symbolic_trace(PartiallySupportedMultiOp())
+        (unexpected_ops_seen, _, partitioned_graphs,) = lower_graph_testing(
+            fx_graph,
+            inputs,
+            unexpected_ops=unexpected_ops,
+            min_block_size=2,
+            torch_executed_ops={"torch.ops.aten.add.Tensor"},
+            testing_partitioning=True,
+            use_fast_partitioner=False,
+        )
+
+        self.assertEquals(
+            len(unexpected_ops_seen),
+            0,
+            f"The following unexpected ops were encountered: {unexpected_ops_seen}",
+        )
+
+        self.assertEquals(
+            len(partitioned_graphs),
+            1,
+            "Without control flow breaks, there should only be a single graph",
+        )
+        self.assertEquals(
+            len(list(partitioned_graphs[0].named_children())),
+            1,
+            "Certain operators are set to run in Torch, expected 1 segment",
+        )
+
+
 if __name__ == "__main__":
     run_tests()
