
Commit 4d7c98f (parent: a776b1f)

Fix Per Row scaling for inference

stack-info: PR: #2253, branch: drisspg/stack/56
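
For context: per-row (rowwise) float8 scaling keeps one scale per output row of the weight instead of a single tensorwise scale, which is what this commit fixes for the inference path. A minimal sketch of the idea, purely illustrative and not the torchao implementation:

import torch

# Illustrative per-row float8 scaling for a weight of shape (out_features, in_features).
weight = torch.randn(1024, 512, dtype=torch.bfloat16)
f8_max = torch.finfo(torch.float8_e4m3fn).max  # 448.0

# One scale per row, shaped (out_features, 1) so it broadcasts over in_features.
scale = weight.abs().amax(dim=-1, keepdim=True).float() / f8_max
weight_f8 = (weight.float() / scale).clamp(-f8_max, f8_max).to(torch.float8_e4m3fn)

# Dequantization broadcasts the per-row scale back over each row.
weight_dq = weight_f8.to(torch.float32) * scale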

5 files changed: +118 −39 lines


test/dtypes/test_affine_quantized_float.py

Lines changed: 45 additions & 13 deletions
@@ -297,21 +297,53 @@ def test_fp8_weight_dimension_warning(self):
     @unittest.skipIf(
         not is_sm_at_least_89(), "Requires GPU with compute capability >= 8.9"
     )
-    def test_mm_float8dq(self):
+    @common_utils.parametrize(
+        "in_features,out_features", [(512, 1024), (256, 768), (1024, 512)]
+    )
+    @common_utils.parametrize(
+        "leading_shape", [(1,), (8,), (16,), (2, 8,), (2, 2, 16,)]
+    )  # fmt: skip
+    @common_utils.parametrize("bias", [True, False])
+    def test_mm_float8dq(self, in_features, out_features, leading_shape, bias: bool):
         device = "cuda"
         dtype = torch.bfloat16
-        weight = torch.randn(512, 1024).to(device).to(dtype)
-        weight = weight.t()
-
-        l = torch.nn.Linear(512, 1024).to(device).to(dtype)
-        l.weight = torch.nn.Parameter(weight)
-        quantize_(l, Float8DynamicActivationFloat8WeightConfig(granularity=PerRow()))
-        # weight shape: 1024 x 512
-        weight = l.weight
-
-        input = torch.randn(1, 512, device=device, dtype=dtype)
-        # make sure it runs
-        torch.nn.functional.linear(input, weight)
+        input_shape = leading_shape + (in_features,)
+
+        ref_linear = (
+            torch.nn.Linear(in_features, out_features, bias=bias).to(device).to(dtype)
+        )
+        test_linear = copy.deepcopy(ref_linear)
+        quantize_(
+            test_linear, Float8DynamicActivationFloat8WeightConfig(granularity=PerRow())
+        )
+
+        quant_weight = test_linear.weight
+
+        self.assertTrue(hasattr(quant_weight, "original_weight_tensor"))
+        weight_impl = quant_weight.original_weight_tensor.tensor_impl
+
+        self.assertTrue(hasattr(weight_impl, "float8_data"))
+        self.assertTrue(hasattr(weight_impl, "scale"))
+        self.assertFalse(weight_impl.transposed)
+
+        # Verify scale shape for row-wise quantization
+        expected_scale_shape = (out_features, 1)
+        actual_scale_shape = weight_impl.scale.shape
+        self.assertEqual(actual_scale_shape, expected_scale_shape)
+
+        self.assertEqual(weight_impl.float8_data.shape, (out_features, in_features))
+
+        input_tensor = torch.randn(*input_shape, device=device, dtype=dtype)
+
+        with torch.no_grad():
+            ref_output = ref_linear(input_tensor)
+            quant_output = torch.nn.functional.linear(input_tensor, quant_weight)
+
+        expected_output_shape = input_tensor.shape[:-1] + (out_features,)
+        self.assertEqual(quant_output.shape, expected_output_shape)
+
+        error = compute_error(ref_output, quant_output)
+        assert error > 20, f"Quantization error is too high got a SQNR of {error}"


 common_utils.instantiate_parametrized_tests(TestAffineQuantizedFloat8Compile)
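
The test gates on compute_error returning an SQNR above 20 dB for the quantized linear against the bf16 reference. For reference, a standard SQNR computation looks roughly like the sketch below; torchao's compute_error may differ in details:

import torch

def sqnr_db(ref: torch.Tensor, test: torch.Tensor) -> float:
    # Signal-to-quantization-noise ratio in decibels.
    signal_power = ref.float().pow(2).mean()
    noise_power = (ref.float() - test.float()).pow(2).mean()
    return (10 * torch.log10(signal_power / noise_power)).item()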

torchao/dtypes/affine_quantized_tensor.py

Lines changed: 3 additions & 4 deletions
@@ -462,10 +462,10 @@ def from_hp_to_floatx(
         if target_dtype in FP8_TYPES:
             original_shape = input_float.shape
             input_float = _layout.pre_process(input_float)
-
-            scale = choose_qparams_affine_float8(input_float, float8_dtype=target_dtype)
+            scale = choose_qparams_affine_float8(
+                input_float, float8_dtype=target_dtype, block_size=block_size
+            )
             data = quantize_affine_float8(input_float, scale, target_dtype)
-
             data, scale, zero_point = _layout.post_process(
                 data, scale, None, block_size
             )
@@ -503,7 +503,6 @@ def from_hp_to_floatx_static(
                 input_float,
                 scale,
                 target_dtype,
-                scale_dtype,
             )

             data, scale, zero_point = _layout.post_process(
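
The from_hp_to_floatx change forwards block_size into choose_qparams_affine_float8, so the float8 scale now follows the layout's block size instead of always being tensorwise. For a 2-D weight with PerRow granularity the block covers a full row, giving one scale per output row; a hedged sketch of that mapping (the helper below is illustrative, not a torchao API):

import torch

def rowwise_block_size(shape: tuple) -> tuple:
    # Illustrative: per-row granularity reduces over the last dim only.
    return (1,) * (len(shape) - 1) + (shape[-1],)

weight = torch.randn(1024, 512)
block_size = rowwise_block_size(weight.shape)  # (1, 512)
# choose_qparams_affine_float8(weight, block_size=block_size) would then
# produce a scale of shape (1024, 1): one scale per output row.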

torchao/dtypes/floatx/float8_layout.py

Lines changed: 15 additions & 12 deletions
@@ -195,28 +195,32 @@ def __torch_dispatch__(cls, func, types, args, kwargs):
        elif func is aten.slice.Tensor:
            self, dim, start, end, step = fill_defaults(args, 5, [0, None, None, 1])
            if dim == 0:
-                # TODO: scale replecation should be dependent on block size
-                if self.scale.ndim == 1:
+                if self.scale.ndim == 0 or (
+                    self.scale.ndim == 1 and self.scale.size(0) == 1
+                ):
+                    # Per Tensor
                    return return_and_correct_aliasing(
                        func,
                        args,
                        kwargs,
-                        args[0]._apply_fn_to_data(
-                            lambda x: aten.slice.Tensor(x, dim, start, end, step)
+                        Float8AQTTensorImpl(
+                            aten.slice.Tensor(self.float8_data, dim, start, end, step),
+                            self.scale,
+                            False,
+                            self._layout,
                        ),
                    )
-                elif self.scale.ndim == 0:
+                elif self.scale.ndim == 2:
+                    # TODO: scale replecation should be dependent on block size
                    return return_and_correct_aliasing(
                        func,
                        args,
                        kwargs,
-                        Float8AQTTensorImpl(
-                            aten.slice.Tensor(self.float8_data, dim, start, end, step),
-                            self.scale,
-                            None,
-                            self._layout,
+                        args[0]._apply_fn_to_data(
+                            lambda x: aten.slice.Tensor(x, dim, start, end, step)
                        ),
                    )
+
                else:
                    raise NotImplementedError(
                        f"Float8AQTTensorImpl dispatch: attempting to run {func}, with scale ndim={dim}, that is not supported"
@@ -333,13 +337,12 @@ def _linear_fp8_act_fp8_weight_impl(
    input_scale = input_tensor.tensor_impl.scale
    # Handle case where input tensor is more than 2D
    inpt_data = inpt_data.reshape(-1, inpt_data.shape[-1])
-
    # Handle rowwise case
    if _is_rowwise_scaled(weight_tensor):
        assert _is_rowwise_scaled(input_tensor), (
            "Input tensor must be rowwise block size"
        )
-        w_scale = w_scale.unsqueeze(-1).T
+        w_scale = w_scale.T
        input_scale = preprocess_scale(input_scale, input_tensor.shape)

    # Preprocess data
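
In the rowwise matmul path the weight scale now arrives as a 2-D (out_features, 1) tensor (matching the choose_qparams output above), so a plain transpose is enough to get the (1, out_features) layout that broadcasts across the output columns; the old unsqueeze(-1) assumed a 1-D scale. A small shape sketch, illustrative only:

import torch

out_features, batch = 1024, 8
w_scale = torch.rand(out_features, 1)   # per-row weight scale, already 2-D
acc = torch.randn(batch, out_features)  # unscaled matmul output (illustrative)

# (out_features, 1).T -> (1, out_features): broadcasts over the batch dimension.
rescaled = acc * w_scale.T
print(rescaled.shape)  # torch.Size([8, 1024])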

torchao/quantization/quant_api.py

Lines changed: 12 additions & 4 deletions
@@ -1434,18 +1434,21 @@ def _float8_weight_only_transform(
 _fp8_granularities = Union[PerTensor, PerRow]


-# Validate and process granularity input
 def _normalize_granularity(
     granularity: Optional[
-        Union[_fp8_granularities, Tuple[_fp8_granularities, _fp8_granularities]]
+        Union[
+            _fp8_granularities,
+            Tuple[_fp8_granularities, _fp8_granularities],
+            list[_fp8_granularities],
+        ]
     ],
 ) -> Tuple[_fp8_granularities, _fp8_granularities]:
     processed_granularity = None
     if granularity is None:
         processed_granularity = (PerTensor(), PerTensor())
     elif isinstance(granularity, (PerTensor, PerRow)):
         processed_granularity = (granularity, granularity)
-    elif isinstance(granularity, tuple) and len(granularity) == 2:
+    elif isinstance(granularity, (tuple, list)) and len(granularity) == 2:
         if not (
             isinstance(granularity[0], (PerTensor, PerRow))
             and isinstance(granularity[1], (PerTensor, PerRow))
@@ -1457,7 +1460,7 @@ def _normalize_granularity(
             raise ValueError(
                 f"Different granularities for activation and weight are not supported: {granularity}, only PerTensor or PerRow are supported."
             )
-        processed_granularity = granularity
+        processed_granularity = tuple(granularity)
     else:
         raise ValueError(
             f"Invalid granularity specification: {granularity}, only PerTensor or PerRow are supported."
@@ -1576,6 +1579,11 @@ def __post_init__(self):
         if self.mm_config is None:
             self.mm_config = Float8MMConfig(use_fast_accum=True)

+        activation_granularity, weight_granularity = _normalize_granularity(
+            self.granularity
+        )
+        self.granularity = (activation_granularity, weight_granularity)
+

 # for bc
 float8_dynamic_activation_float8_weight = Float8DynamicActivationFloat8WeightConfig
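
With this change a two-element list of granularities is accepted and normalized to a tuple when the config is constructed. Usage sketch, assuming a CUDA device and the same public imports the tests use:

import torch
from torchao.quantization import (
    Float8DynamicActivationFloat8WeightConfig,
    PerRow,
    quantize_,
)

model = torch.nn.Linear(512, 1024).cuda().to(torch.bfloat16)
# A list is now accepted and normalized to (activation_granularity, weight_granularity).
config = Float8DynamicActivationFloat8WeightConfig(granularity=[PerRow(), PerRow()])
assert isinstance(config.granularity, tuple)
quantize_(model, config)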

torchao/quantization/quant_primitives.py

Lines changed: 43 additions & 6 deletions
@@ -1970,20 +1970,38 @@ def choose_qparams_affine_float8(
     tensor: torch.Tensor,
     float8_dtype: torch.dtype = torch.float8_e4m3fn,
     scale_dtype: torch.dtype = torch.float32,
+    block_size: Optional[Tuple[int, ...]] = None,
 ) -> torch.Tensor:
     """
     Calculates float8 scaling factor for the given high precision tensor, using tensorwise granularity.

     Args:
         tensor (torch.Tensor): Input tensor to be quantized.
         float8_dtype (torch.dtype): Data type of the quantized tensor (e.g., torch.float8_e4m3fn, torch.float8_e5m2).
+        scale_dtype (torch.dtype): Data type of the scaling factor (e.g., torch.float32).
+        block_size (Optional[Tuple[int, ...]]): Block size for block-wise quantization. If None, tensorwise quantization is used.
     """
+    quant_max = torch.finfo(float8_dtype).max
     # only tensorwise scaling is supported for now:
-    quant_min, quant_max = torch.finfo(float8_dtype).min, torch.finfo(float8_dtype).max
-    min_val_neg = torch.min(tensor)
-    max_val_pos = torch.max(tensor)
-    max_val_pos = torch.max(-min_val_neg, max_val_pos)
-    scale = max_val_pos / (float(quant_max - quant_min) / 2)
+    if block_size is None:
+        max_abs = tensor.abs().max()
+        scale = max_abs / quant_max
+    else:
+        shape_for_reduction, reduction_dims = _get_reduction_params(
+            block_size, tensor.shape
+        )
+        tensor_reshaped = tensor.view(shape_for_reduction)
+        max_abs = tensor_reshaped.abs().amax(dim=reduction_dims, keepdim=True)
+
+        scale = max_abs / quant_max
+        # Reshape scale back to match the expected output shape
+        # The scale tensor should have the same shape as the input divided by block_size
+        output_shape = [
+            input_size // block_size[i] if block_size[i] > 1 else input_size
+            for i, input_size in enumerate(tensor.shape)
+        ]
+        scale = scale.reshape(output_shape)
+
     return scale.to(dtype=scale_dtype)

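Concretely, torch.finfo(torch.float8_e4m3fn).max is 448, so in the block-wise branch each block's scale is its max absolute value divided by 448. A small worked example for the per-row case:

import torch

tensor = torch.tensor([[1.0, -4.0], [0.5, 2.0]])
quant_max = torch.finfo(torch.float8_e4m3fn).max       # 448.0
row_max_abs = tensor.abs().amax(dim=-1, keepdim=True)  # [[4.0], [2.0]]
scale = row_max_abs / quant_max                        # shape (2, 1): one scale per row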

@@ -2027,5 +2045,24 @@ def dequantize_affine_float8(
     # upcasted to `float32` to divide by the scale, since scale is a fp32 for float8 quantization.
     # In order to match numerics between eager and compile, we upcast manually here.
     fp8_tensor = tensor.to(torch.float32)
-    hp_tensor = fp8_tensor * scale
+    # For block-wise quantization, we need to broadcast the scale to match tensor dimensions
+    if scale.shape != tensor.shape:
+        # Calculate the block size from the shape difference
+        block_size = tuple(
+            tensor.shape[i] // scale.shape[i]
+            if scale.shape[i] != tensor.shape[i]
+            else 1
+            for i in range(len(tensor.shape))
+        )
+
+        scale_expanded = scale
+        for i in range(len(tensor.shape)):
+            if block_size[i] > 1:
+                # Repeat the scale values for each block
+                scale_expanded = scale_expanded.repeat_interleave(block_size[i], dim=i)
+    else:
+        # Tensor-wise quantization - scale already matches
+        scale_expanded = scale
+
+    hp_tensor = fp8_tensor * scale_expanded
     return hp_tensor.to(output_dtype)
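
The dequant path now expands a block-wise scale back to the tensor's shape before multiplying. For the per-row case that just repeats each (N, 1) scale entry across its row; a minimal sketch of the same broadcast:

import torch

f8_data = torch.randn(4, 8).to(torch.float8_e4m3fn)  # pretend-quantized data
scale = torch.rand(4, 1)                             # per-row scale, inferred block_size = (1, 8)

# Block size along dim 1 is 8 -> repeat each row's scale 8 times to match the data shape.
scale_expanded = scale.repeat_interleave(8, dim=1)   # shape (4, 8)
hp = f8_data.to(torch.float32) * scale_expanded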
