@@ -235,36 +235,28 @@ def test_remove_zero_arg_cat(self):
235
235
)
236
236
237
237
def test_remove_clone (self ):
238
- class Clone (torch .nn .Module ):
239
- def forward (self , x , y ):
240
- t1 = x .clone ()
241
- t2 = y .clone ()
242
- return t1 + t2
243
-
244
- x = torch .ones (3 , 5 )
245
- y = torch .ones (3 , 5 )
246
- graph_module = export_to_edge (Clone (), (x , y )).exported_program ().graph_module
247
- new_graph_module = RemoveCloneOpPass ()(graph_module ).graph_module
248
- new_graph_module .graph .eliminate_dead_code ()
249
- # Assert that t1 and t2 are optimized away
250
- self .assertEqual (count_node (new_graph_module , torch .ops .aten .clone .out ), 0 )
238
+ builder = GraphBuilder ()
239
+ x = builder .placeholder ("x" , torch .randn ([3 , 5 ], dtype = torch .float32 ))
240
+ clone = builder .call_operator (op = exir_ops .edge .aten .clone .default , args = (x ,))
241
+ builder .output ([clone ])
242
+ original = builder .get_graph_module ()
243
+ graph_after_passes = RemoveCloneOpPass ()(original ).graph_module
244
+ self .assertEqual (
245
+ count_node (graph_after_passes , torch .ops .aten .clone .default ), 0
246
+ )
251
247
252
248
def test_remove_contiguous (self ):
253
- class Contiguous ( torch . nn . Module ):
254
- def forward ( self , x , y ):
255
- t1 = x . contiguous ()
256
- t2 = y . contiguous ( )
257
- return t1 + t2
258
-
259
- x = torch . ones ( 3 , 5 )
260
- y = torch . ones ( 3 , 5 )
261
- graph_module = (
262
- export_to_edge ( Contiguous (), ( x , y )). exported_program (). graph_module
249
+ builder = GraphBuilder ()
250
+ x = builder . placeholder ( "x" , torch . randn ([ 3 , 5 ], dtype = torch . float32 ))
251
+ contiguous = builder . call_operator (
252
+ op = exir_ops . edge . aten . contiguous . default , args = ( x , )
253
+ )
254
+ builder . output ([ contiguous ])
255
+ original = builder . get_graph_module ( )
256
+ graph_after_passes = RemoveContiguousOpPass ()( original ). graph_module
257
+ self . assertEqual (
258
+ count_node ( graph_after_passes , torch . ops . aten . contiguous . default ), 0
263
259
)
264
- new_graph_module = RemoveContiguousOpPass ()(graph_module ).graph_module
265
- new_graph_module .graph .eliminate_dead_code ()
266
- # Assert that t1 and t2 are optimized away
267
- self .assertEqual (count_node (new_graph_module , torch .ops .aten .contiguous .out ), 0 )
268
260
269
261
@parameterized .expand (
270
262
[
@@ -274,119 +266,129 @@ def forward(self, x, y):
274
266
)
275
267
@torch .no_grad ()
276
268
def test_remove_nop_view (self , shape , new_shape ):
277
- class View (torch .nn .Module ):
278
- def __init__ (self , new_shape ):
279
- super ().__init__ ()
280
- self .new_shape = new_shape
281
-
282
- def forward (self , x : torch .Tensor ):
283
- return x .view (self .new_shape )
284
-
285
- model = View (new_shape )
286
- x = torch .randn (shape )
287
- graph_module = export_to_edge (model , (x ,)).exported_program ().graph_module
288
- p = RemoveNopSliceOrViewOpPass ()
289
- graph_after_passes = cast (PassResult , p (graph_module )).graph_module
290
- graph_after_passes .graph .eliminate_dead_code ()
291
- # Assert that view op was removed
269
+ builder = GraphBuilder ()
270
+ x = builder .placeholder ("x" , torch .randn (* shape , dtype = torch .float32 ))
271
+ view = builder .call_operator (
272
+ op = exir_ops .edge .aten .view_copy .default , args = (x , new_shape )
273
+ )
274
+ builder .output ([view ])
275
+ original = builder .get_graph_module ()
276
+ graph_after_passes = cast (
277
+ PassResult , RemoveNopSliceOrViewOpPass ()(original )
278
+ ).graph_module
292
279
self .assertEqual (
293
280
count_node (graph_after_passes , exir_ops .edge .aten .view_copy .default ), 0
294
281
)
295
282
296
283
def test_remove_nop_slice (self ):
297
- class Slice (torch .nn .Module ):
298
- def forward (self , x ):
299
- return torch .slice_copy (x , dim = 0 , start = 0 , step = 1 )
300
-
301
- x = torch .ones (3 , 5 )
302
- model = Slice ()
303
- graph_module = export_to_edge (model , (x ,)).exported_program ().graph_module
304
- p = RemoveNopSliceOrViewOpPass ()
305
- graph_after_passes = cast (PassResult , p (graph_module )).graph_module
306
- graph_after_passes .graph .eliminate_dead_code ()
307
- # Assert that slice op was removed
284
+ builder = GraphBuilder ()
285
+ x = builder .placeholder ("x" , torch .randn (3 , 5 , dtype = torch .float32 ))
286
+ slice_ = builder .call_operator (
287
+ op = exir_ops .edge .aten .slice_copy .Tensor ,
288
+ args = (
289
+ x ,
290
+ 0 , # dim
291
+ 0 , # start
292
+ 3 , # end
293
+ ),
294
+ )
295
+ builder .output ([slice_ ])
296
+ original = builder .get_graph_module ()
297
+ graph_after_passes = cast (
298
+ PassResult , RemoveNopSliceOrViewOpPass ()(original )
299
+ ).graph_module
308
300
self .assertEqual (
309
301
count_node (graph_after_passes , exir_ops .edge .aten .slice_copy .Tensor ), 0
310
302
)
311
303
312
- def test_remove_nop_select (self ):
313
- class SelectFeasible1 ( torch . nn . Module ):
314
- def forward ( self , x ):
315
- y = x . select ( 0 , 0 )
316
- z = y . view ([ 1 , 5 , 6 ])
317
- return z
318
-
319
- x = torch . ones ( 1 , 5 , 6 )
320
- graph_module = (
321
- export_to_edge ( SelectFeasible1 (), ( x ,)). exported_program (). graph_module
304
+ def test_remove_nop_select_before_view (self ):
305
+ builder = GraphBuilder ()
306
+ x = builder . placeholder ( "x" , torch . randn ( 1 , 5 , 6 , dtype = torch . float32 ))
307
+ select = builder . call_operator (
308
+ op = exir_ops . edge . aten . select_copy . int ,
309
+ args = (
310
+ x ,
311
+ 0 , # dim
312
+ 0 , # index
313
+ ),
322
314
)
323
- self .assertEqual (
324
- count_node (graph_module , exir_ops .edge .aten .select_copy .int ), 1
315
+ view = builder .call_operator (
316
+ op = exir_ops .edge .aten .view_copy .default ,
317
+ args = (select , [1 , 5 , 6 ]), # new shape
325
318
)
326
- graph_module = RemoveNopSelectOpPass ()(graph_module ).graph_module
327
- # Assert that select op was removed
319
+ builder .output ([view ])
320
+ original = builder .get_graph_module ()
321
+ graph_after_passes = cast (
322
+ PassResult , RemoveNopSelectOpPass ()(original )
323
+ ).graph_module
328
324
self .assertEqual (
329
- count_node (graph_module , exir_ops .edge .aten .select_copy .int ), 0
325
+ count_node (graph_after_passes , exir_ops .edge .aten .select_copy .int ), 0
330
326
)
331
327
332
- class SelectFeasible2 (torch .nn .Module ):
333
- def forward (self , x , y ):
334
- x = x .select (0 , 0 )
335
- z = x + y
336
- return z
337
-
338
- x = torch .ones (1 , 5 , 6 )
339
- y = torch .ones (1 , 5 , 6 )
340
- graph_module = (
341
- export_to_edge (SelectFeasible2 (), (x , y )).exported_program ().graph_module
342
- )
343
- self .assertEqual (
344
- count_node (graph_module , exir_ops .edge .aten .select_copy .int ), 1
328
+ def test_remove_nop_select_before_add (self ):
329
+ builder = GraphBuilder ()
330
+ x = builder .placeholder ("x" , torch .randn (1 , 5 , 6 , dtype = torch .float32 ))
331
+ y = builder .placeholder ("y" , torch .randn (1 , 5 , 6 , dtype = torch .float32 ))
332
+ select = builder .call_operator (
333
+ op = exir_ops .edge .aten .select_copy .int ,
334
+ args = (
335
+ x ,
336
+ 0 , # dim
337
+ 0 , # index
338
+ ),
345
339
)
346
- graph_module = RemoveNopSelectOpPass ()(graph_module ).graph_module
347
- # Assert that select op was removed
340
+ add = builder .call_operator (op = exir_ops .edge .aten .add .Tensor , args = (select , y ))
341
+ builder .output ([add ])
342
+ original = builder .get_graph_module ()
343
+ graph_after_passes = cast (
344
+ PassResult , RemoveNopSelectOpPass ()(original )
345
+ ).graph_module
348
346
self .assertEqual (
349
- count_node (graph_module , exir_ops .edge .aten .select_copy .int ), 0
347
+ count_node (graph_after_passes , exir_ops .edge .aten .select_copy .int ), 0
350
348
)
351
349
352
- class SelectFeasible3 (torch .nn .Module ):
353
- def forward (self , x , y ):
354
- x = x .select (0 , 0 )
355
- z = x * y
356
- return z
357
-
358
- x = torch .ones (1 , 5 , 6 )
359
- y = torch .ones (1 , 5 , 6 )
360
- graph_module = (
361
- export_to_edge (SelectFeasible3 (), (x , y )).exported_program ().graph_module
362
- )
363
- self .assertEqual (
364
- count_node (graph_module , exir_ops .edge .aten .select_copy .int ), 1
350
+ def test_remove_nop_select_before_mul (self ):
351
+ builder = GraphBuilder ()
352
+ x = builder .placeholder ("x" , torch .randn (1 , 5 , 6 , dtype = torch .float32 ))
353
+ y = builder .placeholder ("y" , torch .randn (1 , 5 , 6 , dtype = torch .float32 ))
354
+ select = builder .call_operator (
355
+ op = exir_ops .edge .aten .select_copy .int ,
356
+ args = (
357
+ x ,
358
+ 0 , # dim
359
+ 0 , # index
360
+ ),
365
361
)
366
- graph_module = RemoveNopSelectOpPass ()(graph_module ).graph_module
367
- # Assert that select op was removed
362
+ mul = builder .call_operator (op = exir_ops .edge .aten .mul .Tensor , args = (select , y ))
363
+ builder .output ([mul ])
364
+ original = builder .get_graph_module ()
365
+ graph_after_passes = cast (
366
+ PassResult , RemoveNopSelectOpPass ()(original )
367
+ ).graph_module
368
368
self .assertEqual (
369
- count_node (graph_module , exir_ops .edge .aten .select_copy .int ), 0
369
+ count_node (graph_after_passes , exir_ops .edge .aten .select_copy .int ), 0
370
370
)
371
371
372
- class SelectFeasible4 (torch .nn .Module ):
373
- def forward (self , x , y ):
374
- x = x .select (0 , 0 )
375
- z = x / y
376
- return z
377
-
378
- x = torch .ones (1 , 5 , 6 )
379
- y = torch .ones (1 , 5 , 6 )
380
- graph_module = (
381
- export_to_edge (SelectFeasible4 (), (x , y )).exported_program ().graph_module
382
- )
383
- self .assertEqual (
384
- count_node (graph_module , exir_ops .edge .aten .select_copy .int ), 1
372
+ def test_remove_nop_select_before_div (self ):
373
+ builder = GraphBuilder ()
374
+ x = builder .placeholder ("x" , torch .randn (1 , 5 , 6 , dtype = torch .float32 ))
375
+ y = builder .placeholder ("y" , torch .randn (1 , 5 , 6 , dtype = torch .float32 ))
376
+ select = builder .call_operator (
377
+ op = exir_ops .edge .aten .select_copy .int ,
378
+ args = (
379
+ x ,
380
+ 0 , # dim
381
+ 0 , # index
382
+ ),
385
383
)
386
- graph_module = RemoveNopSelectOpPass ()(graph_module ).graph_module
387
- # Assert that select op was removed
384
+ div = builder .call_operator (op = exir_ops .edge .aten .div .Tensor , args = (select , y ))
385
+ builder .output ([div ])
386
+ original = builder .get_graph_module ()
387
+ graph_after_passes = cast (
388
+ PassResult , RemoveNopSelectOpPass ()(original )
389
+ ).graph_module
388
390
self .assertEqual (
389
- count_node (graph_module , exir_ops .edge .aten .select_copy .int ), 0
391
+ count_node (graph_after_passes , exir_ops .edge .aten .select_copy .int ), 0
390
392
)
391
393
392
394
def test_remove_nop_quant_dequant (self ):
0 commit comments