Commit 9c60424

Add general fake_quantize_affine op
Summary: Add a general `fake_quantize_affine` op that simulates `quantize_affine` + `dequantize_affine` but without casting the intermediate quantized values to lower bit-widths, intended for quantization-aware training (QAT).

Test Plan: python test/quantization/test_quant_primitives.py -k test_fake_quantize_affine
1 parent 12ac498 commit 9c60424
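
For orientation, here is a minimal usage sketch of the new op, mirroring the parameters used in the test added by this commit; it assumes a torchao build that includes this commit, and the shapes and ranges are illustrative rather than recommendations.

import torch
from torchao.quantization.quant_primitives import (
    MappingType,
    choose_qparams_affine,
    fake_quantize_affine,
)

x = torch.randn(10, 10)
# block_size = [1, 10]: all 10 elements in a row share one
# (scale, zero_point) pair, i.e. per-row groupwise quantization.
block_size = [1, 10]
scale, zero_point = choose_qparams_affine(
    x, MappingType.SYMMETRIC, block_size, torch.int8, -127, 127,
    eps=1e-5, scale_dtype=torch.float,
)

# Simulates an int8 quantize + dequantize round trip, but the result
# stays in x.dtype throughout -- there is no cast to torch.int8.
x_fq = fake_quantize_affine(x, block_size, scale, zero_point, torch.int8, -127, 127)
assert x_fq.dtype == x.dtype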

2 files changed: +125, -5 lines changed

test/quantization/test_quant_primitives.py

Lines changed: 20 additions & 0 deletions
@@ -9,6 +9,7 @@
 import unittest
 import torch
 from torchao.quantization.quant_primitives import (
+    fake_quantize_affine,
     quantize_affine,
     dequantize_affine,
     choose_qparams_affine,
@@ -503,5 +504,24 @@ def test_groupwise_affine_dequantize_tensor_from_qparams(self):
 
         self.assertTrue(torch.equal(w_bf16, w_bf16_ref))
 
+    @unittest.skipIf(not TORCH_VERSION_AFTER_2_4, "skipping when torch version is 2.4 or lower")
+    def test_fake_quantize_affine(self):
+        input = torch.randn(10, 10)
+
+        mapping_type = MappingType.SYMMETRIC
+        block_size = list(input.shape)
+        for i in range(len(block_size) - 1):
+            block_size[i] = 1
+        dtype = torch.int8
+        eps = 1e-5
+        quant_min = -127
+        quant_max = 127
+        scale, zero_point = choose_qparams_affine(input, mapping_type, block_size, dtype, quant_min, quant_max, eps=eps, scale_dtype=torch.float)
+
+        quantized = quantize_affine(input, block_size, scale, zero_point, dtype, quant_min, quant_max)
+        dequantized = dequantize_affine(quantized, block_size, scale, zero_point, dtype, quant_min, quant_max)
+        fake_quantized = fake_quantize_affine(input, block_size, scale, zero_point, dtype, quant_min, quant_max)
+        torch.testing.assert_close(dequantized, fake_quantized)
+
 if __name__ == "__main__":
     unittest.main()
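
The final assertion captures the contract of the new op: fake quantization must match a real quantize/dequantize round trip exactly. As a hand-rolled reference, here is a sketch of the arithmetic for the symmetric, INT zero-point-domain path used in this test; it is not the torchao API, just the math it computes.

import torch

def ref_fake_quant(x, scale, zero_point, quant_min=-127, quant_max=127):
    # Same math as quantize_affine followed by dequantize_affine in the
    # INT zero-point domain; the only step omitted is the intermediate
    # cast to torch.int8, so the values stay in x.dtype throughout.
    q = torch.clamp(torch.round(x / scale) + zero_point, quant_min, quant_max)
    return (q - zero_point) * scale

x = torch.tensor([1.234, -0.567, 0.02])
print(ref_fake_quant(x, scale=torch.tensor(0.02), zero_point=torch.tensor(0.0)))
# tensor([ 1.2400, -0.5600,  0.0200])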

torchao/quantization/quant_primitives.py

Lines changed: 105 additions & 5 deletions
@@ -23,6 +23,7 @@
     "choose_qparams_affine",
     "quantize_affine",
     "dequantize_affine",
+    "fake_quantize_affine",
 ]
 
 class MappingType(Enum):
@@ -203,14 +204,34 @@ def _quantize_affine(
     output_dtype: torch.dtype,
     quant_min: Optional[int] = None,
     quant_max: Optional[int] = None,
-    zero_point_domain: str = "INT",
+    zero_point_domain: str = ZeroPointDomain.INT.name,
 ) -> torch.Tensor:
     """op definition that has compatible signatures with custom op library
     """
+    quant_min, quant_max = _get_and_check_qmin_qmax(output_dtype, quant_min, quant_max)
+    return _quantize_affine_no_dtype_cast(
+        input,
+        block_size,
+        scale,
+        zero_point,
+        quant_min,
+        quant_max,
+        zero_point_domain,
+    ).to(output_dtype)
+
+
+def _quantize_affine_no_dtype_cast(
+    input: torch.Tensor,
+    block_size: List[int],
+    scale: torch.Tensor,
+    zero_point: Optional[torch.Tensor],
+    quant_min: int,
+    quant_max: int,
+    zero_point_domain: str = ZeroPointDomain.INT.name,
+) -> torch.Tensor:
     # TODO: validations
     # TODO: validate scale/zero_point dimensions are compatible with block_size
     assert input.dtype in [torch.float32, torch.float16, torch.bfloat16], f"Unsupported input dtype: {input.dtype}"
-    quant_min, quant_max = _get_and_check_qmin_qmax(output_dtype, quant_min, quant_max)
     shape_for_reduction, reduction_dims = _get_reduction_params(block_size, input.size())
     original_shape = input.shape
     input = input.view(shape_for_reduction)
@@ -224,7 +245,7 @@ def _quantize_affine(
     if zero_point_domain == ZeroPointDomain.INT.name:
         quant = torch.clamp(
             torch.round(input * (1.0 / scale)) + zero_point, quant_min, quant_max
-        ).to(output_dtype)
+        )
     else:
         assert zero_point_domain == ZeroPointDomain.FLOAT.name
         mid_point = (quant_max + quant_min + 1) / 2
@@ -233,11 +254,12 @@ def _quantize_affine(
             torch.clamp(
                 torch.round((input - min_val) / scale),
                 quant_min, quant_max)
-        ).to(output_dtype)
+        )
     quant = quant.view(original_shape)
 
     return quant
 
+
 def dequantize_affine(
     input: torch.Tensor,
     block_size: Tuple[int, ...],
@@ -283,6 +305,7 @@ def dequantize_affine(
         output_dtype=output_dtype,
     )
 
+
 @register_custom_op
 def _dequantize_affine(
     input: torch.Tensor,
@@ -292,7 +315,7 @@ def _dequantize_affine(
     input_dtype: torch.dtype,
     quant_min: Optional[int] = None,
     quant_max: Optional[int] = None,
-    zero_point_domain: str = "INT",
+    zero_point_domain: str = ZeroPointDomain.INT.name,
     output_dtype: torch.dtype = torch.float32,
 ) -> torch.Tensor:
     """op definition that has compatible signatures with custom op library
@@ -303,7 +326,28 @@ def _dequantize_affine(
     assert input.dtype == input_dtype, f"Expected: {input_dtype}, got: {input.dtype}"
     assert output_dtype in [torch.float32, torch.float16, torch.bfloat16], f"Unsupported output dtype: {output_dtype}"
     quant_min, quant_max = _get_and_check_qmin_qmax(input_dtype, quant_min, quant_max)
+    return _dequantize_affine_no_dtype_check(
+        input,
+        block_size,
+        scale,
+        zero_point,
+        quant_min,
+        quant_max,
+        zero_point_domain,
+        output_dtype,
+    )
+
 
+def _dequantize_affine_no_dtype_check(
+    input: torch.Tensor,
+    block_size: List[int],
+    scale: torch.Tensor,
+    zero_point: Optional[torch.Tensor],
+    quant_min: int,
+    quant_max: int,
+    zero_point_domain: str = ZeroPointDomain.INT.name,
+    output_dtype: torch.dtype = torch.float32,
+) -> torch.Tensor:
     shape_for_reduction, reduction_dims = _get_reduction_params(block_size, input.size())
     original_shape = input.shape
     input = input.view(shape_for_reduction)
@@ -335,6 +379,62 @@ def _dequantize_affine(
 
     return dequant.view(original_shape).to(output_dtype)
 
+
+def fake_quantize_affine(
+    input: torch.Tensor,
+    block_size: Tuple[int, ...],
+    scale: torch.Tensor,
+    zero_point: Optional[torch.Tensor],
+    quant_dtype: torch.dtype,
+    quant_min: Optional[int] = None,
+    quant_max: Optional[int] = None,
+    zero_point_domain: ZeroPointDomain = ZeroPointDomain.INT,
+) -> torch.Tensor:
+    """
+    General fake quantize op for quantization-aware training (QAT).
+    This is equivalent to calling `quantize_affine` + `dequantize_affine`
+    but without the dtype casts.
+
+    Args:
+      input (torch.Tensor): original float32, float16 or bfloat16 Tensor
+      block_size (Tuple[int, ...]): granularity of quantization, this means the size of the tensor elements that's sharing the same qparam
+        e.g. when size is the same as the input tensor dimension, we are using per tensor quantization
+      scale (float): quantization parameter for affine quantization
+      zero_point (int): quantization parameter for affine quantization
+      quant_dtype (torch.dtype): desired quantized dtype for determining and validating quant_min and quant_max values.
+      quant_min (Optional[int]): minimum quantized value for output Tensor, if not specified, it will be derived from dtype
+      quant_max (Optional[int]): maximum quantized value for output Tensor, if not specified, it will be derived from dtype
+      zero_point_domain (ZeroPointDomain): the domain that zero_point is in, should be either integer or float
+        if zero_point is in integer domain, zero point is added to the quantized integer value during
+        quantization
+        if zero_point is in floating point domain, zero point is subtracted from the floating point (unquantized)
+        value during quantization
+        default is ZeroPointDomain.INT
+    """
+    input_dtype = input.dtype
+    quant_min, quant_max = _get_and_check_qmin_qmax(quant_dtype, quant_min, quant_max)
+    q = _quantize_affine_no_dtype_cast(
+        input,
+        block_size,
+        scale,
+        zero_point,
+        quant_min,
+        quant_max,
+        zero_point_domain.name,
+    )
+    dq = _dequantize_affine_no_dtype_check(
+        q,
+        block_size,
+        scale,
+        zero_point,
+        quant_min,
+        quant_max,
+        zero_point_domain.name,
+        output_dtype=input_dtype,
+    )
+    return dq
+
+
 def choose_qparams_affine(
     input: torch.Tensor,
     mapping_type: MappingType,
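
The hunks above only show the quantize half of the ZeroPointDomain.FLOAT branch. Below is a hedged sketch of the full FLOAT-domain round trip; the dequantize half is inferred by inverting the quantize math in `_quantize_affine_no_dtype_cast`, since the FLOAT-domain dequantize code is not part of the hunks in this diff.

import torch

def ref_fake_quant_float_zp(x, scale, zero_point, quant_min, quant_max):
    # Quantize, as in the ZeroPointDomain.FLOAT branch above: the zero
    # point lives in the unquantized (float) domain and is folded into
    # min_val before scaling.
    mid_point = (quant_max + quant_min + 1) / 2
    min_val = zero_point - scale * mid_point
    q = torch.clamp(torch.round((x - min_val) / scale), quant_min, quant_max)
    # Dequantize by inverting the affine map -- again with no integer cast.
    return q * scale + min_val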
