Skip to content

Commit 1420d2e

Browse files
zonglinpengfacebook-github-bot
authored and committed
add pass to remove cat from slice pass
Summary: only keep the cat before slice iff the slice range overlaps with *both* tensors in cat. TODO: trace to 1+ level for more cat in a chain TODO: support >2 tensors in a cat Differential Revision: D70425971
1 parent eef7b44 commit 1420d2e

File tree

2 files changed

+117
-0
lines changed

2 files changed

+117
-0
lines changed

backends/cadence/aot/remove_ops.py

+65
Original file line numberDiff line numberDiff line change
@@ -745,6 +745,70 @@ def permute_shape(
745745
return [shape[p] for p in permute_dims]
746746

747747

748+
@register_cadence_pass(CadencePassAttribute(opt_level=3))
class RemoveCatFromSliceCopyPass(ExportPass):
    """Rewire ``slice_copy(cat(...))`` to slice the single cat input that
    fully covers the sliced range.

    The cat is bypassed only when the (step == 1) slice range falls entirely
    within exactly one of the cat's input tensors; the now-unused cat is then
    removed by dead-code elimination. If the range overlaps two or more
    inputs, the cat is kept.

    TODO: trace through chained cats (cat feeding cat).
    TODO: support cats with negative dims / slice indices.
    """

    def call(self, graph_module: torch.fx.GraphModule) -> PassResult:
        graph = graph_module.graph
        slice_copy_nodes = [
            node
            for node in graph.nodes
            if node.target == exir_ops.edge.aten.slice_copy.Tensor
        ]
        for slice_copy_node in slice_copy_nodes:
            # slice_copy.Tensor args: (input, dim=0, start=0, end=+inf,
            # step=1); trailing defaults may be omitted, so fill positionally.
            slice_dim, start_idx, end_idx, step = 0, 0, float("inf"), 1
            input_node, *other_args = slice_copy_node.args
            if len(other_args) >= 1:
                slice_dim = other_args[0]
            if len(other_args) >= 2:
                start_idx = other_args[1]
            if len(other_args) >= 3:
                end_idx = other_args[2]
            if len(other_args) >= 4:
                step = other_args[3]
            if step != 1:
                # A strided slice can pick elements from several cat inputs
                # even when its [start, end) span lies inside one; skip it.
                continue
            if input_node.target != exir_ops.edge.aten.cat.default:
                continue
            slice_copy_dtype = slice_copy_node.meta["val"].dtype
            if input_node.meta["val"].dtype != slice_copy_dtype:
                continue
            # cat.default args: (tensors, dim=0). The previous code read
            # `args[1:]` — a *tuple* — and compared it to the int slice_dim,
            # which never matched when an explicit dim was present, so the
            # pass silently never fired; read the scalar dim instead.
            cat_dim = input_node.args[1] if len(input_node.args) > 1 else 0
            if cat_dim != slice_dim:
                continue
            # NOTE(review): negative dims / slice indices are not normalized
            # here — confirm upstream canonicalization before relying on it.
            base_idx = 0
            cat_input_to_keep = None
            for cat_input_node in input_node.args[0]:
                extent = cat_input_node.meta["val"].shape[cat_dim]
                # Half-open overlap test against this input's
                # [base_idx, base_idx + extent) span along cat_dim; an empty
                # slice (start == end) is attributed to the input whose
                # leading boundary it sits on.
                overlaps = (
                    max(start_idx, base_idx) < min(end_idx, base_idx + extent)
                ) or (start_idx == end_idx == base_idx)
                if overlaps:
                    if (
                        cat_input_to_keep is not None
                        or cat_input_node.meta["val"].dtype != slice_copy_dtype
                    ):
                        # Slice spans more than one cat input, or the
                        # candidate's dtype differs from the slice output
                        # (cat may type-promote): the cat must stay.
                        # (The previous code `continue`d on a dtype mismatch,
                        # skipping the base_idx update below and corrupting
                        # the offsets of every subsequent input.)
                        cat_input_to_keep = None
                        break
                    cat_input_to_keep = cat_input_node
                base_idx += extent
            if cat_input_to_keep is not None:
                slice_copy_node.replace_input_with(input_node, cat_input_to_keep)
        graph_module.recompile()
        graph_module.graph.eliminate_dead_code()
        return super().call(graph_module)
810+
811+
748812
# The following class consolidates functions to remove ops that are redundant
749813
# in Jarvis. Currently, each function in this class iterates over each node of
750814
# the graph module once. In future, we could consolidate them into a monolithic
@@ -765,4 +829,5 @@ class CadenceRemoveNops:
765829
RemoveNopMulOpPass,
766830
RemoveNopAddOpPass,
767831
RemoveNopLinalgVectorNormOpPass,
832+
RemoveCatFromSliceCopyPass,
768833
]

