Skip to content

Commit bb83697

Browse files
authored
Merge branch 'main' into add-module-execute-input-count-check
2 parents d7aa520 + d069d65 commit bb83697

File tree

111 files changed

+2299
-1104
lines changed

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

111 files changed

+2299
-1104
lines changed

CMakeLists.txt

Lines changed: 13 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -48,21 +48,33 @@ project(executorch)
4848
# MARK: - Start EXECUTORCH_H12025_BUILD_MIGRATION --------------------------------------------------
4949

5050
include(${PROJECT_SOURCE_DIR}/tools/cmake/common/preset.cmake)
51+
include(${PROJECT_SOURCE_DIR}/tools/cmake/Utils.cmake)
52+
include(CMakeDependentOption)
53+
include(ExternalProject)
5154

5255
if(NOT CMAKE_CXX_STANDARD)
5356
set(CMAKE_CXX_STANDARD 17)
5457
endif()
5558
announce_configured_options(CMAKE_CXX_STANDARD)
5659

60+
if(NOT CMAKE_SYSTEM_PROCESSOR)
61+
set(CMAKE_SYSTEM_PROCESSOR ${CMAKE_HOST_SYSTEM_PROCESSOR})
62+
endif()
63+
announce_configured_options(CMAKE_SYSTEM_PROCESSOR)
64+
5765
if(NOT CMAKE_BUILD_TYPE)
5866
set(CMAKE_BUILD_TYPE Debug)
5967
endif()
6068
announce_configured_options(CMAKE_BUILD_TYPE)
6169

70+
if(NOT PYTHON_EXECUTABLE)
71+
resolve_python_executable()
72+
endif()
73+
announce_configured_options(PYTHON_EXECUTABLE)
74+
6275
announce_configured_options(CMAKE_CXX_COMPILER_ID)
6376
announce_configured_options(CMAKE_TOOLCHAIN_FILE)
6477
announce_configured_options(BUCK2)
65-
announce_configured_options(PYTHON_EXECUTABLE)
6678

6779
load_build_preset()
6880
include(${PROJECT_SOURCE_DIR}/tools/cmake/preset/default.cmake)
@@ -72,10 +84,6 @@ print_configured_options()
7284

7385
# MARK: - End EXECUTORCH_H12025_BUILD_MIGRATION ----------------------------------------------------
7486

75-
include(tools/cmake/Utils.cmake)
76-
include(CMakeDependentOption)
77-
include(ExternalProject)
78-
7987
set(CMAKE_EXPORT_COMPILE_COMMANDS ON)
8088

8189
# Setup RPATH.
@@ -251,11 +259,6 @@ if(EXECUTORCH_BUILD_TESTS)
251259
include(CTest)
252260
endif()
253261

254-
if(NOT PYTHON_EXECUTABLE)
255-
resolve_python_executable()
256-
endif()
257-
message(STATUS "Using python executable '${PYTHON_EXECUTABLE}'")
258-
259262
# TODO(dbort): Fix these warnings and remove this flag.
260263
set(_common_compile_options -Wno-deprecated-declarations -fPIC)
261264

backends/apple/mps/CMakeLists.txt

Lines changed: 0 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -18,10 +18,6 @@ endif()
1818

1919
include(${EXECUTORCH_ROOT}/tools/cmake/Utils.cmake)
2020

21-
if(NOT PYTHON_EXECUTABLE)
22-
resolve_python_executable()
23-
endif()
24-
2521
set(_common_compile_options -Wno-deprecated-declarations)
2622
set(_common_include_directories ${EXECUTORCH_ROOT}/..)
2723

backends/arm/_passes/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@
2424
from .decompose_gelu_pass import DecomposeGeluPass # noqa
2525
from .decompose_layernorm_pass import DecomposeLayerNormPass # noqa
2626
from .decompose_leaky_relu_pass import DecomposeLeakyReLUPass # noqa
27+
from .decompose_linalg_vector_norm_pass import DecomposeLinearVectorNormPass # noqa
2728
from .decompose_linear_pass import DecomposeLinearPass # noqa
2829
from .decompose_meandim_pass import DecomposeMeanDimPass # noqa
2930
from .decompose_ne_pass import DecomposeNotEqualPass # noqa

