@@ -961,6 +961,69 @@ def test_float8_quant_primitives(self, hp_dtype, float8_dtype):
         torch.testing.assert_close(expected_quantized, quantized)
         torch.testing.assert_close(expected_dequantized, dequantized)
 
+    @parameterized.expand(
+        [
+            torch.float64,
+            torch.float32,
+            torch.bfloat16,
+            torch.float16,
+        ]
+    )
+    def test_choose_qparams_affine_for_inf_scale_reciprocal(self, hp_dtype):
+        # Fixed by #1770; the test will fail for all the variants
+        # before that fix, and will pass afterwards.
+        #
+        # The scale value must be checked, within the
+        # _choose_qparams_affine() function (which
+        # choose_qparams_affine() and others call into), to be a
+        # large enough number that its reciprocal does not become
+        # Inf. Otherwise, during quantization, multiplying by the
+        # scale reciprocal would quantize all values to Inf, except
+        # for zero, which would produce NaN (0 * Inf) as its
+        # quantized value.
+        #
+        # The minimal normalized value for a given floating point
+        # data type is given by torch.finfo(hp_dtype).tiny - let's
+        # call this value "tiny". It can be checked that, for all
+        # of torch.float64, torch.float32, torch.float16 and
+        # torch.bfloat16, the denormalized number equal to tiny/4
+        # produces Inf as its reciprocal.
+        #
+        # Thus, to reproduce the problem, one would create a tensor
+        # with such values that their absolute maximum, after being
+        # divided by the range of the quantized data type (which is
+        # 57344 for torch.float8_e5m2), produces a scale smaller
+        # than tiny/4. Also, the eps parameter should be set to a
+        # value no greater than tiny/4, as the scale is clamped
+        # from below to that value. With such inputs,
+        # choose_qparams_affine() would produce a scale whose
+        # reciprocal is Inf, if not checked for it.
+        #
+        # Note that this may seem like a contrived reproducer.
+        # However, there is existing code that passes
+        # torch.finfo(torch.float32).eps as the eps value,
+        # regardless of scale_dtype. float16 has a rather small
+        # range, so torch.finfo(torch.float32).eps is well below
+        # the smallest scale whose reciprocal is still finite in
+        # float16; with such an eps value, the code below would
+        # produce a scale with an Inf reciprocal even for a float16
+        # tensor that has 0.5 as its maximum value.
+        float8_dtype = torch.float8_e5m2
+        tiny = torch.finfo(hp_dtype).tiny
+        x = torch.tensor([[0, 100 * tiny]], dtype=hp_dtype)
+
+        import pytest
+
+        with pytest.raises(AssertionError) as exc_info:
+            scale, _ = choose_qparams_affine(
+                input=x,
+                mapping_type=MappingType.SYMMETRIC,
+                block_size=[1, 2],
+                target_dtype=float8_dtype,
+                eps=tiny / 4,
+                scale_dtype=hp_dtype,
+                preserve_zero=True,
+                zero_point_domain=ZeroPointDomain.NONE,
+            )
+        assert str(exc_info.value) == "Invalid scale value"
+
 
 if __name__ == "__main__":
     unittest.main()
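
As a quick standalone check of the two claims made in the test's comment, here is an editorial sketch (not part of the PR's diff); it relies only on torch.finfo, the float8_e5m2 range of 57344 quoted in the comment, and the float16 maximum of 65504:

import torch

# For every hp_dtype covered by the test, tiny/4 is a denormal whose
# reciprocal overflows to Inf in that same dtype.
for dtype in (torch.float64, torch.float32, torch.bfloat16, torch.float16):
    tiny = torch.finfo(dtype).tiny
    recip = 1.0 / torch.tensor(tiny / 4, dtype=dtype)
    assert torch.isinf(recip), f"{dtype}: reciprocal stayed finite"

# The "non-contrived" case from the comment: a float16 tensor with
# max 0.5, quantized to float8_e5m2 (full range 57344), gives a
# scale of ~8.7e-6. That is above torch.finfo(torch.float32).eps
# (~1.19e-7), so an eps clamp at that value does nothing, yet the
# scale's float16 reciprocal (~1.15e5) exceeds the float16 maximum
# (65504) and overflows to Inf.
scale = torch.tensor(0.5 / 57344.0, dtype=torch.float16)
assert scale.item() > torch.finfo(torch.float32).eps
assert torch.isinf(1.0 / scale)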