Skip to content

Commit 31fc3a8

Browse files
authored
Merge branch 'main' into hardtanh_fusion_g3
2 parents cdd6fa7 + 150cbe1 commit 31fc3a8

File tree

72 files changed

+2339
-1435
lines changed

Some content is hidden

Large commits have some content hidden by default. Use the search box below to find content that may be hidden.

72 files changed

+2339
-1435
lines changed

CMakeLists.txt

Lines changed: 4 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -764,10 +764,6 @@ if(EXECUTORCH_BUILD_EXTENSION_MODULE)
764764
add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/extension/module)
765765
endif()
766766

767-
if(EXECUTORCH_BUILD_EXTENSION_TRAINING)
768-
add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/extension/training)
769-
endif()
770-
771767
if(EXECUTORCH_BUILD_EXTENSION_RUNNER_UTIL)
772768
add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/extension/runner_util)
773769
endif()
@@ -872,34 +868,13 @@ if(EXECUTORCH_BUILD_PYBIND)
872868

873869
if(EXECUTORCH_BUILD_EXTENSION_TRAINING)
874870

875-
set(_pybind_training_dep_libs
876-
${TORCH_PYTHON_LIBRARY}
877-
etdump
878-
executorch
879-
util
880-
torch
881-
extension_training
882-
)
883-
884-
if(EXECUTORCH_BUILD_XNNPACK)
885-
# need to explicitly specify XNNPACK and microkernels-prod
886-
# here otherwise uses XNNPACK and microkernel-prod symbols from libtorch_cpu
887-
list(APPEND _pybind_training_dep_libs xnnpack_backend XNNPACK microkernels-prod)
888-
endif()
889-
890-
# pybind training
891-
pybind11_add_module(_training_lib SHARED extension/training/pybindings/_training_lib.cpp)
892-
893-
target_include_directories(_training_lib PRIVATE ${TORCH_INCLUDE_DIRS})
894-
target_compile_options(_training_lib PUBLIC ${_pybind_compile_options})
895-
target_link_libraries(_training_lib PRIVATE ${_pybind_training_dep_libs})
896-
897-
install(TARGETS _training_lib
898-
LIBRARY DESTINATION executorch/extension/training/pybindings
899-
)
900871
endif()
901872
endif()
902873

874+
if(EXECUTORCH_BUILD_EXTENSION_TRAINING)
875+
add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/extension/training)
876+
endif()
877+
903878
if(EXECUTORCH_BUILD_KERNELS_CUSTOM)
904879
# TODO: move all custom kernels to ${CMAKE_CURRENT_SOURCE_DIR}/kernels/custom
905880
add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/extension/llm/custom_ops)

backends/apple/coreml/scripts/install_requirements.sh

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@ SCRIPT_DIR_PATH="$(
1212

1313
# TODO(jathu): remove the need to fetch coremltools to build deps for coreml_executor_runner.
1414
# Keep this version in sync with: pyproject.toml
15-
COREMLTOOLS_VERSION="8.1"
15+
COREMLTOOLS_VERSION="8.2"
1616

1717
red=`tput setaf 1`
1818
green=`tput setaf 2`

backends/arm/_passes/__init__.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
from . import arm_pass_utils # noqa
88
from .annotate_channels_last_dim_order_pass import AnnotateChannelsLastDimOrder # noqa
99
from .annotate_decomposed_matmul import AnnotateDecomposedMatmulPass # noqa
10+
from .arm_pass import ArmPass # noqa
1011
from .cast_int64_pass import CastInt64BuffersToInt32Pass # noqa
1112
from .cast_to_int32_pass import CastToInt32Pass # noqa
1213
from .conv1d_unsqueeze_pass import Conv1dUnsqueezePass # noqa
@@ -41,6 +42,10 @@
4142
from .meandim_to_averagepool_pass import ConvertMeanDimToAveragePoolPass # noqa
4243
from .mm_to_bmm_pass import ConvertMmToBmmPass # noqa
4344
from .remove_clone_pass import RemoveClonePass # noqa
45+
from .replace_scalar_with_tensor_pass import ( # noqa
46+
ReplaceScalarWithTensorArgPassTOSABI,
47+
ReplaceScalarWithTensorArgPassTOSAMI,
48+
)
4449
from .scalars_to_attribute_pass import ScalarsToAttributePass # noqa
4550
from .size_adjust_conv2d_pass import SizeAdjustConv2DPass # noqa
4651
from .unsqueeze_before_repeat_pass import UnsqueezeBeforeRepeatPass # noqa

