Add pass to convert split to many slices
Differential Revision: D61211922

Pull Request resolved: pytorch#4562
Erik-Lundell authored Aug 19, 2024
1 parent 4c06907 commit eaf383a
Showing 9 changed files with 241 additions and 18 deletions.
1 change: 1 addition & 0 deletions backends/arm/arm_partitioner.py
@@ -43,6 +43,7 @@ def is_node_supported(self, submodules, node: torch.fx.Node) -> bool:
             exir_ops.edge.aten.hardtanh.default,
             exir_ops.edge.aten.convolution.default,
             exir_ops.edge.aten.div.Tensor,
+            exir_ops.edge.aten.split_with_sizes_copy.default,
             exir_ops.edge.aten.full.default,
             exir_ops.edge.aten._native_batch_norm_legit_no_training.default,
             exir_ops.edge.aten.avg_pool2d.default,
2 changes: 2 additions & 0 deletions backends/arm/operators/op_slice.py
@@ -40,6 +40,8 @@ def define_node(
         shape = input_node.shape
         dim = dim.number
         end = (shape[dim] + end.number) % shape[dim]
+        if end == 0:
+            end = shape[dim]
         size = end - start.number
         assert size > 0
         assert size <= shape[dim]
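The two added lines guard an edge case in the existing modulo normalization: it maps negative end indices into range, but it also maps an end index equal to the dimension size to 0, which would make the computed size non-positive and trip the asserts. A standalone sketch of the arithmetic (normalize_end is a hypothetical helper for illustration, not the backend's actual API):

    def normalize_end(end: int, dim_size: int) -> int:
        # Map negative end indices (e.g. -2) into [0, dim_size).
        end = (dim_size + end) % dim_size
        # A slice running to the very end of the dimension (end == dim_size)
        # wraps to 0 above; restore it so size = end - start stays positive.
        if end == 0:
            end = dim_size
        return end

    assert normalize_end(-2, 10) == 8    # x[..., 0:-2]
    assert normalize_end(10, 10) == 10   # x[..., 4:10]; was 0 before the fix

The test_slice.py change further down (4:5 to 4:10) exercises exactly this boundary by slicing up to the full size of a dimension of length 10.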
4 changes: 4 additions & 0 deletions backends/arm/passes/arm_pass_manager.py
@@ -12,6 +12,9 @@
 from executorch.backends.arm.passes.convert_expand_copy_to_repeat import (
     ConvertExpandCopyToRepeatPass,
 )
+from executorch.backends.arm.passes.convert_split_to_slice import (
+    ConvertSplitToSlicePass,
+)
 from executorch.backends.arm.passes.remove_clone_pass import RemoveClonePass
 from executorch.exir.backend.compile_spec_schema import CompileSpec
 from executorch.exir.pass_manager import PassManager
@@ -28,6 +31,7 @@ def transform_to_backend_pipeline(
         """Apply passes before transforming program to backend"""
         self.add_pass(RemoveClonePass())
         self.add_pass(ConvertExpandCopyToRepeatPass())
+        self.add_pass(ConvertSplitToSlicePass())
         for spec in compile_spec:
             if spec.key == "permute_memory_format":
                 memory_format = spec.value.decode()
70 changes: 70 additions & 0 deletions backends/arm/passes/convert_split_to_slice.py
@@ -0,0 +1,70 @@
# Copyright 2024 Arm Limited and/or its affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import torch.fx
from executorch.backends.arm.tosa_mapping import extract_tensor_meta
from executorch.exir.dialects._ops import ops as exir_ops
from executorch.exir.pass_base import ExportPass, PassResult


class ConvertSplitToSlicePass(ExportPass):
    """
    Replace a split operation with many slice operations.
    """

    split_ops = (
        exir_ops.edge.aten.split_with_sizes_copy.default,
        exir_ops.edge.aten.split_copy.Tensor,
    )
    slice = exir_ops.edge.aten.slice_copy.Tensor

    def call(self, graph_module: torch.fx.GraphModule):
        graph = graph_module.graph
        for node in graph.nodes:
            if node.target not in self.split_ops:
                continue

            # Get useful variables
            split_node = node
            input_node = split_node.all_input_nodes[0]
            output_nodes = split_node.users.copy()
            _, shape, _ = extract_tensor_meta(input_node.meta)
            rank = len(shape)
            split_lengths = split_node.args[1]
            dim = split_node.args[2] if len(split_node.args) > 2 else 0
            dim = (dim + rank) % rank

            assert (
                sum(split_lengths) == shape[dim]
            ), "Given split lengths don't sum up to the size of the dimension."

            # Convert split argument 'split_lengths' to slice arguments start and end.
            starts = [0] * len(split_lengths)
            ends = [0] * len(split_lengths)
            start = 0
            end = 0
            for i, split_length in enumerate(split_lengths):
                end = start + split_length
                starts[i] = start
                ends[i] = end
                start = end

            # Output nodes are of type getitem.
            # Create one slice node for each output node with matching arguments.
            with graph_module.graph.inserting_before(split_node):
                for output_node in output_nodes:
                    index = output_node.args[1]
                    slice_node = graph.create_node(
                        "call_function",
                        self.slice,
                        (input_node, dim, starts[index], ends[index]),
                    )
                    slice_node.meta = split_node.meta.copy()
                    slice_node.meta["val"] = slice_node.meta["val"][index]
                    output_node.replace_input_with(split_node, slice_node)
        graph.eliminate_dead_code()
        graph_module.recompile()
        return PassResult(graph_module, True)
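To make the effect of the pass concrete: the starts/ends loop turns split lengths into cumulative (start, end) pairs, and each slice reproduces one split output. A plain-PyTorch sketch of the same bookkeeping (split_to_slice_args is an illustrative helper, not part of the pass):

    import torch

    def split_to_slice_args(split_lengths):
        # Mirror the starts/ends loop above: cumulative (start, end) pairs.
        bounds, start = [], 0
        for length in split_lengths:
            bounds.append((start, start + length))
            start += length
        return bounds

    x = torch.arange(10)
    bounds = split_to_slice_args([2, 3, 5])   # [(0, 2), (2, 5), (5, 10)]
    slices = [x[s:e] for s, e in bounds]      # one slice per split output
    expected = x.split_with_sizes([2, 3, 5])
    assert all(torch.equal(a, b) for a, b in zip(slices, expected))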
4 changes: 4 additions & 0 deletions backends/arm/quantizer/arm_quantizer_utils.py
@@ -9,6 +9,7 @@
 # Utility functions for ArmQuantizer
 #

+import operator
 from typing import Callable, cast, List

 import torch
@@ -141,8 +142,11 @@ def is_share_obs_or_fq_op(op: Callable) -> bool:
         torch.ops.aten.view_copy.default,
         torch.ops.aten.view.default,
         torch.ops.aten.slice.Tensor,
+        torch.ops.aten.split.Tensor,
+        torch.ops.aten.split_with_sizes.default,
         torch.ops.aten.flatten.using_ints,
         torch.ops.aten.dropout.default,
+        operator.getitem,
     ]
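Context for the new entries: in an exported graph, split produces a single multi-output node whose results are read through operator.getitem, so the split ops and getitem are added to the shared-observer list together, presumably so the slices later created by ConvertSplitToSlicePass see consistent quantization parameters. A small sketch showing where getitem comes from, assuming torch.export (PyTorch 2.1+) is available:

    import operator

    import torch

    class M(torch.nn.Module):
        def forward(self, x):
            return x.split(2, dim=0)

    ep = torch.export.export(M(), (torch.rand(4, 3),))
    targets = {n.target for n in ep.graph.nodes if n.op == "call_function"}
    # Each split output is consumed through operator.getitem.
    assert operator.getitem in targets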


2 changes: 1 addition & 1 deletion backends/arm/test/ops/test_slice.py
@@ -33,7 +33,7 @@ def forward(self, x: torch.Tensor):
         elif x.dim() == 3:
             return x[0:7, 0:1, 0:8]
         elif x.dim() == 4:
-            return x[:, 2:5, 3:5, 4:5]
+            return x[:, 2:5, 3:5, 4:10]

def _test_slice_tosa_MI_pipeline(
self, module: torch.nn.Module, test_data: torch.Tensor
139 changes: 139 additions & 0 deletions backends/arm/test/ops/test_split.py
@@ -0,0 +1,139 @@
# Copyright 2024 Arm Limited and/or its affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import unittest

import torch
from executorch.backends.arm.quantizer.arm_quantizer import (
    ArmQuantizer,
    get_symmetric_quantization_config,
)
from executorch.backends.arm.test import common
from executorch.backends.arm.test.tester.arm_tester import ArmTester
from executorch.backends.xnnpack.test.tester.tester import Quantize
from parameterized import parameterized

test_data_t = tuple[torch.Tensor, int | list[int], int]


class TestSimpleSplit(unittest.TestCase):
    class Split(torch.nn.Module):

        test_data: list[tuple[test_data_t]] = [
            ((torch.rand(10), 2, 0),),
            ((torch.rand(10, 10), 3, 1),),
            ((torch.rand(10, 10), 4, -1),),
            ((torch.rand(10, 15, 10), [2, 2, 11], 1),),
            ((torch.rand(4, 4, 4, 4), 2, 0),),
            ((torch.rand(4, 4, 4, 4), [1, 1, 1, 1], -2),),
        ]

        def forward(
            self, x: torch.Tensor, split_size_or_sections: int | list[int], dim: int
        ):
            return x.split(split_size=split_size_or_sections, dim=dim)

    class SplitWithSizes(torch.nn.Module):
        def forward(self, x: torch.Tensor, split_sizes: list[int], dim: int):
            return x.split_with_sizes(split_sizes=split_sizes, dim=dim)

    class SplitSingleOut(torch.nn.Module):
        def forward(
            self, x: torch.Tensor, split_size_or_sections: int | list[int], dim: int
        ):
            return x.split(split_size=split_size_or_sections, dim=dim)[1]

    class SplitTwoOut(torch.nn.Module):
        def forward(
            self, x: torch.Tensor, split_size_or_sections: int | list[int], dim: int
        ):
            return x.split(split_size=split_size_or_sections, dim=dim)[1:3]

    def _test_split_tosa_MI_pipeline(
        self, module: torch.nn.Module, test_data: test_data_t
    ):
        (
            ArmTester(
                module,
                example_inputs=test_data,
                compile_spec=common.get_tosa_compile_spec(),
            )
            .export()
            .to_edge()
            .check(
                [
                    "executorch_exir_dialects_edge__ops_aten_split_with_sizes_copy_default"
                ]
            )
            .partition()
            .check_count({"torch.ops.higher_order.executorch_call_delegate": 1})
            .to_executorch()
            .run_method_and_compare_outputs(inputs=test_data)
        )

    def _test_split_tosa_BI_pipeline(
        self, module: torch.nn.Module, test_data: test_data_t
    ):

        quantizer = ArmQuantizer().set_io(get_symmetric_quantization_config())
        (
            ArmTester(
                module,
                example_inputs=test_data,
                compile_spec=common.get_tosa_compile_spec(),
            )
            .quantize(Quantize(quantizer, get_symmetric_quantization_config()))
            .export()
            .to_edge()
            .partition()
            .check_count({"torch.ops.higher_order.executorch_call_delegate": 1})
            .to_executorch()
            .run_method_and_compare_outputs(inputs=test_data, qtol=1)
        )

    def _test_split_u55_BI_pipeline(
        self, module: torch.nn.Module, test_data: test_data_t
    ):
        quantizer = ArmQuantizer().set_io(get_symmetric_quantization_config())
        (
            ArmTester(
                module,
                example_inputs=test_data,
                compile_spec=common.get_u55_compile_spec(),
            )
            .quantize(Quantize(quantizer, get_symmetric_quantization_config()))
            .export()
            .check(["torch.ops.aten.split.Tensor"])
            .to_edge()
            .partition()
            .check_count({"torch.ops.higher_order.executorch_call_delegate": 1})
            .to_executorch()
        )

    @parameterized.expand(Split.test_data)
    def test_split_tosa_MI(self, test_data: test_data_t):
        self._test_split_tosa_MI_pipeline(self.Split(), test_data)

    @parameterized.expand([Split.test_data[3], Split.test_data[5]])
    def test_split_with_sizes_tosa_MI(self, test_data: test_data_t):
        assert isinstance(test_data[1], list)
        self._test_split_tosa_MI_pipeline(self.SplitWithSizes(), test_data)

    @parameterized.expand(Split.test_data)
    def test_split_n_out_tosa_MI(self, test_data: test_data_t):
        self._test_split_tosa_MI_pipeline(self.SplitSingleOut(), test_data)
        self._test_split_tosa_MI_pipeline(self.SplitTwoOut(), test_data)

    @parameterized.expand(Split.test_data)
    def test_split_tosa_BI(self, test_data: test_data_t):
        self._test_split_tosa_BI_pipeline(self.Split(), test_data)

    # Fails during Vela compilation when trying to use a Tuple as a NamedTuple.
    # This could be a Vela issue; wait for Regor.
    @parameterized.expand(Split.test_data)
    @unittest.expectedFailure
    def test_split_u55_BI(self, test_data: test_data_t):
        self._test_split_u55_BI_pipeline(self.Split(), test_data)
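For reference, the test data above exercises both calling conventions of torch.split; a standalone illustration in plain PyTorch, independent of the ArmTester pipelines:

    import torch

    x = torch.rand(10, 15, 10)
    # int split_size: equal chunks (the last one may be smaller).
    chunks = x.split(4, dim=0)
    assert [c.shape[0] for c in chunks] == [4, 4, 2]
    # list of section lengths: must sum to the dimension size (2 + 2 + 11 == 15).
    parts = x.split([2, 2, 11], dim=1)
    assert [p.shape[1] for p in parts] == [2, 2, 11]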
35 changes: 19 additions & 16 deletions backends/arm/test/runner_utils.py
@@ -202,7 +202,7 @@ def set_timeout(self, timeout: int):
     def run_corstone300(
         self,
         inputs: Tuple[torch.Tensor],
-    ) -> torch.Tensor:
+    ) -> list[torch.Tensor]:

         assert (
             self._has_init_run
@@ -268,12 +268,12 @@ def run_corstone300(

         tosa_ref_output = np.fromfile(out_path_with_suffix, dtype=np.float32)
         tosa_ref_output = torch.from_numpy(tosa_ref_output).reshape(inputs[0].shape)
-        return tosa_ref_output
+        return [tosa_ref_output]

     def run_tosa_ref_model(
         self,
         inputs: Tuple[torch.Tensor],
-    ) -> torch.Tensor:
+    ) -> list[torch.Tensor]:
         """
         Run TOSA reference model using the tosa_reference_model program.
@@ -369,23 +369,26 @@ def run_tosa_ref_model(
         # Load desc.json, just to get the name of the output file above
         with open(desc_file_path) as f:
             desc_json = json.load(f)
-        ofm_file_npy = os.path.join(self.intermediate_path, desc_json["ofm_file"][0])
-
-        # Load the output file (OFM) and return it as a numpy array
-        tosa_ref_output = np.load(ofm_file_npy)
-
-        if self.is_quantized:
-            # Need to dequant back to FP32 for comparison with torch output
-            quant_param = self.qp_output
-            assert (
-                quant_param is not None
-            ), "There are no quantization parameters, check output parameters"
-            tosa_ref_output = (tosa_ref_output - quant_param.zp) * quant_param.scale
-
-        # tosa_output is a numpy array, convert to torch tensor for comparison
-        tosa_ref_output = torch.from_numpy(tosa_ref_output.astype("float32"))
-
-        return tosa_ref_output
+        tosa_ref_outputs = []
+        for ofm_file in desc_json["ofm_file"]:
+            ofm_file_npy = os.path.join(self.intermediate_path, ofm_file)
+
+            # Load the output file (OFM) and return it as a numpy array
+            tosa_ref_output = np.load(ofm_file_npy)
+
+            if self.is_quantized:
+                # Need to dequant back to FP32 for comparison with torch output
+                quant_param = self.qp_output
+                assert (
+                    quant_param is not None
+                ), "There are no quantization parameters, check output parameters"
+                tosa_ref_output = (tosa_ref_output - quant_param.zp) * quant_param.scale
+
+            # tosa_output is a numpy array, convert to torch tensor for comparison
+            tosa_ref_outputs.append(torch.from_numpy(tosa_ref_output.astype("float32")))
+
+        return tosa_ref_outputs


 def prep_data_for_save(
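The per-output dequantization applied in that loop follows the standard affine mapping, real = (q - zero_point) * scale. A minimal standalone sketch of the same formula (generic zp/scale values for illustration):

    import numpy as np

    def dequantize(q: np.ndarray, zp: int, scale: float) -> np.ndarray:
        # real_value = (quantized_value - zero_point) * scale;
        # cast first so uint8 subtraction cannot wrap around.
        return (q.astype(np.float32) - zp) * scale

    q = np.array([0, 128, 255], dtype=np.uint8)
    print(dequantize(q, zp=128, scale=0.1))  # [-12.8  0.   12.7]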
2 changes: 1 addition & 1 deletion backends/arm/test/tester/arm_tester.py
@@ -260,7 +260,7 @@ def run_method_and_compare_outputs(
             print(f"Run {run_iteration} with input shapes: {input_shapes}")

             reference_output = reference_stage.run_artifact(reference_input)
-            test_output = (test_stage.run_artifact(test_input),)
+            test_output = tuple(test_stage.run_artifact(test_input))
             if is_nhwc:
                 test_output = self.transpose_data_format(test_output, "NCHW")
