Skip to content

add pass to remove cat from slice pass #8857

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Mar 7, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
66 changes: 66 additions & 0 deletions backends/cadence/aot/remove_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -807,6 +807,72 @@ def remove_branched(
user.replace_all_uses_with(node.args[0])


class RemoveCatFromSliceCopyPass(ExportPass):
def _remove_unused_cat(self, graph_module: torch.fx.GraphModule) -> None:
slice_copy_nodes = [
node
for node in graph_module.graph.nodes
if node.target == exir_ops.edge.aten.slice_copy.Tensor
]
for slice_copy_node in slice_copy_nodes:
slice_dim, start_idx, end_idx, step = 0, 0, float("inf"), 1
input_node, *other_args = slice_copy_node.args
if len(other_args) >= 1:
slice_dim = other_args[0]
if len(other_args) >= 2:
start_idx = other_args[1]
if len(other_args) >= 3:
end_idx = other_args[2]
if len(other_args) >= 4:
step = other_args[3]
if step != 1:
continue
slice_copy_dtype = slice_copy_node.meta["val"].dtype
if input_node.target != exir_ops.edge.aten.cat.default:
continue
cat_dtype = input_node.meta["val"].dtype
if slice_copy_dtype != cat_dtype:
continue
cat_dim = input_node.args[1:]
if len(cat_dim) == 0:
cat_dim = 0
if cat_dim != slice_dim:
continue
cat_output_shape = input_node.meta["val"].shape
start_idx = (
cat_output_shape[cat_dim] + start_idx if start_idx < 0 else start_idx
)
end_idx = (
cat_output_shape[cat_dim]
if end_idx > cat_output_shape[cat_dim]
else end_idx
)
base_idx = 0
cat_input_to_keep = None
for cat_input_node in input_node.args[0]:
cat_input_dtype = cat_input_node.meta["val"].dtype
if slice_copy_dtype != cat_input_dtype:
continue
cat_input_shape = cat_input_node.meta["val"].shape

# check if the slice range overlaps with the cat range
if (
base_idx <= start_idx
and end_idx <= list(cat_input_shape)[cat_dim] + base_idx
):
cat_input_to_keep = cat_input_node
break
base_idx += list(cat_input_shape)[cat_dim]
if cat_input_to_keep is not None:
slice_copy_node.replace_input_with(input_node, cat_input_to_keep)

def call(self, graph_module: torch.fx.GraphModule) -> PassResult:
self._remove_unused_cat(graph_module)
graph_module.recompile()
graph_module.graph.eliminate_dead_code()
return super().call(graph_module)


# The following class consolidates functions to remove ops that are redundant
# in Jarvis. Currently, each function in this class iterates over each node of
# the graph module once. In future, we could consolidate them into a monolithic
Expand Down
52 changes: 52 additions & 0 deletions backends/cadence/aot/tests/test_remove_ops_passes.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
from executorch.backends.cadence.aot.remove_ops import (
RemoveAliasCopyOpPass,
RemoveBranchedQuantDequant,
RemoveCatFromSliceCopyPass,
RemoveCloneOpPass,
RemoveContiguousOpPass,
RemoveDetachCopyPass,
Expand Down Expand Up @@ -741,3 +742,54 @@ def forward(self, x):
},
)
)

def test_remove_cat_from_slice_copy_all_removal(self) -> None:
class M(torch.nn.Module):
def __init__(self):
super().__init__()

def forward(self, x, y):
x1 = torch.cat((x, y), 0) # (2, 4)
return torch.slice_copy(x1, dim=0, start=0, end=1)

inputs = tuple(torch.randn(2, 4) for _ in range(2))
graph_module = export_to_edge(M(), inputs).exported_program().graph_module
p = RemoveCatFromSliceCopyPass()
graph_module = cast(PassResult, p(graph_module)).graph_module

# Ensure both cat nodes were removed
self.assertEqual(count_node(graph_module, exir_ops.edge.aten.cat.default), 0)

def test_remove_cat_from_slice_copy_no_removal(self) -> None:
class M(torch.nn.Module):
def __init__(self):
super().__init__()

def forward(self, x, y):
x1 = torch.cat((x, y), 0) # (2, 4)
return torch.slice_copy(x1, dim=0, start=0, end=3)

inputs = tuple(torch.randn(2, 4) for _ in range(2))
graph_module = export_to_edge(M(), inputs).exported_program().graph_module
p = RemoveCatFromSliceCopyPass()
graph_module = cast(PassResult, p(graph_module)).graph_module

# Ensure both cat nodes were removed
self.assertEqual(count_node(graph_module, exir_ops.edge.aten.cat.default), 1)

def test_remove_cat_from_slice_copy_zero_range(self) -> None:
class M(torch.nn.Module):
def __init__(self):
super().__init__()

def forward(self, x, y):
x1 = torch.cat((x, y), 0) # (2, 4)
return torch.slice_copy(x1, dim=0, start=0, end=0)

inputs = tuple(torch.randn(2, 4) for _ in range(2))
graph_module = export_to_edge(M(), inputs).exported_program().graph_module
p = RemoveCatFromSliceCopyPass()
graph_module = cast(PassResult, p(graph_module)).graph_module

# Ensure both cat nodes were removed
self.assertEqual(count_node(graph_module, exir_ops.edge.aten.cat.default), 0)
Loading