Add CoreML support for to_edge_transform_and_lower #8505

Merged 1 commit on Feb 17, 2025
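
This change teaches CoreMLPartitioner to report which ops should not be decomposed, so the to_edge_transform_and_lower flow keeps CoreML-supported ops (such as scaled_dot_product_attention) intact and delegates them whole. Below is a minimal sketch of the resulting export flow, adapted from the new test in this PR; the SDPA module, tensor shapes, and variable names are illustrative, not part of the PR's API.

import executorch.exir
import torch

from executorch.backends.apple.coreml.partition import CoreMLPartitioner


class SDPA(torch.nn.Module):
    def forward(self, q, k, v, mask):
        return torch.ops.aten.scaled_dot_product_attention.default(
            q, k, v, attn_mask=mask
        )


q = torch.randn(1, 12, 1, 16)
k = torch.randn(1, 12, 32, 16)
v = torch.randn(1, 12, 32, 16)
mask = torch.randn(1, 32)
ep = torch.export.export(SDPA().eval(), (q, k, v, mask))

# to_edge_transform_and_lower asks the partitioner for ops_to_not_decompose(),
# so SDPA stays a single op and is handed to the CoreML delegate undecomposed.
edge = executorch.exir.to_edge_transform_and_lower(ep, partitioner=[CoreMLPartitioner()])
et_program = edge.to_executorch()

Compare this with the plain to_edge + to_backend path exercised in the test below, where SDPA is decomposed before the partitioner sees it.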
16 changes: 15 additions & 1 deletion backends/apple/coreml/partition/coreml_partitioner.py
@@ -3,7 +3,7 @@
# Please refer to the license found in the LICENSE file in the root directory of the source tree.

import logging
-from typing import List, Optional
+from typing import Callable, List, Optional, Tuple

import coremltools as ct

@@ -104,3 +104,17 @@ def partition(self, exported_program: ExportedProgram) -> PartitionResult:
        return PartitionResult(
            tagged_exported_program=exported_program, partition_tags=partition_tags
        )

    def ops_to_not_decompose(
        self, ep: ExportedProgram
    ) -> Tuple[List[torch._ops.OpOverload], Optional[Callable[[torch.fx.Node], bool]]]:
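        # Report every op that the CoreML backend supports directly; returning these
        # targets tells to_edge_transform_and_lower not to decompose them before partitioning.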
        do_not_decompose = []
        op_support = OperatorsSupportedForCoreMLBackend()
        for node in ep.graph.nodes:
            if (
                node.op == "call_function"
                and isinstance(node.target, torch._ops.OpOverload)
                and op_support.is_node_supported(None, node)
            ):
                do_not_decompose.append(node.target)
        return do_not_decompose, None
46 changes: 46 additions & 0 deletions backends/apple/coreml/test/test_coreml_partitioner.py
@@ -13,6 +13,7 @@

from executorch.backends.apple.coreml.compiler import CoreMLBackend
from executorch.backends.apple.coreml.partition import CoreMLPartitioner
from executorch.exir.backend.utils import format_delegated_graph


class TestCoreMLPartitioner(unittest.TestCase):
@@ -79,6 +80,50 @@ def test_vit_skip_conv(self):
"getitem",
]

    def test_ops_to_not_decompose(self):
        class Model(torch.nn.Module):
            def forward(self, q, k, v, mask):
                return torch.ops.aten.scaled_dot_product_attention.default(
                    q, k, v, attn_mask=mask
                )

        model = Model()
        model.eval()

        batch_size = 1
        n_heads = 12
        seq_len = 1
        max_seq_length = 32
        embedding_dim = 16
        q = torch.randn(batch_size, n_heads, seq_len, embedding_dim)
        k = torch.randn(batch_size, n_heads, max_seq_length, embedding_dim)
        v = torch.randn(batch_size, n_heads, max_seq_length, embedding_dim)
        mask = torch.randn(seq_len, max_seq_length)
        example_inputs = (q, k, v, mask)
        ep = torch.export.export(model, example_inputs)
        coreml_partitioner = CoreMLPartitioner()

        # Using to_edge_transform_and_lower, we expect SDPA to be preserved and to show up in the delegated graph
        edge_program_manager = executorch.exir.to_edge_transform_and_lower(
            ep, partitioner=[coreml_partitioner]
        )
        self.assertTrue(
            "executorch.exir.dialects.edge._ops.aten.scaled_dot_product_attention.default"
            in format_delegated_graph(
                edge_program_manager.exported_program().graph_module
            )
        )

        # Using the to_edge flow, we expect SDPA to be decomposed and not show up in the delegated graph
        edge_program_manager2 = executorch.exir.to_edge(ep)
        edge_program_manager2.to_backend(coreml_partitioner)
        self.assertTrue(
            "executorch.exir.dialects.edge._ops.aten.scaled_dot_product_attention.default"
            not in format_delegated_graph(
                edge_program_manager2.exported_program().graph_module
            )
        )

    def test_buffer(self):
        embedding_dim = 3
        max_seq_len = 2
@@ -129,4 +174,5 @@ def forward(self, q, k_val, input_pos):
test_runner = TestCoreMLPartitioner()
test_runner.test_add_sub_skip_mm()
test_runner.test_vit_skip_conv()
test_runner.test_ops_to_not_decompose()
test_runner.test_buffer()