Skip to content

Commit

Permalink
Add support for upsample_nearest2d op in the Arm backend
Browse files Browse the repository at this point in the history
Change-Id: Id0b742214e5432957b2f573b4218f09a4d9734e4
  • Loading branch information
Tessil authored and freddan80 committed Nov 13, 2024
1 parent 667f600 commit 4c68660
Show file tree
Hide file tree
Showing 9 changed files with 354 additions and 1 deletion.
2 changes: 2 additions & 0 deletions backends/arm/arm_partitioner.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,7 @@ def is_node_supported(self, submodules, node: torch.fx.Node) -> bool:
exir_ops.edge.aten.sub.Tensor,
exir_ops.edge.aten.sum.dim_IntList,
exir_ops.edge.aten.tanh.default,
exir_ops.edge.aten.upsample_nearest2d.vec,
exir_ops.edge.aten.view_copy.default,
exir_ops.edge.aten.clone.default,
exir_ops.edge.aten.mean.dim,
Expand Down Expand Up @@ -144,5 +145,6 @@ def ops_to_not_decompose(
) -> Tuple[List[torch._ops.OpOverload], Optional[Callable[[torch.fx.Node], bool]]]:
ops_to_not_decompose = [
torch.ops.aten.linear.default,
torch.ops.aten.upsample_nearest2d.vec,
]
return (ops_to_not_decompose, None)
1 change: 1 addition & 0 deletions backends/arm/operators/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,5 +37,6 @@
op_tanh,
op_transpose,
op_unsqueeze,
op_upsample_nearest2d,
op_view,
)
68 changes: 68 additions & 0 deletions backends/arm/operators/op_upsample_nearest2d.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,68 @@
# Copyright 2024 Arm Limited and/or its affiliates.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
from typing import List

import serializer.tosa_serializer as ts
import torch
from executorch.backends.arm.operators.node_visitor import (
NodeVisitor,
register_node_visitor,
)
from executorch.backends.arm.tosa_mapping import TosaArg
from executorch.backends.arm.tosa_utils import get_resize_parameters, tosa_shape
from serializer.tosa_serializer import TosaOp

from tosa.ResizeMode import ResizeMode


@register_node_visitor
class UpsampleNearest2dVisitor(NodeVisitor):
    """Emit a TOSA RESIZE operator (NEAREST mode) for upsample_nearest2d nodes."""

    target = "aten.upsample_nearest2d.vec"

    def __init__(self, *args):
        super().__init__(*args)

    def define_node(
        self,
        node: torch.fx.Node,
        tosa_graph: ts.TosaSerializer,
        inputs: List[TosaArg],
        output: TosaArg,
        is_quant_node: bool,
    ) -> None:
        """Append a RESIZE operator for this node to ``tosa_graph``.

        Only ``inputs[0]`` and ``output`` are consulted; ``node`` and
        ``is_quant_node`` are not read by this implementation.

        Args:
            node: The FX node being lowered.
            tosa_graph: Serializer that receives the RESIZE operator.
            inputs: Node arguments; ``inputs[0]`` is the tensor to resize.
            output: Description of the resized output tensor.
            is_quant_node: Whether the node is quantized.

        Raises:
            AssertionError: If either shape is unknown, or if a computed
                scale/border value does not fit in a signed 16-bit integer.
        """
        shapes_are_static = inputs[0].shape is not None and output.shape is not None
        assert shapes_are_static, "Only static shapes are supported"

        # tosa_shape yields the dimensions in NHWC order, so [1:3] is (H, W).
        src_nhwc = tosa_shape(inputs[0].shape, inputs[0].dim_order)
        src_hw = torch.tensor(src_nhwc[1:3])
        # The op's own size / scale_factor arguments are deliberately not read:
        # with static shapes the output shape already encodes the result.
        dst_nhwc = tosa_shape(output.shape, output.dim_order)
        dst_hw = torch.tensor(dst_nhwc[1:3])

        numer_hw, denom_hw, offset_hw, border_hw = get_resize_parameters(
            src_hw, dst_hw, ResizeMode.NEAREST, align_corners=True
        )

        # NOTE(review): presumably required because the RESIZE attribute
        # fields are 16-bit in TOSA -- confirm against the specification.
        int16_min = -(2**15)
        int16_max = 2**15 - 1
        for values in (numer_hw, denom_hw, border_hw):
            assert torch.all(values >= int16_min) and torch.all(values <= int16_max)

        resize_attr = ts.TosaSerializerAttribute()
        resize_attr.ResizeAttribute(
            # Interleaved as [y numerator, y denominator, x numerator, x denominator].
            scale=[numer_hw[0], denom_hw[0], numer_hw[1], denom_hw[1]],
            offset=offset_hw.tolist(),
            border=border_hw.tolist(),
            mode=ResizeMode.NEAREST,
        )

        tosa_graph.addOperator(
            TosaOp.Op().RESIZE, [inputs[0].name], [output.name], resize_attr
        )
