forked from pytorch/executorch
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Add pass to convert split to many slice
Differential Revision: D61211922 Pull Request resolved: pytorch#4562
- Loading branch information
1 parent
4c06907
commit eaf383a
Showing 9 changed files with 241 additions and 18 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,70 @@ | ||
# Copyright 2024 Arm Limited and/or its affiliates. | ||
# All rights reserved. | ||
# | ||
# This source code is licensed under the BSD-style license found in the | ||
# LICENSE file in the root directory of this source tree. | ||
|
||
import torch.fx | ||
from executorch.backends.arm.tosa_mapping import extract_tensor_meta | ||
from executorch.exir.dialects._ops import ops as exir_ops | ||
from executorch.exir.pass_base import ExportPass, PassResult | ||
|
||
|
||
class ConvertSplitToSlicePass(ExportPass):
    """
    Replace a split operation with many slice operations.

    Every user of a split node is a getitem node that selects one chunk;
    each such getitem is rewired to a dedicated slice_copy node over the
    split's input, after which the split itself is dead code and removed.
    """

    # Edge-dialect split variants this pass rewrites.
    split_ops = (
        exir_ops.edge.aten.split_with_sizes_copy.default,
        exir_ops.edge.aten.split_copy.Tensor,
    )
    # Target op each output chunk is lowered to.
    slice = exir_ops.edge.aten.slice_copy.Tensor

    def call(self, graph_module: torch.fx.GraphModule):
        """Rewrite all split nodes in ``graph_module`` into slice_copy nodes.

        Returns:
            PassResult marking the module as modified.

        Raises:
            AssertionError: if the split lengths do not sum to the size of
                the split dimension.
        """
        graph = graph_module.graph
        for node in graph.nodes:
            if node.target not in self.split_ops:
                continue

            # Get useful variables
            split_node = node
            input_node = split_node.all_input_nodes[0]
            output_nodes = split_node.users.copy()
            _, shape, _ = extract_tensor_meta(input_node.meta)
            rank = len(shape)
            split_lengths = split_node.args[1]
            dim = split_node.args[2] if len(split_node.args) > 2 else 0
            # Normalize a possibly negative dim into [0, rank).
            dim = (dim + rank) % rank

            # Explicit raise instead of `assert`: asserts are stripped under
            # `python -O`, which would silently skip this validation. The
            # exception type is kept as AssertionError for compatibility.
            if sum(split_lengths) != shape[dim]:
                raise AssertionError(
                    "Given split lengths don't sum up to the size of the dimension."
                )

            # Convert split argument 'split_lengths' to slice arguments start and end.
            starts = [0] * len(split_lengths)
            ends = [0] * len(split_lengths)
            start = 0
            end = 0
            for i, split_length in enumerate(split_lengths):
                end = start + split_length
                starts[i] = start
                ends[i] = end
                start = end

            # Output nodes are of type getitem.
            # Create one slice node for each output node with matching arguments.
            with graph_module.graph.inserting_before(split_node):
                for output_node in output_nodes:
                    index = output_node.args[1]
                    slice_node = graph.create_node(
                        "call_function",
                        self.slice,
                        (input_node, dim, starts[index], ends[index]),
                    )
                    # Reuse the split's meta but narrow "val" to this chunk's
                    # FakeTensor so downstream shape inference stays correct.
                    slice_node.meta = split_node.meta.copy()
                    slice_node.meta["val"] = slice_node.meta["val"][index]
                    output_node.replace_input_with(split_node, slice_node)
        graph.eliminate_dead_code()
        graph_module.recompile()
        return PassResult(graph_module, True)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,139 @@ | ||
