File tree: 1 file changed (+7 -1 lines changed)
Original file line number | Diff line number | Diff line change
@@ -873,7 +873,9 @@ def _choose_qparams_affine(
873
873
f"Only symmetric quantization is supported for FP8 types, got { mapping_type } "
874
874
)
875
875
876
+ scale_device = None
876
877
if input is not None :
878
+ scale_device = input .device
877
879
if scale_dtype is None :
878
880
scale_dtype = input .dtype
879
881
if eps is None :
@@ -902,6 +904,8 @@ def _choose_qparams_affine(
902
904
if eps is None :
903
905
eps = torch .finfo (min_val .dtype ).eps
904
906
907
+ scale_device = min_val .device
908
+
905
909
if preserve_zero :
906
910
min_val_neg = torch .min (min_val , torch .zeros_like (min_val ))
907
911
max_val_pos = torch .max (max_val , torch .zeros_like (max_val ))
@@ -948,7 +952,9 @@ def _choose_qparams_affine(
948
952
scale = torch .clamp (scale , min = eps )
949
953
else :
950
954
assert mapping_type == MappingType .ASYMMETRIC .name
951
- scale = (max_val_pos - min_val_neg ) / float (quant_max - quant_min )
955
+ scale = (max_val_pos - min_val_neg ) / torch .tensor (
956
+ float (quant_max - quant_min ), dtype = scale_dtype , device = scale_device
957
+ )
952
958
scale = torch .clamp (scale , min = eps )
953
959
if zero_point_domain == ZeroPointDomain .NONE .name :
954
960
zero_point = None
You can’t perform that action at this time.
0 commit comments