1 change: 1 addition & 0 deletions backends/arm/quantizer/arm_quantizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -270,6 +270,7 @@ class ArmQuantizer(Quantizer):
"mm",
"one_to_one",
"generic",
"upsample_nearest2d",
]

def __init__(self) -> None:
Expand Down
1 change: 1 addition & 0 deletions backends/arm/quantizer/quantization_annotation/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,4 +59,5 @@ def decorator(annotator: AnnotatorType):
mul_annotator,
one_to_one_annotator,
sub_annotator,
upsample_nearest2d_annotator,
)
Original file line number Diff line number Diff line change
@@ -0,0 +1,71 @@
# Copyright 2024 Arm Limited and/or its affiliates.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import itertools
from typing import Callable, List, Optional

import torch
from executorch.backends.arm.quantizer.quantization_annotation import register_annotator
from executorch.backends.arm.quantizer.quantization_config import QuantizationConfig
from torch.ao.quantization.quantizer import (
QuantizationAnnotation,
SharedQuantizationSpec,
)
from torch.fx import Node
from torch.fx.passes.utils.source_matcher_utils import get_source_partitions


def _filter_upsample_nearest2d(filter_fn: Optional[Callable[[Node], bool]] = None):
    """Build a node predicate that accepts only upsample_nearest2d.vec nodes.

    Args:
        filter_fn: Optional extra predicate. When given, a node must satisfy
            both the operator check and ``filter_fn``. It is only called for
            nodes that already passed the operator check.

    Returns:
        A callable taking a node and returning a truthy value when the node
        targets ``torch.ops.aten.upsample_nearest2d.vec`` (and passes
        ``filter_fn`` if one was supplied).
    """

    # Named explicitly rather than ``filter`` so the builtin is not shadowed.
    def is_upsample_nearest2d(node: Node):
        is_upsample = node.target == torch.ops.aten.upsample_nearest2d.vec
        if filter_fn is None:
            return is_upsample
        return is_upsample and filter_fn(node)

    return is_upsample_nearest2d


