 # in a context outside of Jarvis', so exercise caution while invoking this in a
 # pass list outside of Jarvis.

-import itertools
 import logging
 from dataclasses import dataclass, field
-from typing import Callable, cast, Dict, Iterable, List, Optional, Sequence, Union
+from typing import cast, List, Optional, Sequence

 import torch
 import torch.fx
@@ -538,211 +537,175 @@ def call_operator(
         return super().call_operator(op, args, kwargs, meta)


-@register_cadence_pass(CadencePassAttribute(opt_level=1))
+@register_cadence_pass(CadencePassAttribute(opt_level=2))
 class RemovePermutesAroundElementwiseOps(ExportPass):
     """
     Looks for subgraphs of elementwise ops sandwiched between permutes and removes those
-    permutes if possible. This pass is targeted at models where delegated subgraphs
-    must be in NHWC format, so there's usually a to_NHWC permute before each delegate and
-    a to_NCHW permute after it. If all the ops between two delegates are elementwise ops
-    then these permutes can be safely removed.
-    Allows special handling for certain non-elementwise ops that can be easily updated based on
-    the permute's parameter, such as mean and cat
+    permutes if possible.
+    Allows special handling for certain non-elementwise ops that can be easily updated
+    based on the permute's parameter, such as mean, cat, and slice.
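+    For example, permute([0, 2, 3, 1]) -> hardtanh -> permute([0, 3, 1, 2]) becomes
+    just hardtanh on the unpermuted input.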
     """

     @dataclass()
     class Subgraph:
-        """
-        Keeps track of nodes grouped as a subgraph between two sets of permutes
-        """
-
-        start_permutes: set[torch.fx.Node] = field(default_factory=set)
-        end_permutes: set[torch.fx.Node] = field(default_factory=set)
-        intermediate_nodes: set[torch.fx.Node] = field(default_factory=set)
-        is_valid: bool = True
-
-    elementwise_ops: set[EdgeOpOverload] = {
+        start_permute: list[int]
+        end_permute: list[int]
+        # Nodes in the subgraph; does not include permutes.
+        nodes: set[torch.fx.Node] = field(default_factory=set)
+        # Incoming edges to the subgraph from permute nodes.
+        edges_in: set[tuple[torch.fx.Node, torch.fx.Node]] = field(default_factory=set)
+        # Outgoing edges of the subgraph to permute nodes.
+        edges_out: set[tuple[torch.fx.Node, torch.fx.Node]] = field(default_factory=set)
+
+    permutable_ops: set[EdgeOpOverload] = {
         exir_ops.edge.aten.add.Tensor,
         exir_ops.edge.aten.mul.Tensor,
-        exir_ops.edge.aten.mean.dim,
-        exir_ops.edge.aten.cat.default,
         exir_ops.edge.aten.hardtanh.default,
         exir_ops.edge.quantized_decomposed.quantize_per_tensor.default,
         exir_ops.edge.quantized_decomposed.dequantize_per_tensor.default,
         exir_ops.edge.cadence.quantize_per_tensor.default,
         exir_ops.edge.cadence.dequantize_per_tensor.default,
+        # Ops that require special handling.
+        exir_ops.edge.aten.cat.default,
+        exir_ops.edge.aten.mean.dim,
+        exir_ops.edge.aten.slice_copy.Tensor,
     }

-    # must be initialized in the constructor
-    special_handling: Dict[EdgeOpOverload, Callable[[torch.fx.Node], None]] = {}
-
-    to_NCHW = [0, 3, 1, 2]
-    to_NHWC = [0, 2, 3, 1]
-
-    def __init__(self) -> None:
-        super().__init__()
-        self.visited: set[object] = set()
-        self.special_handling = {
-            exir_ops.edge.aten.mean.dim: self.handle_mean_dim,
-            exir_ops.edge.aten.cat.default: self.handle_cat,
-        }
-
     def call(self, graph_module: torch.fx.GraphModule) -> PassResult:
-        self.visited = set()
+        subgraphs_found: list[RemovePermutesAroundElementwiseOps.Subgraph] = []
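+        # Nodes already claimed by an accepted subgraph; hitting one of them
+        # invalidates any new candidate subgraph (see visit).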
+        processed_nodes: set[torch.fx.Node] = set()
         for node in graph_module.graph.nodes:
-            sg = self.Subgraph()
-            self.start_search(node, sg)
-            if self.is_valid_subgraph(sg):
-                logging.debug(f"Found valid subgraph: {sg}")
-                self.handle_subgraph(graph_module, sg)
+            if node.target != exir_ops.edge.aten.permute_copy.default:
+                continue

-        result = super().call(graph_module)
-        return result
+            start_permute = self.get_permutation(node)
+            # Expected end permutation for the subgraph.
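+            # (This is the inverse of start_permute, e.g. [0, 2, 3, 1] inverts
+            # to [0, 3, 1, 2].)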
+            end_permute = [start_permute.index(i) for i in range(len(start_permute))]

-    def handle_mean_dim(self, mean_dim: torch.fx.Node) -> None:
-        assert mean_dim.target == exir_ops.edge.aten.mean.dim
-        args = list(mean_dim.args)
-        args[1] = [self.to_NCHW[dim] for dim in cast(list[int], args[1])]
-        mean_dim.args = tuple(args)
+            for user in node.users:
+                if user.target not in self.permutable_ops:
+                    continue
+                # Create a separate subgraph for each user since there may be cases
+                # where only a portion of the users are permutable.
+                subgraph = self.Subgraph(start_permute, end_permute)
+                if self.visit(user, subgraph, processed_nodes):
+                    subgraphs_found.append(subgraph)
+                    for node in subgraph.nodes:
+                        processed_nodes.add(node)

-    def handle_cat(self, cat: torch.fx.Node) -> None:
-        assert cat.target == exir_ops.edge.aten.cat.default
-        args = list(cat.args)
-        args[1] = self.to_NCHW[cast(int, args[1])]
-        cat.args = tuple(args)
+        for subgraph in subgraphs_found:
+            self.permute_subgraph(subgraph)

-    def is_valid_subgraph(self, sg: Subgraph) -> bool:
-        return (
-            sg.is_valid
-            and len(sg.start_permutes) > 0
-            and len(sg.end_permutes) > 0
-            and len(sg.intermediate_nodes) > 0
-        )
+        graph_module.graph.eliminate_dead_code()
+        graph_module.recompile()

-    def handle_subgraph(self, graph_module: torch.fx.GraphModule, sg: Subgraph) -> None:
-        for permute in itertools.chain(sg.start_permutes, sg.end_permutes):
-            permute.replace_all_uses_with(permute.args[0])  # pyre-fixme[6]
+        return super().call(graph_module)

-        for node in sg.intermediate_nodes:
-            if node.target in self.special_handling:
-                self.special_handling[node.target](node)
+    def visit(
+        self,
+        node: torch.fx.Node,
+        subgraph: Subgraph,
+        processed_nodes: set[torch.fx.Node],
+    ) -> bool:
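+        # Flood-fills the subgraph outward from `node`, following both users and
+        # inputs. Returns True only if every boundary edge is a permute matching
+        # start_permute (incoming) or end_permute (outgoing) and every interior
+        # node is permutable and not already claimed by another subgraph.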
+        if node in subgraph.nodes:
+            return True
+        if node in processed_nodes or not self.is_node_permutable(node):
+            return False
+        subgraph.nodes.add(node)
+
+        # Traverse downstream:
+        for user in node.users:
+            # Output should either go to a matching permute or another permutable op.
+            if user.target == exir_ops.edge.aten.permute_copy.default:
+                if self.get_permutation(user) != subgraph.end_permute:
+                    return False
+                subgraph.edges_out.add((node, user))
+            elif not self.visit(user, subgraph, processed_nodes):
+                return False

-        graph_module.recompile()
-        graph_module.graph.eliminate_dead_code()
+        # Traverse upstream:
+        for inp in node.all_input_nodes:
+            # Input should either come from a matching permute or another permutable op.
+            if inp.target == exir_ops.edge.aten.permute_copy.default:
+                if self.get_permutation(inp) != subgraph.start_permute:
+                    return False
+                subgraph.edges_in.add((inp, node))
+            elif not self.visit(inp, subgraph, processed_nodes):
+                return False

-    def start_search(self, node: torch.fx.Node, sg: Subgraph) -> None:
-        if node in self.visited:
-            return
+        return True

-        if self.is_starting_permute(node):
-            sg.start_permutes.add(node)
-            self.visited.add(node)
-            for user in node.users:
-                self.search_down(user, sg)
-
-    def search_up(self, node: object, sg: Subgraph) -> None:
-        # non-nodes can be ignored. These would be arguments like integers or lists
-        # of integers, which don't affect the subgraph validity or inclusion set.
-        if not isinstance(node, torch.fx.Node):
-            return
-
-        if node.op == "placeholder":
-            # If we reach a placeholder or other terminal node without encountering
-            # a start permute, then the subgraph is invalid.
-            # This could be because in the add(x, y) case where x is permuted and
-            # y is a graph input, we can't remove the permute on x because it might
-            # become two different shapes that don't broadcast together.
-            # TODO: Adding a permute on y could be the more optimal solution,
-            # but perhaps not in all cases, say if x is small and y is very large.
-            # This transform prefers to be safe over optimal for now.
-            sg.is_valid = False
-            return
-
-        if node in self.visited:
-            return
-
-        self.visited.add(node)
-
-        if self.is_starting_permute(node):
-            sg.start_permutes.add(node)
-            for user in node.users:
-                self.search_down(user, sg)
-        else:
-            self.traverse_intermediate_node(node, sg)
-
-    def search_down(self, node: torch.fx.Node, sg: Subgraph) -> None:
-        if node in self.visited or self.is_starting_permute(node):
-            return
-
-        self.visited.add(node)
-
-        if self.is_ending_permute(node):
-            sg.end_permutes.add(node)
-            for arg in node.args:
-                if isinstance(arg, list):
-                    for elem in arg:
-                        self.search_up(elem, sg)
-                else:
-                    self.search_up(arg, sg)
+    def is_node_permutable(self, node: torch.fx.Node) -> bool:
+        if node.target not in self.permutable_ops:
+            return False
+        if node.target == exir_ops.edge.aten.mean.dim:
+            # keepdim should be True.
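+            # (With keepdim=False the output would lose a dimension, so the
+            # surrounding permutation would no longer match its shape.)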
+            if len(node.args) >= 3:
+                if not node.args[2]:
+                    return False
+            elif "keepdim" in node.kwargs:
+                if not node.kwargs["keepdim"]:
+                    return False
+            else:
+                # Default keepdim is False.
+                return False
+        return True
+
+    def permute_subgraph(self, subgraph: Subgraph) -> None:
+        # Skip incoming permutes.
+        for inp, out in subgraph.edges_in:
+            assert inp.target == exir_ops.edge.aten.permute_copy.default
+            if len(inp.args) >= 1:
+                out.replace_input_with(inp, cast(torch.fx.Node, inp.args[0]))
+            else:
+                out.replace_input_with(inp, cast(torch.fx.Node, inp.kwargs["input"]))
+
+        # Skip outgoing permutes.
+        for inp, out in subgraph.edges_out:
+            assert out.target == exir_ops.edge.aten.permute_copy.default
+            out.replace_all_uses_with(inp)
+
+        # Handle dimension-related node arguments.
+        for node in subgraph.nodes:
+            if node.target == exir_ops.edge.aten.cat.default:
+                self.update_cat(node, subgraph.start_permute)
+            elif node.target == exir_ops.edge.aten.mean.dim:
+                self.update_mean_dim(node, subgraph.start_permute)
+            elif node.target == exir_ops.edge.aten.slice_copy.Tensor:
+                self.update_slice_copy(node, subgraph.start_permute)
+
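+    # The helpers below remap a dim argument: dimension d of the permuted tensor
+    # corresponds to dimension start_permute[d] of the unpermuted tensor.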
+    def update_cat(self, node: torch.fx.Node, start_permute: list[int]) -> None:
+        if len(node.args) >= 2:
+            node.update_arg(1, start_permute[cast(int, node.args[1])])
+        elif "dim" in node.kwargs:
+            node.update_kwarg("dim", start_permute[cast(int, node.kwargs["dim"])])
         else:
-            self.traverse_intermediate_node(node, sg)
-
-    def traverse_intermediate_node(self, node: torch.fx.Node, sg: Subgraph) -> None:
-        if node.target in self.elementwise_ops:
-            sg.intermediate_nodes.add(node)
-            for arg in node.args:
-                if isinstance(arg, list):
-                    for elem in arg:
-                        self.search_up(elem, sg)
-                else:
-                    self.search_up(arg, sg)
-
-            for user in node.users:
-                self.search_down(user, sg)
+            # Default cat dim is 0.
+            node.update_kwarg("dim", start_permute[0])

-        else:
-            sg.is_valid = False
-
-    def is_starting_permute(self, node: torch.fx.Node) -> bool:
-        return self.is_boundary_permute(node, self.to_NCHW)
-
-    def is_ending_permute(self, node: torch.fx.Node) -> bool:
-        return self.is_boundary_permute(node, self.to_NHWC)
-
-    @staticmethod
-    def is_boundary_permute(node: torch.fx.Node, permute_dims: Iterable[int]) -> bool:
-        permute_dims = list(permute_dims)
-        if node.target == exir_ops.edge.aten.permute_copy.default:
-            return cast(list[int], node.args[1]) == permute_dims
-        elif node.target == exir_ops.edge.aten.view_copy.default:
-            # If there's a view node, check if it's swapping two dimensions and
-            # not splitting any others from the input shape.
-            inp = node.args[0]
-            if not isinstance(inp, torch.fx.Node):
-                return False
-            input_shape = inp.meta["val"].shape
-            output_shape = node.args[1]
-            assert isinstance(output_shape, (tuple, list))
-            # If the shapes are equal in length, no dimension is being split or
-            # grouped. Then check if a permute of the input shape results in the output shape.
-            return (
-                len(input_shape) == len(output_shape)
-                and len(input_shape) == len(permute_dims)
-                and RemovePermutesAroundElementwiseOps.permute_shape(
-                    input_shape, permute_dims
-                )
-                == output_shape
+    def update_mean_dim(self, node: torch.fx.Node, start_permute: list[int]) -> None:
+        if len(node.args) >= 2:
+            node.update_arg(
+                1, [start_permute[dim] for dim in cast(list[int], node.args[1])]
             )
         else:
-            return False
+            node.update_kwarg(
+                "dim",
+                [start_permute[dim] for dim in cast(list[int], node.kwargs["dim"])],
+            )

-    @staticmethod
-    def permute_shape(
-        shape: Union[List[int], torch.Size], permute_dims: Iterable[int]
-    ) -> List[int]:
-        permute_dims = list(permute_dims)
-        assert len(shape) == len(permute_dims)
-        return [shape[p] for p in permute_dims]
+    def update_slice_copy(self, node: torch.fx.Node, start_permute: list[int]) -> None:
+        if len(node.args) >= 2:
+            node.update_arg(1, start_permute[cast(int, node.args[1])])
+        else:
+            node.update_kwarg("dim", start_permute[cast(int, node.kwargs["dim"])])
+
+    def get_permutation(self, permute_node: torch.fx.Node) -> list[int]:
+        assert permute_node.target == exir_ops.edge.aten.permute_copy.default
+        if len(permute_node.args) >= 2:
+            return cast(list[int], permute_node.args[1])
+        assert "dim" in permute_node.kwargs
+        return cast(list[int], permute_node.kwargs["dim"])


 @register_cadence_pass(CadencePassAttribute(opt_level=1))