Skip to content

Commit 6115ce4

Browse files
authored
Add hardswish operator to Arm backend (#8136)
Signed-off-by: Tom Allsop <tom.allsop@arm.com>
1 parent 0a936e0 commit 6115ce4

File tree

5 files changed

+133
-1
lines changed

5 files changed

+133
-1
lines changed

backends/arm/_passes/insert_table_ops.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,7 @@ class InsertTableOpsPass(ExportPass):
3131
"""
3232
For ops in self.table_ops they need to be serialized as a TOSA TABLE. This pass replaces these
3333
edge ops with a tosa._table(input: Tensor, target_str: str) where target_str == str(node.target).
34-
When loweringthe _table node target_str will be used to find the corresponding torch operator
34+
When lowering the _table node target_str will be used to find the corresponding torch operator
3535
which will be used to produce the table values in operators/op_table.py.
3636
"""
3737

@@ -43,6 +43,7 @@ class InsertTableOpsPass(ExportPass):
4343
exir_ops.edge.aten.sigmoid.default: torch.sigmoid,
4444
exir_ops.edge.aten.tanh.default: torch.tanh,
4545
exir_ops.edge.aten.hardsigmoid.default: torch.nn.functional.hardsigmoid,
46+
exir_ops.edge.aten.hardswish.default: torch.nn.functional.hardswish,
4647
}
4748

4849
def __init__(self, exported_program: ExportedProgram) -> None:

backends/arm/arm_partitioner.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -115,6 +115,7 @@ def ops_to_not_decompose(
115115
) -> Tuple[List[torch._ops.OpOverload], Optional[Callable[[torch.fx.Node], bool]]]:
116116
ops_to_not_decompose_if_quant_op = [
117117
torch.ops.aten.hardsigmoid.default,
118+
torch.ops.aten.hardswish.default,
118119
]
119120

120121
def filter_fn(node: torch.fx.Node) -> bool:

backends/arm/operator_support/tosa_supported_operators.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -81,6 +81,7 @@ def is_node_supported(self, submodules, node: fx.Node) -> bool:
8181
exir_ops.edge.aten.permute_copy.default,
8282
exir_ops.edge.aten.hardsigmoid.default,
8383
exir_ops.edge.aten.hardtanh.default,
84+
exir_ops.edge.aten.hardswish.default,
8485
exir_ops.edge.aten.convolution.default,
8586
exir_ops.edge.aten.div.Tensor,
8687
exir_ops.edge.aten.eq.Tensor,

backends/arm/quantizer/quantization_annotator.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -133,6 +133,7 @@ def _match_pattern(
133133
torch.ops.aten.tanh.default,
134134
torch.ops.aten.sum.dim_IntList,
135135
torch.ops.aten.hardsigmoid.default,
136+
torch.ops.aten.hardswish.default,
136137
]
137138

138139
_one_to_one_shared_input_qspec = [
Lines changed: 128 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,128 @@
1+
# Copyright 2025 Arm Limited and/or its affiliates.
2+
# All rights reserved.
3+
#
4+
# This source code is licensed under the BSD-style license found in the
5+
# LICENSE file in the root directory of this source tree.
6+
7+
import unittest
8+
9+
from typing import Tuple
10+
11+
import pytest
12+
import torch
13+
14+
from executorch.backends.arm.test import common, conftest
15+
from executorch.backends.arm.test.tester.arm_tester import ArmTester
16+
from executorch.exir.backend.compile_spec_schema import CompileSpec
17+
from parameterized import parameterized
18+
19+
20+
test_data_suite = [
21+
# (test_name, test_data)
22+
("zeros", torch.zeros(1, 10, 10, 10)),
23+
("ones", torch.ones(10, 10, 10)),
24+
("rand", torch.rand(10, 10) - 0.5),
25+
("randn_pos", torch.randn(10) + 10),
26+
("randn_neg", torch.randn(10) - 10),
27+
("ramp", torch.arange(-16, 16, 0.2)),
28+
]
29+
30+
31+
class TestHardswish(unittest.TestCase):
32+
class Hardswish(torch.nn.Module):
33+
def __init__(self):
34+
super().__init__()
35+
self.hardswish = torch.nn.Hardswish()
36+
37+
def forward(self, x):
38+
return self.hardswish(x)
39+
40+
def _test_hardswish_tosa_MI_pipeline(
41+
self, module: torch.nn.Module, test_data: Tuple[torch.tensor]
42+
):
43+
(
44+
ArmTester(
45+
module,
46+
example_inputs=test_data,
47+
compile_spec=common.get_tosa_compile_spec("TOSA-0.80+MI"),
48+
)
49+
.export()
50+
.check(["torch.ops.aten.hardswish.default"])
51+
.check_not(["torch.ops.quantized_decomposed"])
52+
.to_edge_transform_and_lower()
53+
.check_not(["executorch_exir_dialects_edge__ops_aten_clamp_default"])
54+
.check_count({"torch.ops.higher_order.executorch_call_delegate": 1})
55+
.to_executorch()
56+
.run_method_and_compare_outputs(inputs=test_data)
57+
)
58+
59+
def _test_hardswish_tosa_BI_pipeline(
60+
self, module: torch.nn.Module, test_data: Tuple
61+
):
62+
(
63+
ArmTester(
64+
module,
65+
example_inputs=test_data,
66+
compile_spec=common.get_tosa_compile_spec("TOSA-0.80+BI"),
67+
)
68+
.quantize()
69+
.export()
70+
.check(["torch.ops.aten.hardswish.default"])
71+
.check(["torch.ops.quantized_decomposed"])
72+
.to_edge_transform_and_lower()
73+
.check_not(["executorch_exir_dialects_edge__ops_aten_clamp_default"])
74+
.check_count({"torch.ops.higher_order.executorch_call_delegate": 1})
75+
.to_executorch()
76+
.run_method_and_compare_outputs(inputs=test_data)
77+
)
78+
79+
def _test_hardswish_tosa_ethos_BI_pipeline(
80+
self,
81+
compile_spec: list[CompileSpec],
82+
module: torch.nn.Module,
83+
test_data: Tuple[torch.tensor],
84+
):
85+
tester = (
86+
ArmTester(
87+
module,
88+
example_inputs=test_data,
89+
compile_spec=compile_spec,
90+
)
91+
.quantize()
92+
.export()
93+
.check_count({"torch.ops.aten.hardswish.default": 1})
94+
.check(["torch.ops.quantized_decomposed"])
95+
.to_edge_transform_and_lower()
96+
.check_not(["executorch_exir_dialects_edge__ops_aten_clamp_default"])
97+
.check_count({"torch.ops.higher_order.executorch_call_delegate": 1})
98+
.to_executorch()
99+
.serialize()
100+
)
101+
if conftest.is_option_enabled("corstone_fvp"):
102+
tester.run_method_and_compare_outputs(qtol=1, inputs=test_data)
103+
104+
@parameterized.expand(test_data_suite)
105+
def test_hardswish_tosa_MI(
106+
self,
107+
test_name: str,
108+
test_data: torch.Tensor,
109+
):
110+
self._test_hardswish_tosa_MI_pipeline(self.Hardswish(), (test_data,))
111+
112+
@parameterized.expand(test_data_suite)
113+
def test_hardswish_tosa_BI(self, test_name: str, test_data: torch.Tensor):
114+
self._test_hardswish_tosa_BI_pipeline(self.Hardswish(), (test_data,))
115+
116+
@parameterized.expand(test_data_suite)
117+
@pytest.mark.corstone_fvp
118+
def test_hardswish_tosa_u55_BI(self, test_name: str, test_data: torch.Tensor):
119+
self._test_hardswish_tosa_ethos_BI_pipeline(
120+
common.get_u55_compile_spec(), self.Hardswish(), (test_data,)
121+
)
122+
123+
@parameterized.expand(test_data_suite)
124+
@pytest.mark.corstone_fvp
125+
def test_hardswish_tosa_u85_BI(self, test_name: str, test_data: torch.Tensor):
126+
self._test_hardswish_tosa_ethos_BI_pipeline(
127+
common.get_u85_compile_spec(), self.Hardswish(), (test_data,)
128+
)

0 commit comments

Comments
 (0)