Skip to content

Commit b0b690f

Browse files
committed
Fix wrong scale eps applied
1 parent 5549da8 commit b0b690f

File tree

2 files changed

+78
-1
lines changed

2 files changed

+78
-1
lines changed

test/quantization/test_quant_primitives.py

Lines changed: 60 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -961,6 +961,66 @@ def test_float8_quant_primitives(self, hp_dtype, float8_dtype):
961961
torch.testing.assert_close(expected_quantized, quantized)
962962
torch.testing.assert_close(expected_dequantized, dequantized)
963963

964+
@parameterized.expand(
965+
[
966+
torch.float64,
967+
torch.float32,
968+
torch.bfloat16,
969+
torch.float16,
970+
]
971+
)
972+
def test_choose_qparams_affine_for_inf_scale_reciprocal(self, hp_dtype):
973+
# Fixed by #1770, the test will fail for all the variants
974+
# before that fix, and will pass afterwards.
975+
#
976+
# The scale value must be forcefully clamped, within
977+
# _choose_qparams_affine() function, (that
978+
# choose_qparams_affine() and others call into) to a large
979+
# enough number so that its reciprocal does not become Inf.
980+
# Otherwise during the quantization, by multiplying with scale
981+
# reciprocal, all the values will be quantized to Inf value,
982+
# except from zero value that would produce NaN (0*Inf) as
983+
# quantized value.
984+
#
985+
# The minimal normalized value for given floating point data
986+
# type is given by torch.finfo(hp_dtype).tiny - let's call
987+
# this value "tiny". It could be seen by checking, that for
988+
# all of torch.float64, torch.float32, torch.float16 and
989+
# torch.floatb16, denormalized number that is equal to tiny/4
990+
# will produce Inf as its reciprocal.
991+
#
992+
# Thus, to reproduce the problem, one would create a tensor
993+
# with such values that their absolute maximum, after being
994+
# divided with the range of quantized data (that is 57344 for
995+
# torch.float8_e5m2), would produce scale smaller than tiny/4.
996+
# Also, eps parameter should be set to value no greater than
997+
# tiny/4, as scale is clamped from below to that value. With
998+
# such inputs, choose_qparams_affine() will produce Inf as
999+
# scale value.
1000+
#
1001+
# Note that this may seem as contrieved reproducer. However,
1002+
# there are cases with existing code that would pass
1003+
# torch.finfo(torch.float32).eps as eps value, no matters of
1004+
# scale_dtype. The float16 has rather small range, so this
1005+
# value is well bellow torch.finfo(torch.float32).eps, and for
1006+
# such eps value, the code bellow would produce Inf scale even
1007+
# for float16 tensor that has 0.5 as its maximum value.
1008+
float8_dtype = torch.float8_e5m2
1009+
tiny = torch.finfo(hp_dtype).tiny
1010+
x = torch.tensor([[0, 100 * tiny]], dtype=hp_dtype)
1011+
scale, _ = choose_qparams_affine(
1012+
input=x,
1013+
mapping_type=MappingType.SYMMETRIC,
1014+
block_size=[1, 2],
1015+
target_dtype=float8_dtype,
1016+
eps=tiny / 4,
1017+
scale_dtype=hp_dtype,
1018+
preserve_zero=True,
1019+
zero_point_domain=ZeroPointDomain.NONE,
1020+
)
1021+
scale_reciprocal = scale.reciprocal()
1022+
assert not torch.any(torch.isinf(scale_reciprocal)).item()
1023+
9641024

9651025
if __name__ == "__main__":
9661026
unittest.main()

torchao/quantization/quant_primitives.py

Lines changed: 18 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -862,6 +862,7 @@ def _choose_qparams_affine(
862862
3. calculate quantization parameters based on min_val/max_val based on args like `preserve_zero`
863863
and `zero_point_domain`
864864
"""
865+
865866
quant_min, quant_max = _get_and_check_qmin_qmax(target_dtype, quant_min, quant_max)
866867
assert mapping_type in [
867868
MappingType.SYMMETRIC.name,
@@ -909,6 +910,16 @@ def _choose_qparams_affine(
909910
min_val_neg = min_val
910911
max_val_pos = max_val
911912

913+
# Prevent reciprocal of scale, calculated below, to become Inf.
914+
if torch.is_floating_point(max_val):
915+
# In this case, scale will be calculated below in
916+
# max_val.dtype.
917+
eps = max(eps, torch.finfo(max_val.dtype).tiny)
918+
else:
919+
# In this case, scale will be calculated below in
920+
# torch.float32 dtype.
921+
eps = max(eps, torch.finfo(torch.float32).tiny)
922+
912923
if (
913924
mapping_type == MappingType.SYMMETRIC.name
914925
or mapping_type == MappingType.SYMMETRIC_NO_CLIPPING_ERR.name
@@ -969,7 +980,13 @@ def _choose_qparams_affine(
969980

970981
if zero_point is not None:
971982
zero_point = zero_point.to(dtype=zero_point_dtype)
972-
return scale.to(dtype=scale_dtype), zero_point
983+
scale = scale.to(dtype=scale_dtype)
984+
if torch.is_floating_point(scale):
985+
# Again, prevent scale reciprocal to become Inf.
986+
scale = scale.clamp(
987+
min=torch.finfo(scale_dtype).tiny, max=torch.finfo(scale_dtype).max
988+
)
989+
return scale, zero_point
973990

974991

975992
def choose_qparams_and_quantize_affine_qqq(

0 commit comments

Comments
 (0)