# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

from typing import Dict, final, List, Tuple

import torch

from executorch.exir._serialize._named_data_store import NamedDataStore
from executorch.exir.backend.backend_details import BackendDetails, PreprocessResult
from executorch.exir.backend.canonical_partitioners.pattern_op_partitioner import (
    generate_pattern_op_partitions,
)
from executorch.exir.backend.compile_spec_schema import CompileSpec
from executorch.exir.backend.partitioner import (
    DelegationSpec,
    Partitioner,
    PartitionResult,
)
from executorch.exir.dialects._ops import ops as exir_ops
from executorch.exir.graph_module import get_control_flow_submodules
from torch.export.exported_program import ExportedProgram
from torch.fx.passes.operator_support import OperatorSupportBase


# @final: this test backend is not meant to be subclassed.
@final
class BackendWithNamedDataMap(BackendDetails):
    """
    Test backend for named-data-map functionality.

    This backend returns no processed_bytes. Instead, it writes to the
    named data store, using each supported op's name as the key and the
    op's integer code as the payload.
    """

    @staticmethod
    def preprocess(
        edge_program: ExportedProgram,
        compile_specs: List[CompileSpec],
    ) -> PreprocessResult:
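        # Each supported edge op maps to a small integer code; that code is
        # the payload stored under the op's name in the named data store.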
        op_codes = {
            exir_ops.edge.aten.sin.default: 0,
            exir_ops.edge.aten.add.Tensor: 1,
            exir_ops.edge.aten.sub.Tensor: 2,
            exir_ops.edge.aten.mul.Tensor: 3,
            exir_ops.edge.aten.div.Tensor: 4,
        }
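        # Walk the graph and record one named-data entry per supported op.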
        ndm = NamedDataStore()
        for node in edge_program.graph.nodes:
            if node.op == "call_function" and node.target in op_codes:
                # bytes([n]) encodes the single byte n; bytes(n) would
                # instead produce n zero bytes.
                ndm.add_named_data(
                    node.target.__name__, bytes([op_codes[node.target]])
                )

        return PreprocessResult(
            processed_bytes=b"",
            debug_handle_map={},
            data_store_output=ndm.get_named_data_store_output(),
        )


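# Supports exactly the five ops that BackendWithNamedDataMap can encode.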
class SimpleOperatorSupport(OperatorSupportBase):
    def is_node_supported(self, submodules, node: torch.fx.Node) -> bool:
        return node.op == "call_function" and node.target in [
            exir_ops.edge.aten.sin.default,
            exir_ops.edge.aten.add.Tensor,
            exir_ops.edge.aten.sub.Tensor,
            exir_ops.edge.aten.mul.Tensor,
            exir_ops.edge.aten.div.Tensor,
        ]


@final
class BackendWithNDMPartitioner(Partitioner):
    def __init__(self) -> None:
        self._op_support = SimpleOperatorSupport()
        self.backend_id = BackendWithNamedDataMap.__name__
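        # Note: backend_id must match the BackendDetails subclass name above
        # so that to_backend can route tagged partitions to that backend.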

    def _partition_gm(
        self, graph_module: torch.fx.GraphModule, id_start: int = 0
    ) -> Tuple[int, Dict[str, DelegationSpec]]:
        partition_tags: Dict[str, DelegationSpec] = {}
        partition_list = generate_pattern_op_partitions(
            graph_module, op_support=self._op_support
        )

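        # Tag every node in each partition; nodes that share a tag are fused
        # into one delegate. id_start keeps tags unique across submodules.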
        num_partitions_in_gm = len(partition_list)
        for partition in partition_list:
            curr_par_id = partition.id or 0
            delegation_tag = f"tag_{curr_par_id + id_start}"
            for node in partition.nodes:
                node.meta["delegation_tag"] = delegation_tag
            delegation_spec = DelegationSpec(self.backend_id, [])
            partition_tags[delegation_tag] = delegation_spec

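        # Recurse into control-flow submodules (e.g. cond branches),
        # continuing the tag numbering where this graph left off.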
        start_idx_for_submodules = num_partitions_in_gm
        for _, submodule, _ in get_control_flow_submodules(graph_module):
            start_idx_for_submodules, ret_partition_tags = self._partition_gm(
                submodule, start_idx_for_submodules
            )
            partition_tags.update(ret_partition_tags)

        return start_idx_for_submodules, partition_tags

    def partition(self, edge_program: ExportedProgram) -> PartitionResult:
        _, partition_tags = self._partition_gm(edge_program.graph_module)
        return PartitionResult(
            tagged_exported_program=edge_program,
            partition_tags=partition_tags,
        )
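

# A minimal usage sketch, assuming the standard ExecuTorch lowering flow
# (torch.export -> to_edge -> to_backend). _SinAdd is a hypothetical
# example module, not part of the test backend itself.
if __name__ == "__main__":
    from executorch.exir import to_edge

    class _SinAdd(torch.nn.Module):
        def forward(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
            return torch.sin(x) + y

    exported = torch.export.export(
        _SinAdd(), (torch.randn(2), torch.randn(2))
    )
    edge = to_edge(exported)
    # Both sin and add should be tagged by the partitioner and delegated;
    # the op-name/op-code entries travel in the program's data store output.
    lowered = edge.to_backend(BackendWithNDMPartitioner())
    print(lowered.exported_program().graph)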