99from typing import Tuple
1010
1111import torch
12+ from executorch .backends .arm .quantizer .arm_quantizer import (
13+ get_symmetric_a16w8_quantization_config ,
14+ TOSAQuantizer ,
15+ )
16+ from executorch .backends .arm .test import common , conftest
1217
13- from executorch .backends .arm .test import common
1418from executorch .backends .arm .test .tester .test_pipeline import (
1519 EthosU55PipelineINT ,
1620 EthosU85PipelineINT ,
1721 TosaPipelineFP ,
1822 TosaPipelineINT ,
1923 VgfPipeline ,
2024)
21-
25+ from executorch .backends .arm .tosa import TosaSpecification
26+ from executorch .backends .xnnpack .test .tester import Quantize
2227
2328aten_op = "torch.ops.aten.rsqrt.default"
2429input_t1 = Tuple [torch .Tensor ] # Input x
@@ -29,7 +34,7 @@ class Rsqrt(torch.nn.Module):
2934 "ones_4d" : lambda : (torch .ones (1 , 10 , 10 , 10 ),),
3035 "rand_4d_1" : lambda : (torch .rand (1 , 10 , 10 , 10 ),),
3136 "rand_4d_2" : lambda : (torch .rand (1 , 5 , 10 , 20 ),),
32- "rand_3d" : lambda : (torch .rand (5 , 10 , 20 ),),
37+ "rand_3d" : lambda : (torch .rand (5 , 10 , 20 ) + 1.0 ,),
3338 }
3439
3540 def forward (self , x : torch .Tensor ):
@@ -104,3 +109,102 @@ def test_rsqrt_vgf_INT(test_tensor: torch.Tensor):
104109 tosa_version = "TOSA-1.0+INT" ,
105110 )
106111 pipeline .run ()
112+
113+
114+ def get_symmetric_a16w8_rsqrt_quantizer (
115+ u55_config = False , per_channel_quantization = False
116+ ):
117+ tosa_version = conftest .get_option ("tosa_version" )
118+ tosa_profiles = {
119+ "1.0" : TosaSpecification .create_from_string ("TOSA-1.0+INT+int16" ),
120+ }
121+
122+ quantizer = TOSAQuantizer (tosa_profiles [tosa_version ])
123+ quantizer .set_global (
124+ get_symmetric_a16w8_quantization_config (is_per_channel = per_channel_quantization )
125+ )
126+ quantizer .set_module_type (
127+ torch .nn .Linear ,
128+ get_symmetric_a16w8_quantization_config (
129+ is_per_channel = per_channel_quantization
130+ ),
131+ )
132+
133+ return Quantize (
134+ quantizer ,
135+ get_symmetric_a16w8_quantization_config (
136+ is_per_channel = per_channel_quantization
137+ ),
138+ )
139+
140+
141+ @common .parametrize ("test_tensor" , Rsqrt .test_parameters )
142+ def test_rsqrt_int16_tosa_INT (test_tensor : torch .Tensor ):
143+ """Test rsqrt operation with int16 quantization"""
144+ pipeline = TosaPipelineINT [input_t1 ](
145+ Rsqrt (),
146+ test_tensor (),
147+ aten_op ,
148+ exir_op = [],
149+ per_channel_quantization = False ,
150+ use_to_edge_transform_and_lower = True ,
151+ tosa_extensions = ["int16" ],
152+ )
153+
154+ pipeline .change_args (
155+ "quantize" ,
156+ get_symmetric_a16w8_rsqrt_quantizer (
157+ per_channel_quantization = False
158+ ),
159+ )
160+ # Run the pipeline
161+ pipeline .run ()
162+
163+
164+ @common .parametrize ("test_tensor" , Rsqrt .test_parameters )
165+ @common .XfailIfNoCorstone300
166+ def test_rsqrt_int16_u55_INT16 (test_tensor : torch .Tensor ):
167+ """Test rsqrt operation with int16 quantization on U55"""
168+ pipeline = EthosU55PipelineINT [input_t1 ](
169+ Rsqrt (),
170+ test_tensor (),
171+ aten_op ,
172+ exir_ops = [],
173+ per_channel_quantization = True ,
174+ use_to_edge_transform_and_lower = True ,
175+ atol = 1e-02 ,
176+ rtol = 1e-02 ,
177+ run_on_fvp = True ,
178+ )
179+
180+ pipeline .change_args (
181+ "quantize" ,
182+ get_symmetric_a16w8_rsqrt_quantizer (
183+ per_channel_quantization = True
184+ ),
185+ )
186+ pipeline .run ()
187+
188+
189+ @common .parametrize ("test_tensor" , Rsqrt .test_parameters )
190+ @common .XfailIfNoCorstone320
191+ def test_rsqrt_int16_u85_INT16 (test_tensor : torch .Tensor ):
192+ """Test rsqrt operation with int16 quantization on U85"""
193+ pipeline = EthosU85PipelineINT [input_t1 ](
194+ Rsqrt (),
195+ test_tensor (),
196+ aten_op ,
197+ exir_ops = [],
198+ use_to_edge_transform_and_lower = True ,
199+ atol = 1e-02 ,
200+ rtol = 1e-02 ,
201+ run_on_fvp = True ,
202+ )
203+
204+ pipeline .change_args (
205+ "quantize" ,
206+ get_symmetric_a16w8_rsqrt_quantizer (
207+ per_channel_quantization = False
208+ ),
209+ )
210+ pipeline .run ()
0 commit comments