3
3
# This source code is licensed under the BSD-style license found in the
4
4
# LICENSE file in the root directory of this source tree.
5
5
6
+ import pytest
6
7
import torch
7
8
from executorch .backends .arm .quantizer .arm_quantizer import TOSAQuantizer
8
9
from executorch .backends .arm .quantizer .quantization_config import QuantizationConfig
@@ -52,7 +53,7 @@ def _get_32_bit_quant_config():
52
53
return qconfig
53
54
54
55
55
- def get_16bit_sigmoid_quantizer (tosa_str : str ):
56
+ def get_32bit_sigmoid_quantizer (tosa_str : str ):
56
57
tosa_spec = common .TosaSpecification .create_from_string (tosa_str )
57
58
quantizer = TOSAQuantizer (tosa_spec )
58
59
quantizer .set_global (_get_32_bit_quant_config ())
@@ -65,12 +66,12 @@ def get_16bit_sigmoid_quantizer(tosa_str: str):
65
66
66
67
input_t = tuple [torch .Tensor ]
67
68
test_data_suite = {
68
- "ones" : ( torch .ones (10 , 10 , 10 ), ),
69
- "rand" : ( torch .rand (10 , 10 ) - 0.5 ,) ,
70
- "rand_4d" : ( torch .rand (1 , 10 , 10 , 10 ), ),
71
- "randn_pos" : ( torch .randn (10 ) + 10 ,) ,
72
- "randn_neg" : ( torch .randn (10 ) - 10 ,) ,
73
- "ramp" : ( torch .arange (- 16 , 16 , 0.2 ), ),
69
+ "ones" : lambda : torch .ones (10 , 10 , 10 ),
70
+ "rand" : lambda : torch .rand (10 , 10 ) - 0.5 ,
71
+ "rand_4d" : lambda : torch .rand (1 , 10 , 10 , 10 ),
72
+ "randn_pos" : lambda : torch .randn (10 ) + 10 ,
73
+ "randn_neg" : lambda : torch .randn (10 ) - 10 ,
74
+ "ramp" : lambda : torch .arange (- 16 , 16 , 0.2 ),
74
75
}
75
76
76
77
@@ -96,28 +97,28 @@ def forward(self, x):
96
97
97
98
98
99
@common .parametrize ("test_data" , test_data_suite )
100
+ @pytest .mark .flaky (reruns = 5 )
99
101
def test_sigmoid_tosa_BI (test_data ):
100
102
pipeline = TosaPipelineBI (
101
103
Sigmoid (),
102
- test_data ,
104
+ ( test_data (),) ,
103
105
Sigmoid .aten_op ,
104
106
Sigmoid .exir_op ,
105
107
)
106
- pipeline .change_args ("quantize" , get_16bit_sigmoid_quantizer ("TOSA-0.80+BI" ))
108
+ pipeline .change_args ("quantize" , get_32bit_sigmoid_quantizer ("TOSA-0.80+BI" ))
107
109
pipeline .run ()
108
110
109
111
110
112
@common .parametrize ("test_data" , test_data_suite )
113
+ @pytest .mark .flaky (reruns = 5 )
111
114
def test_sigmoid_add_sigmoid_tosa_BI (test_data ):
112
115
pipeline = TosaPipelineBI (
113
116
SigmoidAddSigmoid (),
114
- test_data ,
117
+ ( test_data (),) ,
115
118
Sigmoid .aten_op ,
116
119
Sigmoid .exir_op ,
117
120
)
118
- pipeline .change_args ("quantize" , get_16bit_sigmoid_quantizer ("TOSA-0.80+BI" ))
119
- pipeline .change_args ("run_method_and_compare_outputs" , test_data , qtol = 1 )
120
-
121
+ pipeline .change_args ("quantize" , get_32bit_sigmoid_quantizer ("TOSA-0.80+BI" ))
121
122
pipeline .run ()
122
123
123
124
@@ -129,16 +130,19 @@ def test_sigmoid_add_sigmoid_tosa_BI(test_data):
129
130
"rand" : "AssertionError: Output 0 does not match reference output. MLBEDSW-9770" ,
130
131
"rand_4d" : "AssertionError: Output 0 does not match reference output. MLBEDSW-9770" ,
131
132
"randn_pos" : "AssertionError: Output 0 does not match reference output. MLBEDSW-9770" ,
133
+ "randn_neg" : "AssertionError: Output 0 does not match reference output. MLBEDSW-9770" ,
132
134
"ramp" : "AssertionError: Output 0 does not match reference output. MLBEDSW-9770" ,
133
135
},
136
+ # int16 tables are not supported, but some tests happen to pass regardless.
137
+ # Set them to xfail but strict=False -> ok if they pass.
138
+ strict = False ,
134
139
)
135
140
@common .XfailIfNoCorstone300
136
141
def test_sigmoid_tosa_u55 (test_data ):
137
142
pipeline = EthosU55PipelineBI (
138
- Sigmoid (), test_data , Sigmoid .aten_op , Sigmoid .exir_op , run_on_fvp = True
143
+ Sigmoid (), ( test_data (),) , Sigmoid .aten_op , Sigmoid .exir_op , run_on_fvp = True
139
144
)
140
- pipeline .change_args ("quantize" , get_16bit_sigmoid_quantizer ("TOSA-0.80+BI+u55" ))
141
- pipeline .change_args ("run_method_and_compare_outputs" , test_data , qtol = 1 )
145
+ pipeline .change_args ("quantize" , get_32bit_sigmoid_quantizer ("TOSA-0.80+BI+u55" ))
142
146
pipeline .run ()
143
147
144
148
@@ -153,29 +157,31 @@ def test_sigmoid_tosa_u55(test_data):
153
157
"randn_neg" : "AssertionError: Output 0 does not match reference output. MLBEDSW-9770" ,
154
158
"ramp" : "AssertionError: Output 0 does not match reference output. MLBEDSW-9770" ,
155
159
},
160
+ # int16 tables are not supported, but some tests happen to pass regardless.
161
+ # Set them to xfail but strict=False -> ok if they pass.
162
+ strict = False ,
156
163
)
157
164
@common .XfailIfNoCorstone300
158
165
def test_sigmoid_add_sigmoid_tosa_u55 (test_data ):
159
166
pipeline = EthosU55PipelineBI (
160
167
SigmoidAddSigmoid (),
161
- test_data ,
168
+ ( test_data (),) ,
162
169
Sigmoid .aten_op ,
163
170
Sigmoid .exir_op ,
164
171
run_on_fvp = True ,
165
172
)
166
- pipeline .change_args ("quantize" , get_16bit_sigmoid_quantizer ("TOSA-0.80+BI+u55" ))
167
- pipeline .change_args ("run_method_and_compare_outputs" , test_data , qtol = 1 )
173
+ pipeline .change_args ("quantize" , get_32bit_sigmoid_quantizer ("TOSA-0.80+BI+u55" ))
168
174
pipeline .run ()
169
175
170
176
171
177
@common .parametrize ("test_data" , test_data_suite )
178
+ @pytest .mark .flaky (reruns = 5 )
172
179
@common .XfailIfNoCorstone320
173
180
def test_sigmoid_tosa_u85 (test_data ):
174
181
pipeline = EthosU85PipelineBI (
175
- Sigmoid (), test_data , Sigmoid .aten_op , Sigmoid .exir_op , run_on_fvp = True
182
+ Sigmoid (), ( test_data (),) , Sigmoid .aten_op , Sigmoid .exir_op , run_on_fvp = True
176
183
)
177
- pipeline .change_args ("quantize" , get_16bit_sigmoid_quantizer ("TOSA-0.80+BI" ))
178
- pipeline .change_args ("run_method_and_compare_outputs" , test_data , qtol = 1 )
184
+ pipeline .change_args ("quantize" , get_32bit_sigmoid_quantizer ("TOSA-0.80+BI" ))
179
185
pipeline .run ()
180
186
181
187
@@ -186,15 +192,15 @@ def test_sigmoid_tosa_u85(test_data):
186
192
"ramp" : "AssertionError: Output 0 does not match reference output." ,
187
193
},
188
194
)
195
+ @pytest .mark .flaky (reruns = 5 )
189
196
@common .XfailIfNoCorstone320
190
197
def test_sigmoid_add_sigmoid_tosa_u85 (test_data ):
191
198
pipeline = EthosU85PipelineBI (
192
199
SigmoidAddSigmoid (),
193
- test_data ,
200
+ ( test_data (),) ,
194
201
Sigmoid .aten_op ,
195
202
Sigmoid .exir_op ,
196
203
run_on_fvp = True ,
197
204
)
198
- pipeline .change_args ("quantize" , get_16bit_sigmoid_quantizer ("TOSA-0.80+BI" ))
199
- pipeline .change_args ("run_method_and_compare_outputs" , test_data , qtol = 1 )
205
+ pipeline .change_args ("quantize" , get_32bit_sigmoid_quantizer ("TOSA-0.80+BI" ))
200
206
pipeline .run ()
0 commit comments