Skip to content

Commit b2e0d6e

Browse files
abeakkasfacebook-github-bot
authored andcommitted
Generalize view_copy fusion.
Summary: Implement a simpler and more generalized view_copy fusion that allows branched cases to be fused. Differential Revision: D73443870
1 parent 095722b commit b2e0d6e

File tree

2 files changed

+25
-25
lines changed

2 files changed

+25
-25
lines changed

backends/cadence/aot/fuse_ops.py

Lines changed: 5 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -526,34 +526,14 @@ class FuseCascadedViewOps(ExportPass):
526526
Fuse a cascaded chain of view ops
527527
"""
528528

529-
# Find a chain of view ops, and fuse them into a single permute op.
530-
531529
def fuse_cascaded_view_ops(self, graph_module: torch.fx.GraphModule):
532-
graph = graph_module.graph
533-
for node in graph.nodes:
534-
# We are only interested in view ops
535-
if node.target != exir_ops.edge.aten.view_copy.default:
536-
continue
537-
538-
# Get the cascaded chain of view ops starting at node
539-
cascaded_view_ops = get_cascaded_ops(
540-
[node], [exir_ops.edge.aten.view_copy.default]
541-
)
542-
# The chain must have more than 1 node
543-
if len(cascaded_view_ops) == 1:
530+
view_target = exir_ops.edge.aten.view_copy.default
531+
for view_node in graph_module.graph.find_nodes(op="call_function", target=view_target, sort=True):
532+
input_view = view_node.args[0]
533+
if input_view.op != "call_function" or input_view.target != view_target:
544534
continue
545535

546-
last_view_node = cascaded_view_ops[-1]
547-
with graph.inserting_before(last_view_node):
548-
new_view = graph.call_function(
549-
exir_ops.edge.aten.view_copy.default,
550-
args=(node.args[0], last_view_node.args[1]),
551-
)
552-
last_view_node.replace_all_uses_with(new_view)
553-
554-
# Now erase the chain
555-
for v in reversed(cascaded_view_ops):
556-
graph.erase_node(v)
536+
view_node.replace_input_with(input_view, input_view.args[0])
557537

558538
graph_module.recompile()
559539

backends/cadence/aot/tests/test_fusion_ops_passes.py

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -220,6 +220,26 @@ def forward(self, x):
220220
count_node(graph_module, exir_ops.edge.aten.view_copy.default), 1
221221
)
222222

223+
def test_view_fusion_branched(self):
224+
class ViewFusion(torch.nn.Module):
225+
def forward(self, x):
226+
y = x.view([1, 8, 15])
227+
z = y.view([1, 1, 120])
228+
t = y.view([120, 1, 1])
229+
return z, t
230+
231+
x = torch.randn(8, 5, 3)
232+
graph_module = (
233+
compiler.export_to_cadence(ViewFusion(), (x,))
234+
.exported_program()
235+
.graph_module
236+
)
237+
graph_module.graph.eliminate_dead_code()
238+
# z and t should be fused and y should be eliminated.
239+
self.assertEqual(
240+
count_node(graph_module, exir_ops.edge.aten.view_copy.default), 2
241+
)
242+
223243
def test_force_quant_dequant_fusion(self):
224244
class M(torch.nn.Module):
225245
def __init__(self):

0 commit comments

Comments
 (0)