Skip to content

Commit 732b518

Browse files
3l1facebook-github-bot
authored andcommitted
Enable int16 rsqrt on Ethos-U55/U85 (#14770)
Summary: Fix Rsqrt op for int16 Add unit tests bypass-github-export-checks bypass-github-pytorch-ci-checks bypass-github-executorch-ci-checks Reviewed By: digantdesai Differential Revision: D83802158
1 parent 3bfd5e0 commit 732b518

File tree

3 files changed

+105
-3
lines changed

3 files changed

+105
-3
lines changed

backends/arm/operator_support/ethos_u55_support.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -114,7 +114,7 @@ def is_node_supported( # noqa: C901
114114
return False
115115

116116
if node.target in self.target_ops_i8:
117-
if dtype not in (torch.int8,):
117+
if dtype not in (torch.int8, torch.int16):
118118
self.reporter.report_reject(
119119
node, f"Unsupported dtype {dtype} (Supports i8)."
120120
)

backends/arm/test/ops/test_rsqrt.py

Lines changed: 103 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -9,16 +9,21 @@
99
from typing import Tuple
1010

1111
import torch
12+
from executorch.backends.arm.quantizer.arm_quantizer import (
13+
get_symmetric_a16w8_quantization_config,
14+
TOSAQuantizer,
15+
)
16+
from executorch.backends.arm.test import common, conftest
1217

13-
from executorch.backends.arm.test import common
1418
from executorch.backends.arm.test.tester.test_pipeline import (
1519
EthosU55PipelineINT,
1620
EthosU85PipelineINT,
1721
TosaPipelineFP,
1822
TosaPipelineINT,
1923
VgfPipeline,
2024
)
21-
25+
from executorch.backends.arm.tosa import TosaSpecification
26+
from executorch.backends.xnnpack.test.tester import Quantize
2227

2328
aten_op = "torch.ops.aten.rsqrt.default"
2429
input_t1 = Tuple[torch.Tensor] # Input x
@@ -104,3 +109,99 @@ def test_rsqrt_vgf_INT(test_tensor: torch.Tensor):
104109
tosa_version="TOSA-1.0+INT",
105110
)
106111
pipeline.run()
112+
113+
114+
def get_symmetric_a16w8_rsqrt_quantizer(
115+
u55_config=False, per_channel_quantization=False
116+
):
117+
tosa_version = conftest.get_option("tosa_version")
118+
tosa_profiles = {
119+
"1.0": TosaSpecification.create_from_string("TOSA-1.0+INT+int16"),
120+
}
121+
122+
quantizer = TOSAQuantizer(tosa_profiles[tosa_version])
123+
quantizer.set_global(
124+
get_symmetric_a16w8_quantization_config(is_per_channel=per_channel_quantization)
125+
)
126+
quantizer.set_module_type(
127+
torch.nn.Linear,
128+
get_symmetric_a16w8_quantization_config(
129+
is_per_channel=per_channel_quantization
130+
),
131+
)
132+
133+
return Quantize(
134+
quantizer,
135+
get_symmetric_a16w8_quantization_config(
136+
is_per_channel=per_channel_quantization
137+
),
138+
)
139+
140+
141+
@common.parametrize("test_tensor", Rsqrt.test_parameters)
142+
def test_rsqrt_16a8w_tosa_INT(test_tensor: torch.Tensor):
143+
"""Test rsqrt operation with 16A8W quantization (16-bit activations, 8-bit weights)"""
144+
# Create pipeline with custom 16A8W quantization config
145+
pipeline = TosaPipelineINT[input_t1](
146+
Rsqrt(),
147+
test_tensor(),
148+
aten_op,
149+
exir_op=[],
150+
per_channel_quantization=False,
151+
use_to_edge_transform_and_lower=True,
152+
tosa_extensions=["int16"],
153+
)
154+
155+
pipeline.change_args(
156+
"quantize",
157+
get_symmetric_a16w8_rsqrt_quantizer(
158+
per_channel_quantization=False
159+
),
160+
)
161+
# Run the pipeline
162+
pipeline.run()
163+
164+
165+
@common.parametrize("test_tensor", Rsqrt.test_parameters)
166+
@common.XfailIfNoCorstone300
167+
def test_rsqrt_16a8w_u55_INT16(test_tensor: torch.Tensor):
168+
"""Test rsqrt operation with 16A8W quantization on U55 (16-bit activations, 8-bit weights)"""
169+
pipeline = EthosU55PipelineINT[input_t1](
170+
Rsqrt(),
171+
test_tensor(),
172+
aten_op,
173+
exir_ops=[],
174+
per_channel_quantization=True,
175+
use_to_edge_transform_and_lower=True,
176+
run_on_fvp=True,
177+
)
178+
179+
pipeline.change_args(
180+
"quantize",
181+
get_symmetric_a16w8_rsqrt_quantizer(
182+
per_channel_quantization=True
183+
),
184+
)
185+
pipeline.run()
186+
187+
188+
@common.parametrize("test_tensor", Rsqrt.test_parameters)
189+
@common.XfailIfNoCorstone320
190+
def test_rsqrt_16a8w_u85_INT16(test_tensor: torch.Tensor):
191+
"""Test rsqrt operation with 16A8W quantization on U85 (16-bit activations, 8-bit weights)"""
192+
pipeline = EthosU85PipelineINT[input_t1](
193+
Rsqrt(),
194+
test_tensor(),
195+
aten_op,
196+
exir_ops=[],
197+
use_to_edge_transform_and_lower=True,
198+
run_on_fvp=True,
199+
)
200+
201+
pipeline.change_args(
202+
"quantize",
203+
get_symmetric_a16w8_rsqrt_quantizer(
204+
per_channel_quantization=False
205+
),
206+
)
207+
pipeline.run()

backends/arm/test/targets.bzl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@ def define_arm_tests():
2020
"ops/test_cat.py",
2121
"ops/test_linear.py",
2222
"ops/test_mul.py",
23+
"ops/test_rsqrt.py",
2324
"ops/test_slice.py",
2425
"ops/test_sigmoid.py",
2526
"ops/test_sub.py",

0 commit comments

Comments
 (0)