Skip to content

Commit ed81130

Browse files
YIWENX14 authored and facebook-github-bot committed
Fix device and dtype discrepancy in _choose_qparams_affine (#2210)
Summary: Pull Request resolved: #2210 Differential Revision: D74446877
1 parent 554cb60 commit ed81130

File tree

1 file changed

+7
-1
lines changed

1 file changed

+7
-1
lines changed

torchao/quantization/quant_primitives.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -873,7 +873,9 @@ def _choose_qparams_affine(
873873
f"Only symmetric quantization is supported for FP8 types, got {mapping_type}"
874874
)
875875

876+
scale_device = None
876877
if input is not None:
878+
scale_device = input.device
877879
if scale_dtype is None:
878880
scale_dtype = input.dtype
879881
if eps is None:
@@ -902,6 +904,8 @@ def _choose_qparams_affine(
902904
if eps is None:
903905
eps = torch.finfo(min_val.dtype).eps
904906

907+
scale_device = min_val.device
908+
905909
if preserve_zero:
906910
min_val_neg = torch.min(min_val, torch.zeros_like(min_val))
907911
max_val_pos = torch.max(max_val, torch.zeros_like(max_val))
@@ -948,7 +952,9 @@ def _choose_qparams_affine(
948952
scale = torch.clamp(scale, min=eps)
949953
else:
950954
assert mapping_type == MappingType.ASYMMETRIC.name
951-
scale = (max_val_pos - min_val_neg) / float(quant_max - quant_min)
955+
scale = (max_val_pos - min_val_neg) / torch.tensor(
956+
float(quant_max - quant_min), dtype=scale_dtype, device=scale_device
957+
)
952958
scale = torch.clamp(scale, min=eps)
953959
if zero_point_domain == ZeroPointDomain.NONE.name:
954960
zero_point = None

0 commit comments

Comments (0)