4
4
#
5
5
6
6
import logging
7
- from typing import cast , Any , Dict , List , Union
7
+ from typing import Any , cast , Dict , List , Union
8
8
9
9
import torch
10
10
from executorch .backends .apple .mps .mps_preprocess import MPSBackend
11
11
from executorch .backends .apple .mps .operators .node_visitor import get_node_visitors
12
12
from executorch .backends .apple .mps .utils .mps_utils import is_parameter
13
+ from executorch .backends .transforms import get_shape
13
14
from executorch .exir .backend .backend_details import CompileSpec
14
15
from executorch .exir .backend .canonical_partitioners .pattern_op_partitioner import (
15
16
generate_partitions_from_list_of_nodes ,
20
21
PartitionResult ,
21
22
)
22
23
from executorch .exir .backend .utils import tag_constant_data
24
+ from executorch .exir .dialects ._ops import ops as exir_ops
23
25
from torch .export .exported_program import ExportedProgram
24
26
from torch .fx .passes .infra .partitioner import Partition
25
27
from torch .fx .passes .operator_support import OperatorSupportBase
26
- from executorch .exir .dialects ._ops import ops as exir_ops
27
- from executorch .backends .transforms import get_shape
28
28
29
29
FORMAT = "[%(levelname)s %(asctime)s %(filename)s:%(lineno)s] %(message)s"
30
30
logging .basicConfig (level = logging .DEBUG , format = FORMAT )
36
36
exir_ops .edge .aten .index_put .default ,
37
37
]
38
38
39
+
39
40
class MPSOperatorSupport (OperatorSupportBase ):
40
41
def __init__ (self , edge_program : torch .export .ExportedProgram , compiler_specs ):
41
42
self .node_visitors = get_node_visitors (edge_program )
@@ -90,7 +91,10 @@ def mps_graph_advanced_indexing_support(self, node: torch.fx.Node):
90
91
91
92
def use_metal_kernel (self , node : torch .fx .Node ):
92
93
if node .target in METAL_KERNELS :
93
- if node .target == exir_ops .edge .aten .index .Tensor or node .target == exir_ops .edge .aten .index_put .default :
94
+ if (
95
+ node .target == exir_ops .edge .aten .index .Tensor
96
+ or node .target == exir_ops .edge .aten .index_put .default
97
+ ):
94
98
if not self .mps_graph_advanced_indexing_support (node ):
95
99
return True
96
100
return False
@@ -104,7 +108,9 @@ def tag_nodes(self, partitions: List[Partition]) -> None:
104
108
logging .warning (f"[WARNING] Using Metal kernel for op { node .name } !" )
105
109
# Partition the Metal kernel into a separate partition
106
110
crt_partition_counter += 1
107
- delegation_tag = f"{ delegation_tag } _metal_kernel_{ crt_partition_counter } "
111
+ delegation_tag = (
112
+ f"{ delegation_tag } _metal_kernel_{ crt_partition_counter } "
113
+ )
108
114
crt_partition_counter += 1
109
115
else :
110
116
delegation_tag = f"{ delegation_tag } _{ crt_partition_counter } "
0 commit comments