From eaf383a14a965db0dffd46fde3296409ed5785ed Mon Sep 17 00:00:00 2001
From: Erik Lundell
Date: Mon, 19 Aug 2024 18:18:38 +0200
Subject: [PATCH] Add pass to convert split to many slices

Differential Revision: D61211922

Pull Request resolved: https://github.com/pytorch/executorch/pull/4562
---
 backends/arm/arm_partitioner.py               |   1 +
 backends/arm/operators/op_slice.py            |   2 +
 backends/arm/passes/arm_pass_manager.py       |   4 +
 backends/arm/passes/convert_split_to_slice.py |  70 +++++++++
 backends/arm/quantizer/arm_quantizer_utils.py |   4 +
 backends/arm/test/ops/test_slice.py           |   2 +-
 backends/arm/test/ops/test_split.py           | 139 ++++++++++++++++++
 backends/arm/test/runner_utils.py             |  35 +++--
 backends/arm/test/tester/arm_tester.py        |   2 +-
 9 files changed, 241 insertions(+), 18 deletions(-)
 create mode 100644 backends/arm/passes/convert_split_to_slice.py
 create mode 100644 backends/arm/test/ops/test_split.py

diff --git a/backends/arm/arm_partitioner.py b/backends/arm/arm_partitioner.py
index 72b9c48548..8726533b34 100644
--- a/backends/arm/arm_partitioner.py
+++ b/backends/arm/arm_partitioner.py
@@ -43,6 +43,7 @@ def is_node_supported(self, submodules, node: torch.fx.Node) -> bool:
             exir_ops.edge.aten.hardtanh.default,
             exir_ops.edge.aten.convolution.default,
             exir_ops.edge.aten.div.Tensor,
+            exir_ops.edge.aten.split_with_sizes_copy.default,
             exir_ops.edge.aten.full.default,
             exir_ops.edge.aten._native_batch_norm_legit_no_training.default,
             exir_ops.edge.aten.avg_pool2d.default,
diff --git a/backends/arm/operators/op_slice.py b/backends/arm/operators/op_slice.py
index 8d59835ff0..e562e0724e 100644
--- a/backends/arm/operators/op_slice.py
+++ b/backends/arm/operators/op_slice.py
@@ -40,6 +40,8 @@ def define_node(
         shape = input_node.shape
         dim = dim.number
         end = (shape[dim] + end.number) % shape[dim]
+        if end == 0:
+            end = shape[dim]
         size = end - start.number
         assert size > 0
         assert size <= shape[dim]
diff --git a/backends/arm/passes/arm_pass_manager.py b/backends/arm/passes/arm_pass_manager.py
index 123146a325..054d823dbb 100644
--- a/backends/arm/passes/arm_pass_manager.py
+++ b/backends/arm/passes/arm_pass_manager.py
@@ -12,6 +12,9 @@
 from executorch.backends.arm.passes.convert_expand_copy_to_repeat import (
     ConvertExpandCopyToRepeatPass,
 )
+from executorch.backends.arm.passes.convert_split_to_slice import (
+    ConvertSplitToSlicePass,
+)
 from executorch.backends.arm.passes.remove_clone_pass import RemoveClonePass
 from executorch.exir.backend.compile_spec_schema import CompileSpec
 from executorch.exir.pass_manager import PassManager
@@ -28,6 +31,7 @@ def transform_to_backend_pipeline(
         """Apply passes before transforming program to backend"""
         self.add_pass(RemoveClonePass())
         self.add_pass(ConvertExpandCopyToRepeatPass())
+        self.add_pass(ConvertSplitToSlicePass())
         for spec in compile_spec:
             if spec.key == "permute_memory_format":
                 memory_format = spec.value.decode()
diff --git a/backends/arm/passes/convert_split_to_slice.py b/backends/arm/passes/convert_split_to_slice.py
new file mode 100644
index 0000000000..ff978d4d9e
--- /dev/null
+++ b/backends/arm/passes/convert_split_to_slice.py
@@ -0,0 +1,70 @@
+# Copyright 2024 Arm Limited and/or its affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+import torch.fx
+from executorch.backends.arm.tosa_mapping import extract_tensor_meta
+from executorch.exir.dialects._ops import ops as exir_ops
+from executorch.exir.pass_base import ExportPass, PassResult
+
+
+class ConvertSplitToSlicePass(ExportPass):
+    """
+    Replace a split operation with many slice operations.
+    """
+
+    split_ops = (
+        exir_ops.edge.aten.split_with_sizes_copy.default,
+        exir_ops.edge.aten.split_copy.Tensor,
+    )
+    slice = exir_ops.edge.aten.slice_copy.Tensor
+
+    def call(self, graph_module: torch.fx.GraphModule):
+        graph = graph_module.graph
+        for node in graph.nodes:
+            if node.target not in self.split_ops:
+                continue
+
+            # Get useful variables
+            split_node = node
+            input_node = split_node.all_input_nodes[0]
+            output_nodes = split_node.users.copy()
+            _, shape, _ = extract_tensor_meta(input_node.meta)
+            rank = len(shape)
+            split_lengths = split_node.args[1]
+            dim = split_node.args[2] if len(split_node.args) > 2 else 0
+            dim = (dim + rank) % rank
+
+            assert (
+                sum(split_lengths) == shape[dim]
+            ), "Given split lengths don't sum up to the size of the dimension."
+
+            # Convert split argument 'split_lengths' to slice arguments start and end.
+            starts = [0] * len(split_lengths)
+            ends = [0] * len(split_lengths)
+            start = 0
+            end = 0
+            for i, split_length in enumerate(split_lengths):
+                end = start + split_length
+                starts[i] = start
+                ends[i] = end
+                start = end
+
+            # Output nodes are of type getitem.
+            # Create one slice node for each output node with matching arguments.
+            with graph_module.graph.inserting_before(split_node):
+                for output_node in output_nodes:
+                    index = output_node.args[1]
+                    slice_node = graph.create_node(
+                        "call_function",
+                        self.slice,
+                        (input_node, dim, starts[index], ends[index]),
+                    )
+                    slice_node.meta = split_node.meta.copy()
+                    slice_node.meta["val"] = slice_node.meta["val"][index]
+                    output_node.replace_input_with(split_node, slice_node)
+        graph.eliminate_dead_code()
+        graph_module.recompile()
+        return PassResult(graph_module, True)
diff --git a/backends/arm/quantizer/arm_quantizer_utils.py b/backends/arm/quantizer/arm_quantizer_utils.py
index 89703f89b0..c5da32a40a 100644
--- a/backends/arm/quantizer/arm_quantizer_utils.py
+++ b/backends/arm/quantizer/arm_quantizer_utils.py
@@ -9,6 +9,7 @@
 # Utility functions for ArmQuantizer
 #
 
+import operator
 from typing import Callable, cast, List
 
 import torch
@@ -141,8 +142,11 @@ def is_share_obs_or_fq_op(op: Callable) -> bool:
         torch.ops.aten.view_copy.default,
         torch.ops.aten.view.default,
         torch.ops.aten.slice.Tensor,
+        torch.ops.aten.split.Tensor,
+        torch.ops.aten.split_with_sizes.default,
         torch.ops.aten.flatten.using_ints,
         torch.ops.aten.dropout.default,
+        operator.getitem,
     ]
 
 
diff --git a/backends/arm/test/ops/test_slice.py b/backends/arm/test/ops/test_slice.py
index a1c1e29cbc..14874df156 100644
--- a/backends/arm/test/ops/test_slice.py
+++ b/backends/arm/test/ops/test_slice.py
@@ -33,7 +33,7 @@ def forward(self, x: torch.Tensor):
             elif x.dim() == 3:
                 return x[0:7, 0:1, 0:8]
             elif x.dim() == 4:
-                return x[:, 2:5, 3:5, 4:5]
+                return x[:, 2:5, 3:5, 4:10]
 
     def _test_slice_tosa_MI_pipeline(
         self, module: torch.nn.Module, test_data: torch.Tensor
diff --git a/backends/arm/test/ops/test_split.py b/backends/arm/test/ops/test_split.py
new file mode 100644
index 0000000000..bc998179c0
--- /dev/null
+++ b/backends/arm/test/ops/test_split.py
@@ -0,0 +1,139 @@
+# Copyright 2024 Arm Limited and/or its affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+import unittest
+
+import torch
+from executorch.backends.arm.quantizer.arm_quantizer import (
+    ArmQuantizer,
+    get_symmetric_quantization_config,
+)
+from executorch.backends.arm.test import common
+from executorch.backends.arm.test.tester.arm_tester import ArmTester
+from executorch.backends.xnnpack.test.tester.tester import Quantize
+from parameterized import parameterized
+
+test_data_t = tuple[torch.Tensor, int | list[int], int]
+
+
+class TestSimpleSplit(unittest.TestCase):
+    class Split(torch.nn.Module):
+
+        test_data: list[tuple[test_data_t]] = [
+            ((torch.rand(10), 2, 0),),
+            ((torch.rand(10, 10), 3, 1),),
+            ((torch.rand(10, 10), 4, -1),),
+            ((torch.rand(10, 15, 10), [2, 2, 11], 1),),
+            ((torch.rand(4, 4, 4, 4), 2, 0),),
+            ((torch.rand(4, 4, 4, 4), [1, 1, 1, 1], -2),),
+        ]
+
+        def forward(
+            self, x: torch.Tensor, split_size_or_sections: int | list[int], dim: int
+        ):
+            return x.split(split_size=split_size_or_sections, dim=dim)
+
+    class SplitWithSizes(torch.nn.Module):
+        def forward(self, x: torch.Tensor, split_sizes: list[int], dim: int):
+            return x.split_with_sizes(split_sizes=split_sizes, dim=dim)
+
+    class SplitSingleOut(torch.nn.Module):
+        def forward(
+            self, x: torch.Tensor, split_size_or_sections: int | list[int], dim: int
+        ):
+            return x.split(split_size=split_size_or_sections, dim=dim)[1]
+
+    class SplitTwoOut(torch.nn.Module):
+        def forward(
+            self, x: torch.Tensor, split_size_or_sections: int | list[int], dim: int
+        ):
+            return x.split(split_size=split_size_or_sections, dim=dim)[1:3]
+
+    def _test_split_tosa_MI_pipeline(
+        self, module: torch.nn.Module, test_data: test_data_t
+    ):
+        (
+            ArmTester(
+                module,
+                example_inputs=test_data,
+                compile_spec=common.get_tosa_compile_spec(),
+            )
+            .export()
+            .to_edge()
+            .check(
+                [
+                    "executorch_exir_dialects_edge__ops_aten_split_with_sizes_copy_default"
+                ]
+            )
+            .partition()
+            .check_count({"torch.ops.higher_order.executorch_call_delegate": 1})
+            .to_executorch()
+            .run_method_and_compare_outputs(inputs=test_data)
+        )
+
+    def _test_split_tosa_BI_pipeline(
+        self, module: torch.nn.Module, test_data: test_data_t
+    ):
+
+        quantizer = ArmQuantizer().set_io(get_symmetric_quantization_config())
+        (
+            ArmTester(
+                module,
+                example_inputs=test_data,
+                compile_spec=common.get_tosa_compile_spec(),
+            )
+            .quantize(Quantize(quantizer, get_symmetric_quantization_config()))
+            .export()
+            .to_edge()
+            .partition()
+            .check_count({"torch.ops.higher_order.executorch_call_delegate": 1})
+            .to_executorch()
+            .run_method_and_compare_outputs(inputs=test_data, qtol=1)
+        )
+
+    def _test_split_u55_BI_pipeline(
+        self, module: torch.nn.Module, test_data: test_data_t
+    ):
+        quantizer = ArmQuantizer().set_io(get_symmetric_quantization_config())
+        (
+            ArmTester(
+                module,
+                example_inputs=test_data,
+                compile_spec=common.get_u55_compile_spec(),
+            )
+            .quantize(Quantize(quantizer, get_symmetric_quantization_config()))
+            .export()
+            .check(["torch.ops.aten.split.Tensor"])
+            .to_edge()
+            .partition()
+            .check_count({"torch.ops.higher_order.executorch_call_delegate": 1})
+            .to_executorch()
+        )
+
+    @parameterized.expand(Split.test_data)
+    def test_split_tosa_MI(self, test_data: test_data_t):
+        self._test_split_tosa_MI_pipeline(self.Split(), test_data)
+
+    @parameterized.expand([Split.test_data[3], Split.test_data[5]])
+    def test_split_with_sizes_tosa_MI(self, test_data: test_data_t):
+        assert isinstance(test_data[1], list)
+        self._test_split_tosa_MI_pipeline(self.SplitWithSizes(), test_data)
+
+    @parameterized.expand(Split.test_data)
+    def test_split_n_out_tosa_MI(self, test_data: test_data_t):
+        self._test_split_tosa_MI_pipeline(self.SplitSingleOut(), test_data)
+        self._test_split_tosa_MI_pipeline(self.SplitTwoOut(), test_data)
+
+    @parameterized.expand(Split.test_data)
+    def test_split_tosa_BI(self, test_data: test_data_t):
+        self._test_split_tosa_BI_pipeline(self.Split(), test_data)
+
+    # Fails during Vela compilation when trying to use a Tuple as a NamedTuple.
+    # Could be a Vela issue; wait until Regor.
+    @parameterized.expand(Split.test_data)
+    @unittest.expectedFailure
+    def test_split_u55_BI(self, test_data: test_data_t):
+        self._test_split_u55_BI_pipeline(self.Split(), test_data)
diff --git a/backends/arm/test/runner_utils.py b/backends/arm/test/runner_utils.py
index 58c99a9201..4e3b447103 100644
--- a/backends/arm/test/runner_utils.py
+++ b/backends/arm/test/runner_utils.py
@@ -202,7 +202,7 @@ def set_timeout(self, timeout: int):
     def run_corstone300(
         self,
         inputs: Tuple[torch.Tensor],
-    ) -> torch.Tensor:
+    ) -> list[torch.Tensor]:
 
         assert (
             self._has_init_run
@@ -268,12 +268,12 @@ def run_corstone300(
         tosa_ref_output = np.fromfile(out_path_with_suffix, dtype=np.float32)
 
         tosa_ref_output = torch.from_numpy(tosa_ref_output).reshape(inputs[0].shape)
-        return tosa_ref_output
+        return [tosa_ref_output]
 
     def run_tosa_ref_model(
         self,
         inputs: Tuple[torch.Tensor],
-    ) -> torch.Tensor:
+    ) -> list[torch.Tensor]:
         """
         Run TOSA reference model using the tosa_reference_model program.
 
@@ -369,23 +369,26 @@ def run_tosa_ref_model(
         # Load desc.json, just to get the name of the output file above
         with open(desc_file_path) as f:
             desc_json = json.load(f)
-        ofm_file_npy = os.path.join(self.intermediate_path, desc_json["ofm_file"][0])
 
-        # Load the output file (OFM) and return it as a numpy array
-        tosa_ref_output = np.load(ofm_file_npy)
+        tosa_ref_outputs = []
+        for ofm_file in desc_json["ofm_file"]:
+            ofm_file_npy = os.path.join(self.intermediate_path, ofm_file)
 
-        if self.is_quantized:
-            # Need to dequant back to FP32 for comparison with torch output
-            quant_param = self.qp_output
-            assert (
-                quant_param is not None
-            ), "There are no quantization parameters, check output parameters"
-            tosa_ref_output = (tosa_ref_output - quant_param.zp) * quant_param.scale
+            # Load the output file (OFM) as a numpy array
+            tosa_ref_output = np.load(ofm_file_npy)
 
-        # tosa_output is a numpy array, convert to torch tensor for comparison
-        tosa_ref_output = torch.from_numpy(tosa_ref_output.astype("float32"))
+            if self.is_quantized:
+                # Need to dequant back to FP32 for comparison with torch output
+                quant_param = self.qp_output
+                assert (
+                    quant_param is not None
+                ), "There are no quantization parameters, check output parameters"
+                tosa_ref_output = (tosa_ref_output - quant_param.zp) * quant_param.scale
 
-        return tosa_ref_output
+            # tosa_output is a numpy array, convert to torch tensor for comparison
+            tosa_ref_outputs.append(torch.from_numpy(tosa_ref_output.astype("float32")))
+
+        return tosa_ref_outputs
 
 
 def prep_data_for_save(
diff --git a/backends/arm/test/tester/arm_tester.py b/backends/arm/test/tester/arm_tester.py
index be5ea7dd71..41fc907fdf 100644
--- a/backends/arm/test/tester/arm_tester.py
+++ b/backends/arm/test/tester/arm_tester.py
@@ -260,7 +260,7 @@ def run_method_and_compare_outputs(
             print(f"Run {run_iteration} with input shapes: {input_shapes}")
 
             reference_output = reference_stage.run_artifact(reference_input)
-            test_output = (test_stage.run_artifact(test_input),)
+            test_output = tuple(test_stage.run_artifact(test_input))
 
             if is_nhwc:
                 test_output = self.transpose_data_format(test_output, "NCHW")
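
Note (illustration only, not part of the patch): the rewrite performed by
ConvertSplitToSlicePass rests on a simple equivalence. Each output of
split_with_sizes is a slice of the input along `dim`, with start/end offsets
taken from the running sum of the split lengths, which is exactly how the
`starts`/`ends` arrays are built in the pass. A minimal self-contained sketch
of that equivalence, using Tensor.narrow as a stand-in for the
aten.slice_copy node the pass actually emits:

    import torch

    x = torch.rand(10, 15, 10)
    split_lengths = [2, 2, 11]
    dim = 1

    # Reference behaviour: one output tensor per split length.
    expected = x.split_with_sizes(split_lengths, dim=dim)

    # Same outputs via slices: starts/ends are cumulative sums of the lengths.
    outputs = []
    start = 0
    for length in split_lengths:
        end = start + length
        # narrow(dim, start, end - start) is equivalent to slice_copy(x, dim, start, end)
        outputs.append(x.narrow(dim, start, end - start))
        start = end

    assert all(torch.equal(a, b) for a, b in zip(expected, outputs))

Since every consumer of a split node is a getitem selecting output i, the
pass can rewire each consumer directly to the i-th slice node and let
eliminate_dead_code() drop the original split.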