Partition Mutable Buffer as Core ML State #5165

Merged
7 commits, merged on Sep 10, 2024
13 changes: 12 additions & 1 deletion backends/apple/coreml/partition/coreml_partitioner.py
@@ -17,7 +17,7 @@
Partitioner,
PartitionResult,
)
-from executorch.exir.backend.utils import tag_constant_data
+from executorch.exir.backend.utils import tag_constant_data, tag_mutated_buffer
from torch.export.exported_program import ExportedProgram
from torch.fx.passes.infra.partitioner import CapabilityBasedPartitioner
from torch.fx.passes.operator_support import OperatorSupportBase
@@ -61,6 +61,7 @@ def __init__(
self,
skip_ops_for_coreml_delegation: Optional[List[str]] = None,
compile_specs: Optional[List[CompileSpec]] = None,
take_over_mutable_buffer: Optional[bool] = True,
) -> None:
if skip_ops_for_coreml_delegation is None:
skip_ops_for_coreml_delegation = []
@@ -69,6 +70,7 @@
backend_id=CoreMLBackend.__name__,
compile_specs=compile_specs if compile_specs is not None else [],
)
self.take_over_mutable_buffer = take_over_mutable_buffer

def partition(self, exported_program: ExportedProgram) -> PartitionResult:
# Run the CapabilityBasedPartitioner to return the largest possible
@@ -89,6 +91,15 @@ def partition(self, exported_program: ExportedProgram) -> PartitionResult:
partition_tags[tag] = self.delegation_spec

tag_constant_data(exported_program)
if self.take_over_mutable_buffer:
logger.info(
"Core ML partitioner will take over torch mutable buffer as Core ML state, "
"so if your model contains mutable buffer, "
"then you will need MacOS15+/iOS18+ to execute. "
"If you want your mutable buffer model to be compatible with older OS, "
"then please set `take_over_mutable_buffer=False`"
)
tag_mutated_buffer(exported_program)

return PartitionResult(
tagged_exported_program=exported_program, partition_tags=partition_tags
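
Taken together, the partitioner change means a stateful model can be lowered with its mutable buffer handed to Core ML in just a few lines. Below is a minimal sketch (assuming a `model` with a registered mutable buffer and matching `example_inputs`; those two names are placeholders, everything else comes from the diff above):

```python
import coremltools as ct
import executorch.exir
import torch

from executorch.backends.apple.coreml.compiler import CoreMLBackend
from executorch.backends.apple.coreml.partition import CoreMLPartitioner

# Export a model that mutates a registered buffer (e.g. a KV cache).
exported_program = torch.export.export(model, example_inputs)

# Stateful execution requires the iOS18 deployment target.
compile_specs = CoreMLBackend.generate_compile_specs(
    minimum_deployment_target=ct.target.iOS18
)

# take_over_mutable_buffer defaults to True; pass False to keep the buffer
# in the ExecuTorch runtime and stay compatible with older OS versions.
partitioner = CoreMLPartitioner(
    compile_specs=compile_specs,
    take_over_mutable_buffer=True,
)

edge_manager = executorch.exir.to_edge(exported_program)
delegated_manager = edge_manager.to_backend(partitioner)
```
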
7 changes: 6 additions & 1 deletion backends/apple/coreml/scripts/install_requirements.sh
@@ -24,7 +24,7 @@ rm -rf "$COREML_DIR_PATH/third-party"
mkdir "$COREML_DIR_PATH/third-party"

echo "${green}ExecuTorch: Cloning coremltools."
-git clone --depth 1 --branch 8.0b1 "https://github.com/apple/coremltools.git" $COREMLTOOLS_DIR_PATH
+git clone --depth 1 --branch 8.0b2 "https://github.com/apple/coremltools.git" $COREMLTOOLS_DIR_PATH
cd $COREMLTOOLS_DIR_PATH

STATUS=$?
@@ -47,6 +47,11 @@ cmake --build "$COREMLTOOLS_DIR_PATH/build" --parallel

echo "${green}ExecuTorch: Installing coremltools."
pip install "$COREMLTOOLS_DIR_PATH"
# coremltools has started supporting numpy 2.0, but the ExecuTorch
# example-model test environment still uses an older transformers,
# so for now we need to downgrade numpy to 1.x.
# TODO: Remove this numpy downgrade once a newer transformers is used.
pip install numpy==1.26.4
STATUS=$?
if [ $STATUS -ne 0 ]; then
echo "${red}ExecuTorch: Failed to install coremltools."
49 changes: 49 additions & 0 deletions backends/apple/coreml/test/test_coreml_partitioner.py
@@ -4,11 +4,14 @@

import unittest

import coremltools as ct

import executorch.exir

import torch
import torchvision

from executorch.backends.apple.coreml.compiler import CoreMLBackend
from executorch.backends.apple.coreml.partition import CoreMLPartitioner


@@ -86,8 +89,54 @@ def test_vit_skip_conv(self):
if node.op == "call_function"
] == total

def test_buffer(self):
embedding_dim = 3
max_seq_len = 2

class Model(torch.nn.Module):
def __init__(self):
super().__init__()
self.register_buffer(
"cache",
torch.zeros((max_seq_len, embedding_dim), dtype=torch.float32),
)

def forward(self, q, k_val, input_pos):
q_T = q.transpose(0, 1)
k = torch.ops.aten.index_put_(self.cache, [input_pos, None], k_val)
attn = k.mm(q_T)
return attn

model = Model()
model.eval()

q = torch.randn((1, embedding_dim))
k_val = torch.randn((1, embedding_dim))
input_pos = torch.tensor([0])
example_inputs = (q, k_val, input_pos)
exir_program_aten = torch.export.export(model, example_inputs)

compile_specs = CoreMLBackend.generate_compile_specs(
minimum_deployment_target=ct.target.iOS18
)
partitioner = CoreMLPartitioner(compile_specs=compile_specs)
edge_program_manager = executorch.exir.to_edge(
exir_program_aten, compile_config=self.edge_compile_config
)
delegated_program_manager = edge_program_manager.to_backend(partitioner)

assert [
node.target.__name__
for node in delegated_program_manager.exported_program().graph.nodes
if node.op == "call_function"
] == [
"executorch_call_delegate",
"getitem",
]


if __name__ == "__main__":
test_runner = TestCoreMLPartitioner()
test_runner.test_add_sub_skip_mm()
test_runner.test_vit_skip_conv()
test_runner.test_buffer()
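
To run such a delegated program on device, it would then be serialized to a `.pte` file following the standard ExecuTorch flow; a short sketch continuing from the test above (the file name is a placeholder):

```python
executorch_program = delegated_program_manager.to_executorch()
with open("coreml_stateful_model.pte", "wb") as f:
    f.write(executorch_program.buffer)
```
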
9 changes: 8 additions & 1 deletion examples/models/llama2/export_llama_lib.py
@@ -281,6 +281,11 @@ def build_args_parser() -> argparse.ArgumentParser:
parser.add_argument("-V", "--vulkan", action="store_true")
parser.add_argument("--mps", action="store_true")
parser.add_argument("--coreml", action="store_true")
parser.add_argument(
"--coreml-enable-state",
action="store_true",
help="This option is only for coreml, and is only supported for MacOS15+/iOS18+",
)
parser.add_argument(
"--qnn",
action="store_true",
@@ -504,7 +509,9 @@ def _export_llama(modelname, args) -> LLMEdgeManager: # noqa: C901

if args.coreml:
coreml_partitioner = get_coreml_partitioner(
-            args.use_kv_cache, args.pt2e_quantize
+            args.use_kv_cache and args.coreml_enable_state,
+            args.embedding_quantize,
+            args.pt2e_quantize,
)
partitioners.append(coreml_partitioner)
modelname = f"coreml_{modelname}"
34 changes: 34 additions & 0 deletions exir/backend/utils.py
@@ -383,6 +383,40 @@ def tag_constant_data(edge_program: ExportedProgram) -> None:
node.meta["delegation_tag"] = user_tags.pop()


def tag_mutated_buffer(edge_program: ExportedProgram) -> None:
"""
Util function for partitioners. This function tags the mutated buffer nodes
whose users all belong within the same partition. This should be called after tagging all other nodes.
Any buffer which is used as input to a subgraph, will be tagged with the same tag as that
subgraph. Throw error when buffers is used across different partitions. That is the
underlying data will be owned by multiple delegates.
"""
for node in edge_program.graph.nodes:
# Determine whether this node is a mutated buffer
is_mutated_buffer_node = False
if node.op == "placeholder" and is_buffer(edge_program, node):
for node_user in node.users:
if node_user.name in edge_program.graph_signature.buffers_to_mutate:
is_mutated_buffer_node = True
break
# This node is a mutated buffer; tag it
if is_mutated_buffer_node:
user_tags = set()
for user in node.users:
user_tag = user.meta.get("delegation_tag", None)
if user_tag is not None:
user_tags.add(user_tag)
if len(user_tags) > 1:
logging.info(
f"The data node is used across multiple partitions, including {user_tags}. "
"If the data is too large and it's not preferred to copy, please tag the "
"constant node like node.['no_copy'] = True and they won't be copied."
)
# Tag the buffer node with the same tag as its user(s)
if len(user_tags) > 0:
node.meta["delegation_tag"] = user_tags.pop()


# TODO - style: use templated types
class DelegateMappingBuilder:
"""
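
The signature fields that `tag_mutated_buffer` keys off of can be inspected directly on a toy export. A sketch mirroring the cache model from the test above (`Cache` and its shapes are illustrative):

```python
import torch


class Cache(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.register_buffer("cache", torch.zeros(2, 3))

    def forward(self, input_pos, k_val):
        # The in-place write records `cache` in buffers_to_mutate.
        k = torch.ops.aten.index_put_(self.cache, [input_pos], k_val)
        return k.sum()


ep = torch.export.export(Cache(), (torch.tensor([0]), torch.randn(1, 3)))
sig = ep.graph_signature
print(sig.inputs_to_buffers)  # placeholder node name -> buffer FQN
print(sig.buffers_to_mutate)  # output node name -> mutated buffer FQN
```
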
30 changes: 15 additions & 15 deletions extension/llm/export/partitioner_lib.py
@@ -56,11 +56,10 @@ def get_mps_partitioner(use_kv_cache: bool = False):


def get_coreml_partitioner(
-    use_kv_cache: bool = False, pt2e_quantize: Optional[str] = None
+    enable_state: bool = False,
+    embedding_quantize: Optional[str] = None,
+    pt2e_quantize: Optional[str] = None,
):
-    assert (
-        use_kv_cache is True
-    ), "CoreML backend currently only supports static shape and use_kv_cache=True is the only way to support it at the moment"
try:
import coremltools as ct
from executorch.backends.apple.coreml.compiler import ( # pyre-ignore
@@ -75,22 +74,22 @@
)

minimum_deployment_target = ct.target.iOS15
-    # In Core ML, quantization in introduced in iOS 16
-    if pt2e_quantize is not None:
+    # In Core ML, stateful execution is introduced in iOS 18
+    if enable_state:
+        minimum_deployment_target = max(minimum_deployment_target, ct.target.iOS18)
+    # In Core ML, quantization is introduced in iOS 16
+    if embedding_quantize is not None or pt2e_quantize is not None:
minimum_deployment_target = max(minimum_deployment_target, ct.target.iOS16)
# In Core ML, 8-bit activation quantization is introduced in iOS 17
if pt2e_quantize in ("coreml_8a_c8w", "coreml_baseline_8a_c8w"):
if (
embedding_quantize is not None and int(embedding_quantize.split(",")[0]) == 8
) or pt2e_quantize in ("coreml_8a_c8w", "coreml_baseline_8a_c8w"):
minimum_deployment_target = max(minimum_deployment_target, ct.target.iOS17)
# In Core ML, 4-bit weight compression is introduced in iOS 18
if pt2e_quantize in ("coreml_c4w", "coreml_8a_c4w", "coreml_baseline_8a_c4w"):
if (
embedding_quantize is not None and int(embedding_quantize.split(",")[0]) == 4
) or pt2e_quantize in ("coreml_c4w", "coreml_8a_c4w", "coreml_baseline_8a_c4w"):
minimum_deployment_target = max(minimum_deployment_target, ct.target.iOS18)
-    # In Core ML, stateful execution is introduced in iOS 18
-    # TODO (https://github.com/pytorch/executorch/issues/4209)
-    # For now, since mutable buffer is kept in executorch runtime,
-    # state is out of place and can be handled by older iOS.
-    # Once mutable buffer can be handed over to delegate, i.e. state becomes in-place, we will have
-    # if use_kv_cache:
-    #     minimum_deployment_target = max(minimum_deployment_target, ct.target.iOS18)

compile_specs = CoreMLBackend.generate_compile_specs( # pyre-fixme[16]
minimum_deployment_target=minimum_deployment_target,
@@ -101,6 +100,7 @@
)
return CoreMLPartitioner( # pyre-fixme[16]
compile_specs=compile_specs,
+        take_over_mutable_buffer=enable_state,
)


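
As a concrete reading of the escalation above: enabling state, or requesting 4-bit embedding quantization (`embedding_quantize="4,32"`, whose first field is the bit width per the parsing in the diff), each pushes the target to iOS 18. A worked sketch, assuming coremltools is installed:

```python
import coremltools as ct

# One concrete configuration run through the same checks as above.
enable_state = True
embedding_quantize = "4,32"  # "<bitwidth>,<groupsize>"
pt2e_quantize = None

target = ct.target.iOS15
if enable_state:  # stateful execution -> iOS 18
    target = max(target, ct.target.iOS18)
if embedding_quantize is not None or pt2e_quantize is not None:
    target = max(target, ct.target.iOS16)  # any quantization -> iOS 16
if int(embedding_quantize.split(",")[0]) == 4:
    target = max(target, ct.target.iOS18)  # 4-bit weights -> iOS 18

assert target == ct.target.iOS18
```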