Skip to content

Add hardswish operator to Arm backend #8136

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Feb 4, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion backends/arm/_passes/insert_table_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ class InsertTableOpsPass(ExportPass):
"""
For ops in self.table_ops they need to be serialized as a TOSA TABLE. This pass replaces these
edge ops with a tosa._table(input: Tensor, target_str: str) where target_str == str(node.target).
When loweringthe _table node target_str will be used to find the corresponding torch operator
When lowering the _table node target_str will be used to find the corresponding torch operator
which will be used to produce the table values in operators/op_table.py.
"""

Expand All @@ -43,6 +43,7 @@ class InsertTableOpsPass(ExportPass):
exir_ops.edge.aten.sigmoid.default: torch.sigmoid,
exir_ops.edge.aten.tanh.default: torch.tanh,
exir_ops.edge.aten.hardsigmoid.default: torch.nn.functional.hardsigmoid,
exir_ops.edge.aten.hardswish.default: torch.nn.functional.hardswish,
}

def __init__(self, exported_program: ExportedProgram) -> None:
Expand Down
1 change: 1 addition & 0 deletions backends/arm/arm_partitioner.py
Original file line number Diff line number Diff line change
Expand Up @@ -115,6 +115,7 @@ def ops_to_not_decompose(
) -> Tuple[List[torch._ops.OpOverload], Optional[Callable[[torch.fx.Node], bool]]]:
ops_to_not_decompose_if_quant_op = [
torch.ops.aten.hardsigmoid.default,
torch.ops.aten.hardswish.default,
]

def filter_fn(node: torch.fx.Node) -> bool:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -81,6 +81,7 @@ def is_node_supported(self, submodules, node: fx.Node) -> bool:
exir_ops.edge.aten.permute_copy.default,
exir_ops.edge.aten.hardsigmoid.default,
exir_ops.edge.aten.hardtanh.default,
exir_ops.edge.aten.hardswish.default,
exir_ops.edge.aten.convolution.default,
exir_ops.edge.aten.div.Tensor,
exir_ops.edge.aten.eq.Tensor,
Expand Down
1 change: 1 addition & 0 deletions backends/arm/quantizer/quantization_annotator.py
Original file line number Diff line number Diff line change
Expand Up @@ -133,6 +133,7 @@ def _match_pattern(
torch.ops.aten.tanh.default,
torch.ops.aten.sum.dim_IntList,
torch.ops.aten.hardsigmoid.default,
torch.ops.aten.hardswish.default,
]

_one_to_one_shared_input_qspec = [
Expand Down
128 changes: 128 additions & 0 deletions backends/arm/test/ops/test_hardswish.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,128 @@
# Copyright 2025 Arm Limited and/or its affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import unittest

from typing import Tuple

import pytest
import torch

from executorch.backends.arm.test import common, conftest
from executorch.backends.arm.test.tester.arm_tester import ArmTester
from executorch.exir.backend.compile_spec_schema import CompileSpec
from parameterized import parameterized


test_data_suite = [
# (test_name, test_data)
("zeros", torch.zeros(1, 10, 10, 10)),
("ones", torch.ones(10, 10, 10)),
("rand", torch.rand(10, 10) - 0.5),
("randn_pos", torch.randn(10) + 10),
("randn_neg", torch.randn(10) - 10),
("ramp", torch.arange(-16, 16, 0.2)),
]


class TestHardswish(unittest.TestCase):
class Hardswish(torch.nn.Module):
def __init__(self):
super().__init__()
self.hardswish = torch.nn.Hardswish()

def forward(self, x):
return self.hardswish(x)

def _test_hardswish_tosa_MI_pipeline(
self, module: torch.nn.Module, test_data: Tuple[torch.tensor]
):
(
ArmTester(
module,
example_inputs=test_data,
compile_spec=common.get_tosa_compile_spec("TOSA-0.80+MI"),
)
.export()
.check(["torch.ops.aten.hardswish.default"])
.check_not(["torch.ops.quantized_decomposed"])
.to_edge_transform_and_lower()
.check_not(["executorch_exir_dialects_edge__ops_aten_clamp_default"])
.check_count({"torch.ops.higher_order.executorch_call_delegate": 1})
.to_executorch()
.run_method_and_compare_outputs(inputs=test_data)
)

def _test_hardswish_tosa_BI_pipeline(
self, module: torch.nn.Module, test_data: Tuple
):
(
ArmTester(
module,
example_inputs=test_data,
compile_spec=common.get_tosa_compile_spec("TOSA-0.80+BI"),
)
.quantize()
.export()
.check(["torch.ops.aten.hardswish.default"])
.check(["torch.ops.quantized_decomposed"])
.to_edge_transform_and_lower()
.check_not(["executorch_exir_dialects_edge__ops_aten_clamp_default"])
.check_count({"torch.ops.higher_order.executorch_call_delegate": 1})
.to_executorch()
.run_method_and_compare_outputs(inputs=test_data)
)

def _test_hardswish_tosa_ethos_BI_pipeline(
self,
compile_spec: list[CompileSpec],
module: torch.nn.Module,
test_data: Tuple[torch.tensor],
):
tester = (
ArmTester(
module,
example_inputs=test_data,
compile_spec=compile_spec,
)
.quantize()
.export()
.check_count({"torch.ops.aten.hardswish.default": 1})
.check(["torch.ops.quantized_decomposed"])
.to_edge_transform_and_lower()
.check_not(["executorch_exir_dialects_edge__ops_aten_clamp_default"])
.check_count({"torch.ops.higher_order.executorch_call_delegate": 1})
.to_executorch()
.serialize()
)
if conftest.is_option_enabled("corstone_fvp"):
tester.run_method_and_compare_outputs(qtol=1, inputs=test_data)

@parameterized.expand(test_data_suite)
def test_hardswish_tosa_MI(
self,
test_name: str,
test_data: torch.Tensor,
):
self._test_hardswish_tosa_MI_pipeline(self.Hardswish(), (test_data,))

@parameterized.expand(test_data_suite)
def test_hardswish_tosa_BI(self, test_name: str, test_data: torch.Tensor):
self._test_hardswish_tosa_BI_pipeline(self.Hardswish(), (test_data,))

@parameterized.expand(test_data_suite)
@pytest.mark.corstone_fvp
def test_hardswish_tosa_u55_BI(self, test_name: str, test_data: torch.Tensor):
self._test_hardswish_tosa_ethos_BI_pipeline(
common.get_u55_compile_spec(), self.Hardswish(), (test_data,)
)

@parameterized.expand(test_data_suite)
@pytest.mark.corstone_fvp
def test_hardswish_tosa_u85_BI(self, test_name: str, test_data: torch.Tensor):
self._test_hardswish_tosa_ethos_BI_pipeline(
common.get_u85_compile_spec(), self.Hardswish(), (test_data,)
)
Loading