backends/arm/_passes/arm_pass_manager.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,7 @@
2929
DecomposeLayerNormPass,
3030
DecomposeLeakyReLUPass,
3131
DecomposeLinearPass,
32+
DecomposeLinearVectorNormPass,
3233
DecomposeMeanDimPass,
3334
DecomposeNotEqualPass,
3435
DecomposeSelectPass,
@@ -86,6 +87,7 @@ def _tosa_080_BI_pipeline(self, exported_program: ExportedProgram) -> GraphModul
8687
self.add_pass(ConvertSplitToSlicePass())
8788
self.add_pass(ConvertMmToBmmPass())
8889
self.add_pass(DecomposeLinearPass())
90+
self.add_pass(DecomposeLinearVectorNormPass())
8991
self.add_pass(DecomposeMeanDimPass())
9092
self.add_pass(ConvertFullLikeToFullPass())
9193
self.add_pass(ConvertToClampPass())
@@ -133,6 +135,7 @@ def _tosa_080_MI_pipeline(self, exported_program: ExportedProgram) -> GraphModul
133135
self.add_pass(FuseBatchnorm2DPass(exported_program))
134136
self.add_pass(ConvertMmToBmmPass())
135137
self.add_pass(DecomposeLinearPass())
138+
self.add_pass(DecomposeLinearVectorNormPass())
136139
self.add_pass(DecomposeLeakyReLUPass())
137140
self.add_pass(DecomposeBatchNormPass())
138141
self.add_pass(DecomposeLayerNormPass())
@@ -207,6 +210,7 @@ def transform_for_annotation_pipeline(self, graph_module: GraphModule):
207210
self.add_pass(DecomposeCosineSimilarityPass())
208211
self.add_pass(DecomposeDivPass())
209212
self.add_pass(DecomposeLeakyReLUPass())
213+
self.add_pass(DecomposeLinearVectorNormPass())
210214
self.add_pass(DecomposeSqrtPass())
211215
self.add_pass(DecomposeSiluPass())
212216

Lines changed: 78 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,78 @@
1+
# Copyright 2025 Arm Limited and/or its affiliates.
2+
#
3+
# This source code is licensed under the BSD-style license found in the
4+
# LICENSE file in the root directory of this source tree.
5+
6+
import torch
7+
from executorch.exir.pass_base import ExportPass
8+
9+
10+
class DecomposeLinearVectorNormPass(ExportPass):
11+
"""
12+
This pass decomposes aten.linalg_vector_norm.default into more primitive ops.
13+
We need to add this pass before quantization for graph annotation.
14+
By default, aten.linalg_vector_norm op is decomposed during legalization to Edge IR.
15+
16+
The decomposition is as follows:
17+
18+
For p == 1:
19+
out = REDUCE_SUM(ABS(x), dims, keepdim)
20+
21+
For p == 2:
22+
out = SQRT(REDUCE_SUM(MUL(x, x), dims, keepdim))
23+
24+
For arbitrary p:
25+
We dont support arbitrary p, because our decomposition looks like
26+
out = POW(REDUCE_SUM(POW(ABS(x), p), dims, keepdim), 1/p)
27+
In this case we need to wrap p into Tensor and we need to know
28+
dtype prior, but we dont know this from FX graph.
29+
"""
30+
31+
torch_linalg_vector_norm = (torch.ops.aten.linalg_vector_norm.default,)
32+
33+
def call_operator(self, op, args, kwargs, meta):
34+
if op not in self.torch_linalg_vector_norm:
35+
return super().call_operator(op, args, kwargs, meta)
36+
37+
# Extract inputs and optional arguments.
38+
# Expected args:
39+
# args[0]: input tensor
40+
# args[1]: norm order 'p' (optional, default: 2.0)
41+
# args[2]: dimensions to reduce (should be provided)
42+
# args[3]: keepdim flag (optional, default: False)
43+
input_tensor = args[0]
44+
norm_order = args[1] if len(args) > 1 else 2.0
45+
norm_dim = args[2] if len(args) > 2 else None
46+
keepdim = args[3] if len(args) > 3 else False
47+
48+
if norm_order not in (1, 2):
49+
raise ValueError(
50+
f"The order of {norm_order}\n"
51+
f"is not supported for linalg_vector_norm operator"
52+
)
53+
54+
if norm_dim is None:
55+
raise ValueError("The norm_dim for linalg_vector_norm is None.")
56+
57+
dims = [norm_dim] if isinstance(norm_dim, int) else list(norm_dim)
58+
59+
# Decomposition based on norm order.
60+
if norm_order == 1:
61+
op1 = super().call_operator(
62+
torch.ops.aten.abs.default, (input_tensor,), {}, meta
63+
)
64+
op2 = super().call_operator(
65+
torch.ops.aten.sum.dim_IntList, (op1, dims, keepdim), {}, meta
66+
)
67+
return op2
68+
69+
elif norm_order == 2:
70+
# For p == 2, decomposition is sqrt(sum(x * x, dims, keepdim))
71+
op1 = super().call_operator(
72+
torch.ops.aten.mul.Tensor, (input_tensor, input_tensor), {}, meta
73+
)
74+
op2 = super().call_operator(
75+
torch.ops.aten.sum.dim_IntList, (op1, dims, keepdim), {}, meta
76+
)
77+
op3 = super().call_operator(torch.ops.aten.sqrt.default, (op2,), {}, meta)
78+
return op3

