Commit 38720f4

Update
[ghstack-poisoned]
2 parents beecf2a + a01571f commit 38720f4

File tree

134 files changed: +2846 -574 lines changed


CMakeLists.txt

Lines changed: 11 additions & 2 deletions
@@ -186,6 +186,10 @@ option(EXECUTORCH_BUILD_EXTENSION_FLAT_TENSOR "Build the Flat Tensor extension"
   OFF
 )
 
+option(EXECUTORCH_BUILD_EXTENSION_LLM "Build the LLM extension"
+       OFF
+)
+
 option(EXECUTORCH_BUILD_EXTENSION_MODULE "Build the Module extension" OFF)
 
 option(EXECUTORCH_BUILD_EXTENSION_RUNNER_UTIL "Build the Runner Util extension"
@@ -245,7 +249,7 @@ cmake_dependent_option(
 )
 
 if(EXECUTORCH_BUILD_EXTENSION_FLAT_TENSOR)
-  set(EXECUTORCH_BUILF_EXTENSION_DATA_LOADER ON)
+  set(EXECUTORCH_BUILD_EXTENSION_DATA_LOADER ON)
 endif()
 
 if(EXECUTORCH_BUILD_EXTENSION_TRAINING)
@@ -348,6 +352,7 @@ if(EXECUTORCH_BUILD_PTHREADPOOL)
 endif()
 
 if(EXECUTORCH_BUILD_TESTS)
+  set(EXECUTORCH_BUILD_EXTENSION_FLAT_TENSOR ON)
   include(CTest)
 endif()
 
@@ -373,7 +378,7 @@ if(NOT "${_repo_dir_name}" STREQUAL "executorch")
     "fix for this restriction."
   )
 endif()
-set(_common_include_directories ${CMAKE_CURRENT_SOURCE_DIR}/.. ${CMAKE_CURRENT_SOURCE_DIR}/runtime/core/portable_type)
+set(_common_include_directories ${CMAKE_CURRENT_SOURCE_DIR}/.. ${CMAKE_CURRENT_SOURCE_DIR}/runtime/core/portable_type/c10)
 
 #
 # The `_<target>_srcs` lists are defined by including ${EXECUTORCH_SRCS_FILE}.
@@ -717,6 +722,10 @@ if(EXECUTORCH_BUILD_EXTENSION_FLAT_TENSOR)
   add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/extension/flat_tensor/serialize)
 endif()
 
+if(EXECUTORCH_BUILD_EXTENSION_LLM)
+  add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/extension/llm/tokenizer)
+endif()
+
 if(EXECUTORCH_BUILD_EXTENSION_MODULE)
   add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/extension/module)
 endif()

backends/apple/coreml/CMakeLists.txt

Lines changed: 1 addition & 1 deletion
@@ -134,7 +134,7 @@ target_include_directories(
   coremldelegate PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/runtime/util
 )
 target_include_directories(coremldelegate PRIVATE ${EXECUTORCH_ROOT}/..)
-target_include_directories(coremldelegate PRIVATE ${EXECUTORCH_ROOT}/runtime/core/portable_type)
+target_include_directories(coremldelegate PRIVATE ${EXECUTORCH_ROOT}/runtime/core/portable_type/c10)
 target_compile_definitions(coremldelegate PRIVATE C10_USING_CUSTOM_GENERATED_MACROS)
 target_link_libraries(coremldelegate PRIVATE executorch_core)
backends/apple/coreml/partition/coreml_partitioner.py

Lines changed: 15 additions & 1 deletion
@@ -3,7 +3,7 @@
 # Please refer to the license found in the LICENSE file in the root directory of the source tree.
 
 import logging
-from typing import List, Optional
+from typing import Callable, List, Optional, Tuple
 
 import coremltools as ct
 
@@ -104,3 +104,17 @@ def partition(self, exported_program: ExportedProgram) -> PartitionResult:
         return PartitionResult(
             tagged_exported_program=exported_program, partition_tags=partition_tags
         )
+
+    def ops_to_not_decompose(
+        self, ep: ExportedProgram
+    ) -> Tuple[List[torch._ops.OpOverload], Optional[Callable[[torch.fx.Node], bool]]]:
+        do_not_decompose = []
+        op_support = OperatorsSupportedForCoreMLBackend()
+        for node in ep.graph.nodes:
+            if (
+                node.op == "call_function"
+                and isinstance(node.target, torch._ops.OpOverload)
+                and op_support.is_node_supported(None, node)
+            ):
+                do_not_decompose.append(node.target)
+        return do_not_decompose, None
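
The new ops_to_not_decompose hook is queried by the to_edge_transform_and_lower flow before decompositions run, so ops the CoreML backend already supports are handed to the delegate intact (the test added below exercises exactly this). A minimal sketch of calling the hook directly; the Model here is a hypothetical stand-in:

import torch

from executorch.backends.apple.coreml.partition import CoreMLPartitioner


class Model(torch.nn.Module):  # hypothetical example model
    def forward(self, q, k, v):
        return torch.nn.functional.scaled_dot_product_attention(q, k, v)


ep = torch.export.export(Model().eval(), (torch.randn(1, 4, 8, 16),) * 3)
partitioner = CoreMLPartitioner()
# Returns the OpOverloads found in the graph that the CoreML op-support
# checker accepts, plus an optional per-node filter (None here).
preserved_ops, node_filter = partitioner.ops_to_not_decompose(ep)
print(preserved_ops)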

backends/apple/coreml/runtime/workspace/executorchcoreml.xcodeproj/project.pbxproj

Lines changed: 2 additions & 2 deletions
@@ -922,7 +922,7 @@
 "$(SRCROOT)/../kvstore",
 "$(SRCROOT)/../inmemoryfs",
 "$(SRCROOT)/../include",
-"$(SRCROOT)/../include/executorch/runtime/core/portable_type",
+"$(SRCROOT)/../include/executorch/runtime/core/portable_type/c10",
 "$(SRCROOT)/../sdk",
 "$(SRCROOT)/../util",
 "$(SRCROOT)/../../third-party/nlohmann_json/single_include",
@@ -954,7 +954,7 @@
 "$(SRCROOT)/../kvstore",
 "$(SRCROOT)/../inmemoryfs",
 "$(SRCROOT)/../include",
-"$(SRCROOT)/../include/executorch/runtime/core/portable_type",
+"$(SRCROOT)/../include/executorch/runtime/core/portable_type/c10",
 "$(SRCROOT)/../sdk",
 "$(SRCROOT)/../util",
 "$(SRCROOT)/../../third-party/nlohmann_json/single_include",

backends/apple/coreml/test/test_coreml_partitioner.py

Lines changed: 46 additions & 0 deletions
@@ -13,6 +13,7 @@
 
 from executorch.backends.apple.coreml.compiler import CoreMLBackend
 from executorch.backends.apple.coreml.partition import CoreMLPartitioner
+from executorch.exir.backend.utils import format_delegated_graph
 
 
 class TestCoreMLPartitioner(unittest.TestCase):
@@ -79,6 +80,50 @@ def test_vit_skip_conv(self):
             "getitem",
         ]
 
+    def test_ops_to_not_decompose(self):
+        class Model(torch.nn.Module):
+            def forward(self, q, k, v, mask):
+                return torch.ops.aten.scaled_dot_product_attention.default(
+                    q, k, v, attn_mask=mask
+                )
+
+        model = Model()
+        model.eval()
+
+        batch_size = 1
+        n_heads = 12
+        seq_len = 1
+        max_seq_length = 32
+        embedding_dim = 16
+        q = torch.randn(batch_size, n_heads, seq_len, embedding_dim)
+        k = torch.randn(batch_size, n_heads, max_seq_length, embedding_dim)
+        v = torch.randn(batch_size, n_heads, max_seq_length, embedding_dim)
+        mask = torch.randn(seq_len, max_seq_length)
+        example_inputs = (q, k, v, mask)
+        ep = torch.export.export(model, example_inputs)
+        coreml_partitioner = CoreMLPartitioner()
+
+        # Using to_edge_transform_and_lower, we expect SDPA will be preserved and show up in delegated graph
+        edge_program_manager = executorch.exir.to_edge_transform_and_lower(
+            ep, partitioner=[coreml_partitioner]
+        )
+        self.assertTrue(
+            "executorch.exir.dialects.edge._ops.aten.scaled_dot_product_attention.default"
+            in format_delegated_graph(
+                edge_program_manager.exported_program().graph_module
+            )
+        )
+
+        # Using to_edge flow, we expect SDPA will be decomposed and not show up in delegated graph
+        edge_program_manager2 = executorch.exir.to_edge(ep)
+        edge_program_manager2.to_backend(coreml_partitioner)
+        self.assertTrue(
+            "executorch.exir.dialects.edge._ops.aten.scaled_dot_product_attention.default"
+            not in format_delegated_graph(
+                edge_program_manager2.exported_program().graph_module
+            )
+        )
+
     def test_buffer(self):
         embedding_dim = 3
         max_seq_len = 2
@@ -129,4 +174,5 @@ def forward(self, q, k_val, input_pos):
     test_runner = TestCoreMLPartitioner()
     test_runner.test_add_sub_skip_mm()
     test_runner.test_vit_skip_conv()
+    test_runner.test_ops_to_not_decompose()
     test_runner.test_buffer()

backends/arm/CMakeLists.txt

Lines changed: 1 addition & 1 deletion
@@ -14,7 +14,7 @@ endif()
 
 include(${EXECUTORCH_ROOT}/build/Utils.cmake)
 
-set(_common_include_directories ${EXECUTORCH_ROOT}/.. ${EXECUTORCH_ROOT}/runtime/core/portable_type)
+set(_common_include_directories ${EXECUTORCH_ROOT}/.. ${EXECUTORCH_ROOT}/runtime/core/portable_type/c10)
 add_compile_definitions(C10_USING_CUSTOM_GENERATED_MACROS)
 
 # Third-party folder and Ethos-U driver inclued

backends/arm/TARGETS

Lines changed: 4 additions & 1 deletion
@@ -4,7 +4,10 @@ load("@fbcode_macros//build_defs:python_library.bzl", "python_library")
 python_library(
     name = "arm_partitioner",
     srcs = [
-        "arm_partitioner.py",
+        "ethosu_backend.py",
+        "ethosu_partitioner.py",
+        "tosa_backend.py",
+        "tosa_partitioner.py",
     ],
     typing = True,
     deps = [

backends/arm/_passes/arm_pass_manager.py

Lines changed: 6 additions & 1 deletion
@@ -77,6 +77,9 @@
 )
 from executorch.backends.arm.tosa_specification import TosaSpecification
 
+from executorch.backends.transforms.replace_scalar_with_tensor import (
+    ReplaceScalarWithTensorArgPass,
+)
 from executorch.backends.xnnpack._passes.remove_getitem_op import RemoveGetItemPass
 from executorch.exir import ExportedProgram
 from executorch.exir.pass_manager import PassManager
@@ -102,6 +105,7 @@ def _tosa_080_BI_pipeline(self, exported_program: ExportedProgram) -> GraphModul
         self.add_pass(ConvertMeanDimToAveragePoolPass())
         self.add_pass(ConvertFullLikeToFullPass())
 
+        self.add_pass(ReplaceScalarWithTensorArgPass())
         self.add_pass(AnnotateDecomposedMatmulPass())
         self.add_pass(QuantizeOperatorArguments())
         self.add_pass(FoldAndAnnotateQParamsPass())  # type: ignore[call-arg]
@@ -125,7 +129,7 @@ def _tosa_080_BI_pipeline(self, exported_program: ExportedProgram) -> GraphModul
         return self._transform(exported_program.graph_module)
 
     def _tosa_080_MI_pipeline(self, exported_program: ExportedProgram) -> GraphModule:
-
+        self.add_pass(ReplaceScalarWithTensorArgPass())
         self.add_pass(FuseQuantizedActivationPass())
         self.add_pass(RemoveGetItemPass())
         self.add_pass(ConvertSplitToSlicePass())
@@ -176,6 +180,7 @@ def transform_to_backend_pipeline(self, exported_program: ExportedProgram):
 
     def transform_for_annotation_pipeline(self, graph_module: GraphModule):
         self.add_pass(ScalarsToAttributePass())
+        self.add_pass(ReplaceScalarWithTensorArgPass())
         self.add_pass(DecomposeLayerNormPass())
         self.add_pass(DecomposeVarPass())
         self.add_pass(DecomposeMeanDimPass())
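
ReplaceScalarWithTensorArgPass, imported from executorch.backends.transforms.replace_scalar_with_tensor, is now run in both TOSA pipelines and in the annotation pipeline. As its name suggests, it rewrites scalar-overload calls into their Tensor overloads with the scalar lifted into a tensor; a rough eager-mode illustration of the equivalence such a rewrite relies on (an assumption-level sketch, not the pass implementation):

import torch

x = torch.randn(4)
scalar = 2.0

# The Scalar overload of an op ...
out_scalar = torch.ops.aten.add.Scalar(x, scalar)
# ... agrees with the Tensor overload once the scalar is materialized
# as a tensor, giving the pipelines one uniform form to handle.
out_tensor = torch.ops.aten.add.Tensor(x, torch.tensor(scalar))
assert torch.allclose(out_scalar, out_tensor)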

backends/arm/_passes/fuse_batchnorm2d_pass.py

Lines changed: 1 addition & 0 deletions
@@ -114,6 +114,7 @@ def try_set_param(
         if not try_set_param(conv_bias_node, fused_conv_bias) and try_set_param(
             bn_bias_node, fused_conv_bias
         ):
+            # pyre-ignore[60]
             # Conv didn't have bias but batchnorm did, steal bias from batchnorm.
             conv_args = (*conv.args[0:2], bn_bias_node, *conv.args[3:])
             conv.args = conv_args
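
For context, the pass folds batchnorm statistics into the preceding convolution, and the branch above covers the case where the conv has no bias of its own, so the folded bias must come from the batchnorm parameters. PyTorch's eval-time folding helper demonstrates the same math (shown for illustration; it is not what the Arm pass calls):

import torch
from torch.nn.utils.fusion import fuse_conv_bn_eval

conv = torch.nn.Conv2d(3, 8, kernel_size=3, bias=False)  # conv without its own bias
bn = torch.nn.BatchNorm2d(8)
conv.eval()
bn.eval()

# Produces a single Conv2d whose weight and newly created bias absorb the
# batchnorm scale, shift, and running statistics.
fused = fuse_conv_bn_eval(conv, bn)

x = torch.randn(1, 3, 16, 16)
assert torch.allclose(bn(conv(x)), fused(x), atol=1e-5)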

backends/arm/_passes/scalars_to_attribute_pass.py

Lines changed: 12 additions & 0 deletions
@@ -76,5 +76,17 @@ def call(self, graph_module: GraphModule) -> PassResult:
                 new_args.append(get_attr_node)
             n.args = tuple(new_args)
 
+            # Replace rsub.Scalar with sub.Tensor as retracing will fail otherwise
+            if n.target == torch.ops.aten.rsub.Scalar:
+                with graph_module.graph.inserting_after(n):
+                    reversed_args = (n.args[1], n.args[0])
+                    sub = graph_module.graph.create_node(
+                        "call_function", torch.ops.aten.sub.Tensor, reversed_args, {}
+                    )
+                    n.replace_all_uses_with(sub)
+                    sub.meta["val"] = n.meta["val"]
+                graph_module.graph.erase_node(n)
+
         graph_module.recompile()
+        graph_module = super().call(graph_module).graph_module
         return PassResult(graph_module, True)
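
The rewrite is sound because rsub is subtraction with its operands reversed: rsub(a, b) == b - a. A quick eager-mode check of the identity the pass applies at the FX-graph level (illustrative only):

import torch

a = torch.randn(4)
b = torch.randn(4)

# aten.rsub computes b - a, so swapping the arguments turns it into a
# plain sub, exactly the node rewrite performed above.
assert torch.allclose(torch.rsub(a, b), torch.sub(b, a))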
