|
4 | 4 | #
|
5 | 5 |
|
6 | 6 | import logging
|
7 |
| -from typing import Any, Dict, List, Union |
| 7 | +from typing import Any, cast, Dict, List, Union |
8 | 8 |
|
9 | 9 | import torch
|
10 | 10 | from executorch.backends.apple.mps.mps_preprocess import MPSBackend
|
11 | 11 | from executorch.backends.apple.mps.operators.node_visitor import get_node_visitors
|
12 | 12 | from executorch.backends.apple.mps.utils.mps_utils import is_parameter
|
| 13 | +from executorch.backends.transforms import get_shape |
13 | 14 | from executorch.exir.backend.backend_details import CompileSpec
|
14 | 15 | from executorch.exir.backend.canonical_partitioners.pattern_op_partitioner import (
|
15 | 16 | generate_partitions_from_list_of_nodes,
|
|
20 | 21 | PartitionResult,
|
21 | 22 | )
|
22 | 23 | from executorch.exir.backend.utils import tag_constant_data
|
| 24 | +from executorch.exir.dialects._ops import ops as exir_ops |
23 | 25 | from torch.export.exported_program import ExportedProgram
|
24 | 26 | from torch.fx.passes.infra.partitioner import Partition
|
25 | 27 | from torch.fx.passes.operator_support import OperatorSupportBase
|
|
28 | 30 | logging.basicConfig(level=logging.DEBUG, format=FORMAT)
|
29 | 31 |
|
30 | 32 |
|
# Ops implemented as hand-written Metal kernels. Nodes with these targets may
# be split into their own partition by `tag_nodes` when MPSGraph's
# advanced-indexing support does not cover them (see `use_metal_kernel`).
METAL_KERNELS = [
    exir_ops.edge.aten.index.Tensor,
    exir_ops.edge.aten.index_put.default,
]
| 39 | + |
31 | 40 | class MPSOperatorSupport(OperatorSupportBase):
|
32 | 41 | def __init__(self, edge_program: torch.export.ExportedProgram, compiler_specs):
|
33 | 42 | self.node_visitors = get_node_visitors(edge_program)
|
@@ -65,10 +74,47 @@ def generate_partitions(self, edge_program: ExportedProgram) -> List[Any]:
|
65 | 74 | op_support=self.supported_ops,
|
66 | 75 | )
|
67 | 76 |
|
| 77 | + def mps_graph_advanced_indexing_support(self, node: torch.fx.Node): |
| 78 | + num_indices = 0 |
| 79 | + tensors = cast(List[torch.fx.Node], node.args[1]) |
| 80 | + input = cast(torch.fx.Node, node.args[0]) |
| 81 | + for t in tensors: |
| 82 | + if t is not None: |
| 83 | + num_indices += 1 |
| 84 | + # Can dispatch to MPSGraph if the length of the slices is equal |
| 85 | + # to the number of dimensions of the sliced tensors, or only one |
| 86 | + # slice is present. All other cases will fallback to a Metal kernel. |
| 87 | + if num_indices == len(get_shape(input)) or num_indices == 1: |
| 88 | + return True |
| 89 | + |
| 90 | + return False |
| 91 | + |
| 92 | + def use_metal_kernel(self, node: torch.fx.Node): |
| 93 | + if node.target in METAL_KERNELS: |
| 94 | + if ( |
| 95 | + node.target == exir_ops.edge.aten.index.Tensor |
| 96 | + or node.target == exir_ops.edge.aten.index_put.default |
| 97 | + ): |
| 98 | + if not self.mps_graph_advanced_indexing_support(node): |
| 99 | + return True |
| 100 | + return False |
| 101 | + |
    def tag_nodes(self, partitions: List[Partition]) -> None:
        """Assign a `delegation_tag` to every node of every partition.

        Nodes that must run as a custom Metal kernel are given a distinct tag
        so they end up in their own delegated partition, separate from the
        surrounding MPSGraph-dispatched nodes.
        """
        for partition in partitions:
            # Sub-partition counter within the current partition; bumped
            # around each Metal-kernel node so it is isolated on both sides.
            crt_partition_counter = 0
            # NOTE(review): iteration order determines which counter each node
            # gets; `sorted` presumably yields graph/topological node order —
            # confirm torch.fx.Node ordering semantics.
            for node in sorted(partition.nodes):
                delegation_tag = f"mps_{partition.id}"
                if self.use_metal_kernel(node):
                    logging.warning(f"[WARNING] Using Metal kernel for op {node.name}!")
                    # Partition the Metal kernel into a separate partition.
                    # First bump: separates this node from the preceding
                    # non-Metal nodes (which used the previous counter value).
                    crt_partition_counter += 1
                    delegation_tag = (
                        f"{delegation_tag}_metal_kernel_{crt_partition_counter}"
                    )
                    # Second bump: nodes after this one start a fresh
                    # sub-partition rather than rejoining the previous one.
                    crt_partition_counter += 1
                else:
                    delegation_tag = f"{delegation_tag}_{crt_partition_counter}"

                node.meta["delegation_tag"] = delegation_tag
                # Record the tag -> delegation spec mapping consumed by the
                # backend delegation machinery.
                self.partition_tags[delegation_tag] = self.delegation_spec
|
74 | 120 |
|
|
0 commit comments