@@ -961,6 +961,66 @@ def test_float8_quant_primitives(self, hp_dtype, float8_dtype):
         torch.testing.assert_close(expected_quantized, quantized)
         torch.testing.assert_close(expected_dequantized, dequantized)
 
+    @parameterized.expand(
+        [
+            torch.float64,
+            torch.float32,
+            torch.bfloat16,
+            torch.float16,
+        ]
+    )
+    def test_choose_qparams_affine_for_inf_scale_reciprocal(self, hp_dtype):
+        # Fixed by #1770; this test will fail for all the variants
+        # before that fix, and will pass afterwards.
+        #
+        # The scale value must be forcefully clamped, within the
+        # _choose_qparams_affine() function (which
+        # choose_qparams_affine() and others call into), to a
+        # large enough number so that its reciprocal does not
+        # become Inf. Otherwise, during quantization, multiplying
+        # by the scale reciprocal quantizes all the values to Inf,
+        # except for zero, which produces NaN (0*Inf) as its
+        # quantized value.
+        #
+        # The minimal normalized value for a given floating point
+        # data type is given by torch.finfo(hp_dtype).tiny - let's
+        # call this value "tiny". One can check that, for all of
+        # torch.float64, torch.float32, torch.float16 and
+        # torch.bfloat16, the denormalized number equal to tiny/4
+        # produces Inf as its reciprocal.
+        #
+        # Thus, to reproduce the problem, one creates a tensor whose
+        # absolute maximum, after being divided by the range of the
+        # quantized data type (which is 57344 for
+        # torch.float8_e5m2), produces a scale smaller than tiny/4.
+        # Also, the eps parameter should be set to a value no
+        # greater than tiny/4, as the scale is clamped from below to
+        # eps. With such inputs, choose_qparams_affine() will
+        # produce Inf as the scale value.
+        #
+        # Note that this may seem like a contrived reproducer, but
+        # there is existing code that passes
+        # torch.finfo(torch.float32).eps as the eps value, no matter
+        # the scale_dtype. float16 has a rather small range, so that
+        # value is well below its tiny/4, and for such an eps, the
+        # code below would produce an Inf scale even for a float16
+        # tensor with 0.5 as its maximum value.
+        float8_dtype = torch.float8_e5m2
+        tiny = torch.finfo(hp_dtype).tiny
+        x = torch.tensor([[0, 100 * tiny]], dtype=hp_dtype)
+        scale, _ = choose_qparams_affine(
+            input=x,
+            mapping_type=MappingType.SYMMETRIC,
+            block_size=[1, 2],
+            target_dtype=float8_dtype,
+            eps=tiny / 4,
+            scale_dtype=hp_dtype,
+            preserve_zero=True,
+            zero_point_domain=ZeroPointDomain.NONE,
+        )
+        scale_reciprocal = scale.reciprocal()
+        assert not torch.any(torch.isinf(scale_reciprocal)).item()
+
 
 if __name__ == "__main__":
     unittest.main()
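
As a quick sanity check of the claims in the test's comment, here is a minimal standalone sketch (separate from the test above, using only stock PyTorch; the four dtypes, the tiny/4 threshold, and the 0*Inf-to-NaN behavior come from the comment, the rest is illustrative):

```python
import torch

# For every dtype the test parameterizes over, the reciprocal of the
# denormal value tiny/4 overflows the dtype's finite range to Inf.
for dtype in (torch.float64, torch.float32, torch.bfloat16, torch.float16):
    tiny = torch.finfo(dtype).tiny  # smallest normalized value
    recip = torch.tensor(tiny / 4, dtype=dtype).reciprocal()
    print(dtype, recip.item())  # prints inf for all four dtypes

# An Inf scale reciprocal then turns every quantized value into Inf,
# except zero, which becomes NaN (0 * Inf).
x = torch.tensor([0.0, 0.5], dtype=torch.float16)
print(x * torch.tensor(float("inf"), dtype=torch.float16))  # [nan, inf]
```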