@@ -224,11 +224,13 @@ def verify_nop_memory_alloc(self, graph_module: torch.fx.GraphModule) -> None:
224
224
225
225
# Initializes the nodes metadata and runs the GenerateMemoryViewConstraints,
226
226
# GenerateSliceAndSelectNopConstraints, and GenerateCatNopConstraints passes.
227
- def run_memory_planning (self , original , alloc_graph_input = True ) -> GraphModule :
227
+ def run_memory_planning (
228
+ self , original , opt_level = 2 , alloc_graph_input = True
229
+ ) -> GraphModule :
228
230
graph_module = SpecPropPass ().call (original ).graph_module
229
231
return CadenceMemoryPlanning (
230
232
get_default_memory_config (),
231
- opt_level = 2 ,
233
+ opt_level = opt_level ,
232
234
mem_algo = 1 , # greedy_by_size_for_offset_calculation_with_hierarchy
233
235
alloc_graph_input = alloc_graph_input ,
234
236
)(graph_module ).graph_module
@@ -535,130 +537,239 @@ def test_optimize_cat_with_slice_infeasible(self) -> None:
535
537
self .assertEqual (count_node (graph_module , torch .ops .aten .cat .out ), 1 )
536
538
self .verify_nop_memory_alloc (graph_module )
537
539
538
- def test_optimize_slice_Tensor (self ) -> None :
539
- class SliceTensor (torch .nn .Module ):
540
- def forward (self , x , y , z ):
541
- x1 = torch .add (x , 2.4 , 3.1 )
542
- # This slice should always be optimized, since x1 is not placeholder
543
- # and the slice is along the outermost dim
544
- t1 = torch .ops .aten .slice (x1 , 0 , 1 , 2 )
545
- # This slice should not be optimized when alloc_graph_input=False,
546
- # since y is a placeholder node
547
- t2 = torch .ops .aten .slice (y , 0 , 0 , 1 )
548
- # This slice should be always optimized, since the dims before
549
- # sliced dims are 1
550
- z1 = torch .add (z , 2.4 , 3.1 )
551
- t3 = torch .ops .aten .slice (z1 , 1 , 4 , 5 )
552
- return (t1 + t2 ) * t3
553
-
554
- x = torch .ones (3 , 6 )
555
- y = torch .ones (2 , 6 )
556
- z = torch .ones (1 , 6 )
557
- # Run the memory planning pass and get the graph module
558
- graph_module = (
559
- compiler .export_to_executorch_gen_etrecord (
560
- SliceTensor (),
561
- (x , y , z ),
562
- opt_level = 2 ,
563
- mem_algo = 1 ,
564
- alloc_graph_input = False ,
565
- )
566
- .exported_program ()
567
- .graph_module
540
+ def test_optimize_slice_outermost (self ) -> None :
541
+ builder = GraphBuilder ()
542
+ x = builder .placeholder ("x" , torch .ones (3 , 6 , dtype = torch .float32 ))
543
+ to_add_to_x = builder .call_operator (
544
+ op = exir_ops .edge .aten .full .default ,
545
+ args = ([3 , 6 ], 123.0 ),
546
+ kwargs = {"dtype" : torch .float32 },
547
+ )
548
+ add_x = builder .call_operator (
549
+ op = exir_ops .edge .aten .add .Tensor ,
550
+ args = (x , to_add_to_x ),
551
+ )
552
+ slice_out = builder .call_operator (
553
+ op = exir_ops .edge .aten .full .default ,
554
+ args = ([1 , 6 ], 0.0 ),
555
+ kwargs = {"dtype" : torch .float32 },
568
556
)
557
+ # This slice should always be optimized, since add_x is not placeholder
558
+ # and the slice is along the outermost dim
559
+ slice_result = builder .call_operator (
560
+ op = torch .ops .aten .slice_copy .Tensor_out ,
561
+ args = (
562
+ add_x ,
563
+ 0 , # dim
564
+ 1 , # start
565
+ 2 , # end
566
+ 1 , # step
567
+ ),
568
+ kwargs = {"out" : slice_out },
569
+ )
570
+ builder .output ([slice_result ])
571
+ original = builder .get_graph_module ()
572
+ graph_module = self .run_memory_planning (original , alloc_graph_input = False )
569
573
graph_module .graph .eliminate_dead_code ()
570
- # Assert that t2 is not optimized away
571
574
self .assertEqual (
572
- count_node (graph_module , torch .ops .aten .slice_copy .Tensor_out ), 1
575
+ count_node (graph_module , torch .ops .aten ._slice_copy_nop .Tensor_out ), 1
576
+ )
577
+ self .verify_nop_memory_alloc (graph_module )
578
+
579
+ def test_optimize_slice_non_outermost (self ) -> None :
580
+ builder = GraphBuilder ()
581
+ x = builder .placeholder ("x" , torch .ones (1 , 6 , dtype = torch .float32 ))
582
+ to_add_to_x = builder .call_operator (
583
+ op = exir_ops .edge .aten .full .default ,
584
+ args = ([1 , 6 ], 123.0 ),
585
+ kwargs = {"dtype" : torch .float32 },
586
+ )
587
+ add_x = builder .call_operator (
588
+ op = exir_ops .edge .aten .add .Tensor ,
589
+ args = (x , to_add_to_x ),
590
+ )
591
+ slice_out = builder .call_operator (
592
+ op = exir_ops .edge .aten .full .default ,
593
+ args = ([1 , 2 ], 0.0 ),
594
+ kwargs = {"dtype" : torch .float32 },
595
+ )
596
+ # This slice should be always optimized, since the dims before
597
+ # sliced dims are 1.
598
+ slice_result = builder .call_operator (
599
+ op = torch .ops .aten .slice_copy .Tensor_out ,
600
+ args = (
601
+ add_x ,
602
+ 1 , # dim
603
+ 4 , # start
604
+ 6 , # end
605
+ 1 , # step
606
+ ),
607
+ kwargs = {"out" : slice_out },
573
608
)
574
- # Assert that t1 and t3 are optimized to slice_copy_nop veresion
609
+ builder .output ([slice_result ])
610
+ original = builder .get_graph_module ()
611
+ graph_module = self .run_memory_planning (original , alloc_graph_input = False )
612
+ graph_module .graph .eliminate_dead_code ()
575
613
self .assertEqual (
576
- count_node (graph_module , torch .ops .aten ._slice_copy_nop .Tensor_out ), 2
614
+ count_node (graph_module , torch .ops .aten ._slice_copy_nop .Tensor_out ), 1
577
615
)
616
+ self .verify_nop_memory_alloc (graph_module )
617
+
618
+ def test_optimize_slice_depending_on_opt_level (self ) -> None :
619
+ builder = GraphBuilder ()
620
+ x = builder .placeholder ("x" , torch .ones (2 , 6 , dtype = torch .float32 ))
621
+ slice_out = builder .call_operator (
622
+ op = exir_ops .edge .aten .full .default ,
623
+ args = ([1 , 6 ], 0.0 ),
624
+ kwargs = {"dtype" : torch .float32 },
625
+ )
626
+ # This slice should not be optimized when alloc_graph_input=False,
627
+ # since x is a placeholder node
628
+ slice_result = builder .call_operator (
629
+ op = torch .ops .aten .slice_copy .Tensor_out ,
630
+ args = (
631
+ x ,
632
+ 0 , # dim
633
+ 0 , # start
634
+ 1 , # end
635
+ 1 , # step
636
+ ),
637
+ kwargs = {"out" : slice_out },
638
+ )
639
+ builder .output ([slice_result ])
640
+ original = builder .get_graph_module ()
641
+ graph_module = self .run_memory_planning (
642
+ original , opt_level = 2 , alloc_graph_input = False
643
+ )
644
+ graph_module .graph .eliminate_dead_code ()
645
+ self .assertEqual (
646
+ count_node (graph_module , torch .ops .aten .slice_copy .Tensor_out ), 1
647
+ )
648
+ self .verify_nop_memory_alloc (graph_module )
649
+
578
650
# When we compile with alloc_graph_input=True, all the slice ops must
579
- # be optimized.
580
- # Optimizing cat ops is only at opt_level 2+, and requires the memory planning
581
- # pass to run:
582
- graph_module = (
583
- compiler .export_to_executorch_gen_etrecord (
584
- SliceTensor (),
585
- (x , y , z ),
586
- opt_level = 3 ,
587
- mem_algo = 1 ,
588
- alloc_graph_input = True ,
589
- )
590
- .exported_program ()
591
- .graph_module
651
+ # be optimized, which is available only at opt_level 2+.
652
+ graph_module = self .run_memory_planning (
653
+ original , opt_level = 3 , alloc_graph_input = True
592
654
)
593
655
graph_module .graph .eliminate_dead_code ()
594
- self .assertFalse (count_node (graph_module , torch .ops .aten .slice_copy .Tensor_out ))
595
656
self .assertEqual (
596
- count_node (graph_module , torch .ops .aten ._slice_copy_nop .Tensor_out ), 3
657
+ count_node (graph_module , torch .ops .aten ._slice_copy_nop .Tensor_out ), 1
597
658
)
598
659
self .verify_nop_memory_alloc (graph_module )
599
660
600
- def test_optimize_select_Tensor (self ) -> None :
601
- class SelectTensor (torch .nn .Module ):
602
- def forward (self , x , y , z ):
603
- x1 = torch .add (x , 2.4 , 3.1 )
604
- # This select should always be optimized, since x1 is not
605
- # placeholder, and the select is along the outermost dim
606
- t1 = torch .select_copy (x1 , 0 , 1 )
607
- # This select should not be optimized if alloc_graph_input=False,
608
- # since y is a placeholder node.
609
- t2 = torch .select_copy (y , 0 , 0 )
610
- # This select should always be optimized, since the dims before
611
- # select dims are 1
612
- z1 = torch .add (z , 2.4 , 3.1 )
613
- t3 = torch .select (z1 , 1 , 4 )
614
- return (t1 + t2 ) * t3
615
-
616
- x = torch .ones (3 , 6 )
617
- y = torch .ones (2 , 6 )
618
- z = torch .ones (1 , 6 )
619
- # Optimizing select ops is only at opt_level 2+, and requires the memory planning
620
- # pass to run:
621
- graph_module = (
622
- compiler .export_to_executorch_gen_etrecord (
623
- SelectTensor (),
624
- (x , y , z ),
625
- opt_level = 2 ,
626
- mem_algo = 1 ,
627
- alloc_graph_input = False ,
628
- )
629
- .exported_program ()
630
- .graph_module
661
+ def test_optimize_select_outermost (self ) -> None :
662
+ builder = GraphBuilder ()
663
+ x = builder .placeholder ("x" , torch .ones (3 , 6 , dtype = torch .float32 ))
664
+ to_add_to_x = builder .call_operator (
665
+ op = exir_ops .edge .aten .full .default ,
666
+ args = ([3 , 6 ], 123.0 ),
667
+ kwargs = {"dtype" : torch .float32 },
631
668
)
669
+ add_x = builder .call_operator (
670
+ op = exir_ops .edge .aten .add .Tensor ,
671
+ args = (x , to_add_to_x ),
672
+ )
673
+ slice_out = builder .call_operator (
674
+ op = exir_ops .edge .aten .full .default ,
675
+ args = ([1 , 6 ], 0.0 ),
676
+ kwargs = {"dtype" : torch .float32 },
677
+ )
678
+ # This select should always be optimized, since add_x is not placeholder
679
+ # and the select is along the outermost dim
680
+ slice_result = builder .call_operator (
681
+ op = torch .ops .aten .select_copy .int_out ,
682
+ args = (
683
+ add_x ,
684
+ 0 , # dim
685
+ 1 , # index
686
+ ),
687
+ kwargs = {"out" : slice_out },
688
+ )
689
+ builder .output ([slice_result ])
690
+ original = builder .get_graph_module ()
691
+ graph_module = self .run_memory_planning (original , alloc_graph_input = False )
632
692
graph_module .graph .eliminate_dead_code ()
633
- # Assert that t2 is not optimized away
634
693
self .assertEqual (
635
- count_node (graph_module , torch .ops .aten .select_copy .int_out ), 1
694
+ count_node (graph_module , torch .ops .aten ._select_copy_nop .int_out ), 1
695
+ )
696
+ self .verify_nop_memory_alloc (graph_module )
697
+
698
+ def test_optimize_select_non_outermost (self ) -> None :
699
+ builder = GraphBuilder ()
700
+ x = builder .placeholder ("x" , torch .ones (1 , 6 , dtype = torch .float32 ))
701
+ to_add_to_x = builder .call_operator (
702
+ op = exir_ops .edge .aten .full .default ,
703
+ args = ([1 , 6 ], 123.0 ),
704
+ kwargs = {"dtype" : torch .float32 },
705
+ )
706
+ add_x = builder .call_operator (
707
+ op = exir_ops .edge .aten .add .Tensor ,
708
+ args = (x , to_add_to_x ),
709
+ )
710
+ slice_out = builder .call_operator (
711
+ op = exir_ops .edge .aten .full .default ,
712
+ args = ([1 , 2 ], 0.0 ),
713
+ kwargs = {"dtype" : torch .float32 },
714
+ )
715
+ # This select should always be optimized, since the dims before
716
+ # select dims are 1
717
+ slice_result = builder .call_operator (
718
+ op = torch .ops .aten .select_copy .int_out ,
719
+ args = (
720
+ add_x ,
721
+ 1 , # dim
722
+ 4 , # index
723
+ ),
724
+ kwargs = {"out" : slice_out },
636
725
)
637
- # Assert that t1 and t3 are optimized to select_copy_nop veresion
726
+ builder .output ([slice_result ])
727
+ original = builder .get_graph_module ()
728
+ graph_module = self .run_memory_planning (original , alloc_graph_input = False )
729
+ graph_module .graph .eliminate_dead_code ()
638
730
self .assertEqual (
639
- count_node (graph_module , torch .ops .aten ._select_copy_nop .int_out ), 2
731
+ count_node (graph_module , torch .ops .aten ._select_copy_nop .int_out ), 1
640
732
)
641
- # When we compile with alloc_graph_input=True, all the select ops must
642
- # be optimized.
643
- # Optimizing select ops is only at opt_level 2+, and requires the memory planning
644
- # pass to run:
645
- graph_module = (
646
- compiler .export_to_executorch_gen_etrecord (
647
- SelectTensor (),
648
- (x , y , z ),
649
- opt_level = 3 ,
650
- mem_algo = 1 ,
651
- alloc_graph_input = True ,
652
- )
653
- .exported_program ()
654
- .graph_module
733
+ self .verify_nop_memory_alloc (graph_module )
734
+
735
+ def test_optimize_select_depending_on_opt_level (self ) -> None :
736
+ builder = GraphBuilder ()
737
+ x = builder .placeholder ("x" , torch .ones (2 , 6 , dtype = torch .float32 ))
738
+ slice_out = builder .call_operator (
739
+ op = exir_ops .edge .aten .full .default ,
740
+ args = ([1 , 6 ], 0.0 ),
741
+ kwargs = {"dtype" : torch .float32 },
742
+ )
743
+ # This select should not be optimized if alloc_graph_input=False,
744
+ # since x is a placeholder node.
745
+ slice_result = builder .call_operator (
746
+ op = torch .ops .aten .select_copy .int_out ,
747
+ args = (
748
+ x ,
749
+ 0 , # dim
750
+ 0 , # index
751
+ ),
752
+ kwargs = {"out" : slice_out },
753
+ )
754
+ builder .output ([slice_result ])
755
+ original = builder .get_graph_module ()
756
+ graph_module = self .run_memory_planning (
757
+ original , opt_level = 2 , alloc_graph_input = False
655
758
)
656
759
graph_module .graph .eliminate_dead_code ()
657
760
self .assertEqual (
658
- count_node (graph_module , torch .ops .aten .select_copy .int_out ), 0
761
+ count_node (graph_module , torch .ops .aten .select_copy .int_out ), 1
659
762
)
763
+ self .verify_nop_memory_alloc (graph_module )
764
+
765
+ # When we compile with alloc_graph_input=True, all the select ops must
766
+ # be optimized, which is available only at opt_level 2+.
767
+ graph_module = self .run_memory_planning (
768
+ original , opt_level = 3 , alloc_graph_input = True
769
+ )
770
+ graph_module .graph .eliminate_dead_code ()
660
771
self .assertEqual (
661
- count_node (graph_module , torch .ops .aten ._select_copy_nop .int_out ), 3
772
+ count_node (graph_module , torch .ops .aten ._select_copy_nop .int_out ), 1
662
773
)
663
774
self .verify_nop_memory_alloc (graph_module )
664
775
0 commit comments