25
25
_SEED = 1234
26
26
torch .manual_seed (_SEED )
27
27
28
# Helper to run a function twice with identical arguments and verify that
# both calls produce the same result, guarding against hidden side effects.
# NOTE:
# - Does not verify that args and kwargs are left unchanged by `fn`.
# - Assumes the output of `fn` is a single Tensor.
def check_idempotent(self, fn, *args, **kwargs):
    """Call ``fn(*args, **kwargs)`` twice and assert both results are equal.

    Args:
        self: a ``unittest.TestCase`` instance providing ``assertTrue``.
        fn: callable under test; expected to return a single Tensor.
        *args, **kwargs: forwarded verbatim to ``fn`` on both calls.

    Returns:
        The Tensor produced by the second invocation of ``fn``.
    """
    output0 = fn(*args, **kwargs)
    # Use the TestCase assertion rather than a bare `assert` so this check
    # is not silently stripped when Python runs with -O, and so it reports
    # through the unittest machinery like the equality check below.
    self.assertTrue(
        torch.is_tensor(output0),
        f"Expected given function {fn} to return a single Tensor.",
    )
    output1 = fn(*args, **kwargs)
    self.assertTrue(
        torch.equal(output0, output1),
        f"Expected given function {fn} to be idempotent.",
    )
    return output1
28
42
class TestQuantPrimitives (unittest .TestCase ):
29
43
SEED = 123
30
44
31
- @unittest .skipIf (not TORCH_VERSION_AFTER_2_3 , "skipping when torch verion is 2.3 or lower" )
45
+ @unittest .skipIf (not TORCH_VERSION_AFTER_2_3 , "skipping when torch version is 2.3 or lower" )
32
46
def test_get_group_qparams_symmetric (self ):
33
47
"""
34
48
Test that `get_group_qparams_symmetric` produces the exact same scales as
@@ -77,7 +91,7 @@ def test_choose_qparams_group_sym(self):
77
91
self .assertTrue (torch .equal (scale , scale_ref ))
78
92
self .assertTrue (torch .equal (zero_point , zp_ref ))
79
93
80
- @unittest .skipIf (not TORCH_VERSION_AFTER_2_3 , "skipping when torch verion is 2.3 or lower" )
94
+ @unittest .skipIf (not TORCH_VERSION_AFTER_2_3 , "skipping when torch version is 2.3 or lower" )
81
95
def test_choose_qparams_token_asym (self ):
82
96
input = torch .randn (10 , 10 )
83
97
mapping_type = MappingType .ASYMMETRIC
@@ -127,7 +141,7 @@ def test_choose_qparams_tensor_sym(self):
127
141
self .assertTrue (torch .equal (scale , scale_ref ))
128
142
self .assertTrue (torch .equal (zero_point , zp_ref ))
129
143
130
- @unittest .skipIf (not TORCH_VERSION_AFTER_2_4 , "skipping when torch verion is 2.4 or lower" )
144
+ @unittest .skipIf (not TORCH_VERSION_AFTER_2_4 , "skipping when torch version is 2.4 or lower" )
131
145
def test_quantize_activation_per_token_abs_max (self ):
132
146
from torchao .quantization .quant_primitives import quantize_activation_per_token_absmax
133
147
input = torch .randn (10 , 10 )
@@ -148,15 +162,15 @@ def test_quantize_activation_per_token_abs_max(self):
148
162
self .assertTrue (torch .equal (quantized , quantized_ref ))
149
163
self .assertTrue (torch .equal (scale , scale_ref ))
150
164
151
- @unittest .skipIf (not TORCH_VERSION_AFTER_2_4 , "skipping when torch verion is 2.4 or lower" )
165
+ @unittest .skipIf (not TORCH_VERSION_AFTER_2_4 , "skipping when torch version is 2.4 or lower" )
152
166
def test_quantize_activation_per_token_abs_max_zero_input (self ):
153
167
from torchao .quantization .quant_primitives import quantize_activation_per_token_absmax
154
168
input = torch .zeros (10 , 10 )
155
169
# make sure it still works
156
170
quantized_ref , scale_ref = quantize_activation_per_token_absmax (input )
157
171
158
172
159
- @unittest .skipIf (not TORCH_VERSION_AFTER_2_4 , "skipping when torch verion is 2.4 or lower" )
173
+ @unittest .skipIf (not TORCH_VERSION_AFTER_2_4 , "skipping when torch version is 2.4 or lower" )
160
174
def test_quantize_activation_per_token_abs_max_dtype (self ):
161
175
from torchao .quantization .quant_primitives import quantize_activation_per_token_absmax
162
176
input = torch .zeros (10 , 10 , dtype = torch .bfloat16 )
@@ -172,7 +186,7 @@ def test_quantize_activation_per_token_abs_max_dtype(self):
172
186
self .assertTrue (scale_ref .dtype , torch .float32 )
173
187
174
188
175
- @unittest .skipIf (not TORCH_VERSION_AFTER_2_4 , "skipping when torch verion is 2.4 or lower" )
189
+ @unittest .skipIf (not TORCH_VERSION_AFTER_2_4 , "skipping when torch version is 2.4 or lower" )
176
190
def test_quantize_dequantize_group_sym (self ):
177
191
input = torch .randn (10 , 10 )
178
192
mapping_type = MappingType .SYMMETRIC
@@ -181,7 +195,7 @@ def test_quantize_dequantize_group_sym(self):
181
195
scale , zero_point = choose_qparams_affine (input , mapping_type , block_size , dtype , eps = torch .finfo (torch .float32 ).eps )
182
196
183
197
quantized = quantize_affine (input , block_size , scale , zero_point , dtype )
184
- dequantized = dequantize_affine ( quantized , block_size , scale , zero_point , dtype , output_dtype = torch .float32 )
198
+ dequantized = check_idempotent ( self , dequantize_affine , quantized , block_size , scale , zero_point , dtype , output_dtype = torch .float32 )
185
199
186
200
group_size = 2
187
201
quant_min = - 128
@@ -196,7 +210,7 @@ def test_quantize_dequantize_group_sym(self):
196
210
self .assertTrue (torch .equal (quantized , quantized_ref ))
197
211
self .assertTrue (torch .equal (dequantized , dequantized_ref ))
198
212
199
- @unittest .skipIf (not TORCH_VERSION_AFTER_2_4 , "skipping when torch verion is 2.4 or lower" )
213
+ @unittest .skipIf (not TORCH_VERSION_AFTER_2_4 , "skipping when torch version is 2.4 or lower" )
200
214
def test_quantize_dequantize_channel_asym (self ):
201
215
input = torch .randn (10 , 10 )
202
216
mapping_type = MappingType .ASYMMETRIC
@@ -205,7 +219,7 @@ def test_quantize_dequantize_channel_asym(self):
205
219
scale , zero_point = choose_qparams_affine (input , mapping_type , block_size , dtype , eps = torch .finfo (torch .float32 ).eps )
206
220
output_dtype = torch .float32
207
221
quantized = quantize_affine (input , block_size , scale , zero_point , dtype )
208
- dequantized = dequantize_affine ( quantized , block_size , scale , zero_point , dtype , output_dtype = output_dtype )
222
+ dequantized = check_idempotent ( self , dequantize_affine , quantized , block_size , scale , zero_point , dtype , output_dtype = output_dtype )
209
223
210
224
axis = 1
211
225
quant_min = - 128
@@ -219,7 +233,7 @@ def test_quantize_dequantize_channel_asym(self):
219
233
self .assertTrue (torch .equal (quantized , quantized_ref ))
220
234
self .assertTrue (torch .equal (dequantized , dequantized_ref ))
221
235
222
- @unittest .skipIf (not TORCH_VERSION_AFTER_2_4 , "skipping when torch verion is 2.4 or lower" )
236
+ @unittest .skipIf (not TORCH_VERSION_AFTER_2_4 , "skipping when torch version is 2.4 or lower" )
223
237
def test_quantize_dequantize_tensor_asym (self ):
224
238
input = torch .randn (10 , 10 )
225
239
mapping_type = MappingType .ASYMMETRIC
@@ -228,7 +242,7 @@ def test_quantize_dequantize_tensor_asym(self):
228
242
output_dtype = torch .float32
229
243
scale , zero_point = choose_qparams_affine (input , mapping_type , block_size , dtype , eps = torch .finfo (torch .float32 ).eps )
230
244
quantized = quantize_affine (input , block_size , scale , zero_point , dtype )
231
- dequantized = dequantize_affine ( quantized , block_size , scale , zero_point , dtype , output_dtype = output_dtype )
245
+ dequantized = check_idempotent ( self , dequantize_affine , quantized , block_size , scale , zero_point , dtype , output_dtype = output_dtype )
232
246
233
247
axis = 1
234
248
quant_min = - 128
@@ -242,15 +256,15 @@ def test_quantize_dequantize_tensor_asym(self):
242
256
self .assertTrue (torch .equal (quantized , quantized_ref ))
243
257
self .assertTrue (torch .equal (dequantized , dequantized_ref ))
244
258
245
- @unittest .skipIf (not TORCH_VERSION_AFTER_2_4 , "skipping when torch verion is 2.4 or lower" )
259
+ @unittest .skipIf (not TORCH_VERSION_AFTER_2_4 , "skipping when torch version is 2.4 or lower" )
246
260
def test_quantize_dequantize_channel_asym_4d (self ):
247
261
input = torch .randn (3 , 3 , 10 , 10 )
248
262
mapping_type = MappingType .ASYMMETRIC
249
263
dtype = torch .int8
250
264
block_size = (3 , 3 , 1 , 10 )
251
265
scale , zero_point = choose_qparams_affine (input , mapping_type , block_size , dtype , eps = torch .finfo (torch .float32 ).eps )
252
266
quantized = quantize_affine (input , block_size , scale , zero_point , dtype )
253
- dequantized = dequantize_affine ( quantized , block_size , scale , zero_point , dtype , output_dtype = torch .float32 )
267
+ dequantized = check_idempotent ( self , dequantize_affine , quantized , block_size , scale , zero_point , dtype , output_dtype = torch .float32 )
254
268
255
269
axis = 2
256
270
quant_min = - 128
@@ -264,15 +278,15 @@ def test_quantize_dequantize_channel_asym_4d(self):
264
278
self .assertTrue (torch .equal (quantized , quantized_ref ))
265
279
self .assertTrue (torch .equal (dequantized , dequantized_ref ))
266
280
267
- @unittest .skipIf (not TORCH_VERSION_AFTER_2_3 , "skipping when torch verion is 2.3 or lower" )
281
+ @unittest .skipIf (not TORCH_VERSION_AFTER_2_3 , "skipping when torch version is 2.3 or lower" )
268
282
def test_quantize_dequantize_channel_asym_4d_multi_dim_reduction (self ):
269
283
input = torch .randn (3 , 3 , 10 , 10 )
270
284
mapping_type = MappingType .ASYMMETRIC
271
285
dtype = torch .int8
272
286
block_size = (3 , 3 , 2 , 2 )
273
287
scale , zero_point = choose_qparams_affine (input , mapping_type , block_size , dtype , eps = torch .finfo (torch .float32 ).eps )
274
288
quantized = quantize_affine (input , block_size , scale , zero_point , dtype )
275
- dequantized = dequantize_affine ( quantized , block_size , scale , zero_point , dtype , output_dtype = torch .float32 )
289
+ dequantized = check_idempotent ( self , dequantize_affine , quantized , block_size , scale , zero_point , dtype , output_dtype = torch .float32 )
276
290
# we don't have corresponding ops in existing primitives, so just make sure it runs and it's close to float
277
291
torch .testing .assert_close (dequantized , input , rtol = 2 , atol = 0.02 )
278
292
0 commit comments