diff --git a/.github/workflows/build-presets.yml b/.github/workflows/build-presets.yml new file mode 100644 index 00000000000..39bc9dc6480 --- /dev/null +++ b/.github/workflows/build-presets.yml @@ -0,0 +1,13 @@ +name: Build Presets + +on: + pull_request: + push: + branches: + - main + - release/* + workflow_dispatch: + +concurrency: + group: ${{ github.workflow }}-${{ github.event.pull_request.number || github.sha }}-${{ github.event_name == 'workflow_dispatch' }}-${{ github.event_name == 'schedule' }} + cancel-in-progress: true diff --git a/backends/arm/_passes/__init__.py b/backends/arm/_passes/__init__.py index aac3b300f9b..35879d5026c 100644 --- a/backends/arm/_passes/__init__.py +++ b/backends/arm/_passes/__init__.py @@ -19,6 +19,7 @@ from .convert_squeezes_to_view import ConvertSqueezesToViewPass # noqa from .convert_to_clamp import ConvertToClampPass # noqa from .decompose_batchnorm_pass import DecomposeBatchNormPass # noqa +from .decompose_cosine_similarity_pass import DecomposeCosineSimilarityPass # noqa from .decompose_div_pass import DecomposeDivPass # noqa from .decompose_gelu_pass import DecomposeGeluPass # noqa from .decompose_layernorm_pass import DecomposeLayerNormPass # noqa diff --git a/backends/arm/_passes/arm_pass_manager.py b/backends/arm/_passes/arm_pass_manager.py index 7ce07d5e73f..d6d608918a6 100644 --- a/backends/arm/_passes/arm_pass_manager.py +++ b/backends/arm/_passes/arm_pass_manager.py @@ -24,6 +24,7 @@ ConvertSqueezesToViewPass, ConvertToClampPass, DecomposeBatchNormPass, + DecomposeCosineSimilarityPass, DecomposeDivPass, DecomposeGeluPass, DecomposeLayerNormPass, @@ -205,6 +206,7 @@ def transform_for_annotation_pipeline(self, graph_module: GraphModule): self.add_pass(DecomposeVarPass()) self.add_pass(DecomposeMeanDimPass()) self.add_pass(DecomposeNotEqualPass()) + self.add_pass(DecomposeCosineSimilarityPass()) self.add_pass(DecomposeDivPass()) self.add_pass(DecomposeLeakyReLUPass()) self.add_pass(DecomposeSqrtPass()) diff --git 
a/backends/arm/_passes/decompose_cosine_similarity_pass.py b/backends/arm/_passes/decompose_cosine_similarity_pass.py new file mode 100644 index 00000000000..9978e653408 --- /dev/null +++ b/backends/arm/_passes/decompose_cosine_similarity_pass.py @@ -0,0 +1,75 @@ +# Copyright 2025 Arm Limited and/or its affiliates. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +import torch +from executorch.exir.pass_base import ExportPass + +torch_cosine_similarity = (torch.ops.aten.cosine_similarity.default,) + + +class DecomposeCosineSimilarityPass(ExportPass): +    """ +    Decomposition of aten.cosine_similarity: + +        dot   = sum(mul(x1, x2), dims, keepdim=False) +        norm  = pow( sum(mul(x, x), dims, keepdim=False), 0.5 ) +        eps   = full_like( norm1, eps_scalar ) +        n1c   = max(norm1, eps) +        n2c   = max(norm2, eps) +        denom = mul(n1c, n2c) +        out   = div(dot, denom) +    """ + +    def call_operator(self, op, args, kwargs, meta): +        if op not in torch_cosine_similarity: +            return super().call_operator(op, args, kwargs, meta) + +        x1, x2 = args[0], args[1] +        dim = kwargs.get("dim", 1) +        eps = kwargs.get("eps", 1e-8) +        dims = [dim] if isinstance(dim, int) else list(dim) + +        # 1) dot +        prod = super().call_operator(torch.ops.aten.mul.Tensor, (x1, x2), {}, meta) +        dot = super().call_operator( +            torch.ops.aten.sum.dim_IntList, (prod, dims, False), {}, meta +        ) + +        # 2a) norm1 = pow(sum(x1*x1), 0.5) +        x1_sq = super().call_operator(torch.ops.aten.mul.Tensor, (x1, x1), {}, meta) +        s1 = super().call_operator( +            torch.ops.aten.sum.dim_IntList, (x1_sq, dims, False), {}, meta +        ) +        norm1 = super().call_operator( +            torch.ops.aten.pow.Tensor_Scalar, (s1, 0.5), {}, meta +        ) + +        # 2b) norm2 = pow(sum(x2*x2), 0.5) +        x2_sq = super().call_operator(torch.ops.aten.mul.Tensor, (x2, x2), {}, meta) +        s2 = super().call_operator( +            torch.ops.aten.sum.dim_IntList, (x2_sq, dims, False), {}, meta +        ) +        norm2 = super().call_operator( +
torch.ops.aten.pow.Tensor_Scalar, (s2, 0.5), {}, meta +        ) + +        # 3) eps scalar - we need to broadcast it ourselves, as TOSA does not do this for scalars +        eps_t = super().call_operator( +            torch.ops.aten.full_like.default, (norm1, eps), {}, meta +        ) + +        # 4) clamp to avoid zero division +        n1c = super().call_operator( +            torch.ops.aten.maximum.default, (norm1, eps_t), {}, meta +        ) +        n2c = super().call_operator( +            torch.ops.aten.maximum.default, (norm2, eps_t), {}, meta +        ) + +        # 5) denom and divide +        denom = super().call_operator(torch.ops.aten.mul.Tensor, (n1c, n2c), {}, meta) +        out = super().call_operator(torch.ops.aten.div.Tensor, (dot, denom), {}, meta) + +        return out diff --git a/backends/arm/operators/TARGETS b/backends/arm/operators/TARGETS index c7cd5aa79bb..7f97450849d 100644 --- a/backends/arm/operators/TARGETS +++ b/backends/arm/operators/TARGETS @@ -10,6 +10,11 @@ python_library( ], ) +python_library( + name = "operator_validation_utils", + srcs = ["operator_validation_utils.py"], +) + python_library( name = "ops", srcs = glob(["op_*.py", "ops_*.py"]), @@ -17,6 +22,7 @@ python_library( "fbsource//third-party/tosa_tools/v0.80/serialization_lib/python/tosa:tosa", "fbsource//third-party/tosa_tools/v1.00/serialization_lib/python/tosa:tosa", ":node_visitor", + ":operator_validation_utils", "//executorch/backends/arm:tosa_mapping", "//executorch/backends/arm:tosa_quant_utils", "//executorch/backends/arm:tosa_utils", diff --git a/backends/arm/operators/op_abs.py b/backends/arm/operators/op_abs.py index a6d649fd92e..43929d3b1c8 100644 --- a/backends/arm/operators/op_abs.py +++ b/backends/arm/operators/op_abs.py @@ -13,6 +13,9 @@ NodeVisitor, register_node_visitor, ) +from executorch.backends.arm.operators.operator_validation_utils import ( + validate_num_inputs, +) from executorch.backends.arm.tosa_mapping import TosaArg from executorch.backends.arm.tosa_specification import TosaSpecification from torch.fx import Node @@ -39,6 +42,7 @@ def define_node( import 
tosa_tools.v0_80.serializer.tosa_serializer as ts # type: ignore + validate_num_inputs(self.target, inputs, 1) # Specification (0.80) states that input and output types # should all be the same if not (inputs[0].dtype == output.dtype): @@ -105,6 +109,7 @@ def define_node( import tosa_tools.v0_80.serializer.tosa_serializer as ts # type: ignore + validate_num_inputs(self.target, inputs, 1) # Specification (0.80) states that input and output types # should all be the same if not (inputs[0].dtype == output.dtype): @@ -157,6 +162,8 @@ def define_node( import serializer.tosa_serializer as ts # type: ignore + validate_num_inputs(self.target, inputs, 1) + # Specification (1.0) states that input and output types # should all be the same if not (inputs[0].dtype == output.dtype): @@ -224,6 +231,8 @@ def define_node( import serializer.tosa_serializer as ts # type: ignore + validate_num_inputs(self.target, inputs, 1) + # Specification (1.0) states that input and output types # should all be the same if not (inputs[0].dtype == output.dtype): diff --git a/backends/arm/operators/op_add.py b/backends/arm/operators/op_add.py index 11c32a3ae5f..fc8ecbb960a 100644 --- a/backends/arm/operators/op_add.py +++ b/backends/arm/operators/op_add.py @@ -14,6 +14,9 @@ NodeVisitor, register_node_visitor, ) +from executorch.backends.arm.operators.operator_validation_utils import ( + validate_num_inputs, +) from executorch.backends.arm.tosa_mapping import TosaArg from executorch.backends.arm.tosa_specification import TosaSpecification from torch.fx import Node @@ -40,6 +43,7 @@ def define_node( import tosa_tools.v0_80.serializer.tosa_serializer as ts # type: ignore + validate_num_inputs(self.target, inputs, 2) # Specification (0.80) states that input and output types # should all be the same if inputs[0].dtype != inputs[1].dtype or inputs[0].dtype != output.dtype: @@ -118,6 +122,7 @@ def define_node( import tosa_tools.v0_80.serializer.tosa_serializer as ts # type: ignore + 
validate_num_inputs(self.target, inputs, 2) # Specification (0.80) states that input and output types # should all be the same if inputs[0].dtype != inputs[1].dtype or inputs[0].dtype != output.dtype: @@ -169,6 +174,8 @@ def define_node( import serializer.tosa_serializer as ts # type: ignore + validate_num_inputs(self.target, inputs, 2) + # Specification (1.0) states that input and output types # should all be the same if inputs[0].dtype != inputs[1].dtype or inputs[0].dtype != output.dtype: @@ -237,6 +244,8 @@ def define_node( import serializer.tosa_serializer as ts # type: ignore + validate_num_inputs(self.target, inputs, 2) + # Specification (1.0) states that input and output types # should all be the same if inputs[0].dtype != inputs[1].dtype or inputs[0].dtype != output.dtype: diff --git a/backends/arm/operators/op_amax.py b/backends/arm/operators/op_amax.py index 6bb9d563ca6..52cfbb18e81 100644 --- a/backends/arm/operators/op_amax.py +++ b/backends/arm/operators/op_amax.py @@ -9,6 +9,9 @@ NodeVisitor, register_node_visitor, ) +from executorch.backends.arm.operators.operator_validation_utils import ( + validate_num_inputs, +) from executorch.backends.arm.tosa_mapping import TosaArg from torch.fx import Node @@ -31,6 +34,8 @@ def define_node( ) -> None: import tosa_tools.v0_80.serializer.tosa_serializer as ts + validate_num_inputs(self.target, inputs, 3) + input = inputs[0] dim = inputs[1].number @@ -71,6 +76,8 @@ def define_node( ) -> None: import serializer.tosa_serializer as ts + validate_num_inputs(self.target, inputs, 3) + input = inputs[0] dim = inputs[1].number diff --git a/backends/arm/operators/op_amin.py b/backends/arm/operators/op_amin.py index 5c0fee5cfaf..d9f05c6f9f1 100644 --- a/backends/arm/operators/op_amin.py +++ b/backends/arm/operators/op_amin.py @@ -9,6 +9,9 @@ NodeVisitor, register_node_visitor, ) +from executorch.backends.arm.operators.operator_validation_utils import ( + validate_num_inputs, +) from executorch.backends.arm.tosa_mapping 
import TosaArg from torch.fx import Node @@ -31,6 +34,8 @@ def define_node( ) -> None: import tosa_tools.v0_80.serializer.tosa_serializer as ts + validate_num_inputs(self.target, inputs, 3) + input = inputs[0] dim = inputs[1].number @@ -71,6 +76,8 @@ def define_node( ) -> None: import serializer.tosa_serializer as ts + validate_num_inputs(self.target, inputs, 3) + input = inputs[0] dim = inputs[1].number diff --git a/backends/arm/operators/op_any.py b/backends/arm/operators/op_any.py index efb5b0b72b0..d8be68fbbc1 100644 --- a/backends/arm/operators/op_any.py +++ b/backends/arm/operators/op_any.py @@ -10,6 +10,9 @@ NodeVisitor, register_node_visitor, ) +from executorch.backends.arm.operators.operator_validation_utils import ( + validate_num_inputs, +) from executorch.backends.arm.tosa_mapping import TosaArg # type: ignore from torch.fx import Node @@ -30,6 +33,8 @@ def define_node( ) -> None: import tosa_tools.v0_80.serializer.tosa_serializer as ts # type: ignore + validate_num_inputs(self.target, inputs, 3) + if not (inputs[0].dtype == output.dtype): raise ValueError( "All inputs and outputs need same dtype." @@ -69,6 +74,8 @@ def define_node( ) -> None: import serializer.tosa_serializer as ts + validate_num_inputs(self.target, inputs, 3) + if not (inputs[0].dtype == output.dtype): raise ValueError( "All inputs and outputs need same dtype." 
diff --git a/backends/arm/operators/op_avg_pool2d.py b/backends/arm/operators/op_avg_pool2d.py index 727fd52dfd5..504de7319a2 100644 --- a/backends/arm/operators/op_avg_pool2d.py +++ b/backends/arm/operators/op_avg_pool2d.py @@ -16,6 +16,9 @@ NodeVisitor, register_node_visitor, ) +from executorch.backends.arm.operators.operator_validation_utils import ( + validate_num_inputs, +) from executorch.backends.arm.tosa_mapping import TosaArg from executorch.backends.arm.tosa_specification import TosaSpecification @@ -85,6 +88,8 @@ def define_node( ) -> None: import tosa_tools.v0_80.serializer.tosa_serializer as ts # type: ignore + validate_num_inputs(self.target, inputs, [3, 4, 6]) + supported_dtypes = [ts.DType.INT8] if inputs[0].dtype not in supported_dtypes: raise TypeError( @@ -122,6 +127,8 @@ def define_node( ) -> None: import tosa_tools.v0_80.serializer.tosa_serializer as ts # type: ignore + validate_num_inputs(self.target, inputs, [3, 4, 6]) + supported_dtypes = [ts.DType.INT8, ts.DType.FP32] if inputs[0].dtype not in supported_dtypes: raise TypeError( @@ -212,6 +219,8 @@ def define_node( ) -> None: import serializer.tosa_serializer as ts # type: ignore + validate_num_inputs(self.target, inputs, [3, 4, 6]) + supported_dtypes = [ts.DType.INT8] if inputs[0].dtype not in supported_dtypes: raise TypeError( @@ -252,6 +261,8 @@ def define_node( ) -> None: import serializer.tosa_serializer as ts # type: ignore + validate_num_inputs(self.target, inputs, [3, 4, 6]) + supported_dtypes = [ts.DType.INT8, ts.DType.FP32] if inputs[0].dtype not in supported_dtypes: raise TypeError( diff --git a/backends/arm/operators/op_bmm.py b/backends/arm/operators/op_bmm.py index ebc43ca33f6..8c68bde2006 100644 --- a/backends/arm/operators/op_bmm.py +++ b/backends/arm/operators/op_bmm.py @@ -17,7 +17,9 @@ NodeVisitor, register_node_visitor, ) - +from executorch.backends.arm.operators.operator_validation_utils import ( + validate_num_inputs, +) from executorch.backends.arm.tosa_mapping import 
TosaArg from executorch.backends.arm.tosa_quant_utils import build_rescale, build_rescale_v0_80 from executorch.backends.arm.tosa_specification import TosaSpecification @@ -46,6 +48,7 @@ def define_node( import tosa_tools.v0_80.serializer.tosa_serializer as ts # type: ignore + validate_num_inputs(self.target, inputs, 2) if inputs[0].dtype != inputs[1].dtype or inputs[0].dtype != output.dtype: raise TypeError( f"All IO needs to have the same data type, got: " @@ -128,6 +131,8 @@ def define_node( import serializer.tosa_serializer as ts # type: ignore + validate_num_inputs(self.target, inputs, 2) + if inputs[0].dtype != inputs[1].dtype or inputs[0].dtype != output.dtype: raise TypeError( f"All IO needs to have the same data type, got: " diff --git a/backends/arm/operators/op_cat.py b/backends/arm/operators/op_cat.py index bb77ba77940..c7bad9e4429 100644 --- a/backends/arm/operators/op_cat.py +++ b/backends/arm/operators/op_cat.py @@ -11,6 +11,9 @@ NodeVisitor, register_node_visitor, ) +from executorch.backends.arm.operators.operator_validation_utils import ( + validate_num_inputs, +) from executorch.backends.arm.tosa_mapping import TosaArg from torch.fx import Node @@ -33,6 +36,8 @@ def define_node( ) -> None: import tosa_tools.v0_80.serializer.tosa_serializer as ts # type: ignore + validate_num_inputs(self.target, inputs, [1, 2]) + tensors = inputs[0].special dim = 0 if len(inputs) < 2 else inputs[1].number rank = len(output.shape) @@ -68,6 +73,8 @@ def define_node( ) -> None: import serializer.tosa_serializer as ts + validate_num_inputs(self.target, inputs, [1, 2]) + tensors = inputs[0].special dim = 0 if len(inputs) < 2 else inputs[1].number rank = len(output.shape) diff --git a/backends/arm/operators/op_clamp.py b/backends/arm/operators/op_clamp.py index aedcc643e5d..566121d1bbb 100644 --- a/backends/arm/operators/op_clamp.py +++ b/backends/arm/operators/op_clamp.py @@ -15,6 +15,9 @@ NodeVisitor, register_node_visitor, ) +from 
executorch.backends.arm.operators.operator_validation_utils import ( + validate_num_inputs, +) from executorch.backends.arm.tosa_mapping import TosaArg from executorch.backends.arm.tosa_specification import TosaSpecification @@ -65,9 +68,6 @@ def cast_type(value: Any) -> int | float: # Attempt to cast to float return float(value) - if len(node.args) != 2 and len(node.args) != 3: - raise ValueError(f"Expected len(node.args) to be 2 or 3, got {node.args}") - min_arg = dtype_min max_arg = dtype_max @@ -87,10 +87,7 @@ def define_node( inputs: List[TosaArg], output: TosaArg, ) -> None: - if len(node.all_input_nodes) != 1: - raise ValueError( - f"Expected 1 input for {self.target}, got {len(node.all_input_nodes)}" - ) + validate_num_inputs(self.target, inputs, [2, 3]) min_int8, max_int8 = self._get_min_max_arguments( node, @@ -130,10 +127,7 @@ def define_node( ) -> None: import tosa_tools.v0_80.serializer.tosa_serializer as ts # type: ignore - if len(node.all_input_nodes) != 1: - raise ValueError( - f"Expected 1 input for {self.target}, got {len(node.all_input_nodes)}" - ) + validate_num_inputs(self.target, inputs, [2, 3]) if inputs[0].dtype == ts.DType.INT8: # Call the inherited define_node for handling integers @@ -178,9 +172,6 @@ def cast_type(value: Any) -> int | float: # Attempt to cast to float return float(value) - if len(node.args) != 2 and len(node.args) != 3: - raise ValueError(f"Expected len(node.args) to be 2 or 3, got {node.args}") - min_arg = dtype_min max_arg = dtype_max @@ -202,10 +193,7 @@ def define_node( ) -> None: import serializer.tosa_serializer as ts # type: ignore - if len(node.all_input_nodes) != 1: - raise ValueError( - f"Expected 1 input for {self.target}, got {len(node.all_input_nodes)}" - ) + validate_num_inputs(self.target, inputs, [2, 3]) # NOTE: Quantization of the min/max arguments is handled by QuantizeOperatorArguments min_int8, max_int8 = self._get_min_max_arguments( @@ -247,10 +235,7 @@ def define_node( ) -> None: import 
serializer.tosa_serializer as ts # type: ignore - if len(node.all_input_nodes) != 1: - raise ValueError( - f"Expected 1 input for {self.target}, got {len(node.all_input_nodes)}" - ) + validate_num_inputs(self.target, inputs, [2, 3]) min_fp32, max_fp32 = self._get_min_max_arguments( node, diff --git a/backends/arm/operators/op_constant_pad_nd.py b/backends/arm/operators/op_constant_pad_nd.py index 75cdc0b0fc4..57c13664e76 100644 --- a/backends/arm/operators/op_constant_pad_nd.py +++ b/backends/arm/operators/op_constant_pad_nd.py @@ -16,6 +16,9 @@ NodeVisitor, register_node_visitor, ) +from executorch.backends.arm.operators.operator_validation_utils import ( + validate_num_inputs, +) from executorch.backends.arm.tosa_mapping import TosaArg from executorch.backends.arm.tosa_specification import TosaSpecification @@ -39,6 +42,8 @@ def define_node( ) -> None: import tosa_tools.v0_80.serializer.tosa_serializer as ts + validate_num_inputs(self.target, inputs, 3) + if inputs[0].dtype == ts.DType.INT8: input_qparams = get_input_qparams(node) qargs = input_qparams[0] @@ -98,9 +103,10 @@ def define_node( inputs: List[TosaArg], output: TosaArg, ) -> None: - import serializer.tosa_serializer as ts # type: ignore + validate_num_inputs(self.target, inputs, 3) + if inputs[0].dtype == ts.DType.INT8: input_qparams = get_input_qparams(node) qargs = input_qparams[0] diff --git a/backends/arm/operators/op_conv2d.py b/backends/arm/operators/op_conv2d.py index 6f91c181bd2..fd35439d64a 100644 --- a/backends/arm/operators/op_conv2d.py +++ b/backends/arm/operators/op_conv2d.py @@ -17,6 +17,9 @@ NodeVisitor, register_node_visitor, ) +from executorch.backends.arm.operators.operator_validation_utils import ( + validate_num_inputs, +) from executorch.backends.arm.tosa_mapping import TosaArg from executorch.backends.arm.tosa_quant_utils import build_rescale, build_rescale_v0_80 from executorch.backends.arm.tosa_specification import TosaSpecification @@ -67,6 +70,7 @@ def define_node( import 
tosa_tools.v0_80.serializer.tosa_serializer as ts # type: ignore input, weight, bias, stride, pad, dilation, _, _, group = inputs + validate_num_inputs(self.target, inputs, 9) # Get the attributes of convolution. attr = ts.TosaSerializerAttribute() @@ -242,6 +246,7 @@ def define_node( from tosa.RoundingMode import RoundingMode # type: ignore input, weight, bias, stride, pad, dilation, _, _, group = inputs + validate_num_inputs(self.target, inputs, 9) # Get the attributes of convolution. attr = ts.TosaSerializerAttribute() diff --git a/backends/arm/operators/op_cos.py b/backends/arm/operators/op_cos.py index 1fee25511ce..43fa26176e5 100644 --- a/backends/arm/operators/op_cos.py +++ b/backends/arm/operators/op_cos.py @@ -11,6 +11,9 @@ NodeVisitor, register_node_visitor, ) +from executorch.backends.arm.operators.operator_validation_utils import ( + validate_num_inputs, +) from executorch.backends.arm.tosa_mapping import TosaArg from executorch.backends.arm.tosa_specification import TosaSpecification from torch.fx import Node @@ -33,10 +36,7 @@ def define_node( inputs: List[TosaArg], output: TosaArg, ) -> None: - if len(node.all_input_nodes) != 1: - raise ValueError( - f"Expected 1 input for {self.target}, got {len(node.all_input_nodes)}" - ) + validate_num_inputs(self.target, inputs, 1) if inputs[0].dtype != ts.DType.FP32 or output.dtype != ts.DType.FP32: raise ValueError( f"Input and output for {self.target} need to be FP32, got input_dtype: " diff --git a/backends/arm/operators/op_eq.py b/backends/arm/operators/op_eq.py index a36a1f1b0cd..4cfa6012145 100644 --- a/backends/arm/operators/op_eq.py +++ b/backends/arm/operators/op_eq.py @@ -13,6 +13,9 @@ NodeVisitor, register_node_visitor, ) +from executorch.backends.arm.operators.operator_validation_utils import ( + validate_num_inputs, +) from executorch.backends.arm.tosa_mapping import TosaArg from executorch.backends.arm.tosa_specification import TosaSpecification @@ -41,6 +44,8 @@ def define_node( import 
tosa_tools.v0_80.serializer.tosa_serializer as ts # type: ignore + validate_num_inputs(self.target, inputs, 2) + if inputs[0].dtype != inputs[1].dtype: raise TypeError( "All inputs need to have the same data type for operator EQ but got " @@ -89,6 +94,8 @@ def define_node( import serializer.tosa_serializer as ts # type: ignore + validate_num_inputs(self.target, inputs, 2) + if inputs[0].dtype != inputs[1].dtype: raise TypeError( "All inputs need to have the same data type for operator EQ but got " diff --git a/backends/arm/operators/op_erf.py b/backends/arm/operators/op_erf.py index e174069ee77..bfce5c26699 100644 --- a/backends/arm/operators/op_erf.py +++ b/backends/arm/operators/op_erf.py @@ -10,6 +10,9 @@ NodeVisitor, register_node_visitor, ) +from executorch.backends.arm.operators.operator_validation_utils import ( + validate_num_inputs, +) from executorch.backends.arm.tosa_mapping import TosaArg from executorch.backends.arm.tosa_specification import TosaSpecification @@ -33,6 +36,8 @@ def define_node( ) -> None: import tosa_tools.v0_80.serializer.tosa_serializer as ts # type: ignore + validate_num_inputs(self.target, inputs, 1) + if not (inputs[0].dtype == output.dtype): raise ValueError( "All inputs and output need same dtype." @@ -63,6 +68,8 @@ def define_node( ) -> None: import serializer.tosa_serializer as ts + validate_num_inputs(self.target, inputs, 1) + if not (inputs[0].dtype == output.dtype): raise ValueError( "All inputs and output need same dtype." 
diff --git a/backends/arm/operators/op_exp.py b/backends/arm/operators/op_exp.py index 60cc727d149..b23973a20a9 100644 --- a/backends/arm/operators/op_exp.py +++ b/backends/arm/operators/op_exp.py @@ -10,6 +10,9 @@ NodeVisitor, register_node_visitor, ) +from executorch.backends.arm.operators.operator_validation_utils import ( + validate_num_inputs, +) from executorch.backends.arm.tosa_mapping import TosaArg from executorch.backends.arm.tosa_specification import TosaSpecification from torch.fx import Node @@ -34,10 +37,8 @@ def define_node( ) -> None: import tosa_tools.v0_80.serializer.tosa_serializer as ts # type: ignore - if len(node.all_input_nodes) != 1: - raise ValueError( - f"Expected 1 input for {self.target}, got {len(node.all_input_nodes)}" - ) + validate_num_inputs(self.target, inputs, 1) + if inputs[0].dtype != ts.DType.FP32 or output.dtype != ts.DType.FP32: raise ValueError( f"Input and output for {self.target} need to be FP32, got input dtype: " @@ -66,10 +67,8 @@ def define_node( ) -> None: import serializer.tosa_serializer as ts - if len(node.all_input_nodes) != 1: - raise ValueError( - f"Expected 1 input for {self.target}, got {len(node.all_input_nodes)}" - ) + validate_num_inputs(self.target, inputs, 1) + if inputs[0].dtype != ts.DType.FP32 or output.dtype != ts.DType.FP32: raise ValueError( f"Input and output for {self.target} need to be FP32, got input dtype: " diff --git a/backends/arm/operators/op_ge.py b/backends/arm/operators/op_ge.py index c929f5f9c87..9c4425857f8 100644 --- a/backends/arm/operators/op_ge.py +++ b/backends/arm/operators/op_ge.py @@ -13,6 +13,9 @@ NodeVisitor, register_node_visitor, ) +from executorch.backends.arm.operators.operator_validation_utils import ( + validate_num_inputs, +) from executorch.backends.arm.tosa_mapping import TosaArg from executorch.backends.arm.tosa_specification import TosaSpecification @@ -41,6 +44,8 @@ def define_node( import tosa_tools.v0_80.serializer.tosa_serializer as ts # type: ignore + 
validate_num_inputs(self.target, inputs, 2) + if inputs[0].dtype != inputs[1].dtype: raise TypeError( "All inputs need to have the same data type for operator GE but got " @@ -88,6 +93,8 @@ def define_node( import serializer.tosa_serializer as ts # type: ignore + validate_num_inputs(self.target, inputs, 2) + if inputs[0].dtype != inputs[1].dtype: raise TypeError( "All inputs need to have the same data type for operator GE but got " diff --git a/backends/arm/operators/op_gt.py b/backends/arm/operators/op_gt.py index 53196a0d03c..638dee7ccfc 100644 --- a/backends/arm/operators/op_gt.py +++ b/backends/arm/operators/op_gt.py @@ -13,6 +13,9 @@ NodeVisitor, register_node_visitor, ) +from executorch.backends.arm.operators.operator_validation_utils import ( + validate_num_inputs, +) from executorch.backends.arm.tosa_mapping import TosaArg from executorch.backends.arm.tosa_specification import TosaSpecification @@ -41,6 +44,8 @@ def define_node( import tosa_tools.v0_80.serializer.tosa_serializer as ts # type: ignore + validate_num_inputs(self.target, inputs, 2) + if inputs[0].dtype != inputs[1].dtype: raise TypeError( "All inputs need to have the same data type for operator GT but got " @@ -88,6 +93,8 @@ def define_node( import serializer.tosa_serializer as ts # type: ignore + validate_num_inputs(self.target, inputs, 2) + if inputs[0].dtype != inputs[1].dtype: raise TypeError( "All inputs need to have the same data type for operator GT but got " diff --git a/backends/arm/operators/op_le.py b/backends/arm/operators/op_le.py index d927c1ba0db..bc7751c90dc 100644 --- a/backends/arm/operators/op_le.py +++ b/backends/arm/operators/op_le.py @@ -13,6 +13,9 @@ NodeVisitor, register_node_visitor, ) +from executorch.backends.arm.operators.operator_validation_utils import ( + validate_num_inputs, +) from executorch.backends.arm.tosa_mapping import TosaArg from executorch.backends.arm.tosa_specification import TosaSpecification @@ -41,6 +44,8 @@ def define_node( import 
tosa_tools.v0_80.serializer.tosa_serializer as ts # type: ignore + validate_num_inputs(self.target, inputs, 2) + if inputs[0].dtype != inputs[1].dtype: raise TypeError( "All inputs need to have the same data type for operator LE but got " @@ -88,6 +93,8 @@ def define_node( import serializer.tosa_serializer as ts # type: ignore + validate_num_inputs(self.target, inputs, 2) + if inputs[0].dtype != inputs[1].dtype: raise TypeError( "All inputs need to have the same data type for operator LE but got " diff --git a/backends/arm/operators/op_log.py b/backends/arm/operators/op_log.py index b08bbcec003..9b4ef4c7b73 100644 --- a/backends/arm/operators/op_log.py +++ b/backends/arm/operators/op_log.py @@ -10,6 +10,9 @@ NodeVisitor, register_node_visitor, ) +from executorch.backends.arm.operators.operator_validation_utils import ( + validate_num_inputs, +) from executorch.backends.arm.tosa_mapping import TosaArg from executorch.backends.arm.tosa_specification import TosaSpecification from torch.fx import Node @@ -34,10 +37,8 @@ def define_node( ) -> None: import tosa_tools.v0_80.serializer.tosa_serializer as ts # type: ignore - if len(node.all_input_nodes) != 1: - raise ValueError( - f"Expected 1 input for {self.target}, got {len(node.all_input_nodes)}" - ) + validate_num_inputs(self.target, inputs, 1) + if inputs[0].dtype != ts.DType.FP32 or output.dtype != ts.DType.FP32: raise ValueError( f"Input and output for {self.target} need to be FP32, got input_dtype: " @@ -66,10 +67,8 @@ def define_node( ) -> None: import serializer.tosa_serializer as ts - if len(node.all_input_nodes) != 1: - raise ValueError( - f"Expected 1 input for {self.target}, got {len(node.all_input_nodes)}" - ) + validate_num_inputs(self.target, inputs, 1) + if inputs[0].dtype != ts.DType.FP32 or output.dtype != ts.DType.FP32: raise ValueError( f"Input and output for {self.target} need to be FP32, got input_dtype: " diff --git a/backends/arm/operators/op_lt.py b/backends/arm/operators/op_lt.py index 
2e49eda7d98..02ca0d4d263 100644 --- a/backends/arm/operators/op_lt.py +++ b/backends/arm/operators/op_lt.py @@ -13,6 +13,9 @@ NodeVisitor, register_node_visitor, ) +from executorch.backends.arm.operators.operator_validation_utils import ( + validate_num_inputs, +) from executorch.backends.arm.tosa_mapping import TosaArg from executorch.backends.arm.tosa_specification import TosaSpecification @@ -41,6 +44,8 @@ def define_node( import tosa_tools.v0_80.serializer.tosa_serializer as ts # type: ignore + validate_num_inputs(self.target, inputs, 2) + if inputs[0].dtype != inputs[1].dtype: raise TypeError( "All inputs need to have the same data type for operator LT but got " @@ -88,6 +93,8 @@ def define_node( import serializer.tosa_serializer as ts # type: ignore + validate_num_inputs(self.target, inputs, 2) + if inputs[0].dtype != inputs[1].dtype: raise TypeError( "All inputs need to have the same data type for operator LT but got " diff --git a/backends/arm/operators/op_max_pool2d.py b/backends/arm/operators/op_max_pool2d.py index 928262aefc5..40f48d3896f 100644 --- a/backends/arm/operators/op_max_pool2d.py +++ b/backends/arm/operators/op_max_pool2d.py @@ -16,6 +16,9 @@ NodeVisitor, register_node_visitor, ) +from executorch.backends.arm.operators.operator_validation_utils import ( + validate_num_inputs, +) from executorch.backends.arm.tosa_mapping import TosaArg from executorch.backends.arm.tosa_specification import TosaSpecification @@ -41,6 +44,8 @@ def define_node( ) -> None: import tosa_tools.v0_80.serializer.tosa_serializer as ts # type: ignore + validate_num_inputs(self.target, inputs, [3, 4]) + input_tensor = inputs[0] kernel_size = inputs[1].special stride = inputs[2].special @@ -109,6 +114,8 @@ def define_node( import serializer.tosa_serializer as ts # type: ignore + validate_num_inputs(self.target, inputs, [3, 4]) + input_tensor = inputs[0] kernel_size = inputs[1].special stride = inputs[2].special diff --git a/backends/arm/operators/op_maximum.py 
b/backends/arm/operators/op_maximum.py index 983ac5ded6d..5d5c56b90f8 100644 --- a/backends/arm/operators/op_maximum.py +++ b/backends/arm/operators/op_maximum.py @@ -17,7 +17,9 @@ NodeVisitor, register_node_visitor, ) - +from executorch.backends.arm.operators.operator_validation_utils import ( + validate_num_inputs, +) from executorch.backends.arm.tosa_mapping import TosaArg from executorch.backends.arm.tosa_specification import TosaSpecification from executorch.backends.arm.tosa_utils import tosa_shape @@ -46,6 +48,8 @@ def define_node( import tosa_tools.v0_80.serializer.tosa_serializer as ts # type: ignore + validate_num_inputs(self.target, inputs, 2) + if inputs[0].dtype != inputs[1].dtype and inputs[0].dtype != output.dtype: raise TypeError( f"Data type of inputs and output must be the same. Got input 0 dtype: " @@ -113,6 +117,8 @@ def define_node( import serializer.tosa_serializer as ts # type: ignore from tosa.NanPropagationMode import NanPropagationMode # type: ignore + validate_num_inputs(self.target, inputs, 2) + if inputs[0].dtype != inputs[1].dtype and inputs[0].dtype != output.dtype: raise TypeError( f"Data type of inputs and output must be the same. 
Got input 0 dtype: " diff --git a/backends/arm/operators/op_minimum.py b/backends/arm/operators/op_minimum.py index f39e2ce6d61..85c9b4ac3ed 100644 --- a/backends/arm/operators/op_minimum.py +++ b/backends/arm/operators/op_minimum.py @@ -16,6 +16,9 @@ NodeVisitor, register_node_visitor, ) +from executorch.backends.arm.operators.operator_validation_utils import ( + validate_num_inputs, +) from executorch.backends.arm.tosa_mapping import TosaArg from executorch.backends.arm.tosa_specification import TosaSpecification from executorch.backends.arm.tosa_utils import tosa_shape @@ -44,6 +47,8 @@ def define_node( import tosa_tools.v0_80.serializer.tosa_serializer as ts # type: ignore + validate_num_inputs(self.target, inputs, 2) + if inputs[0].dtype != inputs[1].dtype and inputs[0].dtype != output.dtype: raise TypeError( f"Data type of inputs and output must be the same. Got input 0 dtype: " @@ -111,6 +116,8 @@ def define_node( import serializer.tosa_serializer as ts # type: ignore from tosa.NanPropagationMode import NanPropagationMode # type: ignore + validate_num_inputs(self.target, inputs, 2) + if inputs[0].dtype != inputs[1].dtype and inputs[0].dtype != output.dtype: raise TypeError( f"Data type of inputs and output must be the same. 
Got input 0 dtype: " diff --git a/backends/arm/operators/op_mul.py b/backends/arm/operators/op_mul.py index 6c5b94f1a2b..7d84be213b9 100644 --- a/backends/arm/operators/op_mul.py +++ b/backends/arm/operators/op_mul.py @@ -19,6 +19,9 @@ NodeVisitor, register_node_visitor, ) +from executorch.backends.arm.operators.operator_validation_utils import ( + validate_num_inputs, +) from executorch.backends.arm.tosa_mapping import TosaArg from executorch.backends.arm.tosa_specification import TosaSpecification from executorch.backends.arm.tosa_utils import reshape_for_broadcast @@ -42,6 +45,8 @@ def define_node( import tosa_tools.v0_80.serializer.tosa_serializer as ts # type: ignore + validate_num_inputs(self.target, inputs, 2) + if ( inputs[0].dtype != ts.DType.INT8 or inputs[1].dtype != ts.DType.INT8 @@ -122,6 +127,8 @@ def define_node( import tosa_tools.v0_80.serializer.tosa_serializer as ts # type: ignore + validate_num_inputs(self.target, inputs, 2) + if inputs[0].dtype == ts.DType.INT8: return super().define_node(node, tosa_graph, inputs, output) @@ -152,6 +159,8 @@ def define_node( import serializer.tosa_serializer as ts # type: ignore + validate_num_inputs(self.target, inputs, 2) + if ( inputs[0].dtype != ts.DType.INT8 or inputs[1].dtype != ts.DType.INT8 @@ -218,6 +227,8 @@ def define_node( import serializer.tosa_serializer as ts # type: ignore + validate_num_inputs(self.target, inputs, 2) + if inputs[0].dtype == ts.DType.INT8: return super().define_node(node, tosa_graph, inputs, output) diff --git a/backends/arm/operators/op_permute.py b/backends/arm/operators/op_permute.py index f3ea8b00961..b78ee94b774 100644 --- a/backends/arm/operators/op_permute.py +++ b/backends/arm/operators/op_permute.py @@ -13,6 +13,9 @@ NodeVisitor, register_node_visitor, ) +from executorch.backends.arm.operators.operator_validation_utils import ( + validate_num_inputs, +) from executorch.backends.arm.tosa_mapping import TosaArg @@ -105,6 +108,8 @@ def define_node( ) -> None: import 
tosa_tools.v0_80.serializer.tosa_serializer as ts # type: ignore + validate_num_inputs(self.target, inputs, 2) + # The permutation vector describes a permutation P in default Pytorch dim_order. # For rank 4, the default dim_order NCHW. # E.g. (2,3,0,1) -> permute (n,c,h,w) to (w,c,n,h) @@ -142,6 +147,8 @@ def define_node( ) -> None: import serializer.tosa_serializer as ts + validate_num_inputs(self.target, inputs, 2) + # The permutation vector describes a permutation P in default Pytorch dim_order. # For rank 4, the default dim_order NCHW. # E.g. (2,3,0,1) -> permute (n,c,h,w) to (w,c,n,h) diff --git a/backends/arm/operators/op_pow.py b/backends/arm/operators/op_pow.py index 781fce3c79f..0b9ba6321f7 100644 --- a/backends/arm/operators/op_pow.py +++ b/backends/arm/operators/op_pow.py @@ -11,6 +11,9 @@ NodeVisitor, register_node_visitor, ) +from executorch.backends.arm.operators.operator_validation_utils import ( + validate_num_inputs, +) from executorch.backends.arm.tosa_mapping import TosaArg from executorch.backends.arm.tosa_specification import TosaSpecification from torch.fx import Node @@ -36,6 +39,8 @@ def define_node( ) -> None: import tosa_tools.v0_80.serializer.tosa_serializer as ts # type: ignore + validate_num_inputs(self.target, inputs, 2) + if not (inputs[0].dtype == inputs[1].dtype == output.dtype): raise ValueError( "All inputs and outputs need same dtype." @@ -77,6 +82,8 @@ def define_node( ) -> None: import serializer.tosa_serializer as ts + validate_num_inputs(self.target, inputs, 2) + if not (inputs[0].dtype == inputs[1].dtype == output.dtype): raise ValueError( "All inputs and outputs need same dtype." 
diff --git a/backends/arm/operators/op_reciprocal.py b/backends/arm/operators/op_reciprocal.py index 7d1ee951993..d8888ec9d49 100644 --- a/backends/arm/operators/op_reciprocal.py +++ b/backends/arm/operators/op_reciprocal.py @@ -12,6 +12,9 @@ NodeVisitor, register_node_visitor, ) +from executorch.backends.arm.operators.operator_validation_utils import ( + validate_num_inputs, +) from executorch.backends.arm.tosa_mapping import TosaArg from executorch.backends.arm.tosa_specification import TosaSpecification @@ -35,10 +38,8 @@ def define_node( ) -> None: import tosa_tools.v0_80.serializer.tosa_serializer as ts # type: ignore - if len(node.all_input_nodes) != 1: - raise ValueError( - f"Expected 1 input for {self.target}, got {len(node.all_input_nodes)}" - ) + validate_num_inputs(self.target, inputs, 1) + if inputs[0].dtype != ts.DType.FP32 or output.dtype != ts.DType.FP32: raise ValueError( f"Input and output for {self.target} need to be FP32, got " @@ -69,10 +70,8 @@ def define_node( ) -> None: import serializer.tosa_serializer as ts - if len(node.all_input_nodes) != 1: - raise ValueError( - f"Expected 1 input for {self.target}, got {len(node.all_input_nodes)}" - ) + validate_num_inputs(self.target, inputs, 1) + if inputs[0].dtype != ts.DType.FP32 or output.dtype != ts.DType.FP32: raise ValueError( f"Input and output for {self.target} need to be FP32, got " diff --git a/backends/arm/operators/op_repeat.py b/backends/arm/operators/op_repeat.py index 979a10ecff1..1ed42b23b9e 100644 --- a/backends/arm/operators/op_repeat.py +++ b/backends/arm/operators/op_repeat.py @@ -12,6 +12,9 @@ NodeVisitor, register_node_visitor, ) +from executorch.backends.arm.operators.operator_validation_utils import ( + validate_num_inputs, +) from executorch.backends.arm.tosa_mapping import TosaArg from executorch.backends.arm.tosa_utils import tosa_shape @@ -34,6 +37,8 @@ def define_node( ) -> None: import tosa_tools.v0_80.serializer.tosa_serializer as ts # type: ignore + 
validate_num_inputs(self.target, inputs, 2) + multiples = inputs[1].special attr = ts.TosaSerializerAttribute() @@ -61,6 +66,8 @@ def define_node( ) -> None: import serializer.tosa_serializer as ts # type: ignore + validate_num_inputs(self.target, inputs, 2) + multiples = inputs[1].special if len(multiples) == 0: diff --git a/backends/arm/operators/op_rescale.py b/backends/arm/operators/op_rescale.py index 3c9abe1ba57..52953db24d0 100644 --- a/backends/arm/operators/op_rescale.py +++ b/backends/arm/operators/op_rescale.py @@ -13,6 +13,9 @@ NodeVisitor, register_node_visitor, ) +from executorch.backends.arm.operators.operator_validation_utils import ( + validate_num_inputs, +) from executorch.backends.arm.tosa_mapping import TosaArg from executorch.backends.arm.tosa_quant_utils import create_const_ops_for_rescale @@ -35,6 +38,8 @@ def define_node( ) -> None: import tosa_tools.v0_80.serializer.tosa_serializer as ts # type: ignore + validate_num_inputs(self.target, inputs, 5) + input_dtype = node.all_input_nodes[0].meta["val"].dtype output_dtype = cast(torch.dtype, node.args[1]) scale = cast(float, node.args[2]) @@ -91,6 +96,8 @@ def define_node( import serializer.tosa_serializer as ts # type: ignore from tosa.RoundingMode import RoundingMode # type: ignore + validate_num_inputs(self.target, inputs, 5) + input_dtype = node.all_input_nodes[0].meta["val"].dtype output_dtype = cast(torch.dtype, node.args[1]) scale = cast(float, node.args[2]) diff --git a/backends/arm/operators/op_rshift_tensor.py b/backends/arm/operators/op_rshift_tensor.py index 375dd76ba8d..e843f669a58 100644 --- a/backends/arm/operators/op_rshift_tensor.py +++ b/backends/arm/operators/op_rshift_tensor.py @@ -13,6 +13,9 @@ NodeVisitor, register_node_visitor, ) +from executorch.backends.arm.operators.operator_validation_utils import ( + validate_num_inputs, +) from executorch.backends.arm.tosa_mapping import TosaArg from executorch.backends.arm.tosa_specification import Tosa_0_80, Tosa_1_00 @@ -32,6 
+35,8 @@ def define_node( ) -> None: import tosa_tools.v0_80.serializer.tosa_serializer as ts # type: ignore + validate_num_inputs(self.target, inputs, 2) + attr = ts.TosaSerializerAttribute() round = False if isinstance(self.tosa_spec, Tosa_0_80) and self.tosa_spec.is_U55_subset: @@ -63,6 +68,8 @@ def define_node( ) -> None: import serializer.tosa_serializer as ts + validate_num_inputs(self.target, inputs, 2) + attr = ts.TosaSerializerAttribute() round = False if isinstance(self.tosa_spec, Tosa_1_00) and "u55" in self.tosa_spec.extensions: diff --git a/backends/arm/operators/op_rsqrt.py b/backends/arm/operators/op_rsqrt.py index 784c4b4d257..53156e9249a 100644 --- a/backends/arm/operators/op_rsqrt.py +++ b/backends/arm/operators/op_rsqrt.py @@ -12,6 +12,9 @@ NodeVisitor, register_node_visitor, ) +from executorch.backends.arm.operators.operator_validation_utils import ( + validate_num_inputs, +) from executorch.backends.arm.tosa_mapping import TosaArg from executorch.backends.arm.tosa_specification import TosaSpecification @@ -35,10 +38,8 @@ def define_node( ) -> None: import tosa_tools.v0_80.serializer.tosa_serializer as ts # type: ignore - if len(node.all_input_nodes) != 1: - raise ValueError( - f"Expected 1 input for {self.target}, got {len(node.all_input_nodes)}" - ) + validate_num_inputs(self.target, inputs, 1) + if inputs[0].dtype != ts.DType.FP32 or output.dtype != ts.DType.FP32: raise ValueError( f"Input and output for {self.target} need to be FP32, got " @@ -67,10 +68,8 @@ def define_node( ) -> None: import serializer.tosa_serializer as ts - if len(node.all_input_nodes) != 1: - raise ValueError( - f"Expected 1 input for {self.target}, got {len(node.all_input_nodes)}" - ) + validate_num_inputs(self.target, inputs, 1) + if inputs[0].dtype != ts.DType.FP32 or output.dtype != ts.DType.FP32: raise ValueError( f"Input and output for {self.target} need to be FP32, got " diff --git a/backends/arm/operators/op_sigmoid.py b/backends/arm/operators/op_sigmoid.py index 
a43e9ae798f..2881fc02eb5 100644 --- a/backends/arm/operators/op_sigmoid.py +++ b/backends/arm/operators/op_sigmoid.py @@ -10,6 +10,9 @@ NodeVisitor, register_node_visitor, ) +from executorch.backends.arm.operators.operator_validation_utils import ( + validate_num_inputs, +) from executorch.backends.arm.tosa_mapping import TosaArg from executorch.backends.arm.tosa_specification import TosaSpecification from torch.fx import Node @@ -34,10 +37,8 @@ def define_node( ) -> None: import tosa_tools.v0_80.serializer.tosa_serializer as ts # type: ignore - if len(node.all_input_nodes) != 1: - raise ValueError( - f"Expected 1 input for {self.target}, got {len(node.all_input_nodes)}" - ) + validate_num_inputs(self.target, inputs, 1) + if inputs[0].dtype != ts.DType.FP32 or output.dtype != ts.DType.FP32: raise ValueError( f"Input and output for {self.target} need to be FP32, got input_dtype: " @@ -66,10 +67,8 @@ def define_node( ) -> None: import serializer.tosa_serializer as ts - if len(node.all_input_nodes) != 1: - raise ValueError( - f"Expected 1 input for {self.target}, got {len(node.all_input_nodes)}" - ) + validate_num_inputs(self.target, inputs, 1) + if inputs[0].dtype != ts.DType.FP32 or output.dtype != ts.DType.FP32: raise ValueError( f"Input and output for {self.target} need to be FP32, got input_dtype: " diff --git a/backends/arm/operators/op_sin.py b/backends/arm/operators/op_sin.py index ee444c38f37..e082f6cb7a4 100644 --- a/backends/arm/operators/op_sin.py +++ b/backends/arm/operators/op_sin.py @@ -11,6 +11,9 @@ NodeVisitor, register_node_visitor, ) +from executorch.backends.arm.operators.operator_validation_utils import ( + validate_num_inputs, +) from executorch.backends.arm.tosa_mapping import TosaArg from executorch.backends.arm.tosa_specification import TosaSpecification from torch.fx import Node @@ -33,10 +36,8 @@ def define_node( inputs: List[TosaArg], output: TosaArg, ) -> None: - if len(node.all_input_nodes) != 1: - raise ValueError( - f"Expected 1 input 
for {self.target}, got {len(node.all_input_nodes)}" - ) + validate_num_inputs(self.target, inputs, 1) + if inputs[0].dtype != ts.DType.FP32 or output.dtype != ts.DType.FP32: raise ValueError( f"Input and output for {self.target} need to be FP32, got input_dtype: " diff --git a/backends/arm/operators/op_slice.py b/backends/arm/operators/op_slice.py index a8d326cfa9b..412e3cca922 100644 --- a/backends/arm/operators/op_slice.py +++ b/backends/arm/operators/op_slice.py @@ -11,6 +11,9 @@ NodeVisitor, register_node_visitor, ) +from executorch.backends.arm.operators.operator_validation_utils import ( + validate_num_inputs, +) from executorch.backends.arm.tosa_mapping import TosaArg from torch.fx import Node @@ -47,6 +50,8 @@ def define_node( ) -> None: import tosa_tools.v0_80.serializer.tosa_serializer as ts # type: ignore + validate_num_inputs(self.target, inputs, [4, 5]) + # See slice_copy_support.py if not (len(inputs) == 4 or (len(inputs) == 5 and inputs[4].number == 1)): raise ValueError("Unsupported combination of inputs") @@ -99,6 +104,8 @@ def define_node( ) -> None: import serializer.tosa_serializer as ts # type: ignore + validate_num_inputs(self.target, inputs, [4, 5]) + # See slice_copy_support.py if not (len(inputs) == 4 or (len(inputs) == 5 and inputs[4].number == 1)): raise ValueError("Unsupported combination of inputs") diff --git a/backends/arm/operators/op_sub.py b/backends/arm/operators/op_sub.py index 65126f4d4dc..03c930918d7 100644 --- a/backends/arm/operators/op_sub.py +++ b/backends/arm/operators/op_sub.py @@ -14,6 +14,9 @@ NodeVisitor, register_node_visitor, ) +from executorch.backends.arm.operators.operator_validation_utils import ( + validate_num_inputs, +) from executorch.backends.arm.tosa_mapping import TosaArg from executorch.backends.arm.tosa_specification import TosaSpecification from torch.fx import Node @@ -40,6 +43,8 @@ def define_node( import tosa_tools.v0_80.serializer.tosa_serializer as ts # type: ignore + 
validate_num_inputs(self.target, inputs, 2) + # Specification (0.80) states that input and output types # should all be the same if inputs[0].dtype != inputs[1].dtype or inputs[0].dtype != output.dtype: @@ -113,6 +118,8 @@ def define_node( import tosa_tools.v0_80.serializer.tosa_serializer as ts # type: ignore + validate_num_inputs(self.target, inputs, 2) + # Specification (0.80) states that input and output types # should all be the same if inputs[0].dtype != inputs[1].dtype or inputs[0].dtype != output.dtype: @@ -167,6 +174,8 @@ def define_node( import serializer.tosa_serializer as ts # type: ignore + validate_num_inputs(self.target, inputs, 2) + # Specification (1.0) states that input and output types # should all be the same assert inputs[0].dtype == inputs[1].dtype == output.dtype @@ -228,6 +237,8 @@ def define_node( import serializer.tosa_serializer as ts # type: ignore + validate_num_inputs(self.target, inputs, 2) + # Specification (1.0) states that input and output types # should all be the same assert inputs[0].dtype == inputs[1].dtype == output.dtype diff --git a/backends/arm/operators/op_sum.py b/backends/arm/operators/op_sum.py index b898eb6cb67..f232136fd9b 100644 --- a/backends/arm/operators/op_sum.py +++ b/backends/arm/operators/op_sum.py @@ -14,6 +14,9 @@ NodeVisitor, register_node_visitor, ) +from executorch.backends.arm.operators.operator_validation_utils import ( + validate_num_inputs, +) from executorch.backends.arm.tosa_mapping import TosaArg from executorch.backends.arm.tosa_specification import TosaSpecification from torch.fx import Node @@ -40,6 +43,8 @@ def define_node( import tosa_tools.v0_80.serializer.tosa_serializer as ts # type: ignore + validate_num_inputs(self.target, inputs, 3) + input_shape = list(inputs[0].shape) dim_list = cast(list[int], inputs[1].special) dim_list = [dim % len(input_shape) for dim in dim_list] @@ -98,6 +103,8 @@ def define_node( import tosa_tools.v0_80.serializer.tosa_serializer as ts # type: ignore + 
validate_num_inputs(self.target, inputs, 3) + if inputs[0].dtype == ts.DType.INT8: return super().define_node(node, tosa_graph, inputs, output) input_name = inputs[0].name @@ -151,6 +158,8 @@ def define_node( import serializer.tosa_serializer as ts # type: ignore + validate_num_inputs(self.target, inputs, 3) + input_shape = list(inputs[0].shape) dim_list = cast(list[int], inputs[1].special) dim_list = [dim % len(input_shape) for dim in dim_list] @@ -210,6 +219,8 @@ def define_node( import serializer.tosa_serializer as ts # type: ignore + validate_num_inputs(self.target, inputs, 3) + if inputs[0].dtype == ts.DType.INT8: return super().define_node(node, tosa_graph, inputs, output) input_name = inputs[0].name diff --git a/backends/arm/operators/op_table.py b/backends/arm/operators/op_table.py index 454aebecd5e..350403f19bc 100644 --- a/backends/arm/operators/op_table.py +++ b/backends/arm/operators/op_table.py @@ -13,6 +13,9 @@ NodeVisitor, register_node_visitor, ) +from executorch.backends.arm.operators.operator_validation_utils import ( + validate_num_inputs, +) from executorch.backends.arm.tosa_mapping import TosaArg from executorch.backends.arm.tosa_specification import TosaSpecification @@ -33,6 +36,8 @@ def define_node( ) -> None: import tosa_tools.v0_80.serializer.tosa_serializer as ts # type: ignore + validate_num_inputs(self.target, inputs, 1) + if node.name not in self._exported_program.state_dict.keys(): # type: ignore[union-attr] raise RuntimeError( f"Did not find key {node.name} in state_dict {self._exported_program.state_dict.keys()}." @@ -71,6 +76,8 @@ def define_node( ) -> None: import serializer.tosa_serializer as ts # type: ignore + validate_num_inputs(self.target, inputs, 1) + if node.name not in self._exported_program.state_dict.keys(): # type: ignore[union-attr] raise RuntimeError( f"Did not find key {node.name} in state_dict {self._exported_program.state_dict.keys()}." 
diff --git a/backends/arm/operators/op_tanh.py b/backends/arm/operators/op_tanh.py index 01af36c4d37..02727d0fabe 100644 --- a/backends/arm/operators/op_tanh.py +++ b/backends/arm/operators/op_tanh.py @@ -10,8 +10,12 @@ NodeVisitor, register_node_visitor, ) +from executorch.backends.arm.operators.operator_validation_utils import ( + validate_num_inputs, +) from executorch.backends.arm.tosa_mapping import TosaArg from executorch.backends.arm.tosa_specification import TosaSpecification + from torch.fx import Node @@ -34,10 +38,8 @@ def define_node( ) -> None: import tosa_tools.v0_80.serializer.tosa_serializer as ts # type: ignore - if len(node.all_input_nodes) != 1: - raise ValueError( - f"Expected 1 input for {self.target}, got {len(node.all_input_nodes)}" - ) + validate_num_inputs(self.target, inputs, 1) + if inputs[0].dtype != ts.DType.FP32 or output.dtype != ts.DType.FP32: raise ValueError( f"Input and output for {self.target} need to be FP32, got input_dtype: " @@ -66,10 +68,8 @@ def define_node( ) -> None: import serializer.tosa_serializer as ts - if len(node.all_input_nodes) != 1: - raise ValueError( - f"Expected 1 input for {self.target}, got {len(node.all_input_nodes)}" - ) + validate_num_inputs(self.target, inputs, 1) + if inputs[0].dtype != ts.DType.FP32 or output.dtype != ts.DType.FP32: raise ValueError( f"Input and output for {self.target} need to be FP32, got input_dtype: " diff --git a/backends/arm/operators/op_to_copy.py b/backends/arm/operators/op_to_copy.py index 210bfd2f61f..5dde6828f72 100644 --- a/backends/arm/operators/op_to_copy.py +++ b/backends/arm/operators/op_to_copy.py @@ -12,6 +12,9 @@ NodeVisitor, register_node_visitor, ) +from executorch.backends.arm.operators.operator_validation_utils import ( + validate_num_inputs, +) from executorch.backends.arm.tosa_mapping import TosaArg @@ -39,6 +42,8 @@ def define_node( ) -> None: import tosa_tools.v0_80.serializer.tosa_serializer as ts # type: ignore + validate_num_inputs(self.target, inputs, 1) 
+ tosa_graph.addOperator(ts.TosaOp.Op().CAST, [inputs[0].name], [output.name]) @@ -66,4 +71,6 @@ def define_node( ) -> None: import serializer.tosa_serializer as ts # type: ignore + validate_num_inputs(self.target, inputs, 1) + tosa_graph.addOperator(ts.TosaOp.Op().CAST, [inputs[0].name], [output.name]) diff --git a/backends/arm/operators/op_to_dim_order_copy.py b/backends/arm/operators/op_to_dim_order_copy.py index 740576f2736..d68bee88a64 100644 --- a/backends/arm/operators/op_to_dim_order_copy.py +++ b/backends/arm/operators/op_to_dim_order_copy.py @@ -12,6 +12,9 @@ NodeVisitor, register_node_visitor, ) +from executorch.backends.arm.operators.operator_validation_utils import ( + validate_num_inputs, +) from executorch.backends.arm.tosa_mapping import TosaArg @@ -39,6 +42,8 @@ def define_node( ) -> None: import tosa_tools.v0_80.serializer.tosa_serializer as ts # type: ignore + validate_num_inputs(self.target, inputs, 1) + tosa_graph.addOperator(ts.TosaOp.Op().CAST, [inputs[0].name], [output.name]) @@ -66,4 +71,6 @@ def define_node( ) -> None: import serializer.tosa_serializer as ts # type: ignore + validate_num_inputs(self.target, inputs, 1) + tosa_graph.addOperator(ts.TosaOp.Op().CAST, [inputs[0].name], [output.name]) diff --git a/backends/arm/operators/op_transpose.py b/backends/arm/operators/op_transpose.py index ac98979c234..8b0754fa079 100644 --- a/backends/arm/operators/op_transpose.py +++ b/backends/arm/operators/op_transpose.py @@ -13,6 +13,9 @@ NodeVisitor, register_node_visitor, ) +from executorch.backends.arm.operators.operator_validation_utils import ( + validate_num_inputs, +) from executorch.backends.arm.tosa_mapping import TosaArg @@ -37,6 +40,8 @@ def define_node( ) -> None: import tosa_tools.v0_80.serializer.tosa_serializer as ts # type: ignore + validate_num_inputs(self.target, inputs, 2) + output_rank = len(output.shape) perms = [dim % output_rank for dim in inputs[1].special] attr = ts.TosaSerializerAttribute() @@ -67,6 +72,8 @@ def 
define_node( ) -> None: import serializer.tosa_serializer as ts # type: ignore + validate_num_inputs(self.target, inputs, 2) + output_rank = len(output.shape) perms = [dim % output_rank for dim in inputs[1].special] attr = ts.TosaSerializerAttribute() diff --git a/backends/arm/operators/op_upsample_bilinear2d.py b/backends/arm/operators/op_upsample_bilinear2d.py index 1c0fbc11d24..88149a7be91 100644 --- a/backends/arm/operators/op_upsample_bilinear2d.py +++ b/backends/arm/operators/op_upsample_bilinear2d.py @@ -12,6 +12,9 @@ NodeVisitor, register_node_visitor, ) +from executorch.backends.arm.operators.operator_validation_utils import ( + validate_num_inputs, +) from executorch.backends.arm.tosa_mapping import TosaArg from executorch.backends.arm.tosa_quant_utils import build_rescale, build_rescale_v0_80 from executorch.backends.arm.tosa_utils import get_resize_parameters, tosa_shape @@ -36,6 +39,8 @@ def define_node( import tosa_tools.v0_80.serializer.tosa_serializer as ts # type: ignore from tosa_tools.v0_80.tosa.ResizeMode import ResizeMode # type: ignore + validate_num_inputs(self.target, inputs, 4) + if inputs[0].shape is None or output.shape is None: raise ValueError("Only static shapes are supported") @@ -123,6 +128,8 @@ def define_node( from tosa.ResizeMode import ResizeMode # type: ignore from tosa.RoundingMode import RoundingMode # type: ignore + validate_num_inputs(self.target, inputs, 4) + if inputs[0].shape is None or output.shape is None: raise ValueError("Only static shapes are supported") diff --git a/backends/arm/operators/op_upsample_nearest2d.py b/backends/arm/operators/op_upsample_nearest2d.py index c08896c2cdc..da40859de74 100644 --- a/backends/arm/operators/op_upsample_nearest2d.py +++ b/backends/arm/operators/op_upsample_nearest2d.py @@ -12,6 +12,9 @@ NodeVisitor, register_node_visitor, ) +from executorch.backends.arm.operators.operator_validation_utils import ( + validate_num_inputs, +) from executorch.backends.arm.tosa_mapping import TosaArg 
from executorch.backends.arm.tosa_utils import get_resize_parameters, tosa_shape @@ -36,6 +39,8 @@ def define_node( ) -> None: import tosa_tools.v0_80.serializer.tosa_serializer as ts # type: ignore + validate_num_inputs(self.target, inputs, 3) + if inputs[0].shape is None or output.shape is None: raise ValueError("Only static shapes are supported") @@ -92,6 +97,8 @@ def define_node( ) -> None: import serializer.tosa_serializer as ts + validate_num_inputs(self.target, inputs, 3) + assert ( inputs[0].shape is not None and output.shape is not None ), "Only static shapes are supported" diff --git a/backends/arm/operators/op_view.py b/backends/arm/operators/op_view.py index e8dedb65315..22a8146ecbd 100644 --- a/backends/arm/operators/op_view.py +++ b/backends/arm/operators/op_view.py @@ -12,6 +12,9 @@ NodeVisitor, register_node_visitor, ) +from executorch.backends.arm.operators.operator_validation_utils import ( + validate_num_inputs, +) from executorch.backends.arm.tosa_mapping import TosaArg from executorch.backends.arm.tosa_utils import tosa_shape @@ -34,6 +37,8 @@ def define_node( ) -> None: import tosa_tools.v0_80.serializer.tosa_serializer as ts + validate_num_inputs(self.target, inputs, 2) + attr = ts.TosaSerializerAttribute() new_shape = tosa_shape(inputs[1].special, output.dim_order) attr.ReshapeAttribute(new_shape) @@ -61,6 +66,8 @@ def define_node( ) -> None: import serializer.tosa_serializer as ts + validate_num_inputs(self.target, inputs, 2) + tosa_graph = cast(ts.TosaSerializer, tosa_graph) if len(output.shape) != 0: diff --git a/backends/arm/operators/op_where.py b/backends/arm/operators/op_where.py index b58fda1c399..d34f4134def 100644 --- a/backends/arm/operators/op_where.py +++ b/backends/arm/operators/op_where.py @@ -9,6 +9,10 @@ NodeVisitor, register_node_visitor, ) + +from executorch.backends.arm.operators.operator_validation_utils import ( + validate_num_inputs, +) from executorch.backends.arm.tosa_mapping import TosaArg from 
executorch.backends.arm.tosa_specification import TosaSpecification from torch.fx import Node @@ -34,8 +38,7 @@ def _add_node_to_tosa_graph( ) -> None: import tosa_tools.v0_80.serializer.tosa_serializer as ts # type: ignore - if len(inputs) != 3: - raise ValueError(f"aten.where.self expects 3 arguments, got {len(inputs)}") + validate_num_inputs(self.target, inputs, 3) if inputs[0].dtype is not ts.DType.BOOL: raise ValueError("Input 0 needs to have dtype BOOL") @@ -66,6 +69,8 @@ def define_node( ) -> None: import tosa_tools.v0_80.serializer.tosa_serializer as ts # type: ignore + validate_num_inputs(self.target, inputs, 3) + bi_supported_dtypes = [ ts.DType.INT8, ts.DType.INT16, @@ -94,6 +99,8 @@ def define_node( ) -> None: import tosa_tools.v0_80.serializer.tosa_serializer as ts # type: ignore + validate_num_inputs(self.target, inputs, 3) + mi_supported_dtypes = [ ts.DType.FP16, ts.DType.FP32, @@ -125,8 +132,7 @@ def _add_node_to_tosa_graph( ) -> None: import serializer.tosa_serializer as ts - if len(inputs) != 3: - raise ValueError(f"aten.where.self expects 3 arguments, got {len(inputs)}") + validate_num_inputs(self.target, inputs, 3) if inputs[0].dtype is not ts.DType.BOOL: raise ValueError("Input 0 needs to have dtype BOOL") @@ -157,6 +163,8 @@ def define_node( ) -> None: import serializer.tosa_serializer as ts + validate_num_inputs(self.target, inputs, 3) + bi_supported_dtypes = [ ts.DType.INT8, ts.DType.INT16, @@ -185,6 +193,8 @@ def define_node( ) -> None: import serializer.tosa_serializer as ts + validate_num_inputs(self.target, inputs, 3) + mi_supported_dtypes = [ ts.DType.FP16, ts.DType.FP32, diff --git a/backends/arm/operators/operator_validation_utils.py b/backends/arm/operators/operator_validation_utils.py new file mode 100644 index 00000000000..824695b4643 --- /dev/null +++ b/backends/arm/operators/operator_validation_utils.py @@ -0,0 +1,53 @@ +# Copyright 2025 Arm Limited and/or its affiliates. 
+# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +from typing import Any, List + + +def validate_num_inputs(op_name: str, inputs: List[Any], expected: int | List[int]): + """ + Validates the number of inputs provided to an operation against expected values. + + This function checks whether the length of the input list matches the expected + number(s) of inputs. + + Parameters: + ----------- + op_name : str + The name of the operation for which the inputs are being validated. + Used in the error message to provide context. + + inputs : List[TosaArg] + A list of inputs to be validated, where each input is assumed to be an + instance of `TosaArg`. + + expected : int or List[int] + The expected number of inputs. Can be either an integer or a list of integers. + + Raises: + ------- + ValueError + If the number of inputs does not match the expected value(s), a `ValueError` is + raised with a message indicating the operation name and the mismatch in expected + versus provided number of inputs. 
+ + Example: + -------- + # Example usage: + from executorch.backends.arm.operators.operator_validation_utils import ( + validate_num_inputs, + ) + + validate_num_inputs(self.target, inputs, [3, 4]) + + """ + if isinstance(expected, int): + expected = [expected] + if len(inputs) not in expected: + expected_str = ", ".join(map(str, expected)) + raise ValueError( + f"{op_name}: Expected number of input(s) to be " + f"[{expected_str}], got {len(inputs)}" + ) diff --git a/backends/arm/operators/ops_binary.py b/backends/arm/operators/ops_binary.py index 425007bab3c..0a2f4419dfb 100644 --- a/backends/arm/operators/ops_binary.py +++ b/backends/arm/operators/ops_binary.py @@ -14,6 +14,9 @@ NodeVisitor, register_node_visitor, ) +from executorch.backends.arm.operators.operator_validation_utils import ( + validate_num_inputs, +) from executorch.backends.arm.tosa_mapping import TosaArg @@ -33,6 +36,8 @@ def define_node( ) -> None: import tosa_tools.v0_80.serializer.tosa_serializer as ts # type: ignore # noqa: F401 + validate_num_inputs(self.target, inputs, 2) + if not (inputs[0].dtype == inputs[1].dtype == output.dtype): raise ValueError( "All inputs and outputs need same dtype." @@ -62,6 +67,8 @@ def define_node( ) -> None: import serializer.tosa_serializer as ts # type: ignore # noqa: F401 + validate_num_inputs(self.target, inputs, 2) + if not (inputs[0].dtype == inputs[1].dtype == output.dtype): raise ValueError( "All inputs and outputs need same dtype." 
diff --git a/backends/arm/operators/ops_identity.py b/backends/arm/operators/ops_identity.py index 0c41e13d445..cd5fa9956a3 100644 --- a/backends/arm/operators/ops_identity.py +++ b/backends/arm/operators/ops_identity.py @@ -14,6 +14,9 @@ NodeVisitor, register_node_visitor, ) +from executorch.backends.arm.operators.operator_validation_utils import ( + validate_num_inputs, +) from executorch.backends.arm.tosa_mapping import TosaArg @@ -37,6 +40,8 @@ def define_node( ) -> None: import tosa_tools.v0_80.serializer.tosa_serializer as ts + validate_num_inputs(self.target, inputs, 1) + # Simply add an identityOp tosa_graph.addOperator( ts.TosaOp.Op().IDENTITY, [inputs[0].name], [output.name] @@ -69,6 +74,8 @@ def define_node( ) -> None: import serializer.tosa_serializer as ts + validate_num_inputs(self.target, inputs, 1) + # Simply add an identityOp tosa_graph.addOperator( ts.TosaOp.Op().IDENTITY, [inputs[0].name], [output.name] diff --git a/backends/arm/operators/ops_unary.py b/backends/arm/operators/ops_unary.py index 3bb2be16585..b7ba2df4277 100644 --- a/backends/arm/operators/ops_unary.py +++ b/backends/arm/operators/ops_unary.py @@ -12,6 +12,9 @@ NodeVisitor, register_node_visitor, ) +from executorch.backends.arm.operators.operator_validation_utils import ( + validate_num_inputs, +) from executorch.backends.arm.tosa_mapping import TosaArg @@ -38,6 +41,8 @@ def define_node( ) -> None: import tosa_tools.v0_80.serializer.tosa_serializer as ts # type: ignore # noqa: F401 + validate_num_inputs(self.target, inputs, 1) + if not (inputs[0].dtype == output.dtype): raise ValueError( "All inputs and output need same dtype." @@ -76,6 +81,8 @@ def define_node( ) -> None: import serializer.tosa_serializer as ts # type: ignore # noqa: F401 + validate_num_inputs(self.target, inputs, 1) + if not (inputs[0].dtype == output.dtype): raise ValueError( "All inputs and output need same dtype." 
diff --git a/backends/arm/scripts/pre-push b/backends/arm/scripts/pre-push index af8bc48cd9c..b755f2bcc48 100755 --- a/backends/arm/scripts/pre-push +++ b/backends/arm/scripts/pre-push @@ -4,6 +4,10 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. +# Calling this script with any argument is equal to launching it in +# non-interactive mode. "$#" gives the number of positional arguments. +[ "$#" -eq 0 ] && is_script_interactive=1 || is_script_interactive=0 + RESET='\e[0m' RED='\e[31m' GREEN='\e[32m' @@ -31,8 +35,10 @@ VERBS="Add|Fix|Update|Refactor|Improve|Remove|Change|Implement|Create|Modify|"\ # Remote branch REMOTE=$(git rev-parse --abbrev-ref --symbolic-full-name @{u} 2>/dev/null) - -if [ -z "$REMOTE" ]; then +if [ $is_script_interactive -eq 0 ]; then + # Just use the one commit + COMMITS=$(git rev-list HEAD -n 1) +elif [ -z "$REMOTE" ]; then echo -e "${WARNING} Could not find upstream branch to compare to." echo "Please specify the number of commits you are pushing." echo -n "Enter number of commits to check (default 1): " > /dev/tty @@ -155,14 +161,17 @@ for COMMIT in ${COMMITS}; do if [[ ! "$SUBJECT" =~ ^"Arm backend":\ (${VERBS}) ]]; then echo -e "${WARNING} Subject should start with 'Arm backend: '"\ "followed by an imperative verb." >&2 - echo -n "There are warnings in your commit message. Do you want to"\ - "ignore the warning (y/N): " > /dev/tty - read USER_INPUT < /dev/tty + if [ $is_script_interactive -eq 1 ]; then + echo -n "There are warnings in your commit message. Do you want to"\ + "ignore the warning (y/N): " > /dev/tty - # Check user input for warnings - if [[ ! "$USER_INPUT" =~ ^[Yy]$ ]]; then - FAILED=1 + read USER_INPUT < /dev/tty + + # Check user input for warnings + if [[ ! 
"$USER_INPUT" =~ ^[Yy]$ ]]; then + FAILED=1 + fi fi fi diff --git a/backends/arm/test/models/test_nn_functional.py b/backends/arm/test/models/test_nn_functional.py index f258e54ac7f..a802807184c 100644 --- a/backends/arm/test/models/test_nn_functional.py +++ b/backends/arm/test/models/test_nn_functional.py @@ -106,7 +106,6 @@ def test_nn_functional_MI(test_data): x_fails = { "normalize": "MLETORCH-852: Support aten.index_put.default", - "cosine_similarity": "MLETORCH-854: Support aten.linalg_vector_norm.default", "unfold": "Int64 input && MLETORCH-827: Support aten.index.Tensor", "fold": "Int64 input && MLETORCH-827: Support aten.index_put.default", } diff --git a/backends/arm/test/passes/test_decompose_cosine_similarity_pass.py b/backends/arm/test/passes/test_decompose_cosine_similarity_pass.py new file mode 100644 index 00000000000..f3fa95ec10c --- /dev/null +++ b/backends/arm/test/passes/test_decompose_cosine_similarity_pass.py @@ -0,0 +1,52 @@ +# Copyright 2025 Arm Limited and/or its affiliates. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. 
+ +from typing import Tuple + +import torch + +from executorch.backends.arm._passes.decompose_cosine_similarity_pass import ( + DecomposeCosineSimilarityPass, +) +from executorch.backends.arm.test import common +from executorch.backends.arm.test.tester.test_pipeline import PassPipeline + +input_t = Tuple[torch.Tensor, torch.Tensor] + + +class CosineSimilarityModel(torch.nn.Module): + def get_inputs(self) -> input_t: + return (torch.rand(2, 3, 4), torch.rand(2, 3, 4)) + + def forward(self, x1: torch.Tensor, x2: torch.Tensor) -> torch.Tensor: + return torch.cosine_similarity(x1, x2, dim=1, eps=1e-6) + + +modules = {"cosine_basic": CosineSimilarityModel()} + + +@common.parametrize("module", modules) +def test_decompose_cosine_similarity_tosa_BI(module): + + ops_after_pass = { + "executorch_exir_dialects_edge__ops_aten_mul_Tensor": 5, + "executorch_exir_dialects_edge__ops_aten_sum_dim_IntList": 3, + "executorch_exir_dialects_edge__ops_aten_pow_Tensor_Scalar": 2, + "executorch_exir_dialects_edge__ops_aten_full_like_default": 1, + "executorch_exir_dialects_edge__ops_aten_maximum_default": 2, + "executorch_exir_dialects_edge__ops_aten_reciprocal_default": 1, + } + + pipeline = PassPipeline[input_t]( + module, + module.get_inputs(), + tosa_version="TOSA-0.80+BI", + ops_before_pass=None, + ops_not_before_pass=None, + ops_after_pass=ops_after_pass, + ops_not_after_pass=None, + pass_list=[DecomposeCosineSimilarityPass], + ) + pipeline.run() diff --git a/backends/vulkan/runtime/graph/ops/glsl/conv1d.glsl b/backends/vulkan/runtime/graph/ops/glsl/conv1d.glsl index a58a4d3a457..4e3b91e6c49 100644 --- a/backends/vulkan/runtime/graph/ops/glsl/conv1d.glsl +++ b/backends/vulkan/runtime/graph/ops/glsl/conv1d.glsl @@ -86,6 +86,9 @@ void main() { const int in_l = out_l * stride - padding; VEC4_T sum = VEC4_T(0); + const int out_c_packed_index = out_c >> 2; + const int out_c_packed_lane = out_c & 0x3; + for (int in_c = c_start; in_c < c_end; ++in_c) { // "k" tracks the kernel's index for 
our input-kernel computation. // It reads out-of-bound zeros, but trying to avoid them complicates @@ -103,16 +106,16 @@ void main() { // It is possible to further reduce the memory footprint by swapping the // dimensions, using x extent for out_channel, and y for kernel. for (int k = 0; k < kernel_size; k++) { - const ivec3 w_lpos = ivec3(k, in_c % in_group_size, out_c / 4); + const ivec3 w_lpos = ivec3(k, in_c % in_group_size, out_c_packed_index); const VEC4_T weight_texel = load_texel_lpos(kernel_in, w_lpos, kernel_axis_map); - VEC4_T weight = VEC4_T(weight_texel[out_c % 4]); + VEC4_T weight = VEC4_T(weight_texel[out_c_packed_lane]); const ivec3 in_pos = lpos_to_pos(ivec3(in_l + k * dilation, in_c, N), in_axis_map); sum = fma(weight, load_texel(t_in, in_pos), sum); } } - const VEC4_T bias = load_texel_lpos(bias_in, ivec3(out_c, 0, 0), bias_axis_map); + const VEC4_T bias = load_texel_lpos(bias_in, ivec3(out_c_packed_index, 0, 0), bias_axis_map); const ivec3 out_lpos = ivec3(out_l, out_c, N); - write_texel_lpos(t_out, out_lpos, op(sum + bias.x, out_min, out_max), out_axis_map); + write_texel_lpos(t_out, out_lpos, op(sum + bias[out_c_packed_lane], out_min, out_max), out_axis_map); } diff --git a/backends/vulkan/runtime/graph/ops/impl/Convolution.cpp b/backends/vulkan/runtime/graph/ops/impl/Convolution.cpp index 1af34c06819..48256cb2996 100644 --- a/backends/vulkan/runtime/graph/ops/impl/Convolution.cpp +++ b/backends/vulkan/runtime/graph/ops/impl/Convolution.cpp @@ -483,7 +483,7 @@ void add_conv1d_node( weight, /*transposed = */ false, /*storage_type = */ utils::kTexture3D, - /*memory_layout = */ utils::kChannelsPacked); + /*memory_layout = */ utils::kWidthPacked); float out_min_val = 0.0f; float out_max_val = 0.0f; diff --git a/backends/xnnpack/runtime/XNNCompiler.cpp b/backends/xnnpack/runtime/XNNCompiler.cpp index 1d8c6e3fdbc..3b78c1a0b84 100644 --- a/backends/xnnpack/runtime/XNNCompiler.cpp +++ b/backends/xnnpack/runtime/XNNCompiler.cpp @@ -23,10 +23,10 @@ 
namespace backends { namespace xnnpack { namespace delegate { +using executorch::ET_RUNTIME_NAMESPACE::NamedDataMap; using executorch::runtime::Error; using executorch::runtime::FreeableBuffer; using executorch::runtime::MemoryAllocator; -using executorch::runtime::NamedDataMap; using executorch::runtime::Result; /* diff --git a/backends/xnnpack/runtime/XNNExecutor.cpp b/backends/xnnpack/runtime/XNNExecutor.cpp index f18e319ac33..68cb4b4d885 100644 --- a/backends/xnnpack/runtime/XNNExecutor.cpp +++ b/backends/xnnpack/runtime/XNNExecutor.cpp @@ -16,7 +16,7 @@ namespace delegate { using executorch::aten::ScalarType; using executorch::aten::SizesType; using executorch::aten::Tensor; -using executorch::runtime::BackendExecutionContext; +using executorch::ET_RUNTIME_NAMESPACE::BackendExecutionContext; using executorch::runtime::Error; using executorch::runtime::EValue; using executorch::runtime::is_contiguous_dim_order; @@ -95,11 +95,19 @@ ET_NODISCARD Error XNNExecutor::prepare_args(EValue** args) { Tensor* tensor = &args[ext_id]->toTensor(); externals_[i].data = tensor->mutable_data_ptr(); + executorch::aten::DimOrderType dim_order[kTensorDimensionLimit]; + // Reshape runtime inputs if (i < input_ids_.size()) { size_t num_dims = tensor->dim(); + Error err = + ET_RUNTIME_NAMESPACE::get_dim_order(*tensor, dim_order, num_dims); + ET_CHECK_OR_RETURN_ERROR( + err == Error::Ok, + Internal, + "Failed to retrieve dim order from tensor!"); ET_CHECK_OR_RETURN_ERROR( - is_contiguous_dim_order(tensor->dim_order().data(), tensor->dim()), + is_contiguous_dim_order(dim_order, tensor->dim()), Internal, "Expecting default dim_order but got a non default dim_order tensor for external input %u", i); @@ -220,7 +228,7 @@ ET_NODISCARD Error XNNExecutor::resize_outputs(EValue** args) const { expected_output_size, static_cast(num_dim)}; ET_LOG(Debug, "Resizing output tensor to a new shape"); - Error err = resize_tensor(*out_tensor, output_size); + Error err = 
ET_RUNTIME_NAMESPACE::resize_tensor(*out_tensor, output_size); if (err != Error::Ok) { ET_LOG(Error, "Failed to resize output tensor for XNNExecutor"); return err; diff --git a/backends/xnnpack/runtime/XNNExecutor.h b/backends/xnnpack/runtime/XNNExecutor.h index b98c902f44f..8131b6b8b2c 100644 --- a/backends/xnnpack/runtime/XNNExecutor.h +++ b/backends/xnnpack/runtime/XNNExecutor.h @@ -75,7 +75,7 @@ class XNNExecutor { * Executes the graph using the args prepared at prepare_args(). */ ET_NODISCARD executorch::runtime::Error forward( - executorch::runtime::BackendExecutionContext& context); + executorch::ET_RUNTIME_NAMESPACE::BackendExecutionContext& context); /** * Prepares the outputs to be returned by the delegate diff --git a/backends/xnnpack/runtime/XNNPACKBackend.cpp b/backends/xnnpack/runtime/XNNPACKBackend.cpp index 1e2f07bd905..9e02d566d99 100644 --- a/backends/xnnpack/runtime/XNNPACKBackend.cpp +++ b/backends/xnnpack/runtime/XNNPACKBackend.cpp @@ -22,19 +22,20 @@ namespace executorch { namespace backends { using executorch::backends::xnnpack::delegate::XNNWeightsCache; +using executorch::ET_RUNTIME_NAMESPACE::Backend; +using executorch::ET_RUNTIME_NAMESPACE::BackendExecutionContext; +using executorch::ET_RUNTIME_NAMESPACE::BackendInitContext; +using executorch::ET_RUNTIME_NAMESPACE::CompileSpec; +using executorch::ET_RUNTIME_NAMESPACE::DelegateHandle; +using executorch::ET_RUNTIME_NAMESPACE::NamedDataMap; using executorch::runtime::ArrayRef; -using executorch::runtime::Backend; -using executorch::runtime::BackendExecutionContext; -using executorch::runtime::BackendInitContext; -using executorch::runtime::CompileSpec; -using executorch::runtime::DelegateHandle; using executorch::runtime::Error; using executorch::runtime::EValue; using executorch::runtime::FreeableBuffer; -using executorch::runtime::NamedDataMap; using executorch::runtime::Result; -class XnnpackBackend final : public ::executorch::runtime::BackendInterface { +class XnnpackBackend final + : 
public ::executorch::ET_RUNTIME_NAMESPACE::BackendInterface { public: ~XnnpackBackend() = default; diff --git a/backends/xnnpack/runtime/XNNWeightsCache.cpp b/backends/xnnpack/runtime/XNNWeightsCache.cpp index f2842851d3a..1a230c19976 100644 --- a/backends/xnnpack/runtime/XNNWeightsCache.cpp +++ b/backends/xnnpack/runtime/XNNWeightsCache.cpp @@ -19,8 +19,8 @@ namespace backends { namespace xnnpack { namespace delegate { +using executorch::ET_RUNTIME_NAMESPACE::NamedDataMap; using executorch::runtime::MemoryAllocator; -using executorch::runtime::NamedDataMap; XNNWeightsCache::XNNWeightsCache() { weights_cache_.context = this; diff --git a/backends/xnnpack/runtime/XNNWeightsCache.h b/backends/xnnpack/runtime/XNNWeightsCache.h index bc00ac15fd0..f8371f93d01 100644 --- a/backends/xnnpack/runtime/XNNWeightsCache.h +++ b/backends/xnnpack/runtime/XNNWeightsCache.h @@ -23,10 +23,10 @@ namespace backends { namespace xnnpack { namespace delegate { +using executorch::ET_RUNTIME_NAMESPACE::NamedDataMap; using executorch::runtime::Error; using executorch::runtime::FreeableBuffer; using executorch::runtime::MemoryAllocator; -using executorch::runtime::NamedDataMap; using executorch::runtime::Result; struct PackedDataMeta { diff --git a/backends/xnnpack/targets.bzl b/backends/xnnpack/targets.bzl index e97f1941ff7..aee5104b17a 100644 --- a/backends/xnnpack/targets.bzl +++ b/backends/xnnpack/targets.bzl @@ -1,5 +1,5 @@ load("@fbsource//xplat/executorch/backends/xnnpack/third-party:third_party_libs.bzl", "third_party_dep") -load("@fbsource//xplat/executorch/build:runtime_wrapper.bzl", "runtime") +load("@fbsource//xplat/executorch/build:runtime_wrapper.bzl", "get_aten_mode_options", "runtime") def _get_preprocessor_flags(): """ @@ -33,40 +33,42 @@ def define_common_targets(): ], ) - runtime.cxx_library( - name = "xnnpack_backend", - srcs = native.glob([ - "runtime/*.cpp", - "runtime/profiling/*.cpp", - ]), - headers = native.glob([ - "runtime/*.h", - "runtime/profiling/*.h", - ]), - 
visibility = [ - "//executorch/exir/backend:backend_lib", - "//executorch/exir/backend/test/...", - "//executorch/backends/xnnpack/test/...", - "//executorch/extension/pybindings/...", - "@EXECUTORCH_CLIENTS", - ], - preprocessor_flags = [ - # Uncomment to enable per operator timings - # "-DENABLE_XNNPACK_PROFILING", - # Uncomment to enable using KleidiAI Kernels - # "-DENABLE_XNNPACK_KLEIDI" - ] + _get_preprocessor_flags(), - exported_deps = [ - "//executorch/runtime/backend:interface", - ], - deps = [ - third_party_dep("XNNPACK"), - "//executorch/backends/xnnpack/serialization:xnnpack_flatbuffer_header", - "//executorch/extension/threadpool:threadpool", - "//executorch/runtime/core/exec_aten/util:tensor_util", - "//executorch/runtime/executor:pte_data_map" - ], - # XnnpackBackend.cpp needs to compile with executor as whole - # @lint-ignore BUCKLINT: Avoid `link_whole=True` (https://fburl.com/avoid-link-whole) - link_whole = True, - ) + for aten_mode in get_aten_mode_options(): + aten_suffix = "_aten" if aten_mode else "" + runtime.cxx_library( + name = "xnnpack_backend" + aten_suffix, + srcs = native.glob([ + "runtime/*.cpp", + "runtime/profiling/*.cpp", + ]), + headers = native.glob([ + "runtime/*.h", + "runtime/profiling/*.h", + ]), + visibility = [ + "//executorch/exir/backend:backend_lib", + "//executorch/exir/backend/test/...", + "//executorch/backends/xnnpack/test/...", + "//executorch/extension/pybindings/...", + "@EXECUTORCH_CLIENTS", + ], + preprocessor_flags = [ + # Uncomment to enable per operator timings + # "-DENABLE_XNNPACK_PROFILING", + # Uncomment to enable using KleidiAI Kernels + # "-DENABLE_XNNPACK_KLEIDI" + ] + _get_preprocessor_flags(), + exported_deps = [ + "//executorch/runtime/backend:interface" + aten_suffix, + ], + deps = [ + third_party_dep("XNNPACK"), + "//executorch/backends/xnnpack/serialization:xnnpack_flatbuffer_header", + "//executorch/extension/threadpool:threadpool", + "//executorch/runtime/core/exec_aten/util:tensor_util" + 
aten_suffix, + "//executorch/runtime/executor:pte_data_map" + aten_suffix, + ], + # XnnpackBackend.cpp needs to compile with executor as whole + # @lint-ignore BUCKLINT: Avoid `link_whole=True` (https://fburl.com/avoid-link-whole) + link_whole = True, + ) diff --git a/configurations/CMakeLists.txt b/configurations/CMakeLists.txt index d620b722a09..d77ea1633ed 100644 --- a/configurations/CMakeLists.txt +++ b/configurations/CMakeLists.txt @@ -59,7 +59,7 @@ if(EXECUTORCH_BUILD_KERNELS_OPTIMIZED) optimized_kernels ${_optimized_native_cpu_ops_lib_portable_kernels_lib} DEPS - executorch + executorch_core ) install(TARGETS optimized_native_cpu_ops_lib DESTINATION lib) diff --git a/devtools/bundled_program/bundled_program.cpp b/devtools/bundled_program/bundled_program.cpp index df4124e0038..913c349a53a 100644 --- a/devtools/bundled_program/bundled_program.cpp +++ b/devtools/bundled_program/bundled_program.cpp @@ -260,9 +260,16 @@ ET_NODISCARD Error load_bundled_input( if (!method_test.ok()) { return method_test.error(); } - + auto test_cases = method_test.get()->test_cases(); + ET_CHECK_OR_RETURN_ERROR( + testset_idx < test_cases->size(), + InvalidArgument, + "testset_idx %zu is out of range [0, %u]", + testset_idx, + test_cases->size()); auto bundled_inputs = - method_test.get()->test_cases()->Get(testset_idx)->inputs(); + test_cases->Get(static_cast(testset_idx)) + ->inputs(); for (size_t input_idx = 0; input_idx < method.inputs_size(); input_idx++) { auto bundled_input = bundled_inputs->GetMutableObject(input_idx); @@ -359,8 +366,16 @@ ET_NODISCARD Error verify_method_outputs( return method_test.error(); } + auto test_cases = method_test.get()->test_cases(); + ET_CHECK_OR_RETURN_ERROR( + testset_idx < test_cases->size(), + InvalidArgument, + "testset_idx %zu is out of range [0, %u]", + testset_idx, + test_cases->size()); auto bundled_expected_outputs = - method_test.get()->test_cases()->Get(testset_idx)->expected_outputs(); + test_cases->Get(static_cast(testset_idx)) + 
->expected_outputs(); if (bundled_expected_outputs->size() == 0) { // No bundled expected outputs, so we can't verify the method outputs. diff --git a/devtools/bundled_program/schema/targets.bzl b/devtools/bundled_program/schema/targets.bzl index 532a01e039e..1201458b42f 100644 --- a/devtools/bundled_program/schema/targets.bzl +++ b/devtools/bundled_program/schema/targets.bzl @@ -74,6 +74,7 @@ def define_common_targets(): visibility = [ "//executorch/devtools/bundled_program/...", "//executorch/extension/pybindings/...", + "//executorch/extension/module/...", ], exported_headers = { OUTPUT_BUNDLED_HEADER: ":{}[{}]".format(BUNDLED_GEN_RULE_NAME, OUTPUT_BUNDLED_HEADER), diff --git a/examples/cadence/models/resnet18.py b/examples/cadence/models/resnet18.py new file mode 100644 index 00000000000..ebf0533bd84 --- /dev/null +++ b/examples/cadence/models/resnet18.py @@ -0,0 +1,30 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. 
+ +# Example script for exporting simple models to flatbuffer + +import logging + +import torch + +from executorch.backends.cadence.aot.ops_registrations import * # noqa + + +from executorch.backends.cadence.aot.export_example import export_model +from torchvision.models import resnet18, ResNet18_Weights + + +FORMAT = "[%(levelname)s %(asctime)s %(filename)s:%(lineno)s] %(message)s" +logging.basicConfig(level=logging.INFO, format=FORMAT) + + +if __name__ == "__main__": + + model = resnet18(weights=ResNet18_Weights.DEFAULT) + model.eval() + example_inputs = (torch.randn(1, 3, 64, 64),) + + export_model(model, example_inputs) diff --git a/examples/models/llama/CMakeLists.txt b/examples/models/llama/CMakeLists.txt index 12385f32d20..4ea735e5717 100644 --- a/examples/models/llama/CMakeLists.txt +++ b/examples/models/llama/CMakeLists.txt @@ -80,14 +80,12 @@ find_package(gflags REQUIRED) # find `executorch` libraries Same as for gflags set(executorch_DIR ${CMAKE_CURRENT_BINARY_DIR}/../../../lib/cmake/ExecuTorch) find_package(executorch CONFIG REQUIRED) -if(CMAKE_TOOLCHAIN_IOS OR ANDROID) - target_link_options_shared_lib(executorch) -endif() +target_link_options_shared_lib(executorch) # llama_runner library add_subdirectory(runner) -set(link_libraries gflags) +set(link_libraries executorch gflags) set(_srcs main.cpp) if(EXECUTORCH_BUILD_KERNELS_OPTIMIZED) @@ -225,9 +223,5 @@ target_include_directories( target_link_libraries(llama_main PUBLIC llama_runner ${link_libraries}) target_compile_options(llama_main PUBLIC ${_common_compile_options}) -if(APPLE) - target_link_options_shared_lib(executorch) -endif() - # Print all summary executorch_print_configuration_summary() diff --git a/examples/models/llama/llama_transformer.py b/examples/models/llama/llama_transformer.py index 5c8db7f208d..1fdcdcd91fc 100644 --- a/examples/models/llama/llama_transformer.py +++ b/examples/models/llama/llama_transformer.py @@ -13,6 +13,7 @@ import torch.nn.functional as F from 
executorch.examples.models.llama.attention import ( + Attention, ATTENTION_REGISTRY, ForwardOptions, ) @@ -83,19 +84,21 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: class TransformerBlock(nn.Module): - def __init__(self, layer_id: int, args: ModelArgs, rope: Rope): + def __init__(self, args: ModelArgs, attention: Attention): + """ + Transformer block with support for pre-norm and post-norm. + Args: + args (ModelArgs): model configuration parameters. + attention (Attention): attention object to use in the transformer + block. See `attention.py` for types of attention. Make sure + the attention type is registered in the ATTENTION_REGISTRY. + """ super().__init__() self.use_kv_cache = args.use_kv_cache self.n_heads = args.n_heads self.dim = args.dim self.head_dim = args.head_dim - if args.attention_type not in ATTENTION_REGISTRY: - raise ValueError( - f"Unknown attention type: {args.attention_type}. " - f"Available: {list(ATTENTION_REGISTRY.keys())}" - ) - cls = ATTENTION_REGISTRY[args.attention_type] - self.attention = cls(args, layer_id, rope) + self.attention = attention if args.moe: self.block_sparse_moe = MOEFeedForward(args) else: @@ -103,6 +106,24 @@ def __init__(self, layer_id: int, args: ModelArgs, rope: Rope): self.attention_norm = RMSNorm(args.dim, eps=args.norm_eps) self.ffn_norm = RMSNorm(args.dim, eps=args.norm_eps) + @classmethod + def from_type(cls, layer_id, args, rope) -> "TransformerBlock": + """ + Create a TransformerBlock with the legacy constructor. + Args: + layer_id (int): the index of the layer. + args (ModelArgs): model configuration parameters. + rope (Rope): the rope object to use for rotary embeddings. + """ + if args.attention_type not in ATTENTION_REGISTRY: + raise ValueError( + f"Unknown attention type: {args.attention_type}. 
" + f"Available: {list(ATTENTION_REGISTRY.keys())}" + ) + cls = ATTENTION_REGISTRY[args.attention_type] + attention = cls(args, layer_id, rope) + return TransformerBlock(args, attention) + def forward(self, x, freqs_cos, freqs_sin, attn_options: ForwardOptions): # x: 1xN h, attn_options_update = self.attention.forward( self.attention_norm(x), freqs_cos, freqs_sin, **attn_options @@ -117,7 +138,15 @@ def forward(self, x, freqs_cos, freqs_sin, attn_options: ForwardOptions): # x: class Transformer(nn.Module): - def __init__(self, params: ModelArgs): + def __init__(self, params: ModelArgs, layers: nn.ModuleList, rope: Rope): + """ + Transformer model. + Args: + params (ModelArgs): model configuration parameters. + layers (nn.ModuleList): list of transformer blocks - see the + `TransformerBlock` type above. + rope (Rope): the rope object to use for rotary embeddings. + """ super().__init__() self.params = params self.vocab_size = params.vocab_size @@ -130,10 +159,8 @@ def __init__(self, params: ModelArgs): if self.apply_embedding else None ) - self.rope = Rope(params) - self.layers = torch.nn.ModuleList() - for layer_id in range(params.n_layers): - self.layers.append(TransformerBlock(layer_id, params, self.rope)) + self.layers = layers + self.rope = rope self.norm = RMSNorm(params.dim, eps=params.norm_eps) self.output = ( nn.Linear(params.dim, params.vocab_size, bias=False) @@ -212,3 +239,23 @@ def forward( return logits, attn_options_update return logits + + +def construct_transformer(model_args: ModelArgs) -> Transformer: + """ + Construct a Transformer model from the given model arguments. + """ + rope = Rope(model_args) + if model_args.attention_type not in ATTENTION_REGISTRY: + raise ValueError( + f"Unknown attention type: {model_args.attention_type}. 
" + f"Available: {list(ATTENTION_REGISTRY.keys())}" + ) + layers = torch.nn.ModuleList() + cls = ATTENTION_REGISTRY[model_args.attention_type] + for layer_id in range(model_args.n_layers): + attention = cls(model_args, layer_id, rope) + transformer_block = TransformerBlock(model_args, attention) + layers.append(transformer_block) + + return Transformer(model_args, layers, rope) diff --git a/examples/models/llama/model.py b/examples/models/llama/model.py index 2c82841c573..d6400c29db8 100644 --- a/examples/models/llama/model.py +++ b/examples/models/llama/model.py @@ -15,9 +15,10 @@ get_checkpoint_dtype, get_default_model_resource_dir, ) -from executorch.examples.models.llama.llama_transformer import Transformer +from executorch.examples.models.llama.llama_transformer import construct_transformer from executorch.examples.models.llama.model_args import ModelArgs +from executorch.examples.models.llama.rope import Rope from torchao.utils import TorchAOBaseTensor try: @@ -174,7 +175,7 @@ def __init__(self, **kwargs): # They possess all other metadata a tensor carries such as size, stride, requires_grad. with torch.device("meta"): # Model itself is loaded in default dtype, fp32. - self.model_ = Transformer(model_args) + self.model_ = construct_transformer(model_args) # Get checkpoint dtype. 
if checkpoint: self.model_.checkpoint_dtype = get_checkpoint_dtype(checkpoint) diff --git a/examples/models/llama/runner/CMakeLists.txt b/examples/models/llama/runner/CMakeLists.txt index 0807e6fa422..fefee61092d 100644 --- a/examples/models/llama/runner/CMakeLists.txt +++ b/examples/models/llama/runner/CMakeLists.txt @@ -52,8 +52,8 @@ else() add_library(llama_runner SHARED ${_llama_runner__srcs}) endif() -set(llama_runner_deps executorch extension_data_loader extension_module - extension_tensor +set(llama_runner_deps executorch_core extension_data_loader extension_module + extension_tensor extension_flat_tensor ) target_link_libraries(llama_runner PUBLIC ${llama_runner_deps}) diff --git a/examples/models/llama/tests/test_pre_quantization_transforms.py b/examples/models/llama/tests/test_pre_quantization_transforms.py index 345f3fad9ba..dc1f9c6cd71 100644 --- a/examples/models/llama/tests/test_pre_quantization_transforms.py +++ b/examples/models/llama/tests/test_pre_quantization_transforms.py @@ -7,7 +7,10 @@ import unittest import torch -from executorch.examples.models.llama.llama_transformer import Transformer +from executorch.examples.models.llama.llama_transformer import ( + construct_transformer, + Transformer, +) from executorch.examples.models.llama.model_args import ModelArgs from executorch.examples.models.llama.source_transformation.pre_quantization import ( sanitize_checkpoint_from_pre_quantization, @@ -39,7 +42,7 @@ def _prepare_dummy_model(self) -> Transformer: vocab_size=32000, ) - model = Transformer(model_args) + model = construct_transformer(model_args) return model diff --git a/examples/models/llama/tests/test_static_attention.py b/examples/models/llama/tests/test_static_attention.py index a1b6742416e..77b8be5d401 100644 --- a/examples/models/llama/tests/test_static_attention.py +++ b/examples/models/llama/tests/test_static_attention.py @@ -2,7 +2,7 @@ import torch from executorch.examples.models.llama.attention import AttentionMHA, ForwardOptions 
-from executorch.examples.models.llama.llama_transformer import Transformer +from executorch.examples.models.llama.llama_transformer import construct_transformer from executorch.examples.models.llama.model_args import ModelArgs from executorch.examples.models.llama.rope import Rope from executorch.examples.models.llama.static_attention import ( @@ -160,10 +160,10 @@ def test_within_transformer(self): n_layers=4, vocab_size=128, ) - mha_transformer = Transformer(config).eval() + mha_transformer = construct_transformer(config).eval() config.attention_type = "static" - static_transformer = Transformer(config).eval() + static_transformer = construct_transformer(config).eval() static_transformer.load_state_dict(mha_transformer.state_dict(), strict=False) for mha_layer, static_layer in zip( mha_transformer.layers, static_transformer.layers diff --git a/examples/models/llava/CMakeLists.txt b/examples/models/llava/CMakeLists.txt index eeb6c296dd5..232e83d8b0a 100644 --- a/examples/models/llava/CMakeLists.txt +++ b/examples/models/llava/CMakeLists.txt @@ -89,14 +89,12 @@ endif() # find `executorch` libraries Same as for gflags set(executorch_DIR ${CMAKE_CURRENT_BINARY_DIR}/../../../lib/cmake/ExecuTorch) find_package(executorch CONFIG REQUIRED) -if(CMAKE_TOOLCHAIN_IOS OR ANDROID) - target_link_options_shared_lib(executorch) -endif() +target_link_options_shared_lib(executorch) # llava_runner library add_subdirectory(runner) -set(LINK_LIBS gflags) +set(LINK_LIBS executorch gflags) if(NOT LLAVA_RUNNER_NO_TORCH_DUMMY_IMAGE) list(APPEND LINK_LIBS torch) endif() @@ -212,9 +210,5 @@ target_include_directories(llava_main PUBLIC ${_common_include_directories}) target_link_libraries(llava_main PUBLIC llava_runner ${link_libraries}) target_compile_options(llava_main PUBLIC ${_common_compile_options}) -if(APPLE) - target_link_options_shared_lib(executorch) -endif() - # Print all summary executorch_print_configuration_summary() diff --git a/examples/models/llava/model.py 
b/examples/models/llava/model.py index 351356607c8..7bcf560536c 100644 --- a/examples/models/llava/model.py +++ b/examples/models/llava/model.py @@ -12,7 +12,7 @@ import requests import torch -from executorch.examples.models.llama.llama_transformer import Transformer +from executorch.examples.models.llama.llama_transformer import construct_transformer from executorch.examples.models.llama.model_args import ModelArgs from executorch.examples.models.llama.source_transformation.custom_kv_cache import ( @@ -66,7 +66,7 @@ def __init__( use_hf_rope=True, max_seq_len=max_seq_len, ) - self.text_model = Transformer(self.text_model_args) + self.text_model = construct_transformer(self.text_model_args) # use custom op for SDPA. if use_sdpa_with_kv_cache_op: self.text_model = replace_kv_cache_with_custom_kv_cache(self.text_model) diff --git a/examples/models/llava/runner/CMakeLists.txt b/examples/models/llava/runner/CMakeLists.txt index c694bf87c66..1f9d6fa8e1d 100644 --- a/examples/models/llava/runner/CMakeLists.txt +++ b/examples/models/llava/runner/CMakeLists.txt @@ -40,8 +40,8 @@ add_subdirectory( add_library(llava_runner STATIC ${_llava_runner__srcs}) -set(llava_runner_deps executorch extension_data_loader extension_llm_runner - extension_module extension_tensor +set(llava_runner_deps executorch_core extension_data_loader extension_llm_runner + extension_module extension_tensor extension_flat_tensor ) target_link_libraries(llava_runner PUBLIC ${llava_runner_deps}) diff --git a/examples/models/qwen3/4b_config.json b/examples/models/qwen3/4b_config.json index 0874682bd80..a7a710c7779 100644 --- a/examples/models/qwen3/4b_config.json +++ b/examples/models/qwen3/4b_config.json @@ -13,5 +13,5 @@ "use_hf_rope": true, "attention_qkv_bias": false, "use_qk_norm": true, - "qk_norm_before_repo": true + "qk_norm_before_rope": true } diff --git a/examples/models/qwen3/README.md b/examples/models/qwen3/README.md index 767cbafe03d..d5507d79f2f 100644 --- a/examples/models/qwen3/README.md 
+++ b/examples/models/qwen3/README.md @@ -25,6 +25,7 @@ python -m examples.models.llama.export_llama \ -X \ --xnnpack-extended-ops \ -qmode 8da4w \ + --metadata '{"get_bos_id": 151644, "get_eos_ids":[151645]}' \ --output_name="qwen3-0_6b.pte" \ --verbose ``` @@ -40,6 +41,7 @@ python -m examples.models.llama.export_llama \ -X \ --xnnpack-extended-ops \ -qmode 8da4w \ + --metadata '{"get_bos_id": 151644, "get_eos_ids":[151645]}' \ --output_name="qwen3-1_7b.pte" \ --verbose ``` @@ -55,6 +57,7 @@ python -m examples.models.llama.export_llama \ -X \ --xnnpack-extended-ops \ -qmode 8da4w \ + --metadata '{"get_bos_id": 151644, "get_eos_ids":[151645]}' \ --output_name="qwen3-4b.pte" \ --verbose ``` diff --git a/exir/backend/test/demos/rpc/ExecutorBackend.cpp b/exir/backend/test/demos/rpc/ExecutorBackend.cpp index 7dc0d2b2373..7632e4ad33c 100644 --- a/exir/backend/test/demos/rpc/ExecutorBackend.cpp +++ b/exir/backend/test/demos/rpc/ExecutorBackend.cpp @@ -18,6 +18,7 @@ #include #include #include +#include #include #include @@ -37,6 +38,7 @@ using ::executorch::runtime::MemoryAllocator; using ::executorch::runtime::MemoryManager; using ::executorch::runtime::Method; using ::executorch::runtime::MethodMeta; +using ::executorch::runtime::NamedDataMap; using ::executorch::runtime::Program; using ::executorch::runtime::Result; using ::executorch::runtime::Span; @@ -156,9 +158,13 @@ class ExecutorBackend final : public ::executorch::runtime::BackendInterface { new (client_memory_manager) MemoryManager(client_method_allocator, client_planned_memory); + const NamedDataMap* named_data_map = context.get_named_data_map(); // Construct the client Method - Result method_res = - client_program->load_method("forward", client_memory_manager); + Result method_res = client_program->load_method( + "forward", + client_memory_manager, + /*event_tracer=*/nullptr, + named_data_map); if (!method_res.ok()) { ET_LOG( Error, diff --git a/exir/backend/test/demos/rpc/TARGETS 
b/exir/backend/test/demos/rpc/TARGETS index 3fdb1d4360a..d8fb426ba6a 100644 --- a/exir/backend/test/demos/rpc/TARGETS +++ b/exir/backend/test/demos/rpc/TARGETS @@ -11,6 +11,7 @@ runtime.python_library( ], visibility = [ "//executorch/exir/backend/test/...", + "//executorch/test/...", ], deps = [ "//caffe2:torch", diff --git a/exir/backend/test/demos/rpc/executor_backend_preprocess.py b/exir/backend/test/demos/rpc/executor_backend_preprocess.py index 0e5b8a8d3d5..19f7a25bd73 100644 --- a/exir/backend/test/demos/rpc/executor_backend_preprocess.py +++ b/exir/backend/test/demos/rpc/executor_backend_preprocess.py @@ -8,6 +8,8 @@ from typing import final, List +from executorch.exir import ExecutorchBackendConfig + from executorch.exir.backend.backend_details import ( BackendDetails, ExportedProgram, @@ -24,10 +26,14 @@ def preprocess( edge_program: ExportedProgram, compile_specs: List[CompileSpec], ) -> PreprocessResult: + config = ExecutorchBackendConfig() + for spec in compile_specs: + if spec.key == "external_constants": + config.external_constants = True return PreprocessResult( processed_bytes=EdgeProgramManager( edge_programs=edge_program, ) - .to_executorch() + .to_executorch(config) .buffer, ) diff --git a/exir/backend/test/demos/rpc/targets.bzl b/exir/backend/test/demos/rpc/targets.bzl index c5cfb343a6c..486444e400d 100644 --- a/exir/backend/test/demos/rpc/targets.bzl +++ b/exir/backend/test/demos/rpc/targets.bzl @@ -40,6 +40,7 @@ def define_common_targets(): ], visibility = [ "//executorch/exir/backend/test/...", + "//executorch/runtime/executor/test/...", ], deps = [ ":executor_backend", diff --git a/exir/tests/test_passes.py b/exir/tests/test_passes.py index 82420f66e6b..d3c2d0a0936 100644 --- a/exir/tests/test_passes.py +++ b/exir/tests/test_passes.py @@ -682,7 +682,6 @@ def test_dce_recursive(self) -> None: inputs = eager_model.get_random_inputs() gm = export(eager_model, inputs, strict=True).graph_module - self.assertTrue(torch.ops.aten.sub.Tensor in 
collect_ops(gm)) dead_code_elimination_pass(gm) gm.print_readable() self.assertFalse(torch.ops.aten.sub.Tensor in collect_ops(gm)) diff --git a/extension/android/CMakeLists.txt b/extension/android/CMakeLists.txt index 3a1fe79d8f5..b2f7b8d9f47 100644 --- a/extension/android/CMakeLists.txt +++ b/extension/android/CMakeLists.txt @@ -72,6 +72,7 @@ list( link_libraries executorch extension_data_loader + extension_flat_tensor extension_module extension_runner_util extension_tensor diff --git a/extension/data_loader/CMakeLists.txt b/extension/data_loader/CMakeLists.txt index 0af3fbcc161..6779160bcaf 100644 --- a/extension/data_loader/CMakeLists.txt +++ b/extension/data_loader/CMakeLists.txt @@ -18,7 +18,7 @@ endif() list(TRANSFORM _extension_data_loader__srcs PREPEND "${EXECUTORCH_ROOT}/") add_library(extension_data_loader ${_extension_data_loader__srcs}) -target_link_libraries(extension_data_loader executorch) +target_link_libraries(extension_data_loader executorch_core) target_include_directories(extension_data_loader PUBLIC ${EXECUTORCH_ROOT}/..) target_compile_options(extension_data_loader PUBLIC ${_common_compile_options}) diff --git a/extension/flat_tensor/CMakeLists.txt b/extension/flat_tensor/CMakeLists.txt index caacd96b557..d44ed811805 100644 --- a/extension/flat_tensor/CMakeLists.txt +++ b/extension/flat_tensor/CMakeLists.txt @@ -18,7 +18,7 @@ endif() list(TRANSFORM _extension_flat_tensor__srcs PREPEND "${EXECUTORCH_ROOT}/") add_library(extension_flat_tensor ${_extension_flat_tensor__srcs}) -target_link_libraries(extension_flat_tensor executorch extension_data_loader) +target_link_libraries(extension_flat_tensor executorch_core) target_include_directories( extension_flat_tensor PUBLIC ${EXECUTORCH_ROOT}/.. diff --git a/extension/module/bundled_module.cpp b/extension/module/bundled_module.cpp new file mode 100644 index 00000000000..083aef141a0 --- /dev/null +++ b/extension/module/bundled_module.cpp @@ -0,0 +1,112 @@ +/* + * Copyright (c) Meta Platforms, Inc. 
and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#include + +#include +#include +#include +#include + +namespace executorch { +namespace extension { + +namespace { +std::unique_ptr program_data_loader( + const void* bundled_program_ptr) { + auto bundled_program = + bundled_program_flatbuffer::GetBundledProgram(bundled_program_ptr); + // the program inside the bundled program + auto program = bundled_program->program(); + return std::make_unique(program->data(), program->size()); +} +} // namespace + +BundledModule::BundledModule( + const void* bundled_program_ptr, + std::unique_ptr memory_allocator, + std::unique_ptr temp_allocator, + std::unique_ptr event_tracer, + std::unique_ptr data_map_loader) + : Module( + program_data_loader(bundled_program_ptr), + std::move(memory_allocator), + std::move(temp_allocator), + std::move(event_tracer), + std::move(data_map_loader)), + bundled_program_ptr_(bundled_program_ptr) {} + +runtime::Result> BundledModule::from_file( + const std::string& file_path, + std::unique_ptr memory_allocator, + std::unique_ptr temp_allocator, + std::unique_ptr event_tracer, + std::unique_ptr data_map_loader) { + auto data_loader_result = FileDataLoader::from(file_path.c_str()); + if (!data_loader_result.ok()) { + return data_loader_result.error(); + } + + auto file_size_result = data_loader_result->size(); + if (!file_size_result.ok()) { + return file_size_result.error(); + } + + size_t file_size = file_size_result.get(); + auto file_data = std::make_unique(file_size); + auto buffer_result = + data_loader_result->load_into(0, file_size, {}, file_data.get()); + if (buffer_result != runtime::Error::Ok) { + return buffer_result; + } + + // Pass ownership of the data to BundledModule + auto bm = std::make_unique( + file_data.release(), + std::move(memory_allocator), + std::move(temp_allocator), + 
std::move(event_tracer), + std::move(data_map_loader)); + + bm->is_loaded_from_file_ = true; + + return bm; +} + +runtime::Result> BundledModule::execute( + const std::string& method_name, + const size_t testset_idx) { + ET_CHECK_OK_OR_RETURN_ERROR(load_method(method_name)); + auto& method = methods_.at(method_name).method; + + ET_CHECK_OK_OR_RETURN_ERROR( + executorch::BUNDLED_PROGRAM_NAMESPACE::load_bundled_input( + *method, bundled_program_ptr_, testset_idx)); + ET_CHECK_OK_OR_RETURN_ERROR(method->execute()); + + const auto outputs_size = method->outputs_size(); + std::vector outputs(outputs_size); + ET_CHECK_OK_OR_RETURN_ERROR( + method->get_outputs(outputs.data(), outputs_size)); + + return outputs; +} + +runtime::Error BundledModule::verify_method_outputs( + const std::string& method_name, + const size_t testset_idx, + double rtol, + double atol) { + ET_CHECK_OK_OR_RETURN_ERROR(load_method(method_name)); + auto& method = methods_.at(method_name).method; + return executorch::BUNDLED_PROGRAM_NAMESPACE::verify_method_outputs( + *method, bundled_program_ptr_, testset_idx, rtol, atol); +} + +} // namespace extension +} // namespace executorch diff --git a/extension/module/bundled_module.h b/extension/module/bundled_module.h new file mode 100644 index 00000000000..d254a2cdcb5 --- /dev/null +++ b/extension/module/bundled_module.h @@ -0,0 +1,123 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#pragma once + +#include + +namespace executorch { +namespace extension { + +/** + * A facade class for loading bundled programs and executing methods within + * them. + */ +class BundledModule : public Module { + public: + /** + * Constructs an instance with the bundled program buffer pointer. 
+ * + * This constructor reads the program from bundled program buffer to load the + * module with data loader. The bundled program pointer is preserved so that + * the portion outside of program is accessible. + * + * @param[in] bundled_program_ptr A DataLoader used for loading program data. + * @param[in] memory_allocator A MemoryAllocator used for memory management. + * @param[in] temp_allocator A MemoryAllocator to use when allocating + * temporary data during kernel or delegate execution. + * @param[in] event_tracer A EventTracer used for tracking and logging events. + * @param[in] data_map_loader A DataLoader used for loading external weights. + */ + explicit BundledModule( + const void* bundled_program_ptr, + std::unique_ptr memory_allocator = nullptr, + std::unique_ptr temp_allocator = nullptr, + std::unique_ptr event_tracer = nullptr, + std::unique_ptr data_map_loader = nullptr); + + // Disallow copying + BundledModule(const BundledModule&) = delete; + BundledModule& operator=(const BundledModule&) = delete; + // Disallow copying + BundledModule(BundledModule&&) = delete; + BundledModule& operator=(BundledModule&&) = delete; + // Default destructor + ~BundledModule() { + if (is_loaded_from_file_) { + delete[] static_cast(bundled_program_ptr_); + } + } + + /** + * Constructs an instance by loading a bundled program from a file with + * specified memory locking behavior. + * + * @param[in] file_path The path to the ExecuTorch bundled program file to + * load. + * @param[in] memory_allocator A MemoryAllocator used for memory management. + * @param[in] temp_allocator A MemoryAllocator to use when allocating + * temporary data during kernel or delegate execution. + * @param[in] event_tracer A EventTracer used for tracking and logging events. + * @param[in] data_map_loader A DataLoader used for loading external weights. 
+ */ + ET_NODISCARD static runtime::Result> from_file( + const std::string& file_path, + std::unique_ptr memory_allocator = nullptr, + std::unique_ptr temp_allocator = nullptr, + std::unique_ptr event_tracer = nullptr, + std::unique_ptr data_map_loader = nullptr); + + using Module::execute; + + /** + * Execute a specific method with the input value at the given `testset_idx` + * from the bundle to the method. Loads the program and method before + * executing if needed. + * + * This function is a wrapper of `load_bundled_input` in `bundled_program`. + * + * @param[in] method_name The name of the method to execute. + * @param[in] testset_idx The index of the input value to be passed to the + * method. + * + * @returns Return Error::Ok on a successful load, or the error happens during + * execution. + */ + ET_NODISCARD + runtime::Result> execute( + const std::string& method_name, + const size_t testset_idx); + + /** + * Verify the output of a specific method with the expected output from the + * program bundle at the given `testset_idx`. + * + * This function is a wrapper of `verify_method_outputs` in `bundled_program`. + * + * @param[in] method_name The name of the method to extract outputs from. + * @param[in] testset_idx The index of expected output needs to be compared. + * @param[in] rtol Relative tolerance used for data comparsion. + * @param[in] atol Absolute tolerance used for data comparsion. + * + * @returns Return Error::Ok if two outputs match, or the error happens during + * execution. 
+ */ + ET_NODISCARD + runtime::Error verify_method_outputs( + const std::string& method_name, + const size_t testset_idx, + double rtol = 1e-5, + double atol = 1e-8); + + private: + const void* bundled_program_ptr_; + bool is_loaded_from_file_ = false; +}; + +} // namespace extension +} // namespace executorch diff --git a/extension/module/module.cpp b/extension/module/module.cpp index ec01323edc7..6c534b8d560 100644 --- a/extension/module/module.cpp +++ b/extension/module/module.cpp @@ -302,15 +302,5 @@ runtime::Error Module::set_output( output_tensor.mutable_data_ptr(), output_tensor.nbytes(), output_index); } -ET_NODISCARD inline runtime::Result Module::get_method( - const std::string& method_name) { - ET_CHECK_OR_RETURN_ERROR( - methods_.count(method_name) > 0, - InvalidArgument, - "no such method in program: %s", - method_name.c_str()); - return methods_[method_name].method.get(); -} - } // namespace extension } // namespace executorch diff --git a/extension/module/module.h b/extension/module/module.h index 201887b9ccc..0c4d4779bea 100644 --- a/extension/module/module.h +++ b/extension/module/module.h @@ -491,16 +491,6 @@ class Module { std::unique_ptr data_map_; protected: - /** - * Get a method by method name. - * - * @param[in] method_name The name of the method to get. - * - * @returns A Result object containing either a pointer to the requested - * method or an error to indicate failure. 
- */ - ET_NODISCARD inline runtime::Result get_method( - const std::string& method_name); std::unordered_map methods_; friend class ExecuTorchJni; diff --git a/extension/module/targets.bzl b/extension/module/targets.bzl index d8019ce9c4e..3e449da5e14 100644 --- a/extension/module/targets.bzl +++ b/extension/module/targets.bzl @@ -28,6 +28,28 @@ def define_common_targets(): "//executorch/extension/flat_tensor:flat_tensor_data_map" + aten_suffix, ], exported_deps = [ - "//executorch/runtime/executor:program" + aten_suffix, + "//executorch/runtime/executor:program_no_prim_ops" + aten_suffix, + ], + ) + + runtime.cxx_library( + name = "bundled_module" + aten_suffix, + srcs = [ + "bundled_module.cpp", + ], + exported_headers = [ + "bundled_module.h", + ], + visibility = [ + "@EXECUTORCH_CLIENTS", + ], + deps = [ + "//executorch/extension/data_loader:buffer_data_loader", + "//executorch/extension/data_loader:file_data_loader", + "//executorch/devtools/bundled_program:runtime" + aten_suffix, + "//executorch/devtools/bundled_program/schema:bundled_program_schema_fbs", + ], + exported_deps = [ + "//executorch/extension/module:module" + aten_suffix, ], ) diff --git a/extension/module/test/bundled_module_test.cpp b/extension/module/test/bundled_module_test.cpp new file mode 100644 index 00000000000..a07c5dd5486 --- /dev/null +++ b/extension/module/test/bundled_module_test.cpp @@ -0,0 +1,110 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. 
+ */ + +#include +#include + +using namespace ::executorch::extension; +using namespace ::executorch::runtime; + +class BundledModuleTest : public ::testing::Test { + protected: + static void SetUpTestSuite() { + std::string resources_path; + if (const char* env = std::getenv("RESOURCES_PATH")) { + resources_path = env; + } + pte_path_ = std::getenv("ET_MODULE_PTE_PATH"); + bpte_path_ = resources_path + "/bundled_program.bpte"; + } + + static inline std::string bpte_path_; + static inline std::string pte_path_; +}; + +TEST_F(BundledModuleTest, TestExecute) { + auto bundled_module_output = BundledModule::from_file(bpte_path_.c_str()); + EXPECT_EQ(bundled_module_output.error(), Error::Ok); + auto& bundled_module = bundled_module_output.get(); + + auto outputs = bundled_module->execute("forward", /*testset_idx=*/0); + EXPECT_EQ(bundled_module->Module::is_loaded(), true); + EXPECT_EQ(outputs.error(), Error::Ok); + + auto status = + bundled_module->verify_method_outputs("forward", /*testset_idx=*/0); + EXPECT_EQ(status, Error::Ok); +} + +TEST_F(BundledModuleTest, TestNonExistBPFile) { + auto bundled_module_output = + BundledModule::from_file("/path/to/nonexistent/file.bpte"); + EXPECT_EQ(bundled_module_output.error(), Error::AccessFailed); +} + +TEST_F(BundledModuleTest, TestNonBPFile) { + auto bundled_module_output = BundledModule::from_file(pte_path_.c_str()); + EXPECT_EQ(bundled_module_output.error(), Error::Ok); + + auto& bundled_module = bundled_module_output.get(); + + auto outputs = bundled_module->execute("forward", /*testset_idx=*/0); + EXPECT_EQ(bundled_module->Module::is_loaded(), false); + EXPECT_EQ(outputs.error(), Error::InvalidArgument); + + auto status = + bundled_module->verify_method_outputs("forward", /*testset_idx=*/0); + EXPECT_EQ(status, Error::InvalidArgument); +} + +TEST_F(BundledModuleTest, TestExecuteInvalidMethod) { + auto bundled_module_output = BundledModule::from_file(bpte_path_.c_str()); + EXPECT_EQ(bundled_module_output.error(), 
Error::Ok); + auto& bundled_module = bundled_module_output.get(); + + auto outputs = + bundled_module->execute("non_existent_method", /*testset_idx=*/0); + EXPECT_EQ(outputs.error(), Error::InvalidArgument); +} + +TEST_F(BundledModuleTest, TestExecuteInvalidIdx) { + auto bundled_module_output = BundledModule::from_file(bpte_path_.c_str()); + EXPECT_EQ(bundled_module_output.error(), Error::Ok); + auto& bundled_module = bundled_module_output.get(); + + auto outputs = bundled_module->execute("forward", /*testset_idx=*/10000); + EXPECT_EQ(outputs.error(), Error::InvalidArgument); +} + +TEST_F(BundledModuleTest, TestVerifyInvalidMethod) { + auto bundled_module_output = BundledModule::from_file(bpte_path_.c_str()); + EXPECT_EQ(bundled_module_output.error(), Error::Ok); + auto& bundled_module = bundled_module_output.get(); + + auto outputs = bundled_module->execute("forward", /*testset_idx=*/0); + EXPECT_EQ(bundled_module->Module::is_loaded(), true); + EXPECT_EQ(outputs.error(), Error::Ok); + + auto status = bundled_module->verify_method_outputs( + "non_existent_method", /*testset_idx=*/0); + EXPECT_EQ(status, Error::InvalidArgument); +} + +TEST_F(BundledModuleTest, TestVerifyInvalidIdx) { + auto bundled_module_output = BundledModule::from_file(bpte_path_.c_str()); + EXPECT_EQ(bundled_module_output.error(), Error::Ok); + auto& bundled_module = bundled_module_output.get(); + + auto outputs = bundled_module->execute("forward", /*testset_idx=*/0); + EXPECT_EQ(bundled_module->Module::is_loaded(), true); + EXPECT_EQ(outputs.error(), Error::Ok); + + auto status = + bundled_module->verify_method_outputs("forward", /*testset_idx=*/10000); + EXPECT_EQ(status, Error::InvalidArgument); +} diff --git a/extension/module/test/resources/README.md b/extension/module/test/resources/README.md new file mode 100644 index 00000000000..026042ab121 --- /dev/null +++ b/extension/module/test/resources/README.md @@ -0,0 +1,7 @@ +## Resources + +### bundled_program.bpte + + ``` + python3 
extension/module/test/resources/gen_bundled_program.py + ``` diff --git a/extension/module/test/resources/bundled_program.bpte b/extension/module/test/resources/bundled_program.bpte new file mode 100644 index 00000000000..ea5b42c03d0 Binary files /dev/null and b/extension/module/test/resources/bundled_program.bpte differ diff --git a/extension/module/test/resources/gen_bundled_program.py b/extension/module/test/resources/gen_bundled_program.py new file mode 100644 index 00000000000..f1fa0a4a7e3 --- /dev/null +++ b/extension/module/test/resources/gen_bundled_program.py @@ -0,0 +1,99 @@ +import torch + +from executorch.devtools import BundledProgram + +from executorch.devtools.bundled_program.config import MethodTestCase, MethodTestSuite +from executorch.devtools.bundled_program.serialize import ( + serialize_from_bundled_program_to_flatbuffer, +) + +from executorch.exir import to_edge_transform_and_lower +from torch.export import export, export_for_training + + +# Step 1: ExecuTorch Program Export +class SampleModel(torch.nn.Module): + """An example model with multi-methods. Each method has multiple input and single output""" + + def __init__(self) -> None: + super().__init__() + self.register_buffer("a", 3 * torch.ones(2, 2, dtype=torch.int32)) + self.register_buffer("b", 2 * torch.ones(2, 2, dtype=torch.int32)) + + def forward(self, x: torch.Tensor, q: torch.Tensor) -> torch.Tensor: + z = x.clone() + torch.mul(self.a, x, out=z) + y = x.clone() + torch.add(z, self.b, out=y) + torch.add(y, q, out=y) + return y + + +def main() -> None: + """Sample code to generate bundled program and save it to file. It is the same as in https://pytorch.org/executorch/0.6/bundled-io.html#emit-example""" + # Inference method name of SampleModel we want to bundle testcases to. + # Notices that we do not need to bundle testcases for every inference methods. + method_name = "forward" + model = SampleModel() + + # Inputs for graph capture. 
+ capture_input = ( + (torch.rand(2, 2) - 0.5).to(dtype=torch.int32), + (torch.rand(2, 2) - 0.5).to(dtype=torch.int32), + ) + + # Export method's FX Graph. + method_graph = export( + export_for_training(model, capture_input).module(), + capture_input, + ) + + # Emit the traced method into ET Program. + et_program = to_edge_transform_and_lower(method_graph).to_executorch() + + # Step 2: Construct MethodTestSuite for Each Method + + # Prepare the Test Inputs. + + # Number of input sets to be verified + n_input = 10 + + # Input sets to be verified. + inputs = [ + # Each list below is a individual input set. + # The number of inputs, dtype and size of each input follow Program's spec. + [ + (torch.rand(2, 2) - 0.5).to(dtype=torch.int32), + (torch.rand(2, 2) - 0.5).to(dtype=torch.int32), + ] + for _ in range(n_input) + ] + + # Generate Test Suites + method_test_suites = [ + MethodTestSuite( + method_name=method_name, + test_cases=[ + MethodTestCase( + inputs=input, + expected_outputs=(getattr(model, method_name)(*input),), + ) + for input in inputs + ], + ), + ] + + # Step 3: Generate BundledProgram + bundled_program = BundledProgram(et_program, method_test_suites) + + # Step 4: Serialize BundledProgram to flatbuffer. 
+ serialized_bundled_program = serialize_from_bundled_program_to_flatbuffer( + bundled_program + ) + save_path = "bundled_program.bpte" + with open(save_path, "wb") as f: + f.write(serialized_bundled_program) + + +if __name__ == "__main__": + main() diff --git a/extension/module/test/targets.bzl b/extension/module/test/targets.bzl index e308ca89c30..e09b43e356d 100644 --- a/extension/module/test/targets.bzl +++ b/extension/module/test/targets.bzl @@ -42,3 +42,30 @@ def define_common_targets(is_fbcode=False): "-Wno-error=deprecated-declarations", ], ) + + runtime.cxx_test( + name = "bundled_test" + aten_suffix, + srcs = [ + "bundled_module_test.cpp", + ], + deps = [ + "//executorch/kernels/portable:generated_lib" + aten_suffix, + "//executorch/extension/module:bundled_module" + aten_suffix, + "//executorch/extension/tensor:tensor" + aten_suffix, + ], + env = { + "RESOURCES_PATH": "$(location :resources)/resources", + "ET_MODULE_PTE_PATH": "$(location fbcode//executorch/test/models:exported_programs[ModuleAdd.pte])", + }, + platforms = [CXX, ANDROID], # Cannot bundle resources on Apple platform. + compiler_flags = [ + "-Wno-error=deprecated-declarations", + ], + ) + + runtime.filegroup( + name = "resources", + srcs = native.glob([ + "resources/**", + ]), + ) diff --git a/extension/runner_util/CMakeLists.txt b/extension/runner_util/CMakeLists.txt index 19aa884fd77..3483b3babf3 100644 --- a/extension/runner_util/CMakeLists.txt +++ b/extension/runner_util/CMakeLists.txt @@ -18,7 +18,7 @@ endif() list(TRANSFORM _extension_runner_util__srcs PREPEND "${EXECUTORCH_ROOT}/") add_library(extension_runner_util ${_extension_runner_util__srcs}) -target_link_libraries(extension_runner_util executorch) +target_link_libraries(extension_runner_util executorch_core) target_include_directories(extension_runner_util PUBLIC ${EXECUTORCH_ROOT}/..) 
target_compile_options(extension_runner_util PUBLIC ${_common_compile_options}) diff --git a/extension/runner_util/targets.bzl b/extension/runner_util/targets.bzl index 3ab0c26cc72..75b9e1ef905 100644 --- a/extension/runner_util/targets.bzl +++ b/extension/runner_util/targets.bzl @@ -23,6 +23,6 @@ def define_common_targets(): ], exported_deps = [ "//executorch/runtime/core/exec_aten:lib" + aten_suffix, - "//executorch/runtime/executor:program" + aten_suffix, + "//executorch/runtime/executor:program_no_prim_ops" + aten_suffix, ], ) diff --git a/kernels/optimized/CMakeLists.txt b/kernels/optimized/CMakeLists.txt index 693be68c35e..85c829469be 100644 --- a/kernels/optimized/CMakeLists.txt +++ b/kernels/optimized/CMakeLists.txt @@ -71,7 +71,7 @@ target_compile_options(optimized_kernels PUBLIC ${_common_compile_options}) # # optimized_ops_lib: Register optimized ops kernels into Executorch runtime gen_operators_lib( - LIB_NAME "optimized_ops_lib" KERNEL_LIBS optimized_kernels DEPS executorch + LIB_NAME "optimized_ops_lib" KERNEL_LIBS optimized_kernels DEPS executorch_core ) install( diff --git a/kernels/portable/cpu/op_native_dropout.cpp b/kernels/portable/cpu/op_native_dropout.cpp new file mode 100644 index 00000000000..1c4d177e8ed --- /dev/null +++ b/kernels/portable/cpu/op_native_dropout.cpp @@ -0,0 +1,83 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. 
+ */ + +#include +#include + +#include +#include + +namespace torch::executor::native { +std::tuple native_dropout_out( + KernelRuntimeContext& ctx, + const Tensor& input, + double prob, + torch::executor::optional train, + Tensor& out, + Tensor& mask) { + std::tuple ret(out, mask); + ET_KERNEL_CHECK( + ctx, tensors_have_same_dtype(input, out), InvalidArgument, ret); + ET_KERNEL_CHECK( + ctx, tensors_have_same_dim_order(input, out, mask), InvalidArgument, ret); + ET_KERNEL_CHECK( + ctx, + resize_tensor(out, input.sizes()) == Error::Ok, + InvalidArgument, + ret); + ET_KERNEL_CHECK( + ctx, + resize_tensor(mask, input.sizes()) == Error::Ok, + InvalidArgument, + ret); + ET_KERNEL_CHECK(ctx, tensor_is_bool_type(mask), InvalidArgument, ret); + ET_KERNEL_CHECK_MSG( + ctx, + prob >= 0 && prob <= 1, + InvalidArgument, + ret, + "dropout probability has to be between 0 and 1 but got %f", + prob); + + // @lint-ignore CLANGTIDY facebook-hte-CArray + static constexpr const char op_name[] = "native_dropout.out"; + if ((!train.has_value() || train.value()) && prob != 0) { + { + std::mt19937 gen((std::random_device())()); + std::uniform_real_distribution dist; + bool* const mask_data_ptr = mask.mutable_data_ptr(); + for (const auto ii : c10::irange(mask.numel())) { + mask_data_ptr[ii] = dist(gen) >= prob; + } + } + ET_SWITCH_FLOATHBF16_TYPES( + input.scalar_type(), ctx, op_name, CTYPE_COMPUTE, [&]() { + utils::apply_bitensor_elementwise_fn( + [](const auto val, const auto mask_val) { + if (!mask_val) { + return static_cast(0); + } + return val; + }, + ctx, + input, + utils::SupportedTensorDtypes::FLOATHBF16, + mask, + // TODO: should really be just BOOL + utils::SupportedTensorDtypes::BOOL_OR_BYTE, + out, + utils::SupportedTensorDtypes::SAME_AS_COMMON); + }); + } else if (input.numel() > 0) { + std::memcpy(out.mutable_data_ptr(), input.data_ptr(), input.nbytes()); + std::memset(mask.mutable_data_ptr(), true, mask.nbytes()); + } + return ret; +} + +} // namespace 
torch::executor::native diff --git a/kernels/portable/functions.yaml b/kernels/portable/functions.yaml index ab04d3b26ac..466e015e31d 100644 --- a/kernels/portable/functions.yaml +++ b/kernels/portable/functions.yaml @@ -627,6 +627,12 @@ - arg_meta: null kernel_name: torch::executor::narrow_copy_out +- op: native_dropout.out + kernels: + - arg_meta: null + kernel_name: torch::executor::native_dropout_out + tags: nondeterministic_seeded + - op: native_group_norm.out kernels: - arg_meta: null diff --git a/kernels/quantized/CMakeLists.txt b/kernels/quantized/CMakeLists.txt index 29058e9b11d..149db0c17f6 100644 --- a/kernels/quantized/CMakeLists.txt +++ b/kernels/quantized/CMakeLists.txt @@ -142,13 +142,13 @@ if(NOT CMAKE_GENERATOR STREQUAL "Xcode" endif() add_library(quantized_kernels ${_quantized_kernels__srcs}) -target_link_libraries(quantized_kernels PRIVATE executorch) +target_link_libraries(quantized_kernels PRIVATE executorch_core) target_compile_options(quantized_kernels PUBLIC ${_common_compile_options}) # Build a library for _quantized_kernels_srcs # # quantized_ops_lib: Register quantized ops kernels into Executorch runtime gen_operators_lib( - LIB_NAME "quantized_ops_lib" KERNEL_LIBS quantized_kernels DEPS executorch + LIB_NAME "quantized_ops_lib" KERNEL_LIBS quantized_kernels DEPS executorch_core ) install( diff --git a/kernels/test/CMakeLists.txt b/kernels/test/CMakeLists.txt index deb61410b10..a56fc6cab22 100644 --- a/kernels/test/CMakeLists.txt +++ b/kernels/test/CMakeLists.txt @@ -186,6 +186,7 @@ set(all_test_sources "op_mul_test.cpp" "op_pow_test.cpp" "op_native_batch_norm_test.cpp" + "op_native_dropout_test.cpp" "op_native_group_norm_test.cpp" "op_native_layer_norm_test.cpp" "op_ne_test.cpp" diff --git a/kernels/test/op_native_dropout_test.cpp b/kernels/test/op_native_dropout_test.cpp new file mode 100644 index 00000000000..931205f54a5 --- /dev/null +++ b/kernels/test/op_native_dropout_test.cpp @@ -0,0 +1,103 @@ +/* + * Copyright (c) Meta Platforms, 
Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#include +#include // Declares the operator +#include +#include +#include +#include +#include + +#include + +using executorch::aten::ScalarType; +using executorch::aten::Tensor; +using torch::executor::testing::TensorFactory; + +class OpNativeDropoutTest : public OperatorTest { + protected: + void op_native_dropout_out( + const Tensor& self, + double prob, + executorch::aten::optional train, + Tensor& out, + Tensor& mask) { + torch::executor::aten::native_dropout_outf( + context_, self, prob, train, out, mask); + } + + template + void test_dropout() { + TensorFactory tf; + TensorFactory tf_bool; + const std::vector sizes = {3, 2}; + Tensor in = tf.make(sizes, {1, 2, 3, 4, 5, 6}); + Tensor out = tf.zeros(sizes); + Tensor mask = tf_bool.zeros(sizes); + + bool* const mask_data = mask.mutable_data_ptr(); + auto expect_no_drops = [&]() { + EXPECT_TENSOR_CLOSE(out, in); + for (const auto ii : c10::irange(mask.numel())) { + EXPECT_TRUE(mask_data[ii]); + mask_data[ii] = false; + } + }; + + op_native_dropout_out(in, 0, true, out, mask); + expect_no_drops(); + + op_native_dropout_out(in, 0, false, out, mask); + expect_no_drops(); + + op_native_dropout_out(in, 1, false, out, mask); + expect_no_drops(); + + op_native_dropout_out(in, 1, true, out, mask); + auto* const out_data = out.mutable_data_ptr(); + for (const auto ii : c10::irange(out.numel())) { + EXPECT_EQ(out_data[ii], CTYPE(0)); + } + for (const auto ii : c10::irange(mask.numel())) { + EXPECT_FALSE(mask_data[ii]); + mask_data[ii] = 0; + } + } +}; + +TEST_F(OpNativeDropoutTest, Basic) { +#define TEST_ENTRY(ctype, dtype) test_dropout(); + ET_FORALL_FLOATHBF16_TYPES(TEST_ENTRY); +#undef TEST_ENTRY +} + +TEST_F(OpNativeDropoutTest, ProbabilityRangeCheck) { + TensorFactory tf_float; + TensorFactory tf_bool; + const std::vector sizes = 
{2, 3}; + Tensor a = tf_float.ones(sizes); + Tensor out = tf_float.zeros(sizes); + Tensor mask = tf_bool.zeros(sizes); + ET_EXPECT_KERNEL_FAILURE( + context_, op_native_dropout_out(a, -1, true, out, mask)); +} + +TEST_F(OpNativeDropoutTest, MaskBoolCheck) { + TensorFactory tf_float; + TensorFactory tf_byte; + const std::vector sizes = {2, 3}; + Tensor a = tf_float.ones(sizes); + Tensor out = tf_float.zeros(sizes); + Tensor mask_byte = tf_byte.zeros(sizes); + Tensor mask_float = tf_float.zeros(sizes); + ET_EXPECT_KERNEL_FAILURE( + context_, op_native_dropout_out(a, 0.5, true, out, mask_byte)); + ET_EXPECT_KERNEL_FAILURE( + context_, op_native_dropout_out(a, 0.5, true, out, mask_float)); +} diff --git a/kernels/test/targets.bzl b/kernels/test/targets.bzl index 2372ba54c73..c1824674fd4 100644 --- a/kernels/test/targets.bzl +++ b/kernels/test/targets.bzl @@ -272,6 +272,7 @@ def define_common_targets(): _common_op_test("op_mul_test", ["aten", "portable", "optimized"]) _common_op_test("op_narrow_copy_test", ["aten", "portable"]) _common_op_test("op_native_batch_norm_test", ["aten", "portable"]) + _common_op_test("op_native_dropout_test", ["aten", "portable"]) _common_op_test("op_native_group_norm_test", ["aten", "portable"]) _common_op_test("op_native_layer_norm_test", ["aten", "portable", "optimized"]) _common_op_test("op_ne_test", ["aten", "portable"]) diff --git a/runtime/executor/method.cpp b/runtime/executor/method.cpp index b7cd7fa8d12..ac799d3e14e 100644 --- a/runtime/executor/method.cpp +++ b/runtime/executor/method.cpp @@ -329,6 +329,8 @@ Result Method::get_num_external_constants() { } Error Method::parse_external_constants(const NamedDataMap* named_data_map) { + ET_CHECK_OR_RETURN_ERROR( + named_data_map != nullptr, InvalidState, "named_data_map is null"); auto flatbuffer_values = serialization_plan_->values(); size_t n_value = flatbuffer_values->size(); @@ -372,6 +374,7 @@ Error Method::parse_external_constants(const NamedDataMap* named_data_map) { Result 
tensor_layout = named_data_map->get_metadata(key); if (!tensor_layout.ok()) { + ET_LOG(Info, "Failed to get metadata for key %s", key); return tensor_layout.error(); } // Check external tensor compatibility. diff --git a/runtime/executor/test/backend_data_separation_test.cpp b/runtime/executor/test/backend_data_separation_test.cpp new file mode 100644 index 00000000000..32daf3686fc --- /dev/null +++ b/runtime/executor/test/backend_data_separation_test.cpp @@ -0,0 +1,101 @@ + +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include + +using namespace ::testing; +using executorch::extension::FlatTensorDataMap; +using executorch::runtime::DataLoader; +using executorch::runtime::Error; +using executorch::runtime::Method; +using executorch::runtime::Program; +using executorch::runtime::Result; +using executorch::runtime::testing::ManagedMemoryManager; +using torch::executor::util::FileDataLoader; + +constexpr size_t kDefaultNonConstMemBytes = 32 * 1024U; +constexpr size_t kDefaultRuntimeMemBytes = 32 * 1024U; + +class BackendDataSeparationTest : public ::testing::Test { + protected: + void SetUp() override { + // Since these tests cause ET_LOG to be called, the PAL must be initialized + // first. + executorch::runtime::runtime_init(); + + // Make sure that the backend has been registered. Safe to call multiple + // times. Doing this at runtime ensures that it's only registered if these + // tests are run. + ASSERT_EQ(example::register_executor_backend(), Error::Ok); + + // Create data loaders. 
+ Result linear_program_loader = FileDataLoader::from( + std::getenv("ET_MODULE_LINEAR_DELEGATE_PROGRAM_PATH")); + ASSERT_EQ(linear_program_loader.error(), Error::Ok); + linear_program_loader_ = std::make_unique( + std::move(linear_program_loader.get())); + + Result linear_data_loader = + FileDataLoader::from(std::getenv("ET_MODULE_LINEAR_DATA_PATH")); + ASSERT_EQ(linear_data_loader.error(), Error::Ok); + linear_data_loader_ = + std::make_unique(std::move(linear_data_loader.get())); + + // Create programs. + Result linear_program = Program::load( + linear_program_loader_.get(), + Program::Verification::InternalConsistency); + ASSERT_EQ(linear_program.error(), Error::Ok); + linear_program_ = + std::make_unique(std::move(linear_program.get())); + + Result linear_data_map = + FlatTensorDataMap::load(linear_data_loader_.get()); + EXPECT_EQ(linear_data_map.error(), Error::Ok); + linear_data_map_ = + std::make_unique(std::move(linear_data_map.get())); + + ET_LOG( + Info, + "setup done, named_data_map_ = %lu", + linear_data_map_->get_num_keys().get()); + } + + private: + std::unique_ptr linear_program_loader_; + std::unique_ptr linear_data_loader_; + + protected: + std::unique_ptr linear_program_; + std::unique_ptr linear_data_map_; +}; + +TEST_F(BackendDataSeparationTest, TestSeparation) { + ManagedMemoryManager mmm(kDefaultNonConstMemBytes, kDefaultRuntimeMemBytes); + Result method = linear_program_->load_method( + "forward", + &mmm.get(), + /*event_tracer=*/nullptr, + /*named_data_map=*/linear_data_map_.get()); + ASSERT_EQ(method.error(), Error::Ok); + + // Can execute the method. 
+ Error err = method->execute(); + ASSERT_EQ(err, Error::Ok); +} diff --git a/runtime/executor/test/targets.bzl b/runtime/executor/test/targets.bzl index 01f0f91ea1a..39ff0668d5d 100644 --- a/runtime/executor/test/targets.bzl +++ b/runtime/executor/test/targets.bzl @@ -240,6 +240,28 @@ def define_common_targets(is_fbcode = False): }, ) + runtime.cxx_test( + name = "backend_data_separation_test", + srcs = [ + "backend_data_separation_test.cpp", + ], + deps = [ + ":managed_memory_manager", + "//executorch/runtime/executor:program", + "//executorch/extension/data_loader:file_data_loader", + "//executorch/exir/backend/test/demos/rpc:executor_backend", + "//executorch/exir/backend/test/demos/rpc:executor_backend_register", + "//executorch/extension/flat_tensor:flat_tensor_data_map", + ], + env = { + # The tests use these vars to find the program files to load. + # Uses an fbcode target path because the authoring/export tools + # intentionally don't work in xplat (since they're host-only + # tools). 
+ "ET_MODULE_LINEAR_DELEGATE_PROGRAM_PATH": "$(location fbcode//executorch/test/models:exported_executor_backend_program_and_data[ModuleLinear-e.pte])", + "ET_MODULE_LINEAR_DATA_PATH": "$(location fbcode//executorch/test/models:exported_program_and_data[ModuleLinear.ptd])", + }, + ) runtime.cxx_test( name = "memory_manager_test", srcs = [ diff --git a/shim_et/xplat/executorch/kernels/portable/op_registration_util.bzl b/shim_et/xplat/executorch/kernels/portable/op_registration_util.bzl index d0c39bcf17f..3bfc7fdf00f 100644 --- a/shim_et/xplat/executorch/kernels/portable/op_registration_util.bzl +++ b/shim_et/xplat/executorch/kernels/portable/op_registration_util.bzl @@ -883,6 +883,12 @@ ATEN_OPS = ( "//executorch/kernels/portable/cpu/util:normalization_ops_util", ], ), + op_target( + name = "op_native_dropout", + deps = [ + "//executorch/kernels/portable/cpu/util:elementwise_util", + ], + ), op_target( name = "op_native_group_norm", deps = [ diff --git a/test/models/export_delegated_program.py b/test/models/export_delegated_program.py index f7b2f354373..cbfdfaedab3 100644 --- a/test/models/export_delegated_program.py +++ b/test/models/export_delegated_program.py @@ -20,9 +20,13 @@ from executorch.exir import EdgeCompileConfig, to_edge, to_edge_transform_and_lower from executorch.exir.backend.backend_api import to_backend from executorch.exir.backend.backend_details import BackendDetails, PreprocessResult +from executorch.exir.backend.compile_spec_schema import CompileSpec from executorch.exir.backend.test.backend_with_compiler_demo import ( BackendWithCompilerDemo, ) +from executorch.exir.backend.test.demos.rpc.executor_backend_preprocess import ( # noqa: F401 + ExecutorBackend, +) from executorch.exir.passes.external_constants_pass import ( delegate_external_constants_pass, ) @@ -150,13 +154,18 @@ def __init__(self, fn, method_name=method_name): def forward(self, *args, **kwargs): return getattr(self.fn, self.method_name)(*args, **kwargs) - exported_program = 
export(WrapperModule(eager_module), args=inputs, strict=True) + if method_name != "forward": + # Only require wrapper module if we're exporting a specific method other than forward. + exported_program = export(WrapperModule(eager_module), args=inputs, strict=True) + else: + exported_program = export(eager_module, args=inputs, strict=True) edge_config = EdgeCompileConfig(_check_ir_validity=False) et_config = exir.ExecutorchBackendConfig( extract_delegate_segments=extract_delegate_segments, constant_tensor_alignment=constant_tensor_alignment, delegate_alignment=delegate_alignment, + external_constants=external_constants, ) if backend_id == "XnnpackBackend": @@ -181,7 +190,10 @@ def forward(self, *args, **kwargs): else: edge: exir.EdgeProgramManager = to_edge(exported_program) lowered_module = to_backend( # type: ignore[call-arg] - backend_id, edge.exported_program(), compile_specs=[] + backend_id, + edge.exported_program(), + # Just for the demo executor_backend. + compile_specs=[CompileSpec(key="external_constants", value=b"")], ) class CompositeModule(nn.Module): diff --git a/test/models/export_program.py b/test/models/export_program.py index 5387df24aad..e13b63eaf74 100644 --- a/test/models/export_program.py +++ b/test/models/export_program.py @@ -146,6 +146,19 @@ def get_random_inputs(self): return (torch.ones(2, 2, dtype=torch.float),) +# Used for program-data-separation. +class ModuleLinear(torch.nn.Module): + def __init__(self): + super().__init__() + self.linear = torch.nn.Linear(3, 3) + + def forward(self, x: torch.Tensor): + return self.linear(x) + + def get_random_inputs(self): + return (torch.randn(3),) + + class ModuleMultipleEntry(torch.nn.Module): def __init__(self): super().__init__() diff --git a/test/models/targets.bzl b/test/models/targets.bzl index db0f410d727..9e26a6c123b 100644 --- a/test/models/targets.bzl +++ b/test/models/targets.bzl @@ -95,6 +95,7 @@ def define_common_targets(): # Class names of nn.Modules for :exported_programs to export. 
MODULES_AND_DATA_TO_EXPORT = [ "ModuleAddMul", + "ModuleLinear", "ModuleSimpleTrain", ] @@ -104,6 +105,8 @@ def define_common_targets(): outs = { "ModuleAddMul.pte": ["ModuleAddMulProgram.pte"], "ModuleAddMul.ptd": ["ModuleAddMulProgram.ptd"], + "ModuleLinear.pte": ["ModuleLinearProgram.pte"], + "ModuleLinear.ptd": ["ModuleLinearProgram.ptd"], "ModuleSimpleTrainProgram.pte": ["ModuleSimpleTrainProgram.pte"], "ModuleSimpleTrain.ptd": ["ModuleSimpleTrainProgram.ptd"], }, @@ -146,7 +149,7 @@ def define_common_targets(): deps = [ ":export_delegated_program_lib", "//executorch/backends/xnnpack/partition:xnnpack_partitioner", - + "//executorch/exir/backend/test/demos/rpc:executor_backend_preprocess", ], visibility = [], # Private ) @@ -225,3 +228,23 @@ def define_common_targets(): "//executorch/test/...", ], ) + + # Export with demo ExecutorBackend for program-data separation test. + runtime.genrule( + name = "exported_executor_backend_program_and_data", + cmd = "$(exe :export_delegated_program)" + + " --modules ModuleLinear" + + " --backend_id ExecutorBackend" + + " --external_constants" + + " --outdir $OUT", + + outs = { + "ModuleLinear-e.pte": ["ModuleLinear-e.pte"], + }, + default_outs = ["."], + visibility = [ + "//executorch/runtime/executor/test/...", + "//executorch/extension/flat_tensor/test/...", + "//executorch/test/...", + ], + ) diff --git a/tools/cmake/cmake_deps.toml b/tools/cmake/cmake_deps.toml index 18c1130bc82..748eb775921 100644 --- a/tools/cmake/cmake_deps.toml +++ b/tools/cmake/cmake_deps.toml @@ -178,7 +178,6 @@ filters = [ ] deps = [ "executorch_core", - "executorch", ] [targets.extension_flat_tensor_schema] @@ -199,7 +198,6 @@ filters = [ deps = [ "extension_flat_tensor_schema", "executorch_core", - "executorch", ] [targets.extension_module] @@ -210,9 +208,9 @@ filters = [ ".cpp$", ] deps = [ - "executorch", "executorch_core", "extension_data_loader", + "extension_flat_tensor", ] [targets.extension_runner_util] @@ -223,7 +221,6 @@ filters = [ 
".cpp$", ] deps = [ - "executorch", "executorch_core", ] @@ -238,6 +235,8 @@ deps = [ "executorch", "executorch_core", "extension_module", + "extension_data_loader", + "extension_flat_tensor", "extension_runner_util", "extension_tensor", ] @@ -454,6 +453,7 @@ deps = [ "executorch", "executorch_core", "extension_data_loader", + "extension_flat_tensor", "extension_module", "extension_threadpool", "optimized_cpublas", diff --git a/tools/cmake/executorch-config.cmake b/tools/cmake/executorch-config.cmake index 56c7fa2d7d4..a8e756fbb77 100644 --- a/tools/cmake/executorch-config.cmake +++ b/tools/cmake/executorch-config.cmake @@ -66,6 +66,7 @@ set(lib_list etdump bundled_program extension_data_loader + extension_flat_tensor ${FLATCCRT_LIB} coreml_util coreml_inmemoryfs