From 4c686608a389713f49f6b01dcb31a00fee91bab9 Mon Sep 17 00:00:00 2001 From: Thibaut Goetghebuer-Planchon Date: Tue, 27 Aug 2024 17:16:54 +0100 Subject: [PATCH] Add support for upsample_nearest2d op in the Arm backend Change-Id: Id0b742214e5432957b2f573b4218f09a4d9734e4 --- backends/arm/arm_partitioner.py | 2 + backends/arm/operators/__init__.py | 1 + .../arm/operators/op_upsample_nearest2d.py | 68 ++++++++ backends/arm/quantizer/arm_quantizer.py | 1 + .../quantization_annotation/__init__.py | 1 + .../upsample_nearest2d_annotator.py | 71 ++++++++ .../arm/test/ops/test_upsample_nearest2d.py | 165 ++++++++++++++++++ backends/arm/test/tester/arm_tester.py | 1 - backends/arm/tosa_utils.py | 45 +++++ 9 files changed, 354 insertions(+), 1 deletion(-) create mode 100644 backends/arm/operators/op_upsample_nearest2d.py create mode 100644 backends/arm/quantizer/quantization_annotation/upsample_nearest2d_annotator.py create mode 100644 backends/arm/test/ops/test_upsample_nearest2d.py diff --git a/backends/arm/arm_partitioner.py b/backends/arm/arm_partitioner.py index ef924fa434..f4050351d1 100644 --- a/backends/arm/arm_partitioner.py +++ b/backends/arm/arm_partitioner.py @@ -69,6 +69,7 @@ def is_node_supported(self, submodules, node: torch.fx.Node) -> bool: exir_ops.edge.aten.sub.Tensor, exir_ops.edge.aten.sum.dim_IntList, exir_ops.edge.aten.tanh.default, + exir_ops.edge.aten.upsample_nearest2d.vec, exir_ops.edge.aten.view_copy.default, exir_ops.edge.aten.clone.default, exir_ops.edge.aten.mean.dim, @@ -144,5 +145,6 @@ def ops_to_not_decompose( ) -> Tuple[List[torch._ops.OpOverload], Optional[Callable[[torch.fx.Node], bool]]]: ops_to_not_decompose = [ torch.ops.aten.linear.default, + torch.ops.aten.upsample_nearest2d.vec, ] return (ops_to_not_decompose, None) diff --git a/backends/arm/operators/__init__.py b/backends/arm/operators/__init__.py index 6e51c2c141..988765990d 100644 --- a/backends/arm/operators/__init__.py +++ b/backends/arm/operators/__init__.py @@ -37,5 +37,6 @@ op_tanh, op_transpose, op_unsqueeze, + op_upsample_nearest2d, op_view, ) diff --git a/backends/arm/operators/op_upsample_nearest2d.py b/backends/arm/operators/op_upsample_nearest2d.py new file mode 100644 index 0000000000..c6c0423a1a --- /dev/null +++ b/backends/arm/operators/op_upsample_nearest2d.py @@ -0,0 +1,68 @@ +# Copyright 2024 Arm Limited and/or its affiliates. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. +from typing import List + +import serializer.tosa_serializer as ts +import torch +from executorch.backends.arm.operators.node_visitor import ( + NodeVisitor, + register_node_visitor, +) +from executorch.backends.arm.tosa_mapping import TosaArg +from executorch.backends.arm.tosa_utils import get_resize_parameters, tosa_shape +from serializer.tosa_serializer import TosaOp + +from tosa.ResizeMode import ResizeMode + + +@register_node_visitor +class UpsampleNearest2dVisitor(NodeVisitor): + target = "aten.upsample_nearest2d.vec" + + def __init__(self, *args): + super().__init__(*args) + + def define_node( + self, + node: torch.fx.Node, + tosa_graph: ts.TosaSerializer, + inputs: List[TosaArg], + output: TosaArg, + is_quant_node: bool, + ) -> None: + assert ( + inputs[0].shape is not None and output.shape is not None + ), "Only static shapes are supported" + + # tosa_shape output is NHWC, take HW + input_size_yx = torch.tensor( + tosa_shape(inputs[0].shape, inputs[0].dim_order)[1:3] + ) + # Ignore scale and size parameters, directly use the output size as + # we only support static shapes currently + output_size_yx = torch.tensor(tosa_shape(output.shape, output.dim_order)[1:3]) + + scale_n_yx, scale_d_yx, offset_yx, border_yx = get_resize_parameters( + input_size_yx, output_size_yx, ResizeMode.NEAREST, align_corners=True + ) + + def in_int16_range(x): + return torch.all(x >= -(2**15)) and torch.all(x <= 2**15 - 1) + + assert in_int16_range(scale_n_yx) + assert in_int16_range(scale_d_yx) + assert in_int16_range(border_yx) + + attr = ts.TosaSerializerAttribute() + attr.ResizeAttribute( + scale=[scale_n_yx[0], scale_d_yx[0], scale_n_yx[1], scale_d_yx[1]], + offset=offset_yx.tolist(), + border=border_yx.tolist(), + mode=ResizeMode.NEAREST, + ) + + tosa_graph.addOperator( + TosaOp.Op().RESIZE, [inputs[0].name], [output.name], attr + ) diff --git a/backends/arm/quantizer/arm_quantizer.py b/backends/arm/quantizer/arm_quantizer.py index 38fce85de4..6f2a5689d3 100644 --- a/backends/arm/quantizer/arm_quantizer.py +++ b/backends/arm/quantizer/arm_quantizer.py @@ -270,6 +270,7 @@ class ArmQuantizer(Quantizer): "mm", "one_to_one", "generic", + "upsample_nearest2d", ] def __init__(self) -> None: diff --git a/backends/arm/quantizer/quantization_annotation/__init__.py b/backends/arm/quantizer/quantization_annotation/__init__.py index 053e2a4c29..1201df51ad 100644 --- a/backends/arm/quantizer/quantization_annotation/__init__.py +++ b/backends/arm/quantizer/quantization_annotation/__init__.py @@ -59,4 +59,5 @@ def decorator(annotator: AnnotatorType): mul_annotator, one_to_one_annotator, sub_annotator, + upsample_nearest2d_annotator, ) diff --git a/backends/arm/quantizer/quantization_annotation/upsample_nearest2d_annotator.py b/backends/arm/quantizer/quantization_annotation/upsample_nearest2d_annotator.py new file mode 100644 index 0000000000..9d73da5bc0 --- /dev/null +++ b/backends/arm/quantizer/quantization_annotation/upsample_nearest2d_annotator.py @@ -0,0 +1,71 @@ +# Copyright 2024 Arm Limited and/or its affiliates. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +import itertools +from typing import Callable, List, Optional + +import torch +from executorch.backends.arm.quantizer.quantization_annotation import register_annotator +from executorch.backends.arm.quantizer.quantization_config import QuantizationConfig +from torch.ao.quantization.quantizer import ( + QuantizationAnnotation, + SharedQuantizationSpec, +) +from torch.fx import Node +from torch.fx.passes.utils.source_matcher_utils import get_source_partitions + + +def _filter_upsample_nearest2d(filter_fn: Optional[Callable[[Node], bool]] = None): + def filter(node: Node): + is_upsample = node.target == torch.ops.aten.upsample_nearest2d.vec + if filter_fn is None: + return is_upsample + else: + return is_upsample and filter_fn(node) + + return filter + + +@register_annotator("upsample_nearest2d") +def _annotate_upsample_nearest2d( + gm: torch.fx.GraphModule, + quantization_config: QuantizationConfig, + filter_fn: Optional[Callable[[Node], bool]] = None, +) -> Optional[List[List[Node]]]: + module_partitions = get_source_partitions( + gm.graph, + [ + torch.nn.UpsamplingNearest2d, + torch.nn.Upsample, + torch.nn.functional.interpolate, + ], + _filter_upsample_nearest2d(filter_fn), + ) + upsample_partitions = list( + itertools.chain.from_iterable(module_partitions.values()) + ) + annotated_partitions = [] + + for upsample_partition in upsample_partitions: + annotated_partitions.append(upsample_partition.nodes) + + assert len(upsample_partition.nodes) == 1 + upsample_node = upsample_partition.nodes[0] + + input_act = upsample_node.args[0] + assert isinstance(input_act, Node) + + input_act_qspec = quantization_config.get_input_act_qspec() + output_act_qspec = SharedQuantizationSpec((input_act, upsample_node)) + + upsample_node.meta["quantization_annotation"] = QuantizationAnnotation( + input_qspec_map={ + input_act: input_act_qspec, + }, + output_qspec=output_act_qspec, + _annotated=True, + ) + + return annotated_partitions diff --git a/backends/arm/test/ops/test_upsample_nearest2d.py b/backends/arm/test/ops/test_upsample_nearest2d.py new file mode 100644 index 0000000000..d03ac1e441 --- /dev/null +++ b/backends/arm/test/ops/test_upsample_nearest2d.py @@ -0,0 +1,165 @@ +# Copyright 2024 Arm Limited and/or its affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +import unittest + +from typing import Optional, Tuple + +import torch +from executorch.backends.arm.test import common +from executorch.backends.arm.test.tester.arm_tester import ArmTester +from parameterized import parameterized + + +test_data_suite = [ + # (test_name, test_data, size, scale_factor, compare_outputs) + ("rand_double_scale", torch.rand(2, 4, 8, 3), None, 2.0, True), + ("rand_double_scale_one_dim", torch.rand(2, 4, 8, 3), None, (1.0, 2.0), True), + ("rand_double_size", torch.rand(2, 4, 8, 3), (16, 6), None, True), + ("rand_one_double_scale", torch.rand(2, 4, 1, 1), None, 2.0, True), + ("rand_one_double_size", torch.rand(2, 4, 1, 1), (2, 2), None, True), + ("rand_one_same_scale", torch.rand(2, 4, 1, 1), None, 1.0, True), + ("rand_one_same_size", torch.rand(2, 4, 1, 1), (1, 1), None, True), + # Can't compare outputs as the rounding when selecting the nearest pixel is + # different between PyTorch and TOSA. Just check the legalization went well. + # TODO Improve the test infrastructure to support more in depth verification + # of the TOSA legalization results. + ("rand_half_scale", torch.rand(2, 4, 8, 6), None, 0.5, False), + ("rand_half_size", torch.rand(2, 4, 8, 6), (4, 3), None, False), + ("rand_one_and_half_scale", torch.rand(2, 4, 8, 3), None, 1.5, False), + ("rand_one_and_half_size", torch.rand(2, 4, 8, 3), (12, 4), None, False), +] + + +class TestUpsampleNearest2d(unittest.TestCase): + class UpsamplingNearest2d(torch.nn.Module): + def __init__( + self, + size: Optional[Tuple[int]], + scale_factor: Optional[float | Tuple[float]], + ): + super().__init__() + self.upsample = torch.nn.UpsamplingNearest2d( # noqa: TOR101 + size=size, scale_factor=scale_factor + ) + + def forward(self, x): + return self.upsample(x) + + class Upsample(torch.nn.Module): + def __init__( + self, + size: Optional[Tuple[int]], + scale_factor: Optional[float | Tuple[float]], + ): + super().__init__() + self.upsample = torch.nn.Upsample( + size=size, scale_factor=scale_factor, mode="nearest" + ) + + def forward(self, x): + return self.upsample(x) + + class Interpolate(torch.nn.Module): + def __init__( + self, + size: Optional[Tuple[int]], + scale_factor: Optional[float | Tuple[float]], + ): + super().__init__() + self.upsample = lambda x: torch.nn.functional.interpolate( + x, size=size, scale_factor=scale_factor, mode="nearest" + ) + + def forward(self, x): + return self.upsample(x) + + def _test_upsample_nearest_2d_tosa_MI_pipeline( + self, + module: torch.nn.Module, + test_data: Tuple[torch.tensor], + compare_outputs: bool, + ): + tester = ( + ArmTester( + module, + example_inputs=test_data, + compile_spec=common.get_tosa_compile_spec("TOSA-0.80.0+MI"), + ) + .export() + .check(["torch.ops.aten.upsample_nearest2d.vec"]) + .check_not(["torch.ops.quantized_decomposed"]) + .to_edge_transform_and_lower() + .check_not(["torch.ops.aten.upsample_nearest2d.vec"]) + .check_count({"torch.ops.higher_order.executorch_call_delegate": 1}) + .to_executorch() + ) + + if compare_outputs: + tester.run_method_and_compare_outputs(inputs=test_data) + + def _test_upsample_nearest_2d_tosa_BI_pipeline( + self, + module: torch.nn.Module, + test_data: Tuple[torch.tensor], + compare_outputs: bool, + ): + tester = ( + ArmTester( + module, + example_inputs=test_data, + compile_spec=common.get_tosa_compile_spec("TOSA-0.80.0+BI"), + ) + .quantize() + .export() + .check(["torch.ops.aten.upsample_nearest2d.vec"]) + .check(["torch.ops.quantized_decomposed"]) + .to_edge_transform_and_lower() + .check_not(["torch.ops.aten.upsample_nearest2d.vec"]) + .check_count({"torch.ops.higher_order.executorch_call_delegate": 1}) + .to_executorch() + ) + + if compare_outputs: + tester.run_method_and_compare_outputs(inputs=test_data) + + @parameterized.expand(test_data_suite) + def test_upsample_nearest_2d_tosa_MI( + self, + test_name: str, + test_data: torch.Tensor, + size: Optional[Tuple[int]], + scale_factor: Optional[float | Tuple[float]], + compare_outputs: bool, + ): + self._test_upsample_nearest_2d_tosa_MI_pipeline( + self.UpsamplingNearest2d(size, scale_factor), (test_data,), compare_outputs + ) + self._test_upsample_nearest_2d_tosa_MI_pipeline( + self.Upsample(size, scale_factor), (test_data,), compare_outputs + ) + self._test_upsample_nearest_2d_tosa_MI_pipeline( + self.Interpolate(size, scale_factor), (test_data,), compare_outputs + ) + + @parameterized.expand(test_data_suite) + def test_upsample_nearest_2d_tosa_BI( + self, + test_name: str, + test_data: torch.Tensor, + size: Optional[Tuple[int]], + scale_factor: Optional[float | Tuple[float]], + compare_outputs: bool, + ): + self._test_upsample_nearest_2d_tosa_BI_pipeline( + self.UpsamplingNearest2d(size, scale_factor), (test_data,), compare_outputs + ) + self._test_upsample_nearest_2d_tosa_BI_pipeline( + self.Upsample(size, scale_factor), (test_data,), compare_outputs + ) + self._test_upsample_nearest_2d_tosa_BI_pipeline( + self.Interpolate(size, scale_factor), (test_data,), compare_outputs + ) diff --git a/backends/arm/test/tester/arm_tester.py b/backends/arm/test/tester/arm_tester.py index 14a9d1df41..e2062f2428 100644 --- a/backends/arm/test/tester/arm_tester.py +++ b/backends/arm/test/tester/arm_tester.py @@ -287,7 +287,6 @@ def run_method_and_compare_outputs( inputs (Optional[Tuple[torch.Tensor]]): Allows you to input custom input data. The default is random data. """ - edge_stage = self.stages[self.stage_name(tester.ToEdge)] if edge_stage is None: edge_stage = self.stages[self.stage_name(tester.ToEdgeTransformAndLower)] diff --git a/backends/arm/tosa_utils.py b/backends/arm/tosa_utils.py index b61b27853a..35ee6ef6b3 100644 --- a/backends/arm/tosa_utils.py +++ b/backends/arm/tosa_utils.py @@ -298,3 +298,48 @@ def expand_dims( build_reshape(tosa_graph, input_node.name, new_shape, intermediate.name) return intermediate + + +def get_resize_parameters( + input_size: torch.Tensor, + output_size: torch.Tensor, + resize_mode: int, + align_corners: bool, +): + """Get the tosa.resize parameters based on the input and output size. + + Args: + input_size (torch.Tensor): Size of the input + output_size (torch.Tensor): Size of the output + resize_mode (tosa.ResizeMode): The TOSA resize mode + align_corners (bool): Align the corners pixels of the input and output + + Returns: + scale_n (torch.Tensor), scale_d (torch.Tensor), + offset (torch.Tensor), border (torch.Tensor) + """ + assert torch.all(input_size > 0) + assert torch.all(output_size > 0) + + scale_n = torch.tensor( + [ + so - 1 if align_corners and si > 1 and so > 1 else so + for si, so in zip(input_size, output_size) + ] + ) + scale_d = torch.tensor( + [ + si - 1 if align_corners and si > 1 and so > 1 else si + for si, so in zip(input_size, output_size) + ] + ) + + gcd = torch.gcd(scale_n, scale_d) + scale_n = scale_n // gcd + scale_d = scale_d // gcd + + # No half-pixel centre support in PyTorch, no offset needed + offset = torch.zeros_like(input_size) + border = scale_d * (output_size - 1) - scale_n * (input_size - 1) + offset + + return scale_n, scale_d, offset, border