13
13
import executorch .backends .cadence .aot .ops_registrations # noqa
14
14
import torch
15
15
import torch .nn as nn
16
- import torch .nn .functional as F
17
- from executorch .backends .cadence .aot import compiler
18
16
from executorch .backends .cadence .aot .compiler import export_to_edge
19
17
from executorch .backends .cadence .aot .fuse_ops import FuseQuantDequantToRequantizePass
20
18
from executorch .backends .cadence .aot .graph_builder import GraphBuilder
@@ -53,16 +51,15 @@ class TestRemoveOpsPasses(unittest.TestCase):
53
51
)
54
52
@torch .no_grad ()
55
53
def test_remove_to_ops (self , shape : Tuple [int ]):
56
- class M (torch .nn .Module ):
57
- def forward (self , x : torch .Tensor ):
58
- return exir_ops .edge .aten .to (x , dtype = torch .float32 )
59
-
60
- model = M ()
61
- x = torch .randn (shape )
62
- graph_module = export_to_edge (model , (x ,)).exported_program ().graph_module
63
- p = RemoveToOpsPass ()
64
-
65
- graph_after_passes = cast (PassResult , p (graph_module )).graph_module
54
+ builder = GraphBuilder ()
55
+ x = builder .placeholder ("x" , torch .randn (* shape , dtype = torch .float32 ))
56
+ x = builder .call_operator (
57
+ op = exir_ops .edge .aten .to .dtype ,
58
+ args = (x , torch .float32 ),
59
+ )
60
+ builder .output ([x ])
61
+ original = builder .get_graph_module ()
62
+ graph_after_passes = cast (PassResult , RemoveToOpsPass ()(original )).graph_module
66
63
67
64
self .assertEqual (
68
65
count_node (graph_after_passes , exir_ops .edge .aten .to .dtype ),
@@ -83,31 +80,24 @@ def forward(self, x: torch.Tensor):
83
80
)
84
81
@torch .no_grad ()
85
82
def test_remove_nop_add_op_pass (self , shape : Tuple [int ]):
86
- class FullX (torch .nn .Module ):
87
- def forward (self , t : torch .Tensor ):
88
- return torch .add (torch .full (shape , 0 ), t )
89
-
90
- class FullY (torch .nn .Module ):
91
- def forward (self , t : torch .Tensor ):
92
- return torch .add (t , torch .full (shape , 0 ))
93
-
94
- model = FullX ()
95
- t = torch .full (shape , 3 )
96
- graph_module = export_to_edge (model , (t ,)).exported_program ().graph_module
97
-
98
- p = RemoveNopAddOpPass ()
99
-
100
- graph_after_passes = cast (PassResult , p (graph_module )).graph_module
101
- self .assertEqual (
102
- count_node (graph_after_passes , exir_ops .edge .aten .add .Tensor ),
103
- 0 ,
104
- )
105
-
106
- model = FullY ()
107
- graph_module = export_to_edge (model , (t ,)).exported_program ().graph_module
108
-
109
- graph_after_passes = cast (PassResult , p (graph_module )).graph_module
110
-
83
+ builder = GraphBuilder ()
84
+ x = builder .placeholder ("x" , torch .randn (* shape , dtype = torch .float32 ))
85
+ zeros = builder .call_operator (
86
+ op = exir_ops .edge .aten .full .default , args = (shape , 0 )
87
+ )
88
+ left_add = builder .call_operator (
89
+ op = exir_ops .edge .aten .add .Tensor ,
90
+ args = (zeros , x ),
91
+ )
92
+ right_add = builder .call_operator (
93
+ op = exir_ops .edge .aten .add .Tensor ,
94
+ args = (left_add , zeros ),
95
+ )
96
+ builder .output ([right_add ])
97
+ original = builder .get_graph_module ()
98
+ graph_after_passes = cast (
99
+ PassResult , RemoveNopAddOpPass ()(original )
100
+ ).graph_module
111
101
self .assertEqual (
112
102
count_node (graph_after_passes , exir_ops .edge .aten .add .Tensor ),
113
103
0 ,
@@ -122,31 +112,24 @@ def forward(self, t: torch.Tensor):
122
112
)
123
113
@torch .no_grad ()
124
114
def test_remove_nop_mul_op_pass (self , shape : Tuple [int ]):
125
- class FullX (torch .nn .Module ):
126
- def forward (self , t : torch .Tensor ):
127
- return torch .mul (torch .full (shape , 0 ), t )
128
-
129
- class FullY (torch .nn .Module ):
130
- def forward (self , t : torch .Tensor ):
131
- return torch .mul (t , torch .full (shape , 0 ))
132
-
133
- model = FullX ()
134
- t = torch .full (shape , 3 )
135
- graph_module = export_to_edge (model , (t ,)).exported_program ().graph_module
136
-
137
- p = RemoveNopMulOpPass ()
138
-
139
- graph_after_passes = cast (PassResult , p (graph_module )).graph_module
140
- self .assertEqual (
141
- count_node (graph_after_passes , exir_ops .edge .aten .mul .Tensor ),
142
- 0 ,
143
- )
144
-
145
- model = FullY ()
146
- graph_module = export_to_edge (model , (t ,)).exported_program ().graph_module
147
-
148
- graph_after_passes = cast (PassResult , p (graph_module )).graph_module
149
-
115
+ builder = GraphBuilder ()
116
+ x = builder .placeholder ("x" , torch .randn (* shape , dtype = torch .float32 ))
117
+ zeros = builder .call_operator (
118
+ op = exir_ops .edge .aten .full .default , args = (shape , 0 )
119
+ )
120
+ left_mul = builder .call_operator (
121
+ op = exir_ops .edge .aten .mul .Tensor ,
122
+ args = (zeros , x ),
123
+ )
124
+ right_mul = builder .call_operator (
125
+ op = exir_ops .edge .aten .mul .Tensor ,
126
+ args = (left_mul , zeros ),
127
+ )
128
+ builder .output ([right_mul ])
129
+ original = builder .get_graph_module ()
130
+ graph_after_passes = cast (
131
+ PassResult , RemoveNopMulOpPass ()(original )
132
+ ).graph_module
150
133
self .assertEqual (
151
134
count_node (graph_after_passes , exir_ops .edge .aten .mul .Tensor ),
152
135
0 ,
@@ -159,18 +142,16 @@ def forward(self, t: torch.Tensor):
159
142
)
160
143
@torch .no_grad ()
161
144
def test_remove_alias_copy (self , shape : Tuple [int ]):
162
- class M (torch .nn .Module ):
163
- def forward (self , x : torch .Tensor ):
164
- return exir_ops .edge .aten .alias_copy (x )
165
-
166
- model = M ()
167
- x = torch .randn (shape )
168
- graph_module = export_to_edge (model , (x ,)).exported_program ().graph_module
169
-
170
- p = RemoveAliasCopyOpPass ()
171
-
172
- graph_after_passes = cast (PassResult , p (graph_module )).graph_module
173
-
145
+ builder = GraphBuilder ()
146
+ x = builder .placeholder ("x" , torch .randn (* shape , dtype = torch .float32 ))
147
+ alias = builder .call_operator (
148
+ op = exir_ops .edge .aten .alias_copy .default , args = (x ,)
149
+ )
150
+ builder .output ([alias ])
151
+ original = builder .get_graph_module ()
152
+ graph_after_passes = cast (
153
+ PassResult , RemoveAliasCopyOpPass ()(original )
154
+ ).graph_module
174
155
self .assertEqual (
175
156
count_node (graph_after_passes , exir_ops .edge .aten .alias_copy .default ),
176
157
0 ,
@@ -183,19 +164,16 @@ def forward(self, x: torch.Tensor):
183
164
)
184
165
@torch .no_grad ()
185
166
def test_remove_detach_copy (self , shape : Tuple [int ]):
186
- # aten::detach is converted to aten::alias_copy after functionalization & decomposition.
187
- class M (torch .nn .Module ):
188
- def forward (self , x : torch .Tensor ):
189
- return exir_ops .edge .aten .detach_copy (x )
190
-
191
- model = M ()
192
- x = torch .randn (shape )
193
- graph_module = export_to_edge (model , (x ,)).exported_program ().graph_module
194
-
195
- p = RemoveDetachCopyPass ()
196
-
197
- graph_after_passes = cast (PassResult , p (graph_module )).graph_module
198
-
167
+ builder = GraphBuilder ()
168
+ x = builder .placeholder ("x" , torch .randn (* shape , dtype = torch .float32 ))
169
+ detach = builder .call_operator (
170
+ op = exir_ops .edge .aten .detach_copy .default , args = (x ,)
171
+ )
172
+ builder .output ([detach ])
173
+ original = builder .get_graph_module ()
174
+ graph_after_passes = cast (
175
+ PassResult , RemoveDetachCopyPass ()(original )
176
+ ).graph_module
199
177
self .assertEqual (
200
178
count_node (graph_after_passes , exir_ops .edge .aten .detach_copy .default ),
201
179
0 ,
@@ -210,95 +188,51 @@ def forward(self, x: torch.Tensor):
210
188
def test_remove_zero_sized_constant_pad_nd (
211
189
self , shape : Tuple [int ], padding : Tuple [int ]
212
190
):
213
- # F.pad is converted to aten::constant_pad_nd after functionalization & decomposition.
214
- class Padding (torch .nn .Module ):
215
- def __init__ (self ):
216
- super ().__init__ ()
217
- self .padding = padding
218
-
219
- def forward (self , x : torch .Tensor ):
220
- return F .pad (x , self .padding )
221
-
222
- model = Padding ()
223
- x = torch .randn (shape )
224
- graph_module = export_to_edge (model , (x ,)).exported_program ().graph_module
225
-
226
- p = RemoveZeroSizedConstantPadNd ()
227
-
228
- graph_after_passes = cast (PassResult , p (graph_module )).graph_module
229
-
191
+ builder = GraphBuilder ()
192
+ x = builder .placeholder ("x" , torch .randn (* shape , dtype = torch .float32 ))
193
+ pad = builder .call_operator (
194
+ op = exir_ops .edge .aten .constant_pad_nd .default , args = (x , padding )
195
+ )
196
+ builder .output ([pad ])
197
+ original = builder .get_graph_module ()
198
+ graph_after_passes = cast (
199
+ PassResult , RemoveZeroSizedConstantPadNd ()(original )
200
+ ).graph_module
230
201
self .assertEqual (
231
202
count_node (graph_after_passes , exir_ops .edge .aten .constant_pad_nd .default ),
232
203
0 ,
233
204
)
234
205
235
206
def test_remove_expand (self ):
236
- class Expand (torch .nn .Module ):
237
- def forward (self , x ):
238
- return torch .ops .aten .expand_copy (x , [2 , 3 , 5 ])
239
-
240
- x = torch .ones (2 , 3 , 5 )
241
- p = RemoveNopExpandOpPass ()
242
- graph_module = export_to_edge (Expand (), (x ,)).exported_program ().graph_module
243
- graph_module = p (graph_module ).graph_module
244
- # Assert that expand op is optimized away, since it is a nop
207
+ builder = GraphBuilder ()
208
+ x = builder .placeholder ("x" , torch .randn ([2 , 3 , 5 ], dtype = torch .float32 ))
209
+ expand = builder .call_operator (
210
+ op = exir_ops .edge .aten .expand_copy .default , args = (x , [2 , 3 , 5 ])
211
+ )
212
+ builder .output ([expand ])
213
+ original = builder .get_graph_module ()
214
+ graph_after_passes = cast (
215
+ PassResult , RemoveNopExpandOpPass ()(original )
216
+ ).graph_module
245
217
self .assertEqual (
246
- count_node (graph_module , exir_ops .edge .aten .expand_copy .default ), 0
218
+ count_node (graph_after_passes , exir_ops .edge .aten .expand_copy .default ), 0
247
219
)
248
220
249
221
def test_remove_zero_arg_cat (self ):
250
- class Cat (torch .nn .Module ):
251
- def forward (self , x , y ):
252
- return torch .ops .aten .cat ((x , y ), 0 )
253
-
254
- x = torch .ones (1 , 0 , 3 , 5 )
255
- y = torch .ones (2 , 0 , 3 , 5 )
256
- graph_module = (
257
- compiler .export_to_cadence (Cat (), (x , y )).exported_program ().graph_module
258
- )
259
- # Assert that cat op is optimized away, since it concatenates
260
- # two zero-sized tensors
261
- self .assertEqual (count_node (graph_module , exir_ops .edge .aten .cat .default ), 0 )
262
-
263
- def test_remove_single_arg_cat (self ):
264
- class Cat (torch .nn .Module ):
265
- def forward (self , x , y ):
266
- z = torch .ones (0 , 5 )
267
- # z is an empty tensor, and concatenation of x with z will
268
- # be x. So we can safely eliminate the following cat op.
269
- x1 = torch .ops .aten .cat ((x , z ))
270
- x2 = torch .add (x1 , 2.4 , 3.1 )
271
- y1 = torch .add (y , 1 , 2 )
272
- return torch .add (x2 , y1 )
273
-
274
- x = torch .ones (3 , 5 )
275
- y = torch .ones (3 , 5 )
276
- graph_module = export_to_edge (Cat (), (x , y )).exported_program ().graph_module
277
- new_graph_module = RemoveZeroSizedCatArgsPass ()(graph_module ).graph_module
278
- new_graph_module .graph .eliminate_dead_code ()
279
- # Assert that x1 is optimized away
280
- self .assertEqual (count_node (new_graph_module , torch .ops .aten .cat .out ), 0 )
281
-
282
- def test_remove_zero_sized_cat (self ):
283
- class Cat (torch .nn .Module ):
284
- def __init__ (self , dim : int ):
285
- super ().__init__ ()
286
- self .dim = dim
287
-
288
- def forward (self , tensors ):
289
- return torch .cat (tensors , self .dim )
290
-
291
- shapes , dim , dtype , _max = [(1 , 0 , 3 ), (2 , 0 , 3 )], 0 , torch .float32 , 127
292
-
293
- in_tensors = [(torch .rand (shape ) * _max ).to (dtype = dtype ) for shape in shapes ]
294
-
295
- model = Cat (dim )
296
- graph_module = (
297
- export_to_edge (model , (in_tensors ,)).exported_program ().graph_module
222
+ builder = GraphBuilder ()
223
+ x = builder .placeholder ("x" , torch .randn ([1 , 0 , 3 , 5 ], dtype = torch .float32 ))
224
+ y = builder .placeholder ("y" , torch .randn ([2 , 0 , 3 , 5 ], dtype = torch .float32 ))
225
+ concat = builder .call_operator (
226
+ op = exir_ops .edge .aten .cat .default , args = ([x , y ], 0 )
227
+ )
228
+ builder .output ([concat ])
229
+ original = builder .get_graph_module ()
230
+ graph_after_passes = cast (
231
+ PassResult , RemoveZeroSizedCatArgsPass ()(original )
232
+ ).graph_module
233
+ self .assertEqual (
234
+ count_node (graph_after_passes , exir_ops .edge .aten .cat .default ), 0
298
235
)
299
- new_graph_module = RemoveZeroSizedCatArgsPass ()(graph_module ).graph_module
300
- new_graph_module .graph .eliminate_dead_code ()
301
- self .assertEqual (count_node (new_graph_module , torch .ops .aten .cat .out ), 0 )
302
236
303
237
def test_remove_clone (self ):
304
238
class Clone (torch .nn .Module ):
0 commit comments