@@ -745,6 +745,70 @@ def permute_shape(
745
745
return [shape [p ] for p in permute_dims ]
746
746
747
747
748
@register_cadence_pass(CadencePassAttribute(opt_level=3))
class RemoveCatFromSliceCopyPass(ExportPass):
    """
    Bypass a cat op that feeds a slice_copy when the sliced range lies
    entirely inside a single input of the cat.

    For such a pattern the slice_copy is rewired to read directly from that
    cat input, and its start/end indices are rebased so that they are
    relative to the chosen input rather than to the concatenated tensor.
    The cat itself is removed by dead code elimination once it has no
    remaining users.
    """

    @staticmethod
    def _normalize_dim(dim: int, rank: int) -> int:
        """Map a possibly negative dimension index onto [0, rank)."""
        return dim + rank if dim < 0 else dim

    @staticmethod
    def _clamp_slice_range(start, end, size: int):
        """
        Resolve slice bounds against a dimension of length `size`.

        None bounds select the full extent, negative bounds count from the
        end, and both bounds are clamped so 0 <= start <= end <= size.
        """
        start = 0 if start is None else start
        end = size if end is None else end
        if start < 0:
            start += size
        if end < 0:
            end += size
        start = min(max(start, 0), size)
        end = min(max(end, start), size)
        return start, end

    def _bypass_cat(self, slice_copy_node: torch.fx.Node) -> None:
        """
        Rewire `slice_copy_node` past its cat producer when that is safe.

        The node is left untouched unless all of the following hold: the
        step is 1, the producer is a cat along the sliced dimension, all
        involved sizes and indices are static ints, the (clamped) slice
        range fits inside exactly one cat input, and that input has the
        same dtype as the slice output.
        """
        # Only positional arguments are decoded below; be conservative and
        # skip nodes that carry keyword arguments.
        if slice_copy_node.kwargs:
            return

        input_node, *other_args = slice_copy_node.args
        given = tuple(other_args[:4])
        # Defaults for (dim, start, end, step) when trailing args are omitted.
        defaults = (0, None, None, 1)
        slice_dim, start_idx, end_idx, step = given + defaults[len(given):]
        if step != 1:
            return
        if input_node.target != exir_ops.edge.aten.cat.default:
            return
        if input_node.kwargs:
            return

        slice_copy_dtype = slice_copy_node.meta["val"].dtype
        cat_val = input_node.meta["val"]
        if slice_copy_dtype != cat_val.dtype:
            return

        cat_shape = list(cat_val.shape)
        rank = len(cat_shape)
        # args[1] is the concat dimension itself (not a tuple of the
        # remaining args); it defaults to 0 when omitted.
        cat_dim = input_node.args[1] if len(input_node.args) > 1 else 0
        if not isinstance(cat_dim, int) or not isinstance(slice_dim, int):
            return
        cat_dim = self._normalize_dim(cat_dim, rank)
        if cat_dim != self._normalize_dim(slice_dim, rank):
            return

        # Symbolic sizes or indices cannot be reasoned about statically.
        dim_size = cat_shape[cat_dim]
        if not isinstance(dim_size, int):
            return
        if any(
            bound is not None and not isinstance(bound, int)
            for bound in (start_idx, end_idx)
        ):
            return
        start_idx, end_idx = self._clamp_slice_range(start_idx, end_idx, dim_size)

        # Walk the cat inputs, tracking the offset at which each one starts
        # in the concatenated dimension.
        base_idx = 0
        for cat_input_node in input_node.args[0]:
            cat_input_val = cat_input_node.meta["val"]
            cat_input_shape = list(cat_input_val.shape)
            if len(cat_input_shape) != rank:
                # cat accepts 1-D empty tensors of any rank mismatch; they
                # contribute nothing along the concat dimension.
                continue
            length = cat_input_shape[cat_dim]
            if not isinstance(length, int):
                return
            if base_idx <= start_idx and end_idx <= base_idx + length:
                if cat_input_val.dtype != slice_copy_dtype:
                    # cat promoted this input's dtype; reading from it
                    # directly would change the slice's output dtype.
                    return
                # Rebase the bounds so they index into the chosen input
                # instead of the concatenated tensor.
                slice_copy_node.args = (
                    cat_input_node,
                    slice_dim,
                    start_idx - base_idx,
                    end_idx - base_idx,
                    1,
                )
                return
            base_idx += length
        # The slice spans more than one cat input: keep the cat.

    def call(self, graph_module: torch.fx.GraphModule) -> PassResult:
        """Apply the cat-bypass rewrite to every slice_copy in the graph."""
        graph = graph_module.graph
        slice_copy_nodes = [
            node
            for node in graph.nodes
            if node.target == exir_ops.edge.aten.slice_copy.Tensor
        ]
        for slice_copy_node in slice_copy_nodes:
            self._bypass_cat(slice_copy_node)
        # Drop cats that lost all users first, then regenerate code from
        # the final graph.
        graph.eliminate_dead_code()
        graph_module.recompile()
        return super().call(graph_module)
810
+
811
+
748
812
# The following class consolidates functions to remove ops that are redundant
749
813
# in Jarvis. Currently, each function in this class iterates over each node of
750
814
# the graph module once. In future, we could consolidate them into a monolithic
@@ -765,4 +829,5 @@ class CadenceRemoveNops:
765
829
RemoveNopMulOpPass ,
766
830
RemoveNopAddOpPass ,
767
831
RemoveNopLinalgVectorNormOpPass ,
832
+ RemoveCatFromSliceCopyPass ,
768
833
]
0 commit comments