backends/arm/_passes/arm_pass.py

Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,33 @@
1+
# Copyright 2025 Arm Limited and/or its affiliates.
2+
#
3+
# This source code is licensed under the BSD-style license found in the
4+
# LICENSE file in the root directory of this source tree.
5+
6+
# pyre-unsafe
7+
8+
import traceback
9+
from typing import Optional
10+
11+
import torch
12+
from executorch.exir.pass_base import ExportPass, NodeMetadata
13+
14+
15+
class ArmPass(ExportPass):
16+
"""Base class for Arm passes"""
17+
18+
def __init__(self, exported_program: Optional[torch.export.ExportedProgram] = None):
19+
super(ArmPass, self).__init__()
20+
self.exported_program = exported_program
21+
22+
def call_operator(self, op, args, kwargs, meta, updated: Optional[bool] = False):
23+
if not updated:
24+
return super().call_operator(op, args, kwargs, meta)
25+
26+
# if updated we should update metadata
27+
new_meta = {}
28+
keys = meta.data.keys()
29+
for key in keys:
30+
new_meta[key] = meta[key]
31+
old_stack_trace = new_meta.get("stack_trace", "")
32+
new_meta["stack_trace"] = f"{old_stack_trace}\n{traceback.format_stack()[-2]}"
33+
return super().call_operator(op, args, kwargs, NodeMetadata(new_meta))

backends/arm/_passes/arm_pass_manager.py

Lines changed: 6 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -42,18 +42,17 @@
4242
MatchArgRanksPass,
4343
QuantizeOperatorArguments,
4444
RemoveClonePass,
45+
ReplaceScalarWithTensorArgPassTOSABI,
46+
ReplaceScalarWithTensorArgPassTOSAMI,
4547
RetraceFoldedDtypesPass,
4648
ScalarsToAttributePass,
4749
SizeAdjustConv2DPass,
4850
UnsqueezeBeforeRepeatPass,
4951
UnsqueezeScalarPlaceholdersPass,
5052
)
53+
5154
from executorch.backends.arm.tosa_specification import Tosa_0_80, TosaSpecification
5255
from executorch.backends.transforms.fuse_view_copy import FuseViewCopyTransform
53-
54-
from executorch.backends.transforms.replace_scalar_with_tensor import (
55-
ReplaceScalarWithTensorArgPass,
56-
)
5756
from executorch.backends.xnnpack._passes.remove_getitem_op import RemoveGetItemPass
5857
from executorch.exir import ExportedProgram
5958
from executorch.exir.pass_manager import PassManager
@@ -84,7 +83,7 @@ def _tosa_080_BI_pipeline(self, exported_program: ExportedProgram) -> GraphModul
8483
if isinstance(self.tosa_spec, Tosa_0_80) and self.tosa_spec.is_U55_subset:
8584
self.add_pass(CastToInt32Pass())
8685

87-
self.add_pass(ReplaceScalarWithTensorArgPass())
86+
self.add_pass(ReplaceScalarWithTensorArgPassTOSABI())
8887
self.add_pass(AnnotateDecomposedMatmulPass())
8988
self.add_pass(QuantizeOperatorArguments())
9089
self.add_pass(FoldAndAnnotateQParamsPass()) # type: ignore[call-arg]
@@ -113,7 +112,7 @@ def _tosa_080_BI_pipeline(self, exported_program: ExportedProgram) -> GraphModul
113112
return self._transform(exported_program.graph_module)
114113

115114
def _tosa_080_MI_pipeline(self, exported_program: ExportedProgram) -> GraphModule:
116-
self.add_pass(ReplaceScalarWithTensorArgPass())
115+
self.add_pass(ReplaceScalarWithTensorArgPassTOSAMI())
117116
self.add_pass(FuseQuantizedActivationPass())
118117
self.add_pass(RemoveGetItemPass())
119118
self.add_pass(ConvertSplitToSlicePass())
@@ -170,7 +169,7 @@ def transform_to_backend_pipeline(self, exported_program: ExportedProgram):
170169
)
171170

172171
def transform_for_annotation_pipeline(self, graph_module: GraphModule):
173-
self.add_pass(ReplaceScalarWithTensorArgPass())
172+
self.add_pass(ReplaceScalarWithTensorArgPassTOSABI())
174173
self.add_pass(ScalarsToAttributePass())
175174
self.add_pass(DecomposeLayerNormPass())
176175
self.add_pass(DecomposeVarPass())

backends/arm/_passes/arm_pass_utils.py

Lines changed: 16 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -7,12 +7,12 @@
77

88
# pyre-unsafe
99

10+
import traceback
1011
from inspect import isclass
1112
from typing import Optional, Sequence
1213

1314
import torch
1415
import torch.fx
15-
1616
from executorch.exir import ExportedProgram
1717
from executorch.exir.dialects._ops import ops as exir_ops
1818

@@ -96,6 +96,7 @@ def create_node(
9696
kwargs: Optional[dict] = None,
9797
quantize: bool = False,
9898
q_params: Optional[tuple] = None,
99+
from_node: Optional[torch.fx.Node] = None,
99100
):
100101
"""
101102
Adds a node to 'graph'. graph.inserting_before/after() should be used before the call to decide where to insert the node.
@@ -108,15 +109,26 @@ def create_node(
108109
args=args,
109110
kwargs=kwargs or {},
110111
)
112+
113+
new_meta = {}
114+
if from_node:
115+
keys = from_node.meta.keys()
116+
for key in keys:
117+
new_meta[key] = from_node.meta[key]
118+
old_stack_trace = new_meta.get("stack_trace", "")
119+
new_meta["stack_trace"] = f"{old_stack_trace}\n{traceback.format_stack()[-2]}"
120+
node.meta = new_meta
121+
111122
if quantize and q_params:
112-
return insert_q_dq_pair(graph, node, q_params)
123+
return insert_q_dq_pair(graph, node, q_params, from_node)
113124
return node
114125

115126

116127
def insert_q_dq_pair(
117128
graph: torch.fx.Graph,
118129
anchor: torch.fx.Node,
119130
q_params: tuple,
131+
from_node: Optional[torch.fx.Node] = None,
120132
):
121133
"""
122134
Inserts a q dq node pair after the node 'anchor'.
@@ -127,13 +139,15 @@ def insert_q_dq_pair(
127139
graph=graph,
128140
op_target=exir_ops.edge.quantized_decomposed.quantize_per_tensor.default,
129141
args=(), # We add the argument last
142+
from_node=from_node if from_node else anchor,
130143
)
131144
q.meta = anchor.meta
132145
with graph.inserting_after(q):
133146
dq = create_node(
134147
graph=graph,
135148
op_target=exir_ops.edge.quantized_decomposed.dequantize_per_tensor.default,
136149
args=(q,) + q_params,
150+
from_node=from_node if from_node else anchor,
137151
)
138152
dq.meta = q.meta
139153
anchor.replace_all_uses_with(dq)

backends/arm/_passes/decompose_layernorm_pass.py

Lines changed: 30 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -9,9 +9,10 @@
99
import operator
1010

1111
import torch
12+
from executorch.backends.arm._passes import ArmPass
1213
from executorch.backends.arm._passes.arm_pass_utils import create_node
1314
from executorch.exir.dialects._ops import ops as exir_ops
14-
from executorch.exir.pass_base import ExportPass, PassResult
15+
from executorch.exir.pass_base import PassResult
1516

1617

1718
def get_layer_norm_decomposition(op) -> tuple:
@@ -40,7 +41,7 @@ def get_layer_norm_decomposition(op) -> tuple:
4041
raise RuntimeError(f"Can't get layer_norm composition for op {op}")
4142

4243

