Commit f203c94

Permute elimination pass fixes.
Differential Revision: D74011447
Pull Request resolved: #10662
1 parent 8ee4487 commit f203c94
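
Context for reviewers: the pass relies on the fact that, for elementwise ops, permuting the inputs, applying the op, and then applying the inverse permutation is the same as applying the op in the original layout. A minimal sketch of that identity (illustrative only, not part of the commit):

    import torch

    x, y = torch.randn(1, 3, 4, 5), torch.randn(1, 3, 4, 5)
    nhwc, nchw = [0, 2, 3, 1], [0, 3, 1, 2]  # mutually inverse permutations
    # permute -> elementwise add -> inverse permute ...
    roundtrip = (x.permute(nhwc) + y.permute(nhwc)).permute(nchw)
    # ... equals the add applied directly, so both permute_copy nodes
    # around the add can be removed from the graph.
    assert torch.equal(roundtrip, x + y)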

File tree: 3 files changed, +232 −201 lines


backends/cadence/aot/TARGETS (+1)
@@ -347,6 +347,7 @@ python_unittest(
         ":compiler",
         "//caffe2:torch",
         "//executorch/backends/cadence/aot:compiler",
+        "//executorch/backends/cadence/aot:graph_builder",
         "//executorch/backends/cadence/aot:ops_registrations",
         "//executorch/backends/cadence/aot:pass_utils",
         "//executorch/backends/cadence/aot:remove_ops",

backends/cadence/aot/remove_ops.py (+138 −175)
@@ -17,10 +17,9 @@
 # in a context outside of Jarvis', so exercise caution while invoking this in a
 # pass list outside of Jarvis.
 
-import itertools
 import logging
 from dataclasses import dataclass, field
-from typing import Callable, cast, Dict, Iterable, List, Optional, Sequence, Union
+from typing import cast, List, Optional, Sequence
 
 import torch
 import torch.fx
@@ -538,211 +537,175 @@ def call_operator(
         return super().call_operator(op, args, kwargs, meta)
 
 
-@register_cadence_pass(CadencePassAttribute(opt_level=1))
+@register_cadence_pass(CadencePassAttribute(opt_level=2))
 class RemovePermutesAroundElementwiseOps(ExportPass):
     """
     Looks for subgraphs of elementwise ops sandwiched between permutes and removes those
-    permutes if possible. This pass is targeted at models where delegated subgraphs
-    must be in NHWC format, so there's usually a to_NHWC permute before each delegate and
-    a to_NCHW permute after it. If all the ops between two delegates are elementwise ops
-    then these permutes can be safely removed.
-    Allows special handling for certain non-elementwise ops that can be easily updated based on
-    the permute's parameter, such as mean and cat
+    permutes if possible.
+    Allows special handling for certain non-elementwise ops that can be easily updated
+    based on the permute's parameter such as mean, cat, and slice.
     """
 
     @dataclass()
     class Subgraph:
-        """
-        Keeps track of nodes grouped as a subgraph between two sets of permutes
-        """
-
-        start_permutes: set[torch.fx.Node] = field(default_factory=set)
-        end_permutes: set[torch.fx.Node] = field(default_factory=set)
-        intermediate_nodes: set[torch.fx.Node] = field(default_factory=set)
-        is_valid: bool = True
-
-    elementwise_ops: set[EdgeOpOverload] = {
+        start_permute: list[int]
+        end_permute: list[int]
+        # Nodes in the subgraph, does not include permutes.
+        nodes: set[torch.fx.Node] = field(default_factory=set)
+        # Incoming edges to the subgraph from permute nodes.
+        edges_in: set[tuple[torch.fx.Node, torch.fx.Node]] = field(default_factory=set)
+        # Outgoing edges of the subgraph to permute nodes.
+        edges_out: set[tuple[torch.fx.Node, torch.fx.Node]] = field(default_factory=set)
+
+    permutable_ops: set[EdgeOpOverload] = {
         exir_ops.edge.aten.add.Tensor,
         exir_ops.edge.aten.mul.Tensor,
-        exir_ops.edge.aten.mean.dim,
-        exir_ops.edge.aten.cat.default,
         exir_ops.edge.aten.hardtanh.default,
         exir_ops.edge.quantized_decomposed.quantize_per_tensor.default,
         exir_ops.edge.quantized_decomposed.dequantize_per_tensor.default,
         exir_ops.edge.cadence.quantize_per_tensor.default,
         exir_ops.edge.cadence.dequantize_per_tensor.default,
+        # Ops that require special handling.
+        exir_ops.edge.aten.cat.default,
+        exir_ops.edge.aten.mean.dim,
+        exir_ops.edge.aten.slice_copy.Tensor,
     }
 
-    # must be initialized in the constructor
-    special_handling: Dict[EdgeOpOverload, Callable[[torch.fx.Node], None]] = {}
-
-    to_NCHW = [0, 3, 1, 2]
-    to_NHWC = [0, 2, 3, 1]
-
-    def __init__(self) -> None:
-        super().__init__()
-        self.visited: set[object] = set()
-        self.special_handling = {
-            exir_ops.edge.aten.mean.dim: self.handle_mean_dim,
-            exir_ops.edge.aten.cat.default: self.handle_cat,
-        }
-
     def call(self, graph_module: torch.fx.GraphModule) -> PassResult:
-        self.visited = set()
+        subgraphs_found: list[RemovePermutesAroundElementwiseOps.Subgraph] = []
+        processed_nodes: set[torch.fx.Node] = set()
         for node in graph_module.graph.nodes:
-            sg = self.Subgraph()
-            self.start_search(node, sg)
-            if self.is_valid_subgraph(sg):
-                logging.debug(f"Found valid subgraph: {sg}")
-                self.handle_subgraph(graph_module, sg)
+            if node.target != exir_ops.edge.aten.permute_copy.default:
+                continue
 
-        result = super().call(graph_module)
-        return result
+            start_permute = self.get_permutation(node)
+            # Expected end permutation for the subgraph.
+            end_permute = [start_permute.index(i) for i in range(len(start_permute))]
 
-    def handle_mean_dim(self, mean_dim: torch.fx.Node) -> None:
-        assert mean_dim.target == exir_ops.edge.aten.mean.dim
-        args = list(mean_dim.args)
-        args[1] = [self.to_NCHW[dim] for dim in cast(list[int], args[1])]
-        mean_dim.args = tuple(args)
+            for user in node.users:
+                if user.target not in self.permutable_ops:
+                    continue
+                # Create a separate subgraph for each user since there may be cases
+                # where only a portion of the users are permutable.
+                subgraph = self.Subgraph(start_permute, end_permute)
+                if self.visit(user, subgraph, processed_nodes):
+                    subgraphs_found.append(subgraph)
+                    for node in subgraph.nodes:
+                        processed_nodes.add(node)
 
-    def handle_cat(self, cat: torch.fx.Node) -> None:
-        assert cat.target == exir_ops.edge.aten.cat.default
-        args = list(cat.args)
-        args[1] = self.to_NCHW[cast(int, args[1])]
-        cat.args = tuple(args)
+        for subgraph in subgraphs_found:
+            self.permute_subgraph(subgraph)
 
-    def is_valid_subgraph(self, sg: Subgraph) -> bool:
-        return (
-            sg.is_valid
-            and len(sg.start_permutes) > 0
-            and len(sg.end_permutes) > 0
-            and len(sg.intermediate_nodes) > 0
-        )
+        graph_module.graph.eliminate_dead_code()
+        graph_module.recompile()
 
-    def handle_subgraph(self, graph_module: torch.fx.GraphModule, sg: Subgraph) -> None:
-        for permute in itertools.chain(sg.start_permutes, sg.end_permutes):
-            permute.replace_all_uses_with(permute.args[0])  # pyre-fixme[6]
+        return super().call(graph_module)
 
-        for node in sg.intermediate_nodes:
-            if node.target in self.special_handling:
-                self.special_handling[node.target](node)
+    def visit(
+        self,
+        node: torch.fx.Node,
+        subgraph: Subgraph,
+        processed_nodes: set[torch.fx.Node],
+    ) -> bool:
+        if node in subgraph.nodes:
+            return True
+        if node in processed_nodes or not self.is_node_permutable(node):
+            return False
+        subgraph.nodes.add(node)
+
+        # Traverse downstream:
+        for user in node.users:
+            # Output should either go to a matching permute or another permutable op.
+            if user.target == exir_ops.edge.aten.permute_copy.default:
+                if self.get_permutation(user) != subgraph.end_permute:
+                    return False
+                subgraph.edges_out.add((node, user))
+            elif not self.visit(user, subgraph, processed_nodes):
+                return False
 
-        graph_module.recompile()
-        graph_module.graph.eliminate_dead_code()
+        # Traverse upstream:
+        for inp in node.all_input_nodes:
+            # Input should either come from a matching permute or another permutable op.
+            if inp.target == exir_ops.edge.aten.permute_copy.default:
+                if self.get_permutation(inp) != subgraph.start_permute:
+                    return False
+                subgraph.edges_in.add((inp, node))
+            elif not self.visit(inp, subgraph, processed_nodes):
+                return False
 
-    def start_search(self, node: torch.fx.Node, sg: Subgraph) -> None:
-        if node in self.visited:
-            return
+        return True
 
-        if self.is_starting_permute(node):
-            sg.start_permutes.add(node)
-            self.visited.add(node)
-            for user in node.users:
-                self.search_down(user, sg)
-
-    def search_up(self, node: object, sg: Subgraph) -> None:
-        # non-nodes can be ignored. These would be arguments like integers or lists
-        # of integers, which don't affect the subgraph validity or inclusion set.
-        if not isinstance(node, torch.fx.Node):
-            return
-
-        if node.op == "placeholder":
-            # If we reach a placeholder or other terminal node without encountering
-            # a start permute, then the subgraph is invalid.
-            # This could be because in the add(x, y) case where x is permuted and
-            # y is a graph input, we can't remove the permute on x because it might
-            # become two different shapes that don't broadcast together.
-            # TODO: Adding a permute on y could be the more optimal solution,
-            # but perhaps not in all cases, say if x is small and y is very large.
-            # This transform prefers to be safe over optimal for now.
-            sg.is_valid = False
-            return
-
-        if node in self.visited:
-            return
-
-        self.visited.add(node)
-
-        if self.is_starting_permute(node):
-            sg.start_permutes.add(node)
-            for user in node.users:
-                self.search_down(user, sg)
-        else:
-            self.traverse_intermediate_node(node, sg)
-
-    def search_down(self, node: torch.fx.Node, sg: Subgraph) -> None:
-        if node in self.visited or self.is_starting_permute(node):
-            return
-
-        self.visited.add(node)
-
-        if self.is_ending_permute(node):
-            sg.end_permutes.add(node)
-            for arg in node.args:
-                if isinstance(arg, list):
-                    for elem in arg:
-                        self.search_up(elem, sg)
-                else:
-                    self.search_up(arg, sg)
+    def is_node_permutable(self, node: torch.fx.Node) -> bool:
+        if node.target not in self.permutable_ops:
+            return False
+        if node.target == exir_ops.edge.aten.mean.dim:
+            # keepdim should be True.
+            if len(node.args) >= 3:
+                if not node.args[2]:
+                    return False
+            elif "keepdim" in node.kwargs:
+                if not node.kwargs["keepdim"]:
+                    return False
+            else:
+                # Default keepdim is False.
+                return False
+        return True
+
+    def permute_subgraph(self, subgraph: Subgraph) -> None:
+        # Skip incoming permutes.
+        for inp, out in subgraph.edges_in:
+            assert inp.target == exir_ops.edge.aten.permute_copy.default
+            if len(inp.args) >= 1:
+                out.replace_input_with(inp, cast(torch.fx.Node, inp.args[0]))
+            else:
+                out.replace_input_with(inp, cast(torch.fx.Node, inp.kwargs["input"]))
+
+        # Skip outgoing permutes.
+        for inp, out in subgraph.edges_out:
+            assert out.target == exir_ops.edge.aten.permute_copy.default
+            out.replace_all_uses_with(inp)
+
+        # Handle dimension related node arguments.
+        for node in subgraph.nodes:
+            if node.target == exir_ops.edge.aten.cat.default:
+                self.update_cat(node, subgraph.start_permute)
+            elif node.target == exir_ops.edge.aten.mean.dim:
+                self.update_mean_dim(node, subgraph.start_permute)
+            elif node.target == exir_ops.edge.aten.slice_copy.Tensor:
+                self.update_slice_copy(node, subgraph.start_permute)
+
+    def update_cat(self, node: torch.fx.Node, start_permute: list[int]) -> None:
+        if len(node.args) >= 2:
+            node.update_arg(1, start_permute[cast(int, node.args[1])])
+        elif "dim" in node.kwargs:
+            node.update_kwarg("dim", start_permute[cast(int, node.kwargs["dim"])])
         else:
-            self.traverse_intermediate_node(node, sg)
-
-    def traverse_intermediate_node(self, node: torch.fx.Node, sg: Subgraph) -> None:
-        if node.target in self.elementwise_ops:
-            sg.intermediate_nodes.add(node)
-            for arg in node.args:
-                if isinstance(arg, list):
-                    for elem in arg:
-                        self.search_up(elem, sg)
-                else:
-                    self.search_up(arg, sg)
-
-            for user in node.users:
-                self.search_down(user, sg)
+            # Default cat dim is 0.
+            node.update_kwarg("dim", start_permute[0])
 
-        else:
-            sg.is_valid = False
-
-    def is_starting_permute(self, node: torch.fx.Node) -> bool:
-        return self.is_boundary_permute(node, self.to_NCHW)
-
-    def is_ending_permute(self, node: torch.fx.Node) -> bool:
-        return self.is_boundary_permute(node, self.to_NHWC)
-
-    @staticmethod
-    def is_boundary_permute(node: torch.fx.Node, permute_dims: Iterable[int]) -> bool:
-        permute_dims = list(permute_dims)
-        if node.target == exir_ops.edge.aten.permute_copy.default:
-            return cast(list[int], node.args[1]) == permute_dims
-        elif node.target == exir_ops.edge.aten.view_copy.default:
-            # If there's a view node, check if it's swapping two dimensions and
-            # not splitting any others from the input shape.
-            inp = node.args[0]
-            if not isinstance(inp, torch.fx.Node):
-                return False
-            input_shape = inp.meta["val"].shape
-            output_shape = node.args[1]
-            assert isinstance(output_shape, (tuple, list))
-            # If the shapes are equal in length, no dimension is being split or
-            # grouped. Then check if a permute of the input shape results in the output shape.
-            return (
-                len(input_shape) == len(output_shape)
-                and len(input_shape) == len(permute_dims)
-                and RemovePermutesAroundElementwiseOps.permute_shape(
-                    input_shape, permute_dims
-                )
-                == output_shape
+    def update_mean_dim(self, node: torch.fx.Node, start_permute: list[int]) -> None:
+        if len(node.args) >= 2:
+            node.update_arg(
+                1, [start_permute[dim] for dim in cast(list[int], node.args[1])]
             )
         else:
-            return False
+            node.update_kwarg(
+                "dim",
+                [start_permute[dim] for dim in cast(list[int], node.kwargs["dim"])],
+            )
 
-    @staticmethod
-    def permute_shape(
-        shape: Union[List[int], torch.Size], permute_dims: Iterable[int]
-    ) -> List[int]:
-        permute_dims = list(permute_dims)
-        assert len(shape) == len(permute_dims)
-        return [shape[p] for p in permute_dims]
+    def update_slice_copy(self, node: torch.fx.Node, start_permute: list[int]) -> None:
+        if len(node.args) >= 2:
+            node.update_arg(1, start_permute[cast(int, node.args[1])])
+        else:
+            node.update_kwarg("dim", start_permute[cast(int, node.kwargs["dim"])])
+
+    def get_permutation(self, permute_node: torch.fx.Node) -> list[int]:
+        assert permute_node.target == exir_ops.edge.aten.permute_copy.default
+        if len(permute_node.args) >= 2:
+            return cast(list[int], permute_node.args[1])
+        assert "dim" in permute_node.kwargs
+        return cast(list[int], permute_node.kwargs["dim"])
 
 
 @register_cadence_pass(CadencePassAttribute(opt_level=1))
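
Two details of the new traversal, restated as a runnable sketch (the variable names below are illustrative, not from the diff):

    # call() pairs each starting permute with its inverse permutation,
    # computed exactly as in the diff above.
    start_permute = [0, 2, 3, 1]  # e.g. a permute_copy from NCHW to NHWC
    end_permute = [start_permute.index(i) for i in range(len(start_permute))]
    assert end_permute == [0, 3, 1, 2]  # the inverse, NHWC back to NCHW

    # Once the permutes are dropped, dim arguments of cat/mean/slice were
    # written against the permuted layout, so update_cat, update_mean_dim,
    # and update_slice_copy remap them into the original layout:
    dim_in_permuted_layout = 3  # channel dim in NHWC
    assert start_permute[dim_in_permuted_layout] == 1  # channel dim in NCHW

The keepdim check in is_node_permutable exists because mean.dim with keepdim=False drops the reduced dimensions and changes the tensor's rank, so the subgraph's start and end permutations would no longer apply to its output.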