# Copyright 2024 Arm Limited and/or its affiliates. | ||
# All rights reserved. | ||
# | ||
# This source code is licensed under the BSD-style license found in the | ||
# LICENSE file in the root directory of this source tree. | ||
|
||
import unittest | ||
|
||
import torch | ||
from executorch.backends.arm.quantizer.arm_quantizer import ( | ||
ArmQuantizer, | ||
get_symmetric_quantization_config, | ||
) | ||
from executorch.backends.arm.test import common | ||
from executorch.backends.arm.test.tester.arm_tester import ArmTester | ||
from executorch.backends.xnnpack.test.tester.tester import Quantize | ||
from parameterized import parameterized | ||
|
||
test_data_t = tuple[torch.Tensor, int | list[int], int] | ||
|
||
|
||
class TestSimpleSplit(unittest.TestCase):
    """Tests lowering of aten.split / split_with_sizes through the Arm (TOSA) backend."""

    class Split(torch.nn.Module):
        """Wraps Tensor.split; also hosts the test data shared by all cases."""

        # Each entry: (input tensor, split_size_or_sections, dim).
        test_data: list[tuple[test_data_t]] = [
            ((torch.rand(10), 2, 0),),
            ((torch.rand(10, 10), 3, 1),),
            ((torch.rand(10, 10), 4, -1),),
            ((torch.rand(10, 15, 10), [2, 2, 11], 1),),
            ((torch.rand(4, 4, 4, 4), 2, 0),),
            ((torch.rand(4, 4, 4, 4), [1, 1, 1, 1], -2),),
        ]

        def forward(
            self, x: torch.Tensor, split_size_or_sections: int | list[int], dim: int
        ):
            return x.split(split_size=split_size_or_sections, dim=dim)

    class SplitWithSizes(torch.nn.Module):
        """Wraps Tensor.split_with_sizes (explicit per-chunk lengths)."""

        def forward(self, x: torch.Tensor, split_sizes: list[int], dim: int):
            return x.split_with_sizes(split_sizes=split_sizes, dim=dim)

    class SplitSingleOut(torch.nn.Module):
        """Returns only one chunk of the split, leaving the others unused."""

        def forward(
            self, x: torch.Tensor, split_size_or_sections: int | list[int], dim: int
        ):
            return x.split(split_size=split_size_or_sections, dim=dim)[1]

    class SplitTwoOut(torch.nn.Module):
        """Returns two chunks of the split via a slice of the outputs."""

        def forward(
            self, x: torch.Tensor, split_size_or_sections: int | list[int], dim: int
        ):
            return x.split(split_size=split_size_or_sections, dim=dim)[1:3]

    def _test_split_tosa_MI_pipeline(
        self, module: torch.nn.Module, test_data: test_data_t
    ):
        """Export/partition/run pipeline for the float (MI) TOSA profile."""
        (
            ArmTester(
                module,
                example_inputs=test_data,
                compile_spec=common.get_tosa_compile_spec(),
            )
            .export()
            .to_edge()
            # After to_edge the split appears as split_with_sizes_copy in the
            # edge dialect; the pass under test consumes this op.
            .check(
                [
                    "executorch_exir_dialects_edge__ops_aten_split_with_sizes_copy_default"
                ]
            )
            .partition()
            .check_count({"torch.ops.higher_order.executorch_call_delegate": 1})
            .to_executorch()
            .run_method_and_compare_outputs(inputs=test_data)
        )

    def _test_split_tosa_BI_pipeline(
        self, module: torch.nn.Module, test_data: test_data_t
    ):
        """Quantized (BI) TOSA pipeline; outputs compared with qtol=1."""

        quantizer = ArmQuantizer().set_io(get_symmetric_quantization_config())
        (
            ArmTester(
                module,
                example_inputs=test_data,
                compile_spec=common.get_tosa_compile_spec(),
            )
            .quantize(Quantize(quantizer, get_symmetric_quantization_config()))
            .export()
            .to_edge()
            .partition()
            .check_count({"torch.ops.higher_order.executorch_call_delegate": 1})
            .to_executorch()
            # qtol=1 allows off-by-one quantized values.
            .run_method_and_compare_outputs(inputs=test_data, qtol=1)
        )

    def _test_split_u55_BI_pipeline(
        self, module: torch.nn.Module, test_data: test_data_t
    ):
        """Quantized pipeline targeting Ethos-U55; compile-only (no run)."""
        quantizer = ArmQuantizer().set_io(get_symmetric_quantization_config())
        (
            ArmTester(
                module,
                example_inputs=test_data,
                compile_spec=common.get_u55_compile_spec(),
            )
            .quantize(Quantize(quantizer, get_symmetric_quantization_config()))
            .export()
            .check(["torch.ops.aten.split.Tensor"])
            .to_edge()
            .partition()
            .check_count({"torch.ops.higher_order.executorch_call_delegate": 1})
            .to_executorch()
        )

    @parameterized.expand(Split.test_data)
    def test_split_tosa_MI(self, test_data: test_data_t):
        self._test_split_tosa_MI_pipeline(self.Split(), test_data)

    # Only cases 3 and 5 carry explicit section lists, as split_with_sizes requires.
    @parameterized.expand([Split.test_data[3], Split.test_data[5]])
    def test_split_with_sizes_tosa_MI(self, test_data: test_data_t):
        assert isinstance(test_data[1], list)
        self._test_split_tosa_MI_pipeline(self.SplitWithSizes(), test_data)

    @parameterized.expand(Split.test_data)
    def test_split_n_out_tosa_MI(self, test_data: test_data_t):
        self._test_split_tosa_MI_pipeline(self.SplitSingleOut(), test_data)
        self._test_split_tosa_MI_pipeline(self.SplitTwoOut(), test_data)

    @parameterized.expand(Split.test_data)
    def test_split_tosa_BI(self, test_data: test_data_t):
        self._test_split_tosa_BI_pipeline(self.Split(), test_data)

    # Fails during Vela compilation when trying to use a Tuple as a Named tuple,
    # Could be Vela Issue, wait until Regor.
    @parameterized.expand(Split.test_data)
    @unittest.expectedFailure
    def test_split_u55_BI(self, test_data: test_data_t):
        self._test_split_u55_BI_pipeline(self.Split(), test_data)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters