Skip to content

Commit dcc44c4

Browse files
committed
Fix wrong scale eps applied
1 parent 665dac0 commit dcc44c4

File tree

2 files changed

+76
-1
lines changed

2 files changed

+76
-1
lines changed

test/quantization/test_quant_primitives.py

Lines changed: 63 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -961,6 +961,69 @@ def test_float8_quant_primitives(self, hp_dtype, float8_dtype):
961961
torch.testing.assert_close(expected_quantized, quantized)
962962
torch.testing.assert_close(expected_dequantized, dequantized)
963963

964+
@parameterized.expand(
965+
[
966+
torch.float64,
967+
torch.float32,
968+
torch.bfloat16,
969+
torch.float16,
970+
]
971+
)
972+
def test_choose_qparams_affine_for_inf_scale_reciprocal(self, hp_dtype):
973+
# Fixed by #1770, the test will fail for all the variants
974+
# before that fix, and will pass afterwards.
975+
#
976+
# The scale value must be checked, within
977+
# the _choose_qparams_affine() function (which
978+
# choose_qparams_affine() and others call into) to be a large
979+
# enough number so that its reciprocal does not become Inf.
980+
# Otherwise during the quantization, by multiplying with scale
981+
# reciprocal, all the values will be quantized to Inf value,
982+
# except from zero value that would produce NaN (0*Inf) as
983+
# quantized value.
984+
#
985+
# The minimal normalized value for given floating point data
986+
# type is given by torch.finfo(hp_dtype).tiny - let's call
987+
# this value "tiny". It can be checked that, for
988+
# all of torch.float64, torch.float32, torch.float16 and
989+
# torch.bfloat16, a denormalized number equal to tiny/4
990+
# will produce Inf as its reciprocal.
991+
#
992+
# Thus, to reproduce the problem, one would create a tensor
993+
# with such values that their absolute maximum, after being
994+
# divided with the range of quantized data (that is 57344 for
995+
# torch.float8_e5m2), would produce scale smaller than tiny/4.
996+
# Also, eps parameter should be set to value no greater than
997+
# tiny/4, as scale is clamped from below to that value. With
998+
# such inputs, choose_qparams_affine() would produce Inf as
999+
# scale value, if not checked for it.
1000+
#
1001+
# Note that this may seem like a contrived reproducer. However,
1002+
# there are cases with existing code that would pass
1003+
# torch.finfo(torch.float32).eps as the eps value, regardless of
1004+
# scale_dtype. The float16 has rather small range, so this
1005+
# value is well below torch.finfo(torch.float32).eps, and for
1006+
# such an eps value, the code below would produce an Inf scale even
1007+
# for a float16 tensor that has 0.5 as its maximum value.
1008+
float8_dtype = torch.float8_e5m2
1009+
tiny = torch.finfo(hp_dtype).tiny
1010+
x = torch.tensor([[0, 100 * tiny]], dtype=hp_dtype)
1011+
1012+
import pytest
1013+
1014+
with pytest.raises(AssertionError) as exc_info:
1015+
scale, _ = choose_qparams_affine(
1016+
input=x,
1017+
mapping_type=MappingType.SYMMETRIC,
1018+
block_size=[1, 2],
1019+
target_dtype=float8_dtype,
1020+
eps=tiny / 4,
1021+
scale_dtype=hp_dtype,
1022+
preserve_zero=True,
1023+
zero_point_domain=ZeroPointDomain.NONE,
1024+
)
1025+
assert str(exc_info.value) == "Invalid scale value"
1026+
9641027

9651028
if __name__ == "__main__":
9661029
unittest.main()

torchao/quantization/quant_primitives.py

Lines changed: 13 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -862,6 +862,7 @@ def _choose_qparams_affine(
862862
3. calculate quantization parameters based on min_val/max_val based on args like `preserve_zero`
863863
and `zero_point_domain`
864864
"""
865+
865866
quant_min, quant_max = _get_and_check_qmin_qmax(target_dtype, quant_min, quant_max)
866867
assert mapping_type in [
867868
MappingType.SYMMETRIC.name,
@@ -950,6 +951,13 @@ def _choose_qparams_affine(
950951
assert mapping_type == MappingType.ASYMMETRIC.name
951952
scale = (max_val_pos - min_val_neg) / float(quant_max - quant_min)
952953
scale = torch.clamp(scale, min=eps)
954+
# Prevent the scale reciprocal from becoming Inf.
955+
assert torch.all(
956+
scale
957+
>= torch.finfo(
958+
scale.dtype if torch.is_floating_point(scale) else torch.float32
959+
).tiny
960+
), "Invalid scale value"
953961
if zero_point_domain == ZeroPointDomain.NONE.name:
954962
zero_point = None
955963
elif zero_point_domain == ZeroPointDomain.INT.name:
@@ -969,7 +977,11 @@ def _choose_qparams_affine(
969977

970978
if zero_point is not None:
971979
zero_point = zero_point.to(dtype=zero_point_dtype)
972-
return scale.to(dtype=scale_dtype), zero_point
980+
scale = scale.to(dtype=scale_dtype)
981+
# Prevent the scale reciprocal from becoming Inf.
982+
if torch.is_floating_point(scale):
983+
assert torch.all(scale >= torch.finfo(scale_dtype).tiny), "Invalid scale value"
984+
return scale, zero_point
973985

974986

975987
def choose_qparams_and_quantize_affine_qqq(

0 commit comments

Comments
 (0)