@@ -526,34 +526,14 @@ class FuseCascadedViewOps(ExportPass):
526526 Fuse a cascaded chain of view ops
527527 """
528528
529- # Find a chain of view ops, and fuse them into a single permute op.
530-
531529 def fuse_cascaded_view_ops (self , graph_module : torch .fx .GraphModule ):
532- graph = graph_module .graph
533- for node in graph .nodes :
534- # We are only interested in view ops
535- if node .target != exir_ops .edge .aten .view_copy .default :
536- continue
537-
538- # Get the cascaded chain of view ops starting at node
539- cascaded_view_ops = get_cascaded_ops (
540- [node ], [exir_ops .edge .aten .view_copy .default ]
541- )
542- # The chain must have more than 1 node
543- if len (cascaded_view_ops ) == 1 :
530+ view_target = exir_ops .edge .aten .view_copy .default
531+ for view_node in graph_module .graph .find_nodes (op = "call_function" , target = view_target , sort = True ):
532+ input_view = view_node .args [0 ]
533+ if input_view .op != "call_function" or input_view .target != view_target :
544534 continue
545535
546- last_view_node = cascaded_view_ops [- 1 ]
547- with graph .inserting_before (last_view_node ):
548- new_view = graph .call_function (
549- exir_ops .edge .aten .view_copy .default ,
550- args = (node .args [0 ], last_view_node .args [1 ]),
551- )
552- last_view_node .replace_all_uses_with (new_view )
553-
554- # Now erase the chain
555- for v in reversed (cascaded_view_ops ):
556- graph .erase_node (v )
536+ view_node .replace_input_with (input_view , input_view .args [0 ])
557537
558538 graph_module .recompile ()
559539
0 commit comments