@register_annotator("upsample_nearest2d")
def _annotate_upsample_nearest2d(
    gm: torch.fx.GraphModule,
    quantization_config: QuantizationConfig,
    filter_fn: Optional[Callable[[Node], bool]] = None,
) -> Optional[List[List[Node]]]:
    """Attach quantization annotations to nearest-neighbour 2D upsample nodes.

    Source partitions originating from ``torch.nn.UpsamplingNearest2d``,
    ``torch.nn.Upsample`` or ``torch.nn.functional.interpolate`` are collected,
    restricted to ``upsample_nearest2d.vec`` nodes (and ``filter_fn`` when
    given). Each partition must consist of exactly one node.

    Args:
        gm: Graph module whose graph is searched and annotated in place.
        quantization_config: Supplies the input activation quantization spec.
        filter_fn: Optional additional node predicate.

    Returns:
        The node lists of every partition that was annotated.
    """
    source_ops = [
        torch.nn.UpsamplingNearest2d,
        torch.nn.Upsample,
        torch.nn.functional.interpolate,
    ]
    partitions_by_source = get_source_partitions(
        gm.graph, source_ops, _filter_upsample_nearest2d(filter_fn)
    )

    annotated = []
    for partition in itertools.chain.from_iterable(partitions_by_source.values()):
        nodes = partition.nodes
        annotated.append(nodes)

        assert len(nodes) == 1
        target_node = nodes[0]

        source_act = target_node.args[0]
        assert isinstance(source_act, Node)

        # The output shares the quantization parameters of the
        # (input, node) edge rather than getting a spec of its own.
        shared_output_qspec = SharedQuantizationSpec((source_act, target_node))
        annotation = QuantizationAnnotation(
            input_qspec_map={source_act: quantization_config.get_input_act_qspec()},
            output_qspec=shared_output_qspec,
            _annotated=True,
        )
        target_node.meta["quantization_annotation"] = annotation

    return annotated
165 changes: 165 additions & 0 deletions backends/arm/test/ops/test_upsample_nearest2d.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,165 @@
# Copyright 2024 Arm Limited and/or its affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import unittest

from typing import Optional, Tuple

import torch
from executorch.backends.arm.test import common
from executorch.backends.arm.test.tester.arm_tester import ArmTester
from parameterized import parameterized


# Parameter sets fed to parameterized.expand by the tests in this file.
# Exactly one of size / scale_factor is set per entry; the other is None.
test_data_suite = [
# Tuple layout: (test_name, test_data, size, scale_factor, compare_outputs)
# compare_outputs=True: the run's outputs are additionally compared via
# run_method_and_compare_outputs.
("rand_double_scale", torch.rand(2, 4, 8, 3), None, 2.0, True),
("rand_double_scale_one_dim", torch.rand(2, 4, 8, 3), None, (1.0, 2.0), True),
("rand_double_size", torch.rand(2, 4, 8, 3), (16, 6), None, True),
("rand_one_double_scale", torch.rand(2, 4, 1, 1), None, 2.0, True),
("rand_one_double_size", torch.rand(2, 4, 1, 1), (2, 2), None, True),
("rand_one_same_scale", torch.rand(2, 4, 1, 1), None, 1.0, True),
("rand_one_same_size", torch.rand(2, 4, 1, 1), (1, 1), None, True),
# compare_outputs=False for the cases below: outputs cannot be compared because
# the rounding used when selecting the nearest pixel differs between PyTorch
# and TOSA. These cases only check that the legalization went well.
# TODO Improve the test infrastructure to support more in-depth verification
# of the TOSA legalization results.
("rand_half_scale", torch.rand(2, 4, 8, 6), None, 0.5, False),
("rand_half_size", torch.rand(2, 4, 8, 6), (4, 3), None, False),
("rand_one_and_half_scale", torch.rand(2, 4, 8, 3), None, 1.5, False),
("rand_one_and_half_size", torch.rand(2, 4, 8, 3), (12, 4), None, False),
]


