1
- from torch_tensorrt .dynamo . partitioning import partition
1
+ from torch_tensorrt .dynamo import partitioning
2
2
from torch .testing ._internal .common_utils import run_tests , TestCase
3
3
from utils import lower_graph_testing
4
4
import torch
5
5
from copy import deepcopy
6
6
import numpy as np
7
7
8
8
9
- class TestPartitioning (TestCase ):
9
+ class TestFastPartitioning (TestCase ):
10
10
def test_partition_fully_supported_one_op (self ):
11
11
class FullySupportedOneOp (torch .nn .Module ):
12
12
def __init__ (self , * args , ** kwargs ) -> None :
@@ -16,7 +16,7 @@ def forward(self, x, y):
16
16
return torch .ops .aten .add .Tensor (x , y )
17
17
18
18
fx_graph = torch .fx .symbolic_trace (FullySupportedOneOp ())
19
- partitioned_graph = partition (deepcopy (fx_graph ))
19
+ partitioned_graph = partitioning . fast_partition (deepcopy (fx_graph ))
20
20
self .assertEquals (
21
21
len (
22
22
[
@@ -42,7 +42,9 @@ def forward(self, x, y):
42
42
return pow_
43
43
44
44
fx_graph = torch .fx .symbolic_trace (FullySupportedMultiOp ())
45
- partitioned_graph = partition (deepcopy (fx_graph ), min_block_size = 2 )
45
+ partitioned_graph = partitioning .fast_partition (
46
+ deepcopy (fx_graph ), min_block_size = 2
47
+ )
46
48
self .assertEquals (
47
49
len (
48
50
[
@@ -69,7 +71,9 @@ def forward(self, x, y):
69
71
return pow_
70
72
71
73
fx_graph = torch .fx .symbolic_trace (PartiallySupportedMultiOp ())
72
- partitioned_graph = partition (deepcopy (fx_graph ), min_block_size = 2 )
74
+ partitioned_graph = partitioning .fast_partition (
75
+ deepcopy (fx_graph ), min_block_size = 2
76
+ )
73
77
self .assertEquals (
74
78
len (
75
79
[
@@ -118,6 +122,7 @@ def forward(self, x, y):
118
122
min_block_size = 2 ,
119
123
torch_executed_ops = {"torch.ops.aten.add.Tensor" },
120
124
testing_partitioning = True ,
125
+ use_fast_partitioner = True ,
121
126
)
122
127
123
128
self .assertEquals (
@@ -144,5 +149,124 @@ def forward(self, x, y):
144
149
)
145
150
146
151
152
+ class TestGlobalPartitioning (TestCase ):
153
+ def test_partition_fully_supported_one_op (self ):
154
+ class FullySupportedOneOp (torch .nn .Module ):
155
+ def __init__ (self , * args , ** kwargs ) -> None :
156
+ super ().__init__ (* args , ** kwargs )
157
+
158
+ def forward (self , x , y ):
159
+ return torch .ops .aten .add .Tensor (x , y )
160
+
161
+ fx_graph = torch .fx .symbolic_trace (FullySupportedOneOp ())
162
+ partitioned_graph = partitioning .global_partition (deepcopy (fx_graph ))
163
+ self .assertEquals (
164
+ len (list (partitioned_graph .named_children ())),
165
+ 0 ,
166
+ "Single operators should not be segmented" ,
167
+ )
168
+
169
+ def test_partition_fully_supported_multi_op (self ):
170
+ class FullySupportedMultiOp (torch .nn .Module ):
171
+ def __init__ (self , * args , ** kwargs ) -> None :
172
+ super ().__init__ (* args , ** kwargs )
173
+
174
+ def forward (self , x , y ):
175
+ sum_ = torch .ops .aten .sub .Tensor (x , y )
176
+ concat_ = torch .ops .aten .cat .default (x , sum_ )
177
+ relu_ = torch .ops .aten .relu .default (concat_ )
178
+ pow_ = torch .ops .aten .pow .Tensor_Scalar (relu_ , 2 )
179
+ return pow_
180
+
181
+ fx_graph = torch .fx .symbolic_trace (FullySupportedMultiOp ())
182
+ partitioned_graph = partitioning .global_partition (
183
+ deepcopy (fx_graph ), min_block_size = 2
184
+ )
185
+ self .assertEquals (
186
+ len (list (partitioned_graph .named_children ())),
187
+ 1 ,
188
+ "All operators are supported, there should be one segment" ,
189
+ )
190
+
191
+ def test_partition_partially_supported_multi_op (self ):
192
+ class PartiallySupportedMultiOp (torch .nn .Module ):
193
+ def __init__ (self , * args , ** kwargs ) -> None :
194
+ super ().__init__ (* args , ** kwargs )
195
+
196
+ def forward (self , x , y ):
197
+ sum_1 = torch .ops .aten .add .Tensor (x , y )
198
+ sum_2 = torch .ops .aten .add .Tensor (x , sum_1 )
199
+ sum_ = np .sum (sum_1 ) + np .sum (sum_2 )
200
+ relu_ = torch .ops .aten .relu .default (sum_ )
201
+ pow_ = torch .ops .aten .pow .Tensor_Scalar (relu_ , 2 )
202
+ return pow_
203
+
204
+ fx_graph = torch .fx .symbolic_trace (PartiallySupportedMultiOp ())
205
+ partitioned_graph = partitioning .global_partition (
206
+ deepcopy (fx_graph ), min_block_size = 2
207
+ )
208
+ self .assertEquals (
209
+ len (list (partitioned_graph .named_children ())),
210
+ 2 ,
211
+ "Unsupported operators interleave supported ones, expected 2 segments" ,
212
+ )
213
+
214
+ def test_partition_partially_supported_with_torch_executed_ops (self ):
215
+ class PartiallySupportedMultiOp (torch .nn .Module ):
216
+ def __init__ (self , * args , ** kwargs ) -> None :
217
+ super ().__init__ (* args , ** kwargs )
218
+
219
+ def forward (self , x , y ):
220
+ sum_1 = torch .ops .aten .add .Tensor (x , y )
221
+ sum_2 = torch .ops .aten .add .Tensor (x , sum_1 )
222
+ sum_ = torch .ops .aten .add .Tensor (sum_1 , sum_2 )
223
+ relu_ = torch .ops .aten .relu .default (sum_ )
224
+ pow_ = torch .ops .aten .pow .Tensor_Scalar (relu_ , 2 )
225
+ return pow_
226
+
227
+ unexpected_ops = {torch .ops .aten .add .Tensor }
228
+
229
+ inputs = [
230
+ torch .randint (
231
+ 1 ,
232
+ 10 ,
233
+ (5 ,),
234
+ ),
235
+ torch .randint (
236
+ 1 ,
237
+ 10 ,
238
+ (5 ,),
239
+ ),
240
+ ]
241
+
242
+ fx_graph = torch .fx .symbolic_trace (PartiallySupportedMultiOp ())
243
+ (unexpected_ops_seen , _ , partitioned_graphs ,) = lower_graph_testing (
244
+ fx_graph ,
245
+ inputs ,
246
+ unexpected_ops = unexpected_ops ,
247
+ min_block_size = 2 ,
248
+ torch_executed_ops = {"torch.ops.aten.add.Tensor" },
249
+ testing_partitioning = True ,
250
+ use_fast_partitioner = False ,
251
+ )
252
+
253
+ self .assertEquals (
254
+ len (unexpected_ops_seen ),
255
+ 0 ,
256
+ f"The following unexpected ops were encountered: { unexpected_ops_seen } " ,
257
+ )
258
+
259
+ self .assertEquals (
260
+ len (partitioned_graphs ),
261
+ 1 ,
262
+ "Without control flow breaks, there should only be a single graph" ,
263
+ )
264
+ self .assertEquals (
265
+ len (list (partitioned_graphs [0 ].named_children ())),
266
+ 1 ,
267
+ "Certain operators are set to run in Torch, expected 1 segment" ,
268
+ )
269
+
270
+
147
271
if __name__ == "__main__" :
148
272
run_tests ()
0 commit comments