backends/cadence/aot/tests/test_remove_ops_passes.py

+52
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@
2121
from executorch.backends.cadence.aot.quantizer.quantizer import CadenceDefaultQuantizer
2222
from executorch.backends.cadence.aot.remove_ops import (
2323
RemoveAliasCopyOpPass,
24+
RemoveCatFromSliceCopyPass,
2425
RemoveCloneOpPass,
2526
RemoveContiguousOpPass,
2627
RemoveDetachCopyPass,
@@ -709,3 +710,54 @@ def forward(self, x):
709710
self.assertEqual(
710711
count_node(graph_module, exir_ops.edge.aten.permute_copy.default), 2
711712
)
713+
714+
def test_remove_cat_from_slice_copy_all_removal(self) -> None:
    """Slice range [0, 1) lies entirely inside the first cat input, so the
    cat must be bypassed and then dead-code eliminated."""

    class M(torch.nn.Module):
        def forward(self, x, y):
            x1 = torch.cat((x, y), 0)  # (4, 4): two (2, 4) inputs stacked on dim 0
            return torch.slice_copy(x1, dim=0, start=0, end=1)

    inputs = tuple(torch.randn(2, 4) for _ in range(2))
    graph_module = export_to_edge(M(), inputs).exported_program().graph_module
    p = RemoveCatFromSliceCopyPass()
    graph_module = cast(PassResult, p(graph_module)).graph_module

    # Ensure the cat node was removed.
    self.assertEqual(count_node(graph_module, exir_ops.edge.aten.cat.default), 0)
730+
731+
def test_remove_cat_from_slice_copy_no_removal(self) -> None:
    """Slice range [0, 3) overlaps both cat inputs, so the cat must be kept."""

    class M(torch.nn.Module):
        def forward(self, x, y):
            x1 = torch.cat((x, y), 0)  # (4, 4): two (2, 4) inputs stacked on dim 0
            return torch.slice_copy(x1, dim=0, start=0, end=3)

    inputs = tuple(torch.randn(2, 4) for _ in range(2))
    graph_module = export_to_edge(M(), inputs).exported_program().graph_module
    p = RemoveCatFromSliceCopyPass()
    graph_module = cast(PassResult, p(graph_module)).graph_module

    # Ensure the cat node was kept (the original comment claimed removal,
    # but count == 1 asserts the opposite).
    self.assertEqual(count_node(graph_module, exir_ops.edge.aten.cat.default), 1)
747+
748+
def test_remove_cat_from_slice_copy_zero_range(self) -> None:
    """An empty slice (start == end == 0) is attributed to the first cat
    input only, so the cat must still be bypassed and eliminated."""

    class M(torch.nn.Module):
        def forward(self, x, y):
            x1 = torch.cat((x, y), 0)  # (4, 4): two (2, 4) inputs stacked on dim 0
            return torch.slice_copy(x1, dim=0, start=0, end=0)

    inputs = tuple(torch.randn(2, 4) for _ in range(2))
    graph_module = export_to_edge(M(), inputs).exported_program().graph_module
    p = RemoveCatFromSliceCopyPass()
    graph_module = cast(PassResult, p(graph_module)).graph_module

    # Ensure the cat node was removed.
    self.assertEqual(count_node(graph_module, exir_ops.edge.aten.cat.default), 0)

0 commit comments

Comments
 (0)