43-
class DecomposeLayerNormPass(ExportPass):
44+
class DecomposeLayerNormPass(ArmPass):
4445
"""
4546
layernorm is defined as: ((x - E[x]) / sqrt(Var[x] + eps)) * weights + bias
4647
Decompose layernorm(x, normalized_shape, weights, bias, eps) to a sequence of:
@@ -111,35 +112,56 @@ def call(self, graph_module: torch.fx.GraphModule):
111112
var_op,
112113
args=(x, dims),
113114
kwargs={"correction": 0, "keepdim": keepdim},
115+
from_node=node,
114116
)
115117
full = create_node(
116118
graph_module.graph,
117119
full_op,
118120
args=(epsilon_reshaped_shape, epsilon),
119121
kwargs={"dtype": dtype},
122+
from_node=node,
123+
)
124+
add0 = create_node(
125+
graph_module.graph, add_op, args=(var, full), from_node=node
126+
)
127+
rsqrt = create_node(
128+
graph_module.graph, rsqrt_op, args=(add0,), from_node=node
129+
)
130+
mul0 = create_node(
131+
graph_module.graph, mul_op, args=(sub, rsqrt), from_node=node
120132
)
121-
add0 = create_node(graph_module.graph, add_op, args=(var, full))
122-
rsqrt = create_node(graph_module.graph, rsqrt_op, args=(add0,))
123-
mul0 = create_node(graph_module.graph, mul_op, args=(sub, rsqrt))
124133
if weights is not None:
125134
weights_reshaped = create_node(
126135
graph_module.graph,
127136
view_op,
128137
args=(weights, weights_reshaped_shape),
138+
from_node=node,
129139
)
130140
mul1 = create_node(
131-
graph_module.graph, mul_op, args=(mul0, weights_reshaped)
141+
graph_module.graph,
142+
mul_op,
143+
args=(
144+
mul0,
145+
weights_reshaped,
146+
),
147+
from_node=node,
132148
)
133149
else:
134150
mul1 = mul0
135151
output = mul1
136152
if bias is not None:
137153
bias_reshaped_shape = weights_reshaped_shape
138154
bias_reshaped = create_node(
139-
graph_module.graph, view_op, args=(bias, bias_reshaped_shape)
155+
graph_module.graph,
156+
view_op,
157+
args=(bias, bias_reshaped_shape),
158+
from_node=node,
140159
)
141160
output = create_node(
142-
graph_module.graph, add_op, args=(mul1, bias_reshaped)
161+
graph_module.graph,
162+
add_op,
163+
args=(mul1, bias_reshaped),
164+
from_node=node,
143165
)
144166

145167
users = [user for user in node.users if node != user]

backends/arm/_passes/decompose_meandim_pass.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
# Copyright 2024 Arm Limited and/or its affiliates.
1+
# Copyright 2024-2025 Arm Limited and/or its affiliates.
22
# All rights reserved.
33
#
44
# This source code is licensed under the BSD-style license found in the
@@ -7,9 +7,9 @@
77
# pyre-unsafe
88

99
import torch
10+
from executorch.backends.arm._passes import ArmPass
1011
from executorch.backends.arm._passes.arm_pass_utils import get_node_arg
1112
from executorch.exir.dialects._ops import ops as exir_ops
12-
from executorch.exir.pass_base import ExportPass
1313

1414

1515
def get_meandim_decomposition(op) -> tuple:
@@ -28,7 +28,7 @@ def get_meandim_decomposition(op) -> tuple:
2828
raise RuntimeError(f"Can't get meandim decomposition for op {op}")
2929

3030

31-
class DecomposeMeanDimPass(ExportPass):
31+
class DecomposeMeanDimPass(ArmPass):
3232
"""
3333
This pass decomposes meandim into a sum and mul node.
3434
@@ -62,8 +62,8 @@ def call_operator(self, op, args, kwargs, meta):
6262

6363
sum_op, full_op, mul_op = get_meandim_decomposition(op)
6464

65-
sum = super().call_operator(sum_op, (x, dim, keepdim), {}, meta)
65+
sum = super().call_operator(sum_op, (x, dim, keepdim), {}, meta, True)
6666
full = super().call_operator(
67-
full_op, ([1] * len(shape), 1 / N), {"dtype": dtype}, meta
67+
full_op, ([1] * len(shape), 1 / N), {"dtype": dtype}, meta, True
6868
)
69-
return super().call_operator(mul_op, (sum, full), {}, meta)
69+
return super().call_operator(mul_op, (sum, full), {}, meta, True)

backends/arm/_passes/decompose_softmax_unstable_pass.py

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -6,8 +6,8 @@
66
# pyre-unsafe
77

88
import torch
9+
from executorch.backends.arm._passes import ArmPass
910
from executorch.exir.dialects._ops import ops as exir_ops
10-
from executorch.exir.pass_base import ExportPass
1111

1212
# For BI case
1313
torch_softmax = (torch.ops.aten.softmax.int, torch.ops.aten.log_softmax.int)
@@ -45,7 +45,7 @@ def get_logsoftmax_ops(op) -> tuple:
4545
raise RuntimeError(f"Can't get softmax decomposition ops for op {op}")
4646

4747

48-
class DecomposeSoftmaxUnstablePass(ExportPass):
48+
class DecomposeSoftmaxUnstablePass(ArmPass):
4949
"""
5050
This pass decomposes log softmax or softmax into more primitive ops.
5151
@@ -66,10 +66,10 @@ def call_operator(self, op, args, kwargs, meta):
6666
_input = args[0]
6767
dim = [args[1]]
6868

69-
op1 = super().call_operator(exp_op, (_input,), {}, meta)
70-
op2 = super().call_operator(sum_op, (op1, dim, True), {}, meta)
71-
op3 = super().call_operator(reciprocal_op, (op2,), {}, meta)
72-
op4 = super().call_operator(mul_op, (op1, op3), {}, meta)
69+
op1 = super().call_operator(exp_op, (_input,), {}, meta, True)
70+
op2 = super().call_operator(sum_op, (op1, dim, True), {}, meta, True)
71+
op3 = super().call_operator(reciprocal_op, (op2,), {}, meta, True)
72+
op4 = super().call_operator(mul_op, (op1, op3), {}, meta, True)
7373
if op in log_softmax:
74-
op4 = super().call_operator(log_op, (op4,), {}, meta)
74+
op4 = super().call_operator(log_op, (op4,), {}, meta, True)
7575
return op4

0 commit comments

Comments
 (0)