Commit cdc082f

YIWENX14 authored and facebook-github-bot committed
Fix device and dtype discrepancy in _choose_qparams_affine (#2210)
Summary: Pull Request resolved: #2210 Differential Revision: D74446877
1 parent 554cb60 commit cdc082f

1 file changed: +7 −1 lines changed

torchao/quantization/quant_primitives.py

Lines changed: 7 additions & 1 deletion
@@ -873,7 +873,9 @@ def _choose_qparams_affine(
             f"Only symmetric quantization is supported for FP8 types, got {mapping_type}"
         )
 
+    scale_device = None
     if input is not None:
+        scale_device = input.device
         if scale_dtype is None:
             scale_dtype = input.dtype
         if eps is None:
@@ -901,6 +903,8 @@ def _choose_qparams_affine(
             scale_dtype = min_val.dtype
         if eps is None:
             eps = torch.finfo(min_val.dtype).eps
+
+        scale_device = min_val.device
 
     if preserve_zero:
         min_val_neg = torch.min(min_val, torch.zeros_like(min_val))
@@ -948,7 +952,9 @@ def _choose_qparams_affine(
         scale = torch.clamp(scale, min=eps)
     else:
         assert mapping_type == MappingType.ASYMMETRIC.name
-        scale = (max_val_pos - min_val_neg) / float(quant_max - quant_min)
+        scale = (max_val_pos - min_val_neg) / torch.tensor(
+            float(quant_max - quant_min), dtype=scale_dtype, device=scale_device
+        )
         scale = torch.clamp(scale, min=eps)
     if zero_point_domain == ZeroPointDomain.NONE.name:
         zero_point = None
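For context on why the last hunk matters: dividing a tensor by a Python float keeps the tensor's existing dtype under PyTorch's type-promotion rules, so the computed scale could silently ignore the requested scale_dtype; dividing by an explicit torch.tensor carrying scale_dtype and scale_device promotes the result to the intended dtype and keeps it on the input's device. A minimal sketch of that promotion difference (values and dtypes are illustrative, not taken from the commit):

import torch

# Pre-patch style: dividing by a Python scalar keeps the tensor's dtype.
max_minus_min = torch.tensor([6.0, 12.0], dtype=torch.bfloat16)  # illustrative range
qrange = 255  # e.g. quant_max - quant_min for uint8
scale_old = max_minus_min / float(qrange)
print(scale_old.dtype)  # torch.bfloat16, regardless of any requested scale_dtype

# Post-patch style: dividing by a tensor with an explicit dtype/device
# promotes the result, so the scale comes out in the requested scale_dtype
# and stays on the same device as the input.
scale_dtype, scale_device = torch.float32, max_minus_min.device
scale_new = max_minus_min / torch.tensor(
    float(qrange), dtype=scale_dtype, device=scale_device
)
print(scale_new.dtype)  # torch.float32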
