Skip to content

Commit d75f450

Browse files
authored
Check dequantize_affine is idempotent (#309)
1 parent 88daa1a commit d75f450

File tree

2 files changed

+34
-16
lines changed

2 files changed

+34
-16
lines changed

test/quantization/test_quant_primitives.py

Lines changed: 29 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -25,10 +25,24 @@
2525
_SEED = 1234
2626
torch.manual_seed(_SEED)
2727

28+
# Helper to verify a function is idempotent: calling it a second time
# with the same arguments must produce the same result (i.e. the call
# has no side effects that alter its own output).
# NOTE:
# - Does not verify the args and kwargs are unchanged.
# - Assumes the output is a single Tensor
def check_idempotent(self, fn, *args, **kwargs):
    """Run ``fn`` twice with identical arguments and assert both calls
    return equal Tensors.

    Args:
        self: a ``unittest.TestCase`` instance (used for its assert helpers).
        fn: the callable under test.
        *args, **kwargs: forwarded verbatim to ``fn`` on both calls.

    Returns:
        The Tensor produced by the second invocation of ``fn``.
    """
    output0 = fn(*args, **kwargs)
    assert torch.is_tensor(output0)
    output1 = fn(*args, **kwargs)
    # Guard the second output as well: without this, a non-Tensor result
    # from the repeated call would surface as an opaque TypeError inside
    # torch.equal rather than a clear assertion failure.
    assert torch.is_tensor(output1)
    self.assertTrue(torch.equal(output0, output1), f"Expected given function {fn} to be idempotent.")
    return output1
40+
41+
2842
class TestQuantPrimitives(unittest.TestCase):
2943
SEED = 123
3044

31-
@unittest.skipIf(not TORCH_VERSION_AFTER_2_3, "skipping when torch verion is 2.3 or lower")
45+
@unittest.skipIf(not TORCH_VERSION_AFTER_2_3, "skipping when torch version is 2.3 or lower")
3246
def test_get_group_qparams_symmetric(self):
3347
"""
3448
Test that `get_group_qparams_symmetric` produces the exact same scales as
@@ -77,7 +91,7 @@ def test_choose_qparams_group_sym(self):
7791
self.assertTrue(torch.equal(scale, scale_ref))
7892
self.assertTrue(torch.equal(zero_point, zp_ref))
7993

80-
@unittest.skipIf(not TORCH_VERSION_AFTER_2_3, "skipping when torch verion is 2.3 or lower")
94+
@unittest.skipIf(not TORCH_VERSION_AFTER_2_3, "skipping when torch version is 2.3 or lower")
8195
def test_choose_qparams_token_asym(self):
8296
input = torch.randn(10, 10)
8397
mapping_type = MappingType.ASYMMETRIC
@@ -127,7 +141,7 @@ def test_choose_qparams_tensor_sym(self):
127141
self.assertTrue(torch.equal(scale, scale_ref))
128142
self.assertTrue(torch.equal(zero_point, zp_ref))
129143

130-
@unittest.skipIf(not TORCH_VERSION_AFTER_2_4, "skipping when torch verion is 2.4 or lower")
144+
@unittest.skipIf(not TORCH_VERSION_AFTER_2_4, "skipping when torch version is 2.4 or lower")
131145
def test_quantize_activation_per_token_abs_max(self):
132146
from torchao.quantization.quant_primitives import quantize_activation_per_token_absmax
133147
input = torch.randn(10, 10)
@@ -148,15 +162,15 @@ def test_quantize_activation_per_token_abs_max(self):
148162
self.assertTrue(torch.equal(quantized, quantized_ref))
149163
self.assertTrue(torch.equal(scale, scale_ref))
150164

151-
@unittest.skipIf(not TORCH_VERSION_AFTER_2_4, "skipping when torch verion is 2.4 or lower")
165+
@unittest.skipIf(not TORCH_VERSION_AFTER_2_4, "skipping when torch version is 2.4 or lower")
152166
def test_quantize_activation_per_token_abs_max_zero_input(self):
153167
from torchao.quantization.quant_primitives import quantize_activation_per_token_absmax
154168
input = torch.zeros(10, 10)
155169
# make sure it still works
156170
quantized_ref, scale_ref = quantize_activation_per_token_absmax(input)
157171

158172

159-
@unittest.skipIf(not TORCH_VERSION_AFTER_2_4, "skipping when torch verion is 2.4 or lower")
173+
@unittest.skipIf(not TORCH_VERSION_AFTER_2_4, "skipping when torch version is 2.4 or lower")
160174
def test_quantize_activation_per_token_abs_max_dtype(self):
161175
from torchao.quantization.quant_primitives import quantize_activation_per_token_absmax
162176
input = torch.zeros(10, 10, dtype=torch.bfloat16)
@@ -172,7 +186,7 @@ def test_quantize_activation_per_token_abs_max_dtype(self):
172186
self.assertTrue(scale_ref.dtype, torch.float32)
173187

174188

175-
@unittest.skipIf(not TORCH_VERSION_AFTER_2_4, "skipping when torch verion is 2.4 or lower")
189+
@unittest.skipIf(not TORCH_VERSION_AFTER_2_4, "skipping when torch version is 2.4 or lower")
176190
def test_quantize_dequantize_group_sym(self):
177191
input = torch.randn(10, 10)
178192
mapping_type = MappingType.SYMMETRIC
@@ -181,7 +195,7 @@ def test_quantize_dequantize_group_sym(self):
181195
scale, zero_point = choose_qparams_affine(input, mapping_type, block_size, dtype, eps=torch.finfo(torch.float32).eps)
182196

183197
quantized = quantize_affine(input, block_size, scale, zero_point, dtype)
184-
dequantized = dequantize_affine(quantized, block_size, scale, zero_point, dtype, output_dtype=torch.float32)
198+
dequantized = check_idempotent(self, dequantize_affine, quantized, block_size, scale, zero_point, dtype, output_dtype=torch.float32)
185199

186200
group_size = 2
187201
quant_min = -128
@@ -196,7 +210,7 @@ def test_quantize_dequantize_group_sym(self):
196210
self.assertTrue(torch.equal(quantized, quantized_ref))
197211
self.assertTrue(torch.equal(dequantized, dequantized_ref))
198212

199-
@unittest.skipIf(not TORCH_VERSION_AFTER_2_4, "skipping when torch verion is 2.4 or lower")
213+
@unittest.skipIf(not TORCH_VERSION_AFTER_2_4, "skipping when torch version is 2.4 or lower")
200214
def test_quantize_dequantize_channel_asym(self):
201215
input = torch.randn(10, 10)
202216
mapping_type = MappingType.ASYMMETRIC
@@ -205,7 +219,7 @@ def test_quantize_dequantize_channel_asym(self):
205219
scale, zero_point = choose_qparams_affine(input, mapping_type, block_size, dtype, eps=torch.finfo(torch.float32).eps)
206220
output_dtype = torch.float32
207221
quantized = quantize_affine(input, block_size, scale, zero_point, dtype)
208-
dequantized = dequantize_affine(quantized, block_size, scale, zero_point, dtype, output_dtype=output_dtype)
222+
dequantized = check_idempotent(self, dequantize_affine, quantized, block_size, scale, zero_point, dtype, output_dtype=output_dtype)
209223

210224
axis = 1
211225
quant_min = -128
@@ -219,7 +233,7 @@ def test_quantize_dequantize_channel_asym(self):
219233
self.assertTrue(torch.equal(quantized, quantized_ref))
220234
self.assertTrue(torch.equal(dequantized, dequantized_ref))
221235

222-
@unittest.skipIf(not TORCH_VERSION_AFTER_2_4, "skipping when torch verion is 2.4 or lower")
236+
@unittest.skipIf(not TORCH_VERSION_AFTER_2_4, "skipping when torch version is 2.4 or lower")
223237
def test_quantize_dequantize_tensor_asym(self):
224238
input = torch.randn(10, 10)
225239
mapping_type = MappingType.ASYMMETRIC
@@ -228,7 +242,7 @@ def test_quantize_dequantize_tensor_asym(self):
228242
output_dtype = torch.float32
229243
scale, zero_point = choose_qparams_affine(input, mapping_type, block_size, dtype, eps=torch.finfo(torch.float32).eps)
230244
quantized = quantize_affine(input, block_size, scale, zero_point, dtype)
231-
dequantized = dequantize_affine(quantized, block_size, scale, zero_point, dtype, output_dtype=output_dtype)
245+
dequantized = check_idempotent(self, dequantize_affine, quantized, block_size, scale, zero_point, dtype, output_dtype=output_dtype)
232246

233247
axis = 1
234248
quant_min = -128
@@ -242,15 +256,15 @@ def test_quantize_dequantize_tensor_asym(self):
242256
self.assertTrue(torch.equal(quantized, quantized_ref))
243257
self.assertTrue(torch.equal(dequantized, dequantized_ref))
244258

245-
@unittest.skipIf(not TORCH_VERSION_AFTER_2_4, "skipping when torch verion is 2.4 or lower")
259+
@unittest.skipIf(not TORCH_VERSION_AFTER_2_4, "skipping when torch version is 2.4 or lower")
246260
def test_quantize_dequantize_channel_asym_4d(self):
247261
input = torch.randn(3, 3, 10, 10)
248262
mapping_type = MappingType.ASYMMETRIC
249263
dtype = torch.int8
250264
block_size = (3, 3, 1, 10)
251265
scale, zero_point = choose_qparams_affine(input, mapping_type, block_size, dtype, eps=torch.finfo(torch.float32).eps)
252266
quantized = quantize_affine(input, block_size, scale, zero_point, dtype)
253-
dequantized = dequantize_affine(quantized, block_size, scale, zero_point, dtype, output_dtype=torch.float32)
267+
dequantized = check_idempotent(self, dequantize_affine, quantized, block_size, scale, zero_point, dtype, output_dtype=torch.float32)
254268

255269
axis = 2
256270
quant_min = -128
@@ -264,15 +278,15 @@ def test_quantize_dequantize_channel_asym_4d(self):
264278
self.assertTrue(torch.equal(quantized, quantized_ref))
265279
self.assertTrue(torch.equal(dequantized, dequantized_ref))
266280

267-
@unittest.skipIf(not TORCH_VERSION_AFTER_2_3, "skipping when torch verion is 2.3 or lower")
281+
@unittest.skipIf(not TORCH_VERSION_AFTER_2_3, "skipping when torch version is 2.3 or lower")
268282
def test_quantize_dequantize_channel_asym_4d_multi_dim_reduction(self):
269283
input = torch.randn(3, 3, 10, 10)
270284
mapping_type = MappingType.ASYMMETRIC
271285
dtype = torch.int8
272286
block_size = (3, 3, 2, 2)
273287
scale, zero_point = choose_qparams_affine(input, mapping_type, block_size, dtype, eps=torch.finfo(torch.float32).eps)
274288
quantized = quantize_affine(input, block_size, scale, zero_point, dtype)
275-
dequantized = dequantize_affine(quantized, block_size, scale, zero_point, dtype, output_dtype=torch.float32)
289+
dequantized = check_idempotent(self, dequantize_affine, quantized, block_size, scale, zero_point, dtype, output_dtype=torch.float32)
276290
# we don't have corresponding ops in existing primitives, so just make sure it runs and it's close to float
277291
torch.testing.assert_close(dequantized, input, rtol=2, atol=0.02)
278292

torchao/quantization/quant_primitives.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -249,6 +249,7 @@ def dequantize_affine(
249249
Output:
250250
dequantized Tensor, with requested dtype or fp32
251251
"""
252+
252253
# TODO: validations
253254
# TODO: validate scale/zero_point dimensions are compatible with block_size
254255
assert input.dtype == input_dtype
@@ -266,14 +267,17 @@ def dequantize_affine(
266267
zero_point = zero_point.view(shape_after_reduction)
267268

268269
if zero_point_domain == ZeroPointDomain.INT:
269-
dequant = input.to(torch.int32)
270+
# Force a copy to avoid input modification due
271+
# to upcoming in-place operations.
272+
dequant = input.to(torch.int32, copy=True)
270273
if zero_point is not None:
271274
dequant -= zero_point.to(torch.int32)
272275
dequant = dequant.to(output_dtype)
273276
dequant *= scale
274277
else:
275278
assert zero_point_domain == ZeroPointDomain.FLOAT, f"Unexpected zero point domain: {zero_point_domain}"
276279
mid_point = (quant_max + quant_min + 1) / 2
280+
# This should allocate new memory and avoid input modification
277281
dequant = input - mid_point
278282
dequant = dequant.to(output_dtype)
279283
dequant *= scale

0 commit comments

Comments
 (0)