Commit bf5c816

Dedup _choose_qparams_per_token_asymmetric

1 parent: b91b6be

File tree

2 files changed (+5, -50 lines)


test/quantization/test_qat.py

Lines changed: 1 addition & 3 deletions

@@ -13,7 +13,6 @@
 import torch
 from torch.ao.quantization.fx._decomposed import quantized_decomposed_lib  # noqa: F401
 from torchao.quantization.prototype.qat import (
-    _choose_qparams_per_token_asymmetric,
     fake_quantize_per_channel_group,
     fake_quantize_per_token,
 )
@@ -91,8 +90,7 @@ def test_fake_quantize_per_token(self):
         torch.manual_seed(self.SEED)
         x = torch.randn(100, 256).requires_grad_()
         x2 = copy.deepcopy(x)
-        # TODO: use torch.ops.aten.quantized_decomposed version instead
-        (s, zp) = _choose_qparams_per_token_asymmetric(
+        (s, zp) = torch.ops.quantized_decomposed._choose_qparams_per_token_asymmetric_impl(
             x,
             torch.int8,  # not used
         )
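
For context, a minimal sketch (not part of the commit) of what the updated test now exercises. It assumes the op is registered as a side effect of importing quantized_decomposed_lib, as the test file already does:

import torch
from torch.ao.quantization.fx._decomposed import quantized_decomposed_lib  # noqa: F401

x = torch.randn(100, 256)
# One (scale, zero_point) pair per token, i.e. per slice along the last dimension.
(s, zp) = torch.ops.quantized_decomposed._choose_qparams_per_token_asymmetric_impl(
    x,
    torch.int8,  # dtype argument is not used by this op
)
# Both outputs are float32; with keepdim reductions the expected shape is (100, 1).
print(s.dtype, zp.dtype, s.shape)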

torchao/quantization/prototype/qat.py

Lines changed: 4 additions & 47 deletions

@@ -142,7 +142,10 @@ def disable_fake_quant(self):
     def forward(self, x: torch.Tensor) -> torch.Tensor:
         # activations: int8 dynamic asymmetric quant
         if self._fake_quant_enabled:
-            (act_scales, act_zp) = _choose_qparams_per_token_asymmetric(
+            (
+                act_scales,
+                act_zp
+            ) = torch.ops.quantized_decomposed._choose_qparams_per_token_asymmetric_impl(
                 x, torch.int8,  # dtype not used
             )
             (act_qmin, act_qmax) = self._get_qmin_qmax(8)
@@ -269,49 +272,3 @@ def fake_quantize_per_token(
     return _GenericFakeQuantize.apply(
         input, scales, zero_points, quant_min, quant_max,
     )
-
-# TODO: This is copied from torch/ao/quantization/fx/_decomposed.py.
-# The version in pytorch does not have backward support yet so we add
-# it here for now until https://github.com/pytorch/pytorch/pull/123452
-# is landed.
-def _choose_qparams_per_token_asymmetric(
-    input: torch.Tensor,
-    dtype: torch.dtype,
-) -> Tuple[torch.Tensor, torch.Tensor]:
-    """Choose quantization parameters for per token quantization. This means for a N dimension Tensor
-    (M1, M2, ...Mn, N), we calculate scales/zero_points for each N elements and quantize
-    every N elements with the same quantization parameter. The dimension for scales/zero_points
-    will be (M1 * M2 ... * Mn)
-
-    Args:
-       input (torch.Tensor): original float32/float16 Tensor
-       dtype (torch.dtype): dtype (e.g. torch.uint8) for input Tensor
-
-    Returns:
-        scales and zero_points, both float32 Tensors
-    """
-    # Based on https://github.com/google/XNNPACK/blob/df156f0cf3db5a4576cc711123eeb54915f82ffc/src/xnnpack/quantization.h#L18
-    qmin, qmax = -128, 127
-    min_val = torch.amin(input, dim=-1, keepdim=True)
-    max_val = torch.amax(input, dim=-1, keepdim=True)
-    min_val_neg = torch.min(min_val, torch.zeros_like(min_val))
-    max_val_pos = torch.max(max_val, torch.zeros_like(max_val))
-    eps = torch.finfo(torch.float32).eps  # use xnnpack eps?
-
-    # scale
-    scale = (max_val_pos - min_val_neg) / float(qmax - qmin)
-    scale = scale.clamp(min=eps)
-
-    # zero point
-    descaled_min = min_val_neg / scale
-    descaled_max = max_val_pos / scale
-    zero_point_from_min_error = qmin + descaled_min
-    zero_point_from_max_error = qmax + descaled_max
-    zero_point = torch.where(
-        zero_point_from_min_error + zero_point_from_max_error > 0,
-        qmin - descaled_min,
-        qmax - descaled_max,
-    )
-    zero_point = torch.clamp(zero_point, qmin, qmax).round()
-
-    return scale.to(torch.float32), zero_point.to(torch.float32)
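
As a sanity check (a sketch, not part of the commit), the formula from the deleted helper can be re-derived by hand and compared against the canonical op it is being deduped into. Since the helper's comment says it was copied from torch/ao/quantization/fx/_decomposed.py, the outputs are assumed to agree in both values and keepdim shapes:

import torch
from torch.ao.quantization.fx._decomposed import quantized_decomposed_lib  # noqa: F401

x = torch.randn(4, 16)
qmin, qmax = -128, 127

# Re-derive qparams following the deleted helper's math.
min_val = torch.amin(x, dim=-1, keepdim=True).clamp(max=0.0)  # min_val_neg
max_val = torch.amax(x, dim=-1, keepdim=True).clamp(min=0.0)  # max_val_pos
scale = ((max_val - min_val) / float(qmax - qmin)).clamp(min=torch.finfo(torch.float32).eps)
zero_point = torch.where(
    (qmin + min_val / scale) + (qmax + max_val / scale) > 0,
    qmin - min_val / scale,
    qmax - max_val / scale,
).clamp(qmin, qmax).round()

# Compare against the op the code now calls; assumes float32, keepdim outputs
# as documented in the removed docstring.
(s_ref, zp_ref) = torch.ops.quantized_decomposed._choose_qparams_per_token_asymmetric_impl(
    x, torch.int8,  # dtype not used
)
torch.testing.assert_close(scale.to(torch.float32), s_ref)
torch.testing.assert_close(zero_point.to(torch.float32), zp_ref)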