backends/arm/operators/op_abs.py

Lines changed: 6 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
)
1616
from executorch.backends.arm.operators.operator_validation_utils import (
1717
validate_num_inputs,
18+
validate_same_dtype,
1819
)
1920
from executorch.backends.arm.tosa_mapping import TosaArg
2021
from executorch.backends.arm.tosa_specification import TosaSpecification
@@ -43,13 +44,8 @@ def define_node(
4344
import tosa_tools.v0_80.serializer.tosa_serializer as ts # type: ignore
4445

4546
validate_num_inputs(self.target, inputs, 1)
46-
# Specification (0.80) states that input and output types
47-
# should all be the same
48-
if not (inputs[0].dtype == output.dtype):
49-
raise ValueError(
50-
"All inputs and outputs need same dtype."
51-
f"Got {inputs[0].dtype=}, {output.dtype=}"
52-
)
47+
validate_same_dtype(self.target, [*inputs, output])
48+
5349
# Handle int8 (quantized) and int32
5450
if not (inputs[0].dtype in [ts.DType.INT8, ts.DType.INT32]):
5551
raise ValueError(
@@ -110,13 +106,7 @@ def define_node(
110106
import tosa_tools.v0_80.serializer.tosa_serializer as ts # type: ignore
111107

112108
validate_num_inputs(self.target, inputs, 1)
113-
# Specification (0.80) states that input and output types
114-
# should all be the same
115-
if not (inputs[0].dtype == output.dtype):
116-
raise ValueError(
117-
"All inputs and output need same dtype."
118-
f"Got {inputs[0].dtype=}, {output.dtype=}"
119-
)
109+
validate_same_dtype(self.target, [*inputs, output])
120110

121111
if inputs[0].dtype in [ts.DType.INT8, ts.DType.INT32]:
122112
# Call the inherited define_node for handling integers
@@ -163,14 +153,8 @@ def define_node(
163153
import serializer.tosa_serializer as ts # type: ignore
164154

165155
validate_num_inputs(self.target, inputs, 1)
156+
validate_same_dtype(self.target, [*inputs, output])
166157

167-
# Specification (1.0) states that input and output types
168-
# should all be the same
169-
if not (inputs[0].dtype == output.dtype):
170-
raise ValueError(
171-
"All inputs and outputs need same dtype."
172-
f"Got {inputs[0].dtype=}, {output.dtype=}"
173-
)
174158
# Handle int8 (quantized) and int32
175159
if not (inputs[0].dtype in [ts.DType.INT8, ts.DType.INT32]):
176160
raise ValueError(
@@ -232,14 +216,7 @@ def define_node(
232216
import serializer.tosa_serializer as ts # type: ignore
233217

234218
validate_num_inputs(self.target, inputs, 1)
235-
236-
# Specification (1.0) states that input and output types
237-
# should all be the same
238-
if not (inputs[0].dtype == output.dtype):
239-
raise ValueError(
240-
"All inputs and output need same dtype."
241-
f"Got {inputs[0].dtype=}, {output.dtype=}"
242-
)
219+
validate_same_dtype(self.target, [*inputs, output])
243220

244221
if inputs[0].dtype in [ts.DType.INT8, ts.DType.INT32]:
245222
# Call the inherited define_node for handling integers

backends/arm/operators/op_add.py

Lines changed: 6 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616
)
1717
from executorch.backends.arm.operators.operator_validation_utils import (
1818
validate_num_inputs,
19+
validate_same_dtype,
1920
)
2021
from executorch.backends.arm.tosa_mapping import TosaArg
2122
from executorch.backends.arm.tosa_specification import TosaSpecification
@@ -44,14 +45,8 @@ def define_node(
4445
import tosa_tools.v0_80.serializer.tosa_serializer as ts # type: ignore
4546

4647
validate_num_inputs(self.target, inputs, 2)
47-
# Specification (0.80) states that input and output types
48-
# should all be the same
49-
if inputs[0].dtype != inputs[1].dtype or inputs[0].dtype != output.dtype:
50-
raise TypeError(
51-
f"All IO needs to have the same data type, got input 1: "
52-
f"{inputs[0].dtype}, input 2: {inputs[1].dtype} and output: "
53-
f"{output.dtype}"
54-
)
48+
validate_same_dtype(self.target, [*inputs, output])
49+
5550
# Handle int8 (quantized) and int32
5651
supported_dtypes = [ts.DType.INT8, ts.DType.INT32]
5752
if inputs[0].dtype not in supported_dtypes:
@@ -123,14 +118,7 @@ def define_node(
123118
import tosa_tools.v0_80.serializer.tosa_serializer as ts # type: ignore
124119

125120
validate_num_inputs(self.target, inputs, 2)
126-
# Specification (0.80) states that input and output types
127-
# should all be the same
128-
if inputs[0].dtype != inputs[1].dtype or inputs[0].dtype != output.dtype:
129-
raise TypeError(
130-
f"All IO needs to have the same data type, got input 1: "
131-
f"{inputs[0].dtype}, input 2: {inputs[1].dtype} and output: "
132-
f"{output.dtype}"
133-
)
121+
validate_same_dtype(self.target, [*inputs, output])
134122

135123
if inputs[0].dtype in [ts.DType.INT8, ts.DType.INT32]:
136124
# Call the inherited define_node for handling integers
@@ -175,15 +163,8 @@ def define_node(
175163
import serializer.tosa_serializer as ts # type: ignore
176164

177165
validate_num_inputs(self.target, inputs, 2)
166+
validate_same_dtype(self.target, [*inputs, output])
178167

179-
# Specification (1.0) states that input and output types
180-
# should all be the same
181-
if inputs[0].dtype != inputs[1].dtype or inputs[0].dtype != output.dtype:
182-
raise TypeError(
183-
f"All IO needs to have the same data type, got input 1: "
184-
f"{inputs[0].dtype}, input 2: {inputs[1].dtype} and output: "
185-
f"{output.dtype}"
186-
)
187168
# Handle int8 (quantized) and int32
188169
supported_dtypes = [ts.DType.INT8, ts.DType.INT32]
189170
if inputs[0].dtype not in supported_dtypes:
@@ -245,15 +226,7 @@ def define_node(
245226
import serializer.tosa_serializer as ts # type: ignore
246227

247228
validate_num_inputs(self.target, inputs, 2)
248-
249-
# Specification (1.0) states that input and output types
250-
# should all be the same
251-
if inputs[0].dtype != inputs[1].dtype or inputs[0].dtype != output.dtype:
252-
raise TypeError(
253-
f"All IO needs to have the same data type, got input 1: "
254-
f"{inputs[0].dtype}, input 2: {inputs[1].dtype} and output: "
255-
f"{output.dtype}"
256-
)
229+
validate_same_dtype(self.target, [*inputs, output])
257230

258231
if inputs[0].dtype in [ts.DType.INT8, ts.DType.INT32]:
259232
# Call the inherited define_node for handling integers

backends/arm/operators/op_amax.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111
)
1212
from executorch.backends.arm.operators.operator_validation_utils import (
1313
validate_num_inputs,
14+
validate_same_dtype,
1415
)
1516
from executorch.backends.arm.tosa_mapping import TosaArg
1617
from torch.fx import Node
@@ -35,6 +36,7 @@ def define_node(
3536
import tosa_tools.v0_80.serializer.tosa_serializer as ts
3637

3738
validate_num_inputs(self.target, inputs, 3)
39+
validate_same_dtype(self.target, [inputs[0], output])
3840

3941
input = inputs[0]
4042
dim = inputs[1].number
@@ -77,6 +79,7 @@ def define_node(
7779
import serializer.tosa_serializer as ts
7880

7981
validate_num_inputs(self.target, inputs, 3)
82+
validate_same_dtype(self.target, [inputs[0], output])
8083

8184
input = inputs[0]
8285
dim = inputs[1].number

backends/arm/operators/op_amin.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111
)
1212
from executorch.backends.arm.operators.operator_validation_utils import (
1313
validate_num_inputs,
14+
validate_same_dtype,
1415
)
1516
from executorch.backends.arm.tosa_mapping import TosaArg
1617
from torch.fx import Node
@@ -35,6 +36,7 @@ def define_node(
3536
import tosa_tools.v0_80.serializer.tosa_serializer as ts
3637

3738
validate_num_inputs(self.target, inputs, 3)
39+
validate_same_dtype(self.target, [inputs[0], output])
3840

3941
input = inputs[0]
4042
dim = inputs[1].number
@@ -77,6 +79,7 @@ def define_node(
7779
import serializer.tosa_serializer as ts
7880

7981
validate_num_inputs(self.target, inputs, 3)
82+
validate_same_dtype(self.target, [inputs[0], output])
8083

8184
input = inputs[0]
8285
dim = inputs[1].number

0 commit comments

Comments
 (0)