class TestUpsampleNearest2d(unittest.TestCase):
    """Lowering tests for nearest-neighbour 2D upsampling on the TOSA MI/BI profiles.

    Each parameterized case is run through three equivalent front-ends
    (``torch.nn.UpsamplingNearest2d``, ``torch.nn.Upsample`` and
    ``torch.nn.functional.interpolate``), all in nearest mode.
    """

    class UpsamplingNearest2d(torch.nn.Module):
        """Module under test built on ``torch.nn.UpsamplingNearest2d``."""

        def __init__(
            self,
            size: Optional[Tuple[int, int]],
            scale_factor: Optional[float | Tuple[float, float]],
        ):
            super().__init__()
            self.upsample = torch.nn.UpsamplingNearest2d(  # noqa: TOR101
                size=size, scale_factor=scale_factor
            )

        def forward(self, x):
            return self.upsample(x)

    class Upsample(torch.nn.Module):
        """Module under test built on ``torch.nn.Upsample`` with mode="nearest"."""

        def __init__(
            self,
            size: Optional[Tuple[int, int]],
            scale_factor: Optional[float | Tuple[float, float]],
        ):
            super().__init__()
            self.upsample = torch.nn.Upsample(
                size=size, scale_factor=scale_factor, mode="nearest"
            )

        def forward(self, x):
            return self.upsample(x)

    class Interpolate(torch.nn.Module):
        """Module under test calling ``torch.nn.functional.interpolate`` directly."""

        def __init__(
            self,
            size: Optional[Tuple[int, int]],
            scale_factor: Optional[float | Tuple[float, float]],
        ):
            super().__init__()
            # The functional call is wrapped so that size and scale_factor
            # are captured at construction time.
            self.upsample = lambda x: torch.nn.functional.interpolate(
                x, size=size, scale_factor=scale_factor, mode="nearest"
            )

        def forward(self, x):
            return self.upsample(x)

    def _test_upsample_nearest_2d_tosa_MI_pipeline(
        self,
        module: torch.nn.Module,
        test_data: Tuple[torch.Tensor],
        compare_outputs: bool,
    ):
        """Run ``module`` through the TOSA-0.80.0+MI (unquantized) pipeline.

        Checks that the upsample op is present after export with no
        quantized_decomposed ops, that it is gone after lowering, and that
        exactly one delegate call remains. Outputs are compared only when
        ``compare_outputs`` is true.
        """
        tester = (
            ArmTester(
                module,
                example_inputs=test_data,
                compile_spec=common.get_tosa_compile_spec("TOSA-0.80.0+MI"),
            )
            .export()
            .check(["torch.ops.aten.upsample_nearest2d.vec"])
            .check_not(["torch.ops.quantized_decomposed"])
            .to_edge_transform_and_lower()
            .check_not(["torch.ops.aten.upsample_nearest2d.vec"])
            .check_count({"torch.ops.higher_order.executorch_call_delegate": 1})
            .to_executorch()
        )

        if compare_outputs:
            tester.run_method_and_compare_outputs(inputs=test_data)

    def _test_upsample_nearest_2d_tosa_BI_pipeline(
        self,
        module: torch.nn.Module,
        test_data: Tuple[torch.Tensor],
        compare_outputs: bool,
    ):
        """Run ``module`` through the TOSA-0.80.0+BI (quantized) pipeline.

        Same checks as the MI pipeline, except that the model is quantized
        first and quantized_decomposed ops must be present after export.
        """
        tester = (
            ArmTester(
                module,
                example_inputs=test_data,
                compile_spec=common.get_tosa_compile_spec("TOSA-0.80.0+BI"),
            )
            .quantize()
            .export()
            .check(["torch.ops.aten.upsample_nearest2d.vec"])
            .check(["torch.ops.quantized_decomposed"])
            .to_edge_transform_and_lower()
            .check_not(["torch.ops.aten.upsample_nearest2d.vec"])
            .check_count({"torch.ops.higher_order.executorch_call_delegate": 1})
            .to_executorch()
        )

        if compare_outputs:
            tester.run_method_and_compare_outputs(inputs=test_data)

    @parameterized.expand(test_data_suite)
    def test_upsample_nearest_2d_tosa_MI(
        self,
        test_name: str,
        test_data: torch.Tensor,
        size: Optional[Tuple[int, int]],
        scale_factor: Optional[float | Tuple[float, float]],
        compare_outputs: bool,
    ):
        """Exercise all three module variants on the MI pipeline."""
        self._test_upsample_nearest_2d_tosa_MI_pipeline(
            self.UpsamplingNearest2d(size, scale_factor), (test_data,), compare_outputs
        )
        self._test_upsample_nearest_2d_tosa_MI_pipeline(
            self.Upsample(size, scale_factor), (test_data,), compare_outputs
        )
        self._test_upsample_nearest_2d_tosa_MI_pipeline(
            self.Interpolate(size, scale_factor), (test_data,), compare_outputs
        )

    @parameterized.expand(test_data_suite)
    def test_upsample_nearest_2d_tosa_BI(
        self,
        test_name: str,
        test_data: torch.Tensor,
        size: Optional[Tuple[int, int]],
        scale_factor: Optional[float | Tuple[float, float]],
        compare_outputs: bool,
    ):
        """Exercise all three module variants on the BI pipeline."""
        self._test_upsample_nearest_2d_tosa_BI_pipeline(
            self.UpsamplingNearest2d(size, scale_factor), (test_data,), compare_outputs
        )
        self._test_upsample_nearest_2d_tosa_BI_pipeline(
            self.Upsample(size, scale_factor), (test_data,), compare_outputs
        )
        self._test_upsample_nearest_2d_tosa_BI_pipeline(
            self.Interpolate(size, scale_factor), (test_data,), compare_outputs
        )
1 change: 0 additions & 1 deletion backends/arm/test/tester/arm_tester.py
Original file line number Diff line number Diff line change
Expand Up @@ -287,7 +287,6 @@ def run_method_and_compare_outputs(
inputs (Optional[Tuple[torch.Tensor]]): Allows you to input custom input data.
The default is random data.
"""

edge_stage = self.stages[self.stage_name(tester.ToEdge)]
if edge_stage is None:
edge_stage = self.stages[self.stage_name(tester.ToEdgeTransformAndLower)]
Expand Down
45 changes: 45 additions & 0 deletions backends/arm/tosa_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -298,3 +298,48 @@ def expand_dims(
build_reshape(tosa_graph, input_node.name, new_shape, intermediate.name)

return intermediate


def get_resize_parameters(
    input_size: torch.Tensor,
    output_size: torch.Tensor,
    resize_mode: int,
    align_corners: bool,
):
    """Derive the tosa.resize scale, offset and border from input/output sizes.

    The scale is expressed per dimension as a reduced fraction
    ``scale_n / scale_d``. With ``align_corners`` the fraction is
    ``(out - 1) / (in - 1)`` for every dimension where both sizes exceed 1;
    otherwise it is ``out / in``.

    Args:
        input_size (torch.Tensor): Size of the input, one entry per dimension.
        output_size (torch.Tensor): Size of the output, one entry per dimension.
        resize_mode (tosa.ResizeMode): The TOSA resize mode. Not consulted by
            the current computation.
        align_corners (bool): Align the corner pixels of the input and output.

    Returns:
        Tuple ``(scale_n, scale_d, offset, border)`` of torch.Tensor.

    Raises:
        AssertionError: If any input or output size is not strictly positive.
    """
    assert torch.all(input_size > 0)
    assert torch.all(output_size > 0)

    numerators = []
    denominators = []
    for in_dim, out_dim in zip(input_size, output_size):
        # Corner alignment is only meaningful when both extents exceed one
        # (it would otherwise produce a zero numerator or denominator).
        use_corner_scale = align_corners and in_dim > 1 and out_dim > 1
        numerators.append(out_dim - 1 if use_corner_scale else out_dim)
        denominators.append(in_dim - 1 if use_corner_scale else in_dim)

    scale_n = torch.tensor(numerators)
    scale_d = torch.tensor(denominators)

    # Reduce every fraction to lowest terms.
    common_divisor = torch.gcd(scale_n, scale_d)
    scale_n = scale_n // common_divisor
    scale_d = scale_d // common_divisor

    # No half-pixel centre support in PyTorch, so the offset stays at zero.
    offset = torch.zeros_like(input_size)
    last_out_index = output_size - 1
    last_in_index = input_size - 1
    border = scale_d * last_out_index - scale_n * last_in_index + offset

    return scale_n, scale_d, offset, border

0 comments on commit 4c68660

Please sign in to comment.