99from typing import Tuple
1010
1111import torch
12+ from executorch .backends .arm .quantizer .arm_quantizer import (
13+ get_symmetric_a16w8_quantization_config ,
14+ TOSAQuantizer ,
15+ )
16+ from executorch .backends .arm .test import common , conftest
1217
13- from executorch .backends .arm .test import common
1418from executorch .backends .arm .test .tester .test_pipeline import (
1519 EthosU55PipelineINT ,
1620 EthosU85PipelineINT ,
1721 TosaPipelineFP ,
1822 TosaPipelineINT ,
1923 VgfPipeline ,
2024)
21-
25+ from executorch .backends .arm .tosa import TosaSpecification
26+ from executorch .backends .xnnpack .test .tester import Quantize
2227
2328aten_op = "torch.ops.aten.rsqrt.default"
2429input_t1 = Tuple [torch .Tensor ] # Input x
@@ -29,7 +34,7 @@ class Rsqrt(torch.nn.Module):
2934 "ones_4d" : lambda : (torch .ones (1 , 10 , 10 , 10 ),),
3035 "rand_4d_1" : lambda : (torch .rand (1 , 10 , 10 , 10 ),),
3136 "rand_4d_2" : lambda : (torch .rand (1 , 5 , 10 , 20 ),),
32- "rand_3d" : lambda : (torch .rand (5 , 10 , 20 ),),
37+ "rand_3d" : lambda : (torch .rand (5 , 10 , 20 ) + 1.0 ,),
3338 }
3439
3540 def forward (self , x : torch .Tensor ):
@@ -104,3 +109,96 @@ def test_rsqrt_vgf_INT(test_tensor: torch.Tensor):
104109 tosa_version = "TOSA-1.0+INT" ,
105110 )
106111 pipeline .run ()
112+
113+
114+ def get_symmetric_a16w8_rsqrt_quantizer (
115+ u55_config = False , per_channel_quantization = False
116+ ):
117+ tosa_version = conftest .get_option ("tosa_version" )
118+ tosa_profiles = {
119+ "1.0" : TosaSpecification .create_from_string ("TOSA-1.0+INT+int16" ),
120+ }
121+
122+ quantizer = TOSAQuantizer (tosa_profiles [tosa_version ])
123+ quantizer .set_global (
124+ get_symmetric_a16w8_quantization_config (is_per_channel = per_channel_quantization )
125+ )
126+
127+ return Quantize (
128+ quantizer ,
129+ get_symmetric_a16w8_quantization_config (
130+ is_per_channel = per_channel_quantization
131+ ),
132+ )
133+
134+
135+ @common .parametrize ("test_tensor" , Rsqrt .test_parameters )
136+ def test_rsqrt_int16_tosa_INT (test_tensor : torch .Tensor ):
137+ """Test rsqrt operation with int16 quantization"""
138+ pipeline = TosaPipelineINT [input_t1 ](
139+ Rsqrt (),
140+ test_tensor (),
141+ aten_op ,
142+ exir_op = [],
143+ per_channel_quantization = False ,
144+ use_to_edge_transform_and_lower = True ,
145+ tosa_extensions = ["int16" ],
146+ )
147+
148+ pipeline .change_args (
149+ "quantize" ,
150+ get_symmetric_a16w8_rsqrt_quantizer (
151+ per_channel_quantization = False
152+ ),
153+ )
154+ # Run the pipeline
155+ pipeline .run ()
156+
157+
158+ @common .parametrize ("test_tensor" , Rsqrt .test_parameters )
159+ @common .XfailIfNoCorstone300
160+ def test_rsqrt_int16_u55_INT16 (test_tensor : torch .Tensor ):
161+ """Test rsqrt operation with int16 quantization on U55"""
162+ pipeline = EthosU55PipelineINT [input_t1 ](
163+ Rsqrt (),
164+ test_tensor (),
165+ aten_op ,
166+ exir_ops = [],
167+ per_channel_quantization = True ,
168+ use_to_edge_transform_and_lower = True ,
169+ atol = 1e-03 ,
170+ rtol = 1e-03 ,
171+ run_on_fvp = True ,
172+ )
173+
174+ pipeline .change_args (
175+ "quantize" ,
176+ get_symmetric_a16w8_rsqrt_quantizer (
177+ per_channel_quantization = True
178+ ),
179+ )
180+ pipeline .run ()
181+
182+
183+ @common .parametrize ("test_tensor" , Rsqrt .test_parameters )
184+ @common .XfailIfNoCorstone320
185+ def test_rsqrt_int16_u85_INT16 (test_tensor : torch .Tensor ):
186+ """Test rsqrt operation with int16 quantization on U85"""
187+ pipeline = EthosU85PipelineINT [input_t1 ](
188+ Rsqrt (),
189+ test_tensor (),
190+ aten_op ,
191+ exir_ops = [],
192+ use_to_edge_transform_and_lower = True ,
193+ atol = 1e-03 ,
194+ rtol = 1e-03 ,
195+ run_on_fvp = True ,
196+ )
197+
198+ pipeline .change_args (
199+ "quantize" ,
200+ get_symmetric_a16w8_rsqrt_quantizer (
201+ per_channel_quantization = False
202+ ),
203+ )
204+ pipeline .run ()
0 commit comments