Skip to content

Commit 3a754b3

Browse files
3l1 authored and meta-codesync[bot] committed
Enable int16 rsqrt on Ethos-U55/U85 (#14770)
Summary:
Pull Request resolved: #14770

Fix Rsqrt op for int16
Add unit tests

bypass-github-export-checks
bypass-github-pytorch-ci-checks
bypass-github-executorch-ci-checks

Reviewed By: Ninja91, digantdesai

Differential Revision: D83802158
1 parent 5467a4d commit 3a754b3

File tree

3 files changed

+125
-11
lines changed

3 files changed

+125
-11
lines changed

backends/arm/_passes/insert_table_ops.py

Lines changed: 17 additions & 8 deletions
Original file line number | Diff line number | Diff line change
@@ -185,18 +185,27 @@ def f(x: torch.Tensor) -> torch.Tensor:
185185
)
186186
# Don't use the 7 LSBs.
187187
x = in_quantargs.dequantize_value((x & ~0x7F))
188+
# x = in_quantargs.dequantize_value(x) // (1 << 7)
188189
x = torch_op(x)
190+
# x = x * (1 << 7)
189191
return out_quantargs.quantize_value(x)
190192

191-
lut_values = f(
192-
torch.linspace(
193-
start=in_quantargs.qmin,
194-
end=in_quantargs.qmax + 1,
195-
steps=513,
196-
# use torch.int32 to avoid overflow for end=in_quantargs.qmax + 1.
197-
dtype=torch.int32,
198-
)
193+
# Create the 9.7 fixed-point value
194+
r = torch.linspace(
195+
start=in_quantargs.qmin,
196+
end=in_quantargs.qmax + 1,
197+
steps=513,
198+
# use torch.int32 to avoid overflow for end=in_quantargs.qmax + 1.
199+
dtype=torch.int32,
199200
)
201+
# # Cast input to a wider type (int32)
202+
# r_int32 = r.to(torch.int32)
203+
# # Extract most significant 9 bits
204+
# index = (r_int32 >> 7) & 0x1FF
205+
# # Extract the fractional 7 bits
206+
# fraction = r_int32 & 0x7F
207+
208+
lut_values = f(r)
200209
# Calculate how much we need to shift table values to fit in 16 signed bits
201210
# ceil(log2(max absolute table value)) + 1 bit for signedness - 16
202211
# Example:

backends/arm/test/ops/test_rsqrt.py

Lines changed: 107 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -9,16 +9,21 @@
99
from typing import Tuple
1010

1111
import torch
12+
from executorch.backends.arm.quantizer.arm_quantizer import (
13+
get_symmetric_a16w8_quantization_config,
14+
TOSAQuantizer,
15+
)
16+
from executorch.backends.arm.test import common, conftest
1217

13-
from executorch.backends.arm.test import common
1418
from executorch.backends.arm.test.tester.test_pipeline import (
1519
EthosU55PipelineINT,
1620
EthosU85PipelineINT,
1721
TosaPipelineFP,
1822
TosaPipelineINT,
1923
VgfPipeline,
2024
)
21-
25+
from executorch.backends.arm.tosa import TosaSpecification
26+
from executorch.backends.xnnpack.test.tester import Quantize
2227

2328
aten_op = "torch.ops.aten.rsqrt.default"
2429
input_t1 = Tuple[torch.Tensor] # Input x
@@ -29,7 +34,7 @@ class Rsqrt(torch.nn.Module):
2934
"ones_4d": lambda: (torch.ones(1, 10, 10, 10),),
3035
"rand_4d_1": lambda: (torch.rand(1, 10, 10, 10),),
3136
"rand_4d_2": lambda: (torch.rand(1, 5, 10, 20),),
32-
"rand_3d": lambda: (torch.rand(5, 10, 20),),
37+
"rand_3d": lambda: (torch.rand(5, 10, 20) + 1.0,),
3338
}
3439

3540
def forward(self, x: torch.Tensor):
@@ -104,3 +109,102 @@ def test_rsqrt_vgf_INT(test_tensor: torch.Tensor):
104109
tosa_version="TOSA-1.0+INT",
105110
)
106111
pipeline.run()
112+
113+
114+
def get_symmetric_a16w8_rsqrt_quantizer(
115+
u55_config=False, per_channel_quantization=False
116+
):
117+
tosa_version = conftest.get_option("tosa_version")
118+
tosa_profiles = {
119+
"1.0": TosaSpecification.create_from_string("TOSA-1.0+INT+int16"),
120+
}
121+
122+
quantizer = TOSAQuantizer(tosa_profiles[tosa_version])
123+
quantizer.set_global(
124+
get_symmetric_a16w8_quantization_config(is_per_channel=per_channel_quantization)
125+
)
126+
quantizer.set_module_type(
127+
torch.nn.Linear,
128+
get_symmetric_a16w8_quantization_config(
129+
is_per_channel=per_channel_quantization
130+
),
131+
)
132+
133+
return Quantize(
134+
quantizer,
135+
get_symmetric_a16w8_quantization_config(
136+
is_per_channel=per_channel_quantization
137+
),
138+
)
139+
140+
141+
@common.parametrize("test_tensor", Rsqrt.test_parameters)
142+
def test_rsqrt_int16_tosa_INT(test_tensor: torch.Tensor):
143+
"""Test rsqrt operation with int16 quantization"""
144+
pipeline = TosaPipelineINT[input_t1](
145+
Rsqrt(),
146+
test_tensor(),
147+
aten_op,
148+
exir_op=[],
149+
per_channel_quantization=False,
150+
use_to_edge_transform_and_lower=True,
151+
tosa_extensions=["int16"],
152+
)
153+
154+
pipeline.change_args(
155+
"quantize",
156+
get_symmetric_a16w8_rsqrt_quantizer(
157+
per_channel_quantization=False
158+
),
159+
)
160+
# Run the pipeline
161+
pipeline.run()
162+
163+
164+
@common.parametrize("test_tensor", Rsqrt.test_parameters)
165+
@common.XfailIfNoCorstone300
166+
def test_rsqrt_int16_u55_INT16(test_tensor: torch.Tensor):
167+
"""Test rsqrt operation with int16 quantization on U55"""
168+
pipeline = EthosU55PipelineINT[input_t1](
169+
Rsqrt(),
170+
test_tensor(),
171+
aten_op,
172+
exir_ops=[],
173+
per_channel_quantization=True,
174+
use_to_edge_transform_and_lower=True,
175+
atol=1e-02,
176+
rtol=1e-02,
177+
run_on_fvp=True,
178+
)
179+
180+
pipeline.change_args(
181+
"quantize",
182+
get_symmetric_a16w8_rsqrt_quantizer(
183+
per_channel_quantization=True
184+
),
185+
)
186+
pipeline.run()
187+
188+
189+
@common.parametrize("test_tensor", Rsqrt.test_parameters)
190+
@common.XfailIfNoCorstone320
191+
def test_rsqrt_int16_u85_INT16(test_tensor: torch.Tensor):
192+
"""Test rsqrt operation with int16 quantization on U85"""
193+
pipeline = EthosU85PipelineINT[input_t1](
194+
Rsqrt(),
195+
test_tensor(),
196+
aten_op,
197+
exir_ops=[],
198+
use_to_edge_transform_and_lower=True,
199+
atol=1e-02,
200+
rtol=1e-02,
201+
run_on_fvp=True,
202+
)
203+
204+
pipeline.change_args(
205+
"quantize",
206+
get_symmetric_a16w8_rsqrt_quantizer(
207+
per_channel_quantization=False
208+
),
209+
)
210+
pipeline.run()

backends/arm/test/targets.bzl

Lines changed: 1 addition & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -24,6 +24,7 @@ def define_arm_tests():
2424
"ops/test_linear.py",
2525
"ops/test_mul.py",
2626
"ops/test_permute.py",
27+
"ops/test_rsqrt.py",
2728
"ops/test_slice.py",
2829
"ops/test_sigmoid.py",
2930
"ops/test_sub.py",

0 commit comments

Comments
 (0)