Skip to content

Commit 535ac19

Browse files
committed
Fix wrong scale eps applied
1 parent d00ee41 commit 535ac19

File tree

1 file changed

+6
-0
lines changed

1 file changed

+6
-0
lines changed

torchao/quantization/quant_primitives.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -944,10 +944,16 @@ def _choose_qparams_affine(
944944
else:
945945
zero_point = torch.full_like(scale, int((quant_max + quant_min + 1) / 2))
946946
scale = torch.clamp(scale, min=eps)
947+
if torch.is_floating_point(scale):
948+
# Prevent 1.0 / scale to become Inf.
949+
scale = torch.clamp(scale, min=2 * torch.finfo(scale.dtype).tiny)
947950
else:
948951
assert mapping_type == MappingType.ASYMMETRIC.name
949952
scale = (max_val_pos - min_val_neg) / float(quant_max - quant_min)
950953
scale = torch.clamp(scale, min=eps)
954+
if torch.is_floating_point(scale):
955+
# Prevent 1.0 / scale to become Inf.
956+
scale = torch.clamp(scale, min=2 * torch.finfo(scale.dtype).tiny)
951957
if zero_point_domain == ZeroPointDomain.NONE.name:
952958
zero_point = None
953959
else:

0 commit comments

Comments
 (0)