 import torch
 from executorch.backends.cadence.aot import compiler
 from executorch.backends.cadence.aot.fuse_ops import (
+    FuseCascadedViewOps,
     FuseFullThenReshapePass,
+    FuseMMWithAdd,
     FuseMulScalarIntoDequantPass,
     FuseMulTensorIntoDequantPass,
     FuseQuantDequantToRequantizePass,
@@ -39,113 +41,133 @@ def check_op_counts(


 class TestFusionPasses(TestFusionPassesBase):
-    def test_addmm_fusion(self):
-        class AddmmFeasible1(torch.nn.Module):
-            def forward(self, x, y, z):
-                t1 = torch.mm(x, y)
-                return torch.add(t1, z)
-
-        x = torch.randn(3, 5)
-        y = torch.randn(5, 6)
-        z = torch.randn(6)
-
-        graph_module = (
-            compiler.export_to_cadence(AddmmFeasible1(), (x, y, z))
-            .exported_program()
-            .graph_module
+    def test_fuse_mm_with_add(self):
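+        # Build an mm -> add graph directly with GraphBuilder and verify
+        # that FuseMMWithAdd rewrites the pair into a single addmm node.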
+        builder = GraphBuilder()
+        x = builder.placeholder("x", torch.randn(3, 5, dtype=torch.float32))
+        y = builder.placeholder("y", torch.randn(5, 6, dtype=torch.float32))
+        z = builder.placeholder("z", torch.randn(6, dtype=torch.float32))
+        mm = builder.call_operator(
+            op=exir_ops.edge.aten.mm.default,
+            args=(x, y),
+        )
+        output = builder.call_operator(op=exir_ops.edge.aten.add.Tensor, args=(mm, z))
+        builder.output([output])
+        original_graph = builder.get_graph_module()
+        converted_graph = FuseMMWithAdd()(original_graph).graph_module
+        converted_graph.graph.eliminate_dead_code()
+        self.assertEqual(
+            count_node(converted_graph, exir_ops.edge.aten.addmm.default), 1
         )
-        graph_module.graph.eliminate_dead_code()
-
-        # Assert that mm and add were fused to addmm
-        self.assertEqual(count_node(graph_module, exir_ops.edge.aten.addmm.default), 1)
-        self.assertEqual(count_node(graph_module, exir_ops.edge.aten.mm.default), 0)
-        self.assertEqual(count_node(graph_module, exir_ops.edge.aten.add.Tensor), 0)
-
-        class AddmmFeasible2(torch.nn.Module):
-            def forward(self, x, y, z):
-                t1 = y.view((8, 6))
-                t2 = torch.mm(x, t1)
-                t3 = t2.view((2, 2, 6))
-                return torch.add(t3, z)
-
-        x = torch.randn(4, 8)
-        y = torch.randn(2, 4, 6)
-        z = torch.randn(6)
+        self.assertEqual(count_node(converted_graph, exir_ops.edge.aten.mm.default), 0)
+        self.assertEqual(count_node(converted_graph, exir_ops.edge.aten.add.Tensor), 0)

-        graph_module = (
-            compiler.export_to_cadence(AddmmFeasible2(), (x, y, z))
-            .exported_program()
-            .graph_module
+    def test_fuse_view_mm_view_add(self):
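+        # Here the mm is sandwiched between view_copy ops; the pass should
+        # still fuse mm and add into addmm, since the bias z broadcasts to
+        # the output of mm.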
+        builder = GraphBuilder()
+        x = builder.placeholder("x", torch.randn(4, 8, dtype=torch.float32))
+        y = builder.placeholder("y", torch.randn(2, 4, 6, dtype=torch.float32))
+        z = builder.placeholder("z", torch.randn(6, dtype=torch.float32))
+        y_view = builder.call_operator(
+            op=exir_ops.edge.aten.view_copy.default, args=(y, [8, 6])
         )
-        graph_module.graph.eliminate_dead_code()
-        # Assert that mm and add were fused to addmm
-        self.assertEqual(count_node(graph_module, exir_ops.edge.aten.addmm.default), 1)
-        self.assertEqual(count_node(graph_module, exir_ops.edge.aten.mm.default), 0)
-        self.assertEqual(count_node(graph_module, exir_ops.edge.aten.add.Tensor), 0)
-
-        # Bias is a singleton value, broadcastable to output of mm
-        class AddmmFeasible3(torch.nn.Module):
-            def forward(self, x, y):
-                t1 = torch.mm(x, y)
-                return torch.add(t1, torch.ones(1))
-
-        x = torch.randn(3, 5)
-        y = torch.randn(5, 6)
-
-        graph_module = (
-            compiler.export_to_cadence(AddmmFeasible3(), (x, y))
-            .exported_program()
-            .graph_module
+        mm = builder.call_operator(
+            op=exir_ops.edge.aten.mm.default,
+            args=(x, y_view),
+        )
+        mm_view = builder.call_operator(
+            op=exir_ops.edge.aten.view_copy.default, args=(mm, [2, 2, 6])
         )
-        graph_module.graph.eliminate_dead_code()
-        # Assert that mm and add were fused to addmm
-        self.assertEqual(count_node(graph_module, exir_ops.edge.aten.addmm.default), 1)
-        self.assertEqual(count_node(graph_module, exir_ops.edge.aten.mm.default), 0)
-        self.assertEqual(count_node(graph_module, exir_ops.edge.aten.add.Tensor), 0)
+        output = builder.call_operator(
+            op=exir_ops.edge.aten.add.Tensor, args=(mm_view, z)
+        )
+        builder.output([output])
+        original_graph = builder.get_graph_module()
+        converted_graph = FuseMMWithAdd()(original_graph).graph_module
+        converted_graph.graph.eliminate_dead_code()
+        self.assertEqual(
+            count_node(converted_graph, exir_ops.edge.aten.addmm.default), 1
+        )
+        self.assertEqual(count_node(converted_graph, exir_ops.edge.aten.mm.default), 0)
+        self.assertEqual(count_node(converted_graph, exir_ops.edge.aten.add.Tensor), 0)

+    def test_keep_view_mm_view_add(self):
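+        # Same view/mm/view/add chain, but the (2, 2, 1) bias does not
+        # broadcast to the output of mm, so no fusion should happen.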
+        builder = GraphBuilder()
+        x = builder.placeholder("x", torch.randn(4, 8, dtype=torch.float32))
+        y = builder.placeholder("y", torch.randn(2, 4, 6, dtype=torch.float32))
         # Bias is not broadcastable to output of mm
-        class AddmmInfeasible1(torch.nn.Module):
-            def forward(self, x, y, z):
-                t1 = y.view((8, 6))
-                t2 = torch.mm(x, t1)
-                t3 = t2.view((2, 2, 6))
-                return torch.add(t3, z)
-
-        x = torch.randn(4, 8)
-        y = torch.randn(2, 4, 6)
-        z = torch.randn(2, 2, 1)
-
-        graph_module = (
-            compiler.export_to_cadence(AddmmInfeasible1(), (x, y, z))
-            .exported_program()
-            .graph_module
+        z = builder.placeholder("z", torch.randn(2, 2, 1, dtype=torch.float32))
+        y_view = builder.call_operator(
+            op=exir_ops.edge.aten.view_copy.default, args=(y, [8, 6])
+        )
+        mm = builder.call_operator(
+            op=exir_ops.edge.aten.mm.default,
+            args=(x, y_view),
         )
-        graph_module.graph.eliminate_dead_code()
+        mm_view = builder.call_operator(
+            op=exir_ops.edge.aten.view_copy.default, args=(mm, [2, 2, 6])
+        )
+        output = builder.call_operator(
+            op=exir_ops.edge.aten.add.Tensor, args=(mm_view, z)
+        )
+        builder.output([output])
+        original_graph = builder.get_graph_module()
+        converted_graph = FuseMMWithAdd()(original_graph).graph_module
+        converted_graph.graph.eliminate_dead_code()
         # Assert that mm and add were not fused to addmm, since z cannot be
         # broadcast to the output of mm.
-        self.assertEqual(count_node(graph_module, exir_ops.edge.aten.add.Tensor), 1)
-
-        # The add consuming the output of mm has more than one user.
-        class AddmmInfeasible2(torch.nn.Module):
-            def forward(self, x, y, z):
-                t1 = torch.mm(x, y)
-                t2 = torch.add(t1, z)
-                t3 = torch.add(t2, z)
-                return torch.add(t2, t3)
+        self.assertEqual(
+            count_node(converted_graph, exir_ops.edge.aten.addmm.default), 0
+        )
+        self.assertEqual(count_node(converted_graph, exir_ops.edge.aten.mm.default), 1)
+        self.assertEqual(count_node(converted_graph, exir_ops.edge.aten.add.Tensor), 1)

-        x = torch.randn(3, 5)
-        y = torch.randn(5, 6)
-        z = torch.randn(6)
+    def test_fuse_mm_add_with_bias(self):
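+        # The bias is a full([1], 1) singleton, broadcastable to the output
+        # of mm, so the mm/add pair should fuse into addmm.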
+        builder = GraphBuilder()
+        x = builder.placeholder("x", torch.randn(3, 5, dtype=torch.float32))
+        y = builder.placeholder("y", torch.randn(5, 6, dtype=torch.float32))
+        mm = builder.call_operator(
+            op=exir_ops.edge.aten.mm.default,
+            args=(x, y),
+        )
+        bias = builder.call_operator(op=exir_ops.edge.aten.full.default, args=([1], 1))
+        output = builder.call_operator(
+            op=exir_ops.edge.aten.add.Tensor, args=(mm, bias)
+        )
+        builder.output([output])
+        original_graph = builder.get_graph_module()
+        converted_graph = FuseMMWithAdd()(original_graph).graph_module
+        converted_graph.graph.eliminate_dead_code()
+        self.assertEqual(
+            count_node(converted_graph, exir_ops.edge.aten.addmm.default), 1
+        )
+        self.assertEqual(count_node(converted_graph, exir_ops.edge.aten.mm.default), 0)
+        self.assertEqual(count_node(converted_graph, exir_ops.edge.aten.add.Tensor), 0)

-        graph_module = (
-            compiler.export_to_cadence(AddmmInfeasible2(), (x, y, z))
-            .exported_program()
-            .graph_module
+    def test_keep_mm_add_with_multiple_users(self):
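+        # Negative case: the mm/add pair must be kept because the add result
+        # is reused by other nodes, as asserted below.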
+        builder = GraphBuilder()
+        x = builder.placeholder("x", torch.randn(3, 5, dtype=torch.float32))
+        y = builder.placeholder("y", torch.randn(5, 6, dtype=torch.float32))
+        z = builder.placeholder("z", torch.randn(6, dtype=torch.float32))
+        mm = builder.call_operator(
+            op=exir_ops.edge.aten.mm.default,
+            args=(x, y),
+        )
+        # The add consuming the output of mm has more than one user.
+        add1 = builder.call_operator(op=exir_ops.edge.aten.add.Tensor, args=(mm, z))
+        add2 = builder.call_operator(op=exir_ops.edge.aten.add.Tensor, args=(add1, z))
+        output = builder.call_operator(
+            op=exir_ops.edge.aten.add.Tensor, args=(add1, add2)
         )
-        graph_module.graph.eliminate_dead_code()
+        builder.output([output])
+        original_graph = builder.get_graph_module()
+        converted_graph = FuseMMWithAdd()(original_graph).graph_module
+        converted_graph.graph.eliminate_dead_code()
         # Assert that mm and add were not fused to addmm, since add has multiple
         # users.
-        self.assertEqual(count_node(graph_module, exir_ops.edge.aten.add.Tensor), 3)
+        self.assertEqual(
+            count_node(converted_graph, exir_ops.edge.aten.addmm.default), 0
+        )
+        self.assertEqual(count_node(converted_graph, exir_ops.edge.aten.mm.default), 1)
+        self.assertEqual(count_node(converted_graph, exir_ops.edge.aten.add.Tensor), 3)

     # TODO(matthiascremon): enable that pass with new flow
     @torch.no_grad()
@@ -184,63 +206,70 @@ def forward(self, x):
         )

     def test_permute_transpose_fusion(self):
-        class PermuteTranspose(torch.nn.Module):
-            def forward(self, x):
-                y = x.permute((0, 2, 4, 1, 3))
-                return y.transpose(0, 1)
-
-        x = torch.randn(3, 1, 3, 1, 4)
-        graph_module = (
-            compiler.export_to_cadence(PermuteTranspose(), (x,))
-            .exported_program()
-            .graph_module
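+        # Build a permute_copy -> transpose_copy pair with GraphBuilder and
+        # run FuseTransposeOrPermuteOpPairsPass; the assertions below expect
+        # the pair to be fused into a single permute.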
+        builder = GraphBuilder()
+        x = builder.placeholder("x", torch.randn(3, 1, 3, 1, 4, dtype=torch.float32))
+        permute = builder.call_operator(
+            op=exir_ops.edge.aten.permute_copy.default, args=(x, [0, 2, 4, 1, 3])
+        )
+        output = builder.call_operator(
+            op=exir_ops.edge.aten.transpose_copy.int,
+            args=(permute, 0, 1),
         )
-        graph_module.graph.eliminate_dead_code()
+        builder.output(output)
+        original_graph = builder.get_graph_module()
+        # Question: this pass cannot be applied because [0, 2, 4] != [2, 0, 4]
+        # in can_fuse_for_chain. Am I using the right pass?
+        converted_graph = FuseTransposeOrPermuteOpPairsPass()(
+            original_graph
+        ).graph_module
+        converted_graph.graph.eliminate_dead_code()
         # Assert that permute op was fused with transpose op
         self.assertEqual(
-            count_node(graph_module, exir_ops.edge.aten.permute_copy.default), 1
+            count_node(converted_graph, exir_ops.edge.aten.permute_copy.default), 1
         )
         self.assertEqual(
-            count_node(graph_module, exir_ops.edge.aten.transpose_copy.int), 0
+            count_node(converted_graph, exir_ops.edge.aten.transpose_copy.int), 0
         )

     def test_view_fusion(self):
-        class ViewFusion(torch.nn.Module):
-            def forward(self, x):
-                x = x.view([1, 8, 15])
-                x = x.view([1, 1, 120])
-                return x.view([1, 12, 10])
-
-        x = torch.randn(8, 5, 3)
-        graph_module = (
-            compiler.export_to_cadence(ViewFusion(), (x,))
-            .exported_program()
-            .graph_module
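+        # A chain of three cascaded view_copy ops should collapse into a
+        # single view with the final shape once FuseCascadedViewOps runs.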
+        builder = GraphBuilder()
+        x = builder.placeholder("x", torch.randn(8, 5, 3, dtype=torch.float32))
+        view1 = builder.call_operator(
+            op=exir_ops.edge.aten.view_copy.default, args=(x, [1, 8, 15])
+        )
+        view2 = builder.call_operator(
+            op=exir_ops.edge.aten.view_copy.default, args=(view1, [1, 1, 120])
+        )
+        output = builder.call_operator(
+            op=exir_ops.edge.aten.view_copy.default, args=(view2, [1, 12, 10])
         )
-        graph_module.graph.eliminate_dead_code()
+        builder.output(output)
+        original_graph = builder.get_graph_module()
+        converted_graph = FuseCascadedViewOps()(original_graph).graph_module
+        converted_graph.graph.eliminate_dead_code()
         # Assert that only one view op remains
         self.assertEqual(
-            count_node(graph_module, exir_ops.edge.aten.view_copy.default), 1
+            count_node(converted_graph, exir_ops.edge.aten.view_copy.default), 1
         )

     def test_view_fusion_branched(self):
-        class ViewFusion(torch.nn.Module):
-            def forward(self, x):
-                y = x.view([1, 8, 15])
-                z = y.view([1, 1, 120])
-                t = y.view([120, 1, 1])
-                return z, t
-
-        x = torch.randn(8, 5, 3)
-        graph_module = (
-            compiler.export_to_cadence(ViewFusion(), (x,))
-            .exported_program()
-            .graph_module
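+        # Two view_copy branches hang off the intermediate view y; each
+        # branch should fuse through y, leaving two view ops and no y node.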
+        builder = GraphBuilder()
+        x = builder.placeholder("x", torch.randn(8, 5, 3, dtype=torch.float32))
+        y = builder.call_operator(
+            op=exir_ops.edge.aten.view_copy.default, args=(x, [1, 8, 15])
+        )
+        z = builder.call_operator(
+            op=exir_ops.edge.aten.view_copy.default, args=(y, [1, 1, 120])
         )
-        graph_module.graph.eliminate_dead_code()
+        t = builder.call_operator(
+            op=exir_ops.edge.aten.view_copy.default, args=(y, [120, 1, 1])
+        )
+        builder.output([z, t])
+        original_graph = builder.get_graph_module()
+        converted_graph = FuseCascadedViewOps()(original_graph).graph_module
+        converted_graph.graph.eliminate_dead_code()
         # z and t should be fused and y should be eliminated.
         self.assertEqual(
-            count_node(graph_module, exir_ops.edge.aten.view_copy.default), 2
+            count_node(converted_graph, exir_ops.edge.aten.view_copy.default), 2
         )

     def test_force_quant_dequant_fusion(self):