Skip to content

Commit e4569d6

Browse files
3l1facebook-github-bot
authored andcommitted
Enable int16 rsqrt on Ethos-U55/U85 (#14770)
Summary: Fix Rsqrt op for int16 Add unit tests bypass-github-export-checks bypass-github-pytorch-ci-checks bypass-github-executorch-ci-checks Reviewed By: Ninja91, digantdesai Differential Revision: D83802158
1 parent 331d771 commit e4569d6

File tree

2 files changed

+102
-3
lines changed

2 files changed

+102
-3
lines changed

backends/arm/test/ops/test_rsqrt.py

Lines changed: 101 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -9,16 +9,21 @@
99
from typing import Tuple
1010

1111
import torch
12+
from executorch.backends.arm.quantizer.arm_quantizer import (
13+
get_symmetric_a16w8_quantization_config,
14+
TOSAQuantizer,
15+
)
16+
from executorch.backends.arm.test import common, conftest
1217

13-
from executorch.backends.arm.test import common
1418
from executorch.backends.arm.test.tester.test_pipeline import (
1519
EthosU55PipelineINT,
1620
EthosU85PipelineINT,
1721
TosaPipelineFP,
1822
TosaPipelineINT,
1923
VgfPipeline,
2024
)
21-
25+
from executorch.backends.arm.tosa import TosaSpecification
26+
from executorch.backends.xnnpack.test.tester import Quantize
2227

2328
aten_op = "torch.ops.aten.rsqrt.default"
2429
input_t1 = Tuple[torch.Tensor] # Input x
@@ -29,7 +34,7 @@ class Rsqrt(torch.nn.Module):
2934
"ones_4d": lambda: (torch.ones(1, 10, 10, 10),),
3035
"rand_4d_1": lambda: (torch.rand(1, 10, 10, 10),),
3136
"rand_4d_2": lambda: (torch.rand(1, 5, 10, 20),),
32-
"rand_3d": lambda: (torch.rand(5, 10, 20),),
37+
"rand_3d": lambda: (torch.rand(5, 10, 20) + 1.0,),
3338
}
3439

3540
def forward(self, x: torch.Tensor):
@@ -104,3 +109,96 @@ def test_rsqrt_vgf_INT(test_tensor: torch.Tensor):
104109
tosa_version="TOSA-1.0+INT",
105110
)
106111
pipeline.run()
112+
113+
114+
def get_symmetric_a16w8_rsqrt_quantizer(
115+
u55_config=False, per_channel_quantization=False
116+
):
117+
tosa_version = conftest.get_option("tosa_version")
118+
tosa_profiles = {
119+
"1.0": TosaSpecification.create_from_string("TOSA-1.0+INT+int16"),
120+
}
121+
122+
quantizer = TOSAQuantizer(tosa_profiles[tosa_version])
123+
quantizer.set_global(
124+
get_symmetric_a16w8_quantization_config(is_per_channel=per_channel_quantization)
125+
)
126+
127+
return Quantize(
128+
quantizer,
129+
get_symmetric_a16w8_quantization_config(
130+
is_per_channel=per_channel_quantization
131+
),
132+
)
133+
134+
135+
@common.parametrize("test_tensor", Rsqrt.test_parameters)
136+
def test_rsqrt_int16_tosa_INT(test_tensor: torch.Tensor):
137+
"""Test rsqrt operation with int16 quantization"""
138+
pipeline = TosaPipelineINT[input_t1](
139+
Rsqrt(),
140+
test_tensor(),
141+
aten_op,
142+
exir_op=[],
143+
per_channel_quantization=False,
144+
use_to_edge_transform_and_lower=True,
145+
tosa_extensions=["int16"],
146+
)
147+
148+
pipeline.change_args(
149+
"quantize",
150+
get_symmetric_a16w8_rsqrt_quantizer(
151+
per_channel_quantization=False
152+
),
153+
)
154+
# Run the pipeline
155+
pipeline.run()
156+
157+
158+
@common.parametrize("test_tensor", Rsqrt.test_parameters)
159+
@common.XfailIfNoCorstone300
160+
def test_rsqrt_int16_u55_INT16(test_tensor: torch.Tensor):
161+
"""Test rsqrt operation with int16 quantization on U55"""
162+
pipeline = EthosU55PipelineINT[input_t1](
163+
Rsqrt(),
164+
test_tensor(),
165+
aten_op,
166+
exir_ops=[],
167+
per_channel_quantization=True,
168+
use_to_edge_transform_and_lower=True,
169+
atol=1e-03,
170+
rtol=1e-03,
171+
run_on_fvp=True,
172+
)
173+
174+
pipeline.change_args(
175+
"quantize",
176+
get_symmetric_a16w8_rsqrt_quantizer(
177+
per_channel_quantization=True
178+
),
179+
)
180+
pipeline.run()
181+
182+
183+
@common.parametrize("test_tensor", Rsqrt.test_parameters)
184+
@common.XfailIfNoCorstone320
185+
def test_rsqrt_int16_u85_INT16(test_tensor: torch.Tensor):
186+
"""Test rsqrt operation with int16 quantization on U85"""
187+
pipeline = EthosU85PipelineINT[input_t1](
188+
Rsqrt(),
189+
test_tensor(),
190+
aten_op,
191+
exir_ops=[],
192+
use_to_edge_transform_and_lower=True,
193+
atol=1e-03,
194+
rtol=1e-03,
195+
run_on_fvp=True,
196+
)
197+
198+
pipeline.change_args(
199+
"quantize",
200+
get_symmetric_a16w8_rsqrt_quantizer(
201+
per_channel_quantization=False
202+
),
203+
)
204+
pipeline.run()

backends/arm/test/targets.bzl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@ def define_arm_tests():
2222
"ops/test_linear.py",
2323
"ops/test_mul.py",
2424
"ops/test_permute.py",
25+
"ops/test_rsqrt.py",
2526
"ops/test_slice.py",
2627
"ops/test_sigmoid.py",
2728
"ops/test_sub.py",

0 commit comments

Comments
 (0)