
Commit ee96b19

Add a way to do power of 2 scaling
stack-info: PR: #2256, branch: drisspg/stack/57
1 parent 5d02444 commit ee96b19
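
For context, this is how the new option is meant to be used end to end. The config name, the scale_dtype field, and the torch.float8_e8m0fnu value come from this diff; the import paths and the toy model are assumptions, so treat this as a minimal sketch rather than the canonical API:

import torch
import torch.nn as nn
from torchao.quantization import Float8DynamicActivationFloat8WeightConfig, PerRow, quantize_

# Toy module; any model with nn.Linear layers works the same way (assumed setup).
model = nn.Sequential(nn.Linear(256, 512)).to("cuda", torch.bfloat16)

# Requesting e8m0 scales rounds the weight scales to powers of 2;
# per the test in this commit they are still stored as fp32 tensors.
config = Float8DynamicActivationFloat8WeightConfig(
    granularity=PerRow(),
    scale_dtype=torch.float8_e8m0fnu,  # requires PyTorch 2.8+
)
quantize_(model, config)

out = model(torch.randn(32, 256, device="cuda", dtype=torch.bfloat16))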

File tree

4 files changed, +73 -4 lines changed


test/dtypes/test_affine_quantized_float.py

Lines changed: 58 additions & 0 deletions
@@ -44,6 +44,7 @@
     choose_qparams_affine,
 )
 from torchao.utils import (
+    TORCH_VERSION_AT_LEAST_2_8,
     is_sm_at_least_89,
     is_sm_at_least_90,
 )
@@ -345,6 +346,63 @@ def test_mm_float8dq(self, in_features, out_features, leading_shape, bias: bool)
         error = compute_error(ref_output, quant_output)
         assert error > 20, f"Quantization error is too high got a SQNR of {error}"
 
+    @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available")
+    @unittest.skipIf(
+        not is_sm_at_least_89(), "Requires GPU with compute capability >= 8.9"
+    )
+    @common_utils.parametrize(
+        "granularity", [PerTensor(), PerRow()] if is_sm_at_least_90() else [PerTensor()]
+    )
+    @unittest.skipIf(
+        not TORCH_VERSION_AT_LEAST_2_8, "Requires PyTorch 2.8+ with e8m0 support"
+    )
+    def test_fp8_e8m0_scale_dtype(self, granularity):
+        """Test float8 quantization with e8m0 scale dtype on PyTorch 2.8+"""
+        device = "cuda"
+        dtype = torch.bfloat16
+        in_features, out_features = 256, 512
+
+        # Create model
+        model = ToyLinearModel(in_features, out_features).to(device).to(dtype)
+        quant_model = copy.deepcopy(model)
+
+        # Create config with e8m0 scale dtype
+        config = Float8DynamicActivationFloat8WeightConfig(
+            granularity=granularity, scale_dtype=torch.float8_e8m0fnu
+        )
+
+        # Quantize the model
+        quantize_(quant_model, config)
+
+        # Verify that the scale dtype is correctly set
+        for layer_name in ["linear1", "linear2"]:
+            layer = getattr(quant_model, layer_name)
+            weight_impl = layer.weight.original_weight_tensor.tensor_impl
+
+            # Although we specify e8m0, the scale is still cast to fp32
+            self.assertEqual(weight_impl.scale.dtype, torch.float32)
+
+            # Verify the scale is a power of 2 (requirement for e8m0)
+            scale_values = weight_impl.scale.float()
+            log2_scales = torch.log2(scale_values)
+            self.assertTrue(
+                torch.allclose(log2_scales, torch.round(log2_scales), atol=0),
+                "e8m0 scales should be powers of 2",
+            )
+
+        # Test forward pass
+        input_tensor = torch.randn(32, in_features, device=device, dtype=dtype)
+
+        with torch.no_grad():
+            output = model(input_tensor)
+            output_quant = quant_model(input_tensor)
+
+        # Verify output shape and that computation completes without error
+        expected_shape = (32, in_features)  # ToyLinearModel maps back to the original size
+        self.assertEqual(output.shape, expected_shape)
+        error = compute_error(output, output_quant)
+        assert error > 20, f"Quantization error is too high, got an SQNR of {error}"
+
 
 common_utils.instantiate_parametrized_tests(TestAffineQuantizedFloat8Compile)

torchao/dtypes/affine_quantized_tensor.py

Lines changed: 4 additions & 1 deletion
@@ -463,7 +463,10 @@ def from_hp_to_floatx(
         original_shape = input_float.shape
         input_float = _layout.pre_process(input_float)
         scale = choose_qparams_affine_float8(
-            input_float, float8_dtype=target_dtype, block_size=block_size
+            input_float,
+            float8_dtype=target_dtype,
+            block_size=block_size,
+            scale_dtype=scale_dtype,
         )
         data = quantize_affine_float8(input_float, scale, target_dtype)
         data, scale, zero_point = _layout.post_process(

torchao/quantization/quant_api.py

Lines changed: 5 additions & 2 deletions
@@ -1407,7 +1407,7 @@ def _float8_weight_only_quant_tensor(weight, config):
         input_float=weight,
         block_size=block_size,
         target_dtype=config.weight_dtype,
-        scale_dtype=None,
+        scale_dtype=torch.float32,
         _layout=Float8Layout(mm_config=None),
     )
     return new_weight
@@ -1564,6 +1564,7 @@ class Float8DynamicActivationFloat8WeightConfig(AOBaseConfig):
             only PerTensor and PerRow are supported.
         mm_config (Float8MMConfig): Configuration for the matrix multiplication. Default uses fast accumulation.
         set_inductor_config (bool): if True, adjusts `torchinductor` settings to recommended values.
+        scale_dtype: By default this is set to fp32; if a user is on PyTorch 2.8+ and sets it to e8m0, we will ensure power-of-2 scaling.
 
     """
@@ -1574,6 +1575,7 @@ class Float8DynamicActivationFloat8WeightConfig(AOBaseConfig):
     ] = None
     mm_config: Optional[Float8MMConfig] = None
     set_inductor_config: bool = True
+    scale_dtype: torch.dtype = torch.float32
 
     def __post_init__(self):
         if self.mm_config is None:
@@ -1594,6 +1596,7 @@ def _float8_dynamic_activation_float8_weight_quantize_tensor(weight, config):
     weight_dtype = config.weight_dtype
     granularity = config.granularity
     mm_config = config.mm_config
+    scale_dtype = config.scale_dtype
 
     activation_granularity, weight_granularity = _normalize_granularity(granularity)
 
@@ -1613,7 +1616,7 @@ def _float8_dynamic_activation_float8_weight_quantize_tensor(weight, config):
         input_float=weight,
         block_size=block_size,
         target_dtype=weight_dtype,
-        scale_dtype=torch.float32,
+        scale_dtype=scale_dtype,
         _layout=Float8Layout(mm_config=mm_config),
     )
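
One point worth calling out from the config change above: scale_dtype defaults to torch.float32, so existing configs keep their previous behavior and power-of-2 rounding only kicks in when e8m0 is requested explicitly. A small illustrative check (the field name comes from this diff; the import path is assumed):

import torch
from torchao.quantization import Float8DynamicActivationFloat8WeightConfig, PerTensor

# Default config: fp32 scales, same behavior as before this commit.
default_cfg = Float8DynamicActivationFloat8WeightConfig(granularity=PerTensor())
assert default_cfg.scale_dtype is torch.float32

# Explicit opt-in: weight scales get rounded to powers of 2 (PyTorch 2.8+).
e8m0_cfg = Float8DynamicActivationFloat8WeightConfig(
    granularity=PerTensor(), scale_dtype=torch.float8_e8m0fnu
)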

torchao/quantization/quant_primitives.py

Lines changed: 6 additions & 1 deletion
@@ -2002,7 +2002,12 @@ def choose_qparams_affine_float8(
     ]
     scale = scale.reshape(output_shape)
 
-    return scale.to(dtype=scale_dtype)
+    if scale_dtype is not torch.float32:
+        # Guard: only reachable on PyTorch 2.8+, where the e8m0 dtype is available
+        assert scale_dtype is torch.float8_e8m0fnu, "Only float8_e8m0fnu is supported"
+        scale = torch.exp2(torch.round(torch.log2(scale)))
+
+    return scale.to(dtype=torch.float32)
 
 
 def quantize_affine_float8(
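
The new branch in choose_qparams_affine_float8 rounds each scale to the nearest power of two with exp2(round(log2(scale))), which is exactly the set of values an e8m0 exponent can represent. A standalone sketch of that rounding, illustrative only and not part of the commit:

import torch

scale = torch.tensor([0.003, 0.75, 1.6, 100.0])
pow2 = torch.exp2(torch.round(torch.log2(scale)))
# pow2 -> tensor([0.0039, 1.0000, 2.0000, 128.0000]), i.e. 2**-8, 2**0, 2**1, 2**7
# log2 of each entry is an integer, so the value fits in an exponent-only e8m0 format.
assert torch.equal(torch.log2(pow2), torch.round(torch.log2(pow2)))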
