
Commit 3dd16c9

Fix CI after quantize op change in PyTorch core (#244)
Summary: pytorch/pytorch#125781 recently changed the numerics of the quantize op subtly. This commit fixes the numerics mismatch caused by that PR by making our quantize ops consistent with the ones in core.

Test Plan:
python test/quantization/test_quant_primitives.py -k test_quantize_dequantize_group_sym
python test/quantization/test_quant_api.py TestQuantFlow.test_quantized_tensor_subclass_8da4w

Reviewers: jerryzh168, cpuhrsch

Subscribers: jerryzh168, cpuhrsch, supriyar
Parent: 10da375
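
For context, the change being tracked is the move from dividing by the scale to multiplying by its reciprocal. A minimal sketch (not part of this commit) of why that matters: x / s and x * (1.0 / s) are not guaranteed to be bit-identical in float32, and when the intermediate lands near a .5 boundary, torch.round() can pick a different integer, shifting the quantized value by one step.

import torch

torch.manual_seed(0)
x = torch.randn(1_000_000)
scale = torch.full_like(x, 0.007)

q_div = torch.round(x / scale)
q_mul = torch.round(x * (1.0 / scale))

# Count elements where the two formulations disagree. The count may be zero
# for some seeds and scales; the point is only that agreement is not
# guaranteed, which is the "subtle" numerics change the commit refers to.
print((q_div != q_mul).sum().item())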

5 files changed (+12, -12 lines)

test/integration/test_integration.py

Lines changed: 1 addition & 1 deletion
@@ -1124,7 +1124,7 @@ def test_weight_only_quant_force_mixed_mm(self, device, dtype):
         m_c = torch.compile(m, mode="max-autotune")
         y_wo, (code,) = run_and_get_code(m_c, x)
         sqnr = compute_error(y_ref, y_wo)
-        self.assertGreater(sqnr, 43.0)
+        self.assertGreaterEqual(sqnr, 42.75)
         if device == "cuda":
             self.assertTrue("mixed_mm" in code)
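
The threshold above is relaxed from a strict 43.0 dB to 42.75 dB because the new quantize numerics nudge the compiled output slightly. For reference, a compute_error-style SQNR check typically boils down to the following sketch (an assumption about the helper; the exact torchao implementation may differ):

import torch

def sqnr_db(reference: torch.Tensor, candidate: torch.Tensor) -> torch.Tensor:
    # Signal-to-quantization-noise ratio in decibels: how much larger the
    # reference signal is than the error introduced by the candidate.
    signal = torch.linalg.norm(reference)
    noise = torch.linalg.norm(reference - candidate)
    return 20 * torch.log10(signal / noise)

ref = torch.randn(1024)
candidate = ref + 0.005 * torch.randn(1024)  # stand-in for a dequantized output
print(sqnr_db(ref, candidate))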

test/quantization/test_qat.py

Lines changed: 8 additions & 8 deletions
@@ -18,7 +18,7 @@
     fake_quantize_per_token,
 )
 from torchao.quantization.quant_primitives import get_group_qparams_symmetric
-from torchao.quantization.utils import TORCH_VERSION_AFTER_2_3
+from torchao.quantization.utils import TORCH_VERSION_AFTER_2_4
 
 
 # TODO: put this in a common test utils file
@@ -58,7 +58,7 @@ def _get_qmin_qmax(self, n_bit: int):
         qmax = 2 ** (n_bit - 1) - 1
         return (qmin, qmax)
 
-    @unittest.skipIf(not TORCH_VERSION_AFTER_2_3, "skipping when torch verion is 2.3 or lower")
+    @unittest.skipIf(not TORCH_VERSION_AFTER_2_4, "skipping when torch verion is 2.4 or lower")
     def test_fake_quantize_per_channel_group(self):
         n_bit = 4
         (qmin, qmax) = self._get_qmin_qmax(n_bit)
@@ -84,7 +84,7 @@ def test_fake_quantize_per_channel_group(self):
         )
         torch.testing.assert_close(out, out_ptq, atol=0, rtol=0)
 
-    @unittest.skipIf(not TORCH_VERSION_AFTER_2_3, "skipping when torch verion is 2.3 or lower")
+    @unittest.skipIf(not TORCH_VERSION_AFTER_2_4, "skipping when torch verion is 2.4 or lower")
     def test_fake_quantize_per_token(self):
         (qmin, qmax) = self._get_qmin_qmax(8)
 
@@ -130,7 +130,7 @@ def _set_ptq_weight(
         ptq_linear.scales = s
         ptq_linear.zeros = zp
 
-    @unittest.skipIf(not TORCH_VERSION_AFTER_2_3, "skipping when torch verion is 2.3 or lower")
+    @unittest.skipIf(not TORCH_VERSION_AFTER_2_4, "skipping when torch verion is 2.4 or lower")
     def test_qat_8da4w_linear(self):
         from torchao.quantization.prototype.qat import Int8DynActInt4WeightQATLinear
         from torchao.quantization.GPTQ import Int8DynActInt4WeightLinear
@@ -155,7 +155,7 @@ def test_qat_8da4w_linear(self):
         ptq_out = ptq_linear(x2)
         torch.testing.assert_close(ptq_out, qat_out, atol=0, rtol=0)
 
-    @unittest.skipIf(not TORCH_VERSION_AFTER_2_3, "skipping when torch verion is 2.3 or lower")
+    @unittest.skipIf(not TORCH_VERSION_AFTER_2_4, "skipping when torch verion is 2.4 or lower")
     def test_qat_8da4w_quantizer(self):
         from torchao.quantization.prototype.qat import Int8DynActInt4WeightQATQuantizer
         from torchao.quantization.GPTQ import Int8DynActInt4WeightQuantizer
@@ -189,7 +189,7 @@ def test_qat_8da4w_quantizer(self):
         for k in ptq_state_dict.keys():
             torch.testing.assert_close(ptq_state_dict[k], converted_state_dict[k], atol=0, rtol=0)
 
-    @unittest.skipIf(not TORCH_VERSION_AFTER_2_3, "skipping when torch verion is 2.3 or lower")
+    @unittest.skipIf(not TORCH_VERSION_AFTER_2_4, "skipping when torch verion is 2.4 or lower")
     def test_qat_8da4w_quantizer_meta_weights(self):
         from torchao.quantization.prototype.qat import Int8DynActInt4WeightQATQuantizer
 
@@ -201,7 +201,7 @@ def test_qat_8da4w_quantizer_meta_weights(self):
         qat_model = qat_quantizer.prepare(m)
         self.assertTrue(all(v.is_meta for v in qat_model.state_dict().values()))
 
-    @unittest.skipIf(not TORCH_VERSION_AFTER_2_3, "skipping when torch verion is 2.3 or lower")
+    @unittest.skipIf(not TORCH_VERSION_AFTER_2_4, "skipping when torch verion is 2.4 or lower")
     def test_qat_8da4w_quantizer_disable_fake_quant(self):
         """
         Test that 8da4w QAT with disabled fake quant matches nn.Linear in forward.
@@ -254,7 +254,7 @@ def test_qat_8da4w_quantizer_disable_fake_quant(self):
         qat_out2 = qat_model2(*x2)
         torch.testing.assert_close(qat_out, qat_out2, atol=0, rtol=0)
 
-    @unittest.skipIf(not TORCH_VERSION_AFTER_2_3, "skipping when torch verion is 2.3 or lower")
+    @unittest.skipIf(not TORCH_VERSION_AFTER_2_4, "skipping when torch verion is 2.4 or lower")
     def test_qat_8da4w_quantizer_disable_fake_quant_backward(self):
         """
         Test that 8da4w QAT with disabled fake quant matches nn.Linear in backward.
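
The skip decorators above move from TORCH_VERSION_AFTER_2_3 to TORCH_VERSION_AFTER_2_4, presumably because the core change from pytorch/pytorch#125781 only ships in PyTorch 2.4 nightlies. The real flag is defined in torchao.quantization.utils; purely as an illustration, such a gate is commonly built along these lines (assumes the packaging package is installed):

import unittest

import torch
from packaging.version import parse

# Hypothetical reimplementation for illustration; the real flag may differ.
TORCH_VERSION_AFTER_2_4 = parse(torch.__version__) >= parse("2.4.0.dev0")

class ExampleGatedTest(unittest.TestCase):
    @unittest.skipIf(not TORCH_VERSION_AFTER_2_4, "requires PyTorch 2.4 or later")
    def test_uses_new_quantize_numerics(self):
        self.assertTrue(True)  # placeholder body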

test/quantization/test_quant_primitives.py

Lines changed: 1 addition & 1 deletion
@@ -156,7 +156,7 @@ def test_quantize_activation_per_token_abs_max_zero_input(self):
         quantized_ref, scale_ref = quantize_activation_per_token_absmax(input)
 
 
-    @unittest.skipIf(not TORCH_VERSION_AFTER_2_3, "skipping when torch verion is 2.3 or lower")
+    @unittest.skipIf(not TORCH_VERSION_AFTER_2_4, "skipping when torch verion is 2.4 or lower")
     def test_quantize_dequantize_group_sym(self):
         input = torch.randn(10, 10)
         mapping_type = MappingType.SYMMETRIC

torchao/quantization/prototype/qat.py

Lines changed: 1 addition & 1 deletion
@@ -209,7 +209,7 @@ def forward(ctx, input, scales, zero_points, quant_min, quant_max):
         # which rounds first before adding the zero points. However, this
         # is what `quantize_per_channel_group` and `quantize_per_token`
         # do and here we try to match that behavior as closely as possible.
-        q = input.div(scales).add(zero_points).round()
+        q = input.mul(1.0 / scales).add(zero_points).round()
         dq = q.clamp(quant_min, quant_max).sub(zero_points).mul(scales)
         # TODO: do we need this mask?
         mask = torch.logical_and((q >= quant_min), (q <= quant_max))
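
To see the fixed path end to end, here is a minimal, self-contained restatement of the fake-quantize forward shown in the diff above; the function mirrors the two lines in the hunk, while the shapes and values below are illustrative, not taken from the repo.

import torch

def fake_quantize(input, scales, zero_points, quant_min, quant_max):
    # Scale by the reciprocal (matching the core quantize op), shift by the
    # zero point, round, then clamp and invert to recover a dequantized value.
    q = input.mul(1.0 / scales).add(zero_points).round()
    dq = q.clamp(quant_min, quant_max).sub(zero_points).mul(scales)
    return dq

x = torch.randn(4, 8)
scales = torch.full((4, 1), 0.05)   # one scale per row, broadcast over columns
zero_points = torch.zeros(4, 1)     # symmetric case: zero point of 0
print(fake_quantize(x, scales, zero_points, -8, 7))  # int4 range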

torchao/quantization/quant_primitives.py

Lines changed: 1 addition & 1 deletion
@@ -201,7 +201,7 @@ def quantize_affine(
 
     if zero_point_domain == ZeroPointDomain.INT:
         quant = torch.clamp(
-            torch.round(input / scale) + zero_point, quant_min, quant_max
+            torch.round(input * (1.0 / scale)) + zero_point, quant_min, quant_max
         ).to(output_dtype)
     else:
         assert zero_point_domain == ZeroPointDomain.FLOAT
