Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
30 changes: 5 additions & 25 deletions backends/cadence/aot/fuse_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -526,34 +526,14 @@ class FuseCascadedViewOps(ExportPass):
Fuse a cascaded chain of view ops
"""

# Fuse cascaded view ops by short-circuiting each view whose input is
# itself a view: redirect it to read from the upstream view's own input,
# which makes the intermediate view dead. Visiting every view node in
# topological order (sort=True) collapses arbitrarily long chains in one
# pass, and — unlike a chain-based walk — also handles branched uses
# where a single view feeds several downstream views.
def fuse_cascaded_view_ops(self, graph_module: torch.fx.GraphModule):
    view_target = exir_ops.edge.aten.view_copy.default
    for view_node in graph_module.graph.find_nodes(
        op="call_function", target=view_target, sort=True
    ):
        input_view = view_node.args[0]
        # Only fuse when the producer of this view is itself a view_copy.
        if input_view.op != "call_function" or input_view.target != view_target:
            continue
        # Skip the intermediate view and read directly from its input.
        # The intermediate becomes dead once no user remains and is
        # removed by a later dead-code-elimination pass.
        view_node.replace_input_with(input_view, input_view.args[0])

    graph_module.recompile()

Expand Down
20 changes: 20 additions & 0 deletions backends/cadence/aot/tests/test_fusion_ops_passes.py
Original file line number Diff line number Diff line change
Expand Up @@ -220,6 +220,26 @@ def forward(self, x):
count_node(graph_module, exir_ops.edge.aten.view_copy.default), 1
)

def test_view_fusion_branched(self):
    """Branched case: one view feeding two downstream views.

    The shared intermediate view must be fused away, leaving only the
    two leaf view ops in the final graph.
    """

    class ViewFusion(torch.nn.Module):
        def forward(self, x):
            mid = x.view([1, 8, 15])
            left = mid.view([1, 1, 120])
            right = mid.view([120, 1, 1])
            return left, right

    sample = torch.randn(8, 5, 3)
    exported = compiler.export_to_cadence(ViewFusion(), (sample,))
    graph_module = exported.exported_program().graph_module
    graph_module.graph.eliminate_dead_code()
    # The two leaf views should be fused through the intermediate view,
    # which is then eliminated — exactly two view ops remain.
    self.assertEqual(
        count_node(graph_module, exir_ops.edge.aten.view_copy.default), 2
    )

def test_force_quant_dequant_fusion(self):
class M(torch.nn.Module):
def __init__(self):
Expand Down
Loading