@@ -204,6 +204,48 @@ def test_upscale():
204
204
verify (mod , (golden_data , golden_output ))
205
205
206
206
207
+ def test_non_power_of_two ():
208
+ for rounding in roundings :
209
+ mod = get_mod (
210
+ data_shape = (32 ,),
211
+ data_dtype = "int32" ,
212
+ out_dtype = "int8" ,
213
+ input_scale = 1 ,
214
+ output_scale = 3 ,
215
+ rounding = rounding ,
216
+ )
217
+
218
+ # Try positive values
219
+ golden_data = np .multiply (np .arange (0 , 32 , 1 ).astype ("int32" ), 3 )
220
+ golden_output = np .arange (0 , 32 , 1 )
221
+ verify (mod , (golden_data , golden_output ))
222
+
223
+ # Try negative values
224
+ golden_data = np .multiply (np .arange (0 , - 32 , - 1 ).astype ("int32" ), 3 )
225
+ golden_output = np .arange (0 , - 32 , - 1 )
226
+ verify (mod , (golden_data , golden_output ))
227
+
228
+ # Try a different scale
229
+ mod = get_mod (
230
+ data_shape = (32 ,),
231
+ data_dtype = "int32" ,
232
+ out_dtype = "int8" ,
233
+ input_scale = 3 ,
234
+ output_scale = 1 ,
235
+ rounding = rounding ,
236
+ )
237
+
238
+ # Try positive values
239
+ golden_data = np .arange (0 , 32 , 1 ).astype ("int32" )
240
+ golden_output = np .multiply (golden_data , 3 )
241
+ verify (mod , (golden_data , golden_output ))
242
+
243
+ # Try negative values
244
+ golden_data = np .arange (0 , - 32 , - 1 ).astype ("int32" )
245
+ golden_output = np .multiply (golden_data , 3 )
246
+ verify (mod , (golden_data , golden_output ))
247
+
248
+
207
249
def test_saturation ():
208
250
for rounding in roundings :
209
251
mod = get_mod (
@@ -397,6 +439,7 @@ def test_per_channel_different_scale():
397
439
test_same_scale ()
398
440
test_downscale ()
399
441
test_upscale ()
442
+ test_non_power_of_two ()
400
443
test_saturation ()
401
444
test_zero_point ()
402
445
test_per_channel_same_scale ()
0 commit comments