Skip to content

Commit a101ea5

Browse files
committed
Update base for Update on "[ET-VK] Replace Uniform buffers with push constants for native layer norm op"
This diff replaces Uniform buffers with push constants for the native layer norm op in the Vulkan backend of Executorch. The changes include updating the shader code to use push constants instead of Uniform buffers, and updating the C++ code to pass the sizes as push constants to the shader. Differential Revision: [D70943355](https://our.internmc.facebook.com/intern/diff/D70943355/) [ghstack-poisoned]
2 parents 7622f74 + cc73d03 commit a101ea5

File tree

105 files changed

+2013
-819
lines changed

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

105 files changed

+2013
-819
lines changed

.github/scripts/extract_benchmark_results.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -360,6 +360,7 @@ def transform(
360360
"app_type": app_type,
361361
# Just keep a copy of the benchmark config here
362362
"benchmark_config": json.dumps(benchmark_config),
363+
"job_conclusion": "SUCCESS",
363364
},
364365
},
365366
"model": {
@@ -455,7 +456,7 @@ def transform_failure_record(
455456
},
456457
"metric": {
457458
"name": "FAILURE_REPORT",
458-
"benchmark_values": 0,
459+
"benchmark_values": [0],
459460
"target_value": 0,
460461
"extra_info": {
461462
"method": "",

CMakeLists.txt

Lines changed: 11 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -645,13 +645,18 @@ target_link_options_shared_lib(executorch)
645645
# Real integrations should supply their own YAML file that only lists the
646646
# operators necessary for the models that will run.
647647
#
648+
if(EXECUTORCH_BUILD_KERNELS_OPTIMIZED)
649+
# find pytorch lib here to make it available to all
650+
# sub-directories. Find it before including portable so that
651+
# optimized_portable_kernels can use it.
652+
find_package_torch_headers()
653+
endif()
654+
648655
if(BUILD_EXECUTORCH_PORTABLE_OPS)
649656
add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/kernels/portable)
650657
endif()
651658

652659
if(EXECUTORCH_BUILD_KERNELS_OPTIMIZED)
653-
# find pytorch lib here to make it available to all sub-directories
654-
find_package_torch_headers()
655660
add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/kernels/optimized)
656661
endif()
657662

@@ -764,10 +769,6 @@ if(EXECUTORCH_BUILD_EXTENSION_MODULE)
764769
add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/extension/module)
765770
endif()
766771

767-
if(EXECUTORCH_BUILD_EXTENSION_TRAINING)
768-
add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/extension/training)
769-
endif()
770-
771772
if(EXECUTORCH_BUILD_EXTENSION_RUNNER_UTIL)
772773
add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/extension/runner_util)
773774
endif()
@@ -872,34 +873,13 @@ if(EXECUTORCH_BUILD_PYBIND)
872873

873874
if(EXECUTORCH_BUILD_EXTENSION_TRAINING)
874875

875-
set(_pybind_training_dep_libs
876-
${TORCH_PYTHON_LIBRARY}
877-
etdump
878-
executorch
879-
util
880-
torch
881-
extension_training
882-
)
883-
884-
if(EXECUTORCH_BUILD_XNNPACK)
885-
# need to explicitly specify XNNPACK and microkernels-prod
886-
# here otherwise uses XNNPACK and microkernel-prod symbols from libtorch_cpu
887-
list(APPEND _pybind_training_dep_libs xnnpack_backend XNNPACK microkernels-prod)
888-
endif()
889-
890-
# pybind training
891-
pybind11_add_module(_training_lib SHARED extension/training/pybindings/_training_lib.cpp)
892-
893-
target_include_directories(_training_lib PRIVATE ${TORCH_INCLUDE_DIRS})
894-
target_compile_options(_training_lib PUBLIC ${_pybind_compile_options})
895-
target_link_libraries(_training_lib PRIVATE ${_pybind_training_dep_libs})
896-
897-
install(TARGETS _training_lib
898-
LIBRARY DESTINATION executorch/extension/training/pybindings
899-
)
900876
endif()
901877
endif()
902878

879+
if(EXECUTORCH_BUILD_EXTENSION_TRAINING)
880+
add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/extension/training)
881+
endif()
882+
903883
if(EXECUTORCH_BUILD_KERNELS_CUSTOM)
904884
# TODO: move all custom kernels to ${CMAKE_CURRENT_SOURCE_DIR}/kernels/custom
905885
add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/extension/llm/custom_ops)

backends/apple/coreml/scripts/install_requirements.sh

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@ SCRIPT_DIR_PATH="$(
1212

1313
# TODO(jathu): remove the need to fetch coremltools to build deps for coreml_executor_runner.
1414
# Keep this version in sync with: pyproject.toml
15-
COREMLTOOLS_VERSION="8.1"
15+
COREMLTOOLS_VERSION="8.2"
1616

1717
red=`tput setaf 1`
1818
green=`tput setaf 2`

backends/apple/coreml/test/test_coreml_quantizer.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,9 @@ def quantize_and_compare(
3232
) -> None:
3333
assert quantization_type in {"PTQ", "QAT"}
3434

35-
pre_autograd_aten_dialect = export_for_training(model, example_inputs).module()
35+
pre_autograd_aten_dialect = export_for_training(
36+
model, example_inputs, strict=True
37+
).module()
3638

3739
quantization_config = LinearQuantizerConfig.from_dict(
3840
{

backends/apple/mps/test/test_mps_utils.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -207,7 +207,7 @@ def lower_module_and_test_output(
207207
expected_output = model(*sample_inputs)
208208

209209
model = torch.export.export_for_training(
210-
model, sample_inputs, dynamic_shapes=dynamic_shapes
210+
model, sample_inputs, dynamic_shapes=dynamic_shapes, strict=True
211211
).module()
212212

213213
edge_program = export_to_edge(

backends/arm/_passes/__init__.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
from . import arm_pass_utils # noqa
88
from .annotate_channels_last_dim_order_pass import AnnotateChannelsLastDimOrder # noqa
99
from .annotate_decomposed_matmul import AnnotateDecomposedMatmulPass # noqa
10+
from .arm_pass import ArmPass # noqa
1011
from .cast_int64_pass import CastInt64BuffersToInt32Pass # noqa
1112
from .cast_to_int32_pass import CastToInt32Pass # noqa
1213
from .conv1d_unsqueeze_pass import Conv1dUnsqueezePass # noqa
@@ -41,6 +42,10 @@
4142
from .meandim_to_averagepool_pass import ConvertMeanDimToAveragePoolPass # noqa
4243
from .mm_to_bmm_pass import ConvertMmToBmmPass # noqa
4344
from .remove_clone_pass import RemoveClonePass # noqa
45+
from .replace_scalar_with_tensor_pass import ( # noqa
46+
ReplaceScalarWithTensorArgPassTOSABI,
47+
ReplaceScalarWithTensorArgPassTOSAMI,
48+
)
4449
from .scalars_to_attribute_pass import ScalarsToAttributePass # noqa
4550
from .size_adjust_conv2d_pass import SizeAdjustConv2DPass # noqa
4651
from .unsqueeze_before_repeat_pass import UnsqueezeBeforeRepeatPass # noqa

backends/arm/_passes/arm_pass.py

Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,33 @@
1+
# Copyright 2025 Arm Limited and/or its affiliates.
2+
#
3+
# This source code is licensed under the BSD-style license found in the
4+
# LICENSE file in the root directory of this source tree.
5+
6+
# pyre-unsafe
7+
8+
import traceback
9+
from typing import Optional
10+
11+
import torch
12+
from executorch.exir.pass_base import ExportPass, NodeMetadata
13+
14+
15+
class ArmPass(ExportPass):
16+
"""Base class for Arm passes"""
17+
18+
def __init__(self, exported_program: Optional[torch.export.ExportedProgram] = None):
19+
super(ArmPass, self).__init__()
20+
self.exported_program = exported_program
21+
22+
def call_operator(self, op, args, kwargs, meta, updated: Optional[bool] = False):
23+
if not updated:
24+
return super().call_operator(op, args, kwargs, meta)
25+
26+
# if updated we should update metadata
27+
new_meta = {}
28+
keys = meta.data.keys()
29+
for key in keys:
30+
new_meta[key] = meta[key]
31+
old_stack_trace = new_meta.get("stack_trace", "")
32+
new_meta["stack_trace"] = f"{old_stack_trace}\n{traceback.format_stack()[-2]}"
33+
return super().call_operator(op, args, kwargs, NodeMetadata(new_meta))

backends/arm/_passes/arm_pass_manager.py

Lines changed: 6 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -42,18 +42,17 @@
4242
MatchArgRanksPass,
4343
QuantizeOperatorArguments,
4444
RemoveClonePass,
45+
ReplaceScalarWithTensorArgPassTOSABI,
46+
ReplaceScalarWithTensorArgPassTOSAMI,
4547
RetraceFoldedDtypesPass,
4648
ScalarsToAttributePass,
4749
SizeAdjustConv2DPass,
4850
UnsqueezeBeforeRepeatPass,
4951
UnsqueezeScalarPlaceholdersPass,
5052
)
53+
5154
from executorch.backends.arm.tosa_specification import Tosa_0_80, TosaSpecification
5255
from executorch.backends.transforms.fuse_view_copy import FuseViewCopyTransform
53-
54-
from executorch.backends.transforms.replace_scalar_with_tensor import (
55-
ReplaceScalarWithTensorArgPass,
56-
)
5756
from executorch.backends.xnnpack._passes.remove_getitem_op import RemoveGetItemPass
5857
from executorch.exir import ExportedProgram
5958
from executorch.exir.pass_manager import PassManager
@@ -84,7 +83,7 @@ def _tosa_080_BI_pipeline(self, exported_program: ExportedProgram) -> GraphModul
8483
if isinstance(self.tosa_spec, Tosa_0_80) and self.tosa_spec.is_U55_subset:
8584
self.add_pass(CastToInt32Pass())
8685

87-
self.add_pass(ReplaceScalarWithTensorArgPass())
86+
self.add_pass(ReplaceScalarWithTensorArgPassTOSABI())
8887
self.add_pass(AnnotateDecomposedMatmulPass())
8988
self.add_pass(QuantizeOperatorArguments())
9089
self.add_pass(FoldAndAnnotateQParamsPass()) # type: ignore[call-arg]
@@ -113,7 +112,7 @@ def _tosa_080_BI_pipeline(self, exported_program: ExportedProgram) -> GraphModul
113112
return self._transform(exported_program.graph_module)
114113

115114
def _tosa_080_MI_pipeline(self, exported_program: ExportedProgram) -> GraphModule:
116-
self.add_pass(ReplaceScalarWithTensorArgPass())
115+
self.add_pass(ReplaceScalarWithTensorArgPassTOSAMI())
117116
self.add_pass(FuseQuantizedActivationPass())
118117
self.add_pass(RemoveGetItemPass())
119118
self.add_pass(ConvertSplitToSlicePass())
@@ -170,7 +169,7 @@ def transform_to_backend_pipeline(self, exported_program: ExportedProgram):
170169
)
171170

172171
def transform_for_annotation_pipeline(self, graph_module: GraphModule):
173-
self.add_pass(ReplaceScalarWithTensorArgPass())
172+
self.add_pass(ReplaceScalarWithTensorArgPassTOSABI())
174173
self.add_pass(ScalarsToAttributePass())
175174
self.add_pass(DecomposeLayerNormPass())
176175
self.add_pass(DecomposeVarPass())

backends/arm/_passes/arm_pass_utils.py

Lines changed: 16 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -7,12 +7,12 @@
77

88
# pyre-unsafe
99

10+
import traceback
1011
from inspect import isclass
1112
from typing import Optional, Sequence
1213

1314
import torch
1415
import torch.fx
15-
1616
from executorch.exir import ExportedProgram
1717
from executorch.exir.dialects._ops import ops as exir_ops
1818

@@ -96,6 +96,7 @@ def create_node(
9696
kwargs: Optional[dict] = None,
9797
quantize: bool = False,
9898
q_params: Optional[tuple] = None,
99+
from_node: Optional[torch.fx.Node] = None,
99100
):
100101
"""
101102
Adds a node to 'graph'. graph.inserting_before/after() should be used before the call to decide where to insert the node.
@@ -108,15 +109,26 @@ def create_node(
108109
args=args,
109110
kwargs=kwargs or {},
110111
)
112+
113+
new_meta = {}
114+
if from_node:
115+
keys = from_node.meta.keys()
116+
for key in keys:
117+
new_meta[key] = from_node.meta[key]
118+
old_stack_trace = new_meta.get("stack_trace", "")
119+
new_meta["stack_trace"] = f"{old_stack_trace}\n{traceback.format_stack()[-2]}"
120+
node.meta = new_meta
121+
111122
if quantize and q_params:
112-
return insert_q_dq_pair(graph, node, q_params)
123+
return insert_q_dq_pair(graph, node, q_params, from_node)
113124
return node
114125

115126

116127
def insert_q_dq_pair(
117128
graph: torch.fx.Graph,
118129
anchor: torch.fx.Node,
119130
q_params: tuple,
131+
from_node: Optional[torch.fx.Node] = None,
120132
):
121133
"""
122134
Inserts a q dq node pair after the node 'anchor'.
@@ -127,13 +139,15 @@ def insert_q_dq_pair(
127139
graph=graph,
128140
op_target=exir_ops.edge.quantized_decomposed.quantize_per_tensor.default,
129141
args=(), # We add the argument last
142+
from_node=from_node if from_node else anchor,
130143
)
131144
q.meta = anchor.meta
132145
with graph.inserting_after(q):
133146
dq = create_node(
134147
graph=graph,
135148
op_target=exir_ops.edge.quantized_decomposed.dequantize_per_tensor.default,
136149
args=(q,) + q_params,
150+
from_node=from_node if from_node else anchor,
137151
)
138152
dq.meta = q.meta
139153
anchor.replace_all_uses_with(dq)

backends/arm/_passes/decompose_layernorm_pass.py

Lines changed: 30 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -9,9 +9,10 @@
99
import operator
1010

1111
import torch
12+
from executorch.backends.arm._passes import ArmPass
1213
from executorch.backends.arm._passes.arm_pass_utils import create_node
1314
from executorch.exir.dialects._ops import ops as exir_ops
14-
from executorch.exir.pass_base import ExportPass, PassResult
15+
from executorch.exir.pass_base import PassResult
1516

1617

1718
def get_layer_norm_decomposition(op) -> tuple:
@@ -40,7 +41,7 @@ def get_layer_norm_decomposition(op) -> tuple:
4041
raise RuntimeError(f"Can't get layer_norm composition for op {op}")
4142

4243

43-
class DecomposeLayerNormPass(ExportPass):
44+
class DecomposeLayerNormPass(ArmPass):
4445
"""
4546
layernorm is defined as: ((x - E[x]) / sqrt(Var[x] + eps)) * weights + bias
4647
Decompose layernorm(x, normalized_shape, weights, bias, eps) to a sequence of:
@@ -111,35 +112,56 @@ def call(self, graph_module: torch.fx.GraphModule):
111112
var_op,
112113
args=(x, dims),
113114
kwargs={"correction": 0, "keepdim": keepdim},
115+
from_node=node,
114116
)
115117
full = create_node(
116118
graph_module.graph,
117119
full_op,
118120
args=(epsilon_reshaped_shape, epsilon),
119121
kwargs={"dtype": dtype},
122+
from_node=node,
123+
)
124+
add0 = create_node(
125+
graph_module.graph, add_op, args=(var, full), from_node=node
126+
)
127+
rsqrt = create_node(
128+
graph_module.graph, rsqrt_op, args=(add0,), from_node=node
129+
)
130+
mul0 = create_node(
131+
graph_module.graph, mul_op, args=(sub, rsqrt), from_node=node
120132
)
121-
add0 = create_node(graph_module.graph, add_op, args=(var, full))
122-
rsqrt = create_node(graph_module.graph, rsqrt_op, args=(add0,))
123-
mul0 = create_node(graph_module.graph, mul_op, args=(sub, rsqrt))
124133
if weights is not None:
125134
weights_reshaped = create_node(
126135
graph_module.graph,
127136
view_op,
128137
args=(weights, weights_reshaped_shape),
138+
from_node=node,
129139
)
130140
mul1 = create_node(
131-
graph_module.graph, mul_op, args=(mul0, weights_reshaped)
141+
graph_module.graph,
142+
mul_op,
143+
args=(
144+
mul0,
145+
weights_reshaped,
146+
),
147+
from_node=node,
132148
)
133149
else:
134150
mul1 = mul0
135151
output = mul1
136152
if bias is not None:
137153
bias_reshaped_shape = weights_reshaped_shape
138154
bias_reshaped = create_node(
139-
graph_module.graph, view_op, args=(bias, bias_reshaped_shape)
155+
graph_module.graph,
156+
view_op,
157+
args=(bias, bias_reshaped_shape),
158+
from_node=node,
140159
)
141160
output = create_node(
142-
graph_module.graph, add_op, args=(mul1, bias_reshaped)
161+
graph_module.graph,
162+
add_op,
163+
args=(mul1, bias_reshaped),
164+
from_node=node,
143165
)
144166

145167
users = [user for user in node.users if node != user]

0 commit comments

Comments
 (0)