
Commit a2f2f09

Fix Per Row scaling for inference
stack-info: PR: #2253, branch: drisspg/stack/56
1 parent adc78b7

5 files changed (+109, -25 lines)


test/dtypes/test_affine_quantized_float.py

Lines changed: 59 additions & 13 deletions
@@ -297,21 +297,67 @@ def test_fp8_weight_dimension_warning(self):
     @unittest.skipIf(
         not is_sm_at_least_89(), "Requires GPU with compute capability >= 8.9"
     )
-    def test_mm_float8dq(self):
+    @common_utils.parametrize(
+        "in_features,out_features", [(512, 1024), (256, 768), (1024, 512)]
+    )
+    @common_utils.parametrize(
+        "leading_shape",
+        [
+            (1,),
+            (8,),
+            (16,),
+            (
+                2,
+                8,
+            ),
+            (
+                2,
+                2,
+                16,
+            ),
+        ],
+    )
+    @common_utils.parametrize("bias", [True, False])
+    def test_mm_float8dq(self, in_features, out_features, leading_shape, bias: bool):
         device = "cuda"
         dtype = torch.bfloat16
-        weight = torch.randn(512, 1024).to(device).to(dtype)
-        weight = weight.t()
-
-        l = torch.nn.Linear(512, 1024).to(device).to(dtype)
-        l.weight = torch.nn.Parameter(weight)
-        quantize_(l, Float8DynamicActivationFloat8WeightConfig(granularity=PerRow()))
-        # weight shape: 1024 x 512
-        weight = l.weight
-
-        input = torch.randn(1, 512, device=device, dtype=dtype)
-        # make sure it runs
-        torch.nn.functional.linear(input, weight)
+        input_shape = leading_shape + (in_features,)
+
+        ref_linear = (
+            torch.nn.Linear(in_features, out_features, bias=bias).to(device).to(dtype)
+        )
+        test_linear = copy.deepcopy(ref_linear)
+        quantize_(
+            test_linear, Float8DynamicActivationFloat8WeightConfig(granularity=PerRow())
+        )
+
+        quant_weight = test_linear.weight
+
+        self.assertTrue(hasattr(quant_weight, "original_weight_tensor"))
+        weight_impl = quant_weight.original_weight_tensor.tensor_impl
+
+        self.assertTrue(hasattr(weight_impl, "float8_data"))
+        self.assertTrue(hasattr(weight_impl, "scale"))
+        self.assertFalse(weight_impl.transposed)
+
+        # Verify scale shape for row-wise quantization
+        expected_scale_shape = (out_features, 1)
+        actual_scale_shape = weight_impl.scale.shape
+        self.assertEqual(actual_scale_shape, expected_scale_shape)
+
+        self.assertEqual(weight_impl.float8_data.shape, (out_features, in_features))
+
+        input_tensor = torch.randn(*input_shape, device=device, dtype=dtype)
+
+        with torch.no_grad():
+            ref_output = ref_linear(input_tensor)
+            quant_output = torch.nn.functional.linear(input_tensor, quant_weight)
+
+        expected_output_shape = input_tensor.shape[:-1] + (out_features,)
+        self.assertEqual(quant_output.shape, expected_output_shape)
+
+        error = compute_error(ref_output, quant_output)
+        assert error > 20, f"Quantization error is too high got a SQNR of {error}"


 common_utils.instantiate_parametrized_tests(TestAffineQuantizedFloat8Compile)
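Note on the shapes this test locks in: with PerRow granularity the weight keeps its (out_features, in_features) layout and carries one fp32 scale per output row. A minimal sketch of that check, assuming the same public torchao names the test file imports (quantize_, Float8DynamicActivationFloat8WeightConfig, PerRow) and a GPU with compute capability >= 8.9:

import torch
from torchao.quantization import (
    Float8DynamicActivationFloat8WeightConfig,
    PerRow,
    quantize_,
)

linear = torch.nn.Linear(512, 1024, bias=False).to("cuda", torch.bfloat16)
quantize_(linear, Float8DynamicActivationFloat8WeightConfig(granularity=PerRow()))

impl = linear.weight.original_weight_tensor.tensor_impl
print(impl.float8_data.shape)  # torch.Size([1024, 512]) -- (out_features, in_features)
print(impl.scale.shape)        # torch.Size([1024, 1])   -- one scale per output row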

torchao/dtypes/affine_quantized_tensor.py

Lines changed: 3 additions & 4 deletions
@@ -462,10 +462,10 @@ def from_hp_to_floatx(
         if target_dtype in FP8_TYPES:
             original_shape = input_float.shape
             input_float = _layout.pre_process(input_float)
-
-            scale = choose_qparams_affine_float8(input_float, float8_dtype=target_dtype)
+            scale = choose_qparams_affine_float8(
+                input_float, float8_dtype=target_dtype, block_size=block_size
+            )
             data = quantize_affine_float8(input_float, scale, target_dtype)
-
             data, scale, zero_point = _layout.post_process(
                 data, scale, None, block_size
             )
@@ -503,7 +503,6 @@ def from_hp_to_floatx_static(
             input_float,
             scale,
             target_dtype,
-            scale_dtype,
         )

         data, scale, zero_point = _layout.post_process(
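The substantive change here is threading block_size through to choose_qparams_affine_float8, so a PerRow layout yields one scale per row instead of a single tensorwise scale. Rough shape arithmetic, with hypothetical numbers for illustration (not code from the commit):

# For a (1024, 512) weight, PerRow corresponds to block_size = (1, 512):
# each block spans one full row, so there is one scale per output row.
weight_shape = (1024, 512)
block_size = (1, 512)
scale_shape = tuple(dim // blk for dim, blk in zip(weight_shape, block_size))
print(scale_shape)  # (1024, 1)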

torchao/dtypes/floatx/float8_layout.py

Lines changed: 1 addition & 2 deletions
@@ -333,13 +333,12 @@ def _linear_fp8_act_fp8_weight_impl(
     input_scale = input_tensor.tensor_impl.scale
     # Handle case where input tensor is more than 2D
     inpt_data = inpt_data.reshape(-1, inpt_data.shape[-1])
-
     # Handle rowwise case
     if _is_rowwise_scaled(weight_tensor):
         assert _is_rowwise_scaled(input_tensor), (
             "Input tensor must be rowwise block size"
         )
-        w_scale = w_scale.unsqueeze(-1).T
+        w_scale = w_scale.T
         input_scale = preprocess_scale(input_scale, input_tensor.shape)

     # Preprocess data
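With the scale now produced upstream as a 2D (out_features, 1) tensor, a bare transpose already gives the (1, out_features) layout this kernel needs; the old unsqueeze(-1) only lines up if w_scale arrives as a 1D vector. A shape-only illustration under that assumption (the shapes here are inferred, not taken from the commit):

import torch

w_scale_1d = torch.rand(1024)      # assumed old shape: (out_features,)
w_scale_2d = torch.rand(1024, 1)   # shape now produced by per-row qparams

print(w_scale_1d.unsqueeze(-1).T.shape)  # torch.Size([1, 1024])
print(w_scale_2d.T.shape)                # torch.Size([1, 1024])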

torchao/quantization/quant_api.py

Lines changed: 5 additions & 0 deletions
@@ -1576,6 +1576,11 @@ def __post_init__(self):
         if self.mm_config is None:
             self.mm_config = Float8MMConfig(use_fast_accum=True)

+        activation_granularity, weight_granularity = _normalize_granularity(
+            self.granularity
+        )
+        self.granularity = (activation_granularity, weight_granularity)
+

 # for bc
 float8_dynamic_activation_float8_weight = Float8DynamicActivationFloat8WeightConfig
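Normalizing granularity at construction time means a config built with a single PerRow() is stored as an explicit (activation, weight) pair, so downstream code sees a consistent shape. A sketch of the intended effect, assuming _normalize_granularity expands a single granularity into a 2-tuple as it is used above (its exact signature lives in quant_api.py):

from torchao.quantization import Float8DynamicActivationFloat8WeightConfig, PerRow

config = Float8DynamicActivationFloat8WeightConfig(granularity=PerRow())
# __post_init__ has run, so granularity is now a pair rather than a bare PerRow().
print(config.granularity)  # expected: (PerRow(), PerRow())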

torchao/quantization/quant_primitives.py

Lines changed: 41 additions & 6 deletions
@@ -1970,6 +1970,7 @@ def choose_qparams_affine_float8(
     tensor: torch.Tensor,
     float8_dtype: torch.dtype = torch.float8_e4m3fn,
     scale_dtype: torch.dtype = torch.float32,
+    block_size: Optional[Tuple[int, ...]] = None,
 ) -> torch.Tensor:
     """
     Calculates float8 scaling factor for the given high precision tensor, using tensorwise granularity.
@@ -1978,12 +1979,27 @@ def choose_qparams_affine_float8(
         tensor (torch.Tensor): Input tensor to be quantized.
         float8_dtype (torch.dtype): Data type of the quantized tensor (e.g., torch.float8_e4m3fn, torch.float8_e5m2).
     """
+    quant_max = torch.finfo(float8_dtype).max
     # only tensorwise scaling is supported for now:
-    quant_min, quant_max = torch.finfo(float8_dtype).min, torch.finfo(float8_dtype).max
-    min_val_neg = torch.min(tensor)
-    max_val_pos = torch.max(tensor)
-    max_val_pos = torch.max(-min_val_neg, max_val_pos)
-    scale = max_val_pos / (float(quant_max - quant_min) / 2)
+    if block_size is None:
+        max_abs = tensor.abs().max()
+        scale = max_abs / quant_max
+    else:
+        shape_for_reduction, reduction_dims = _get_reduction_params(
+            block_size, tensor.shape
+        )
+        tensor_reshaped = tensor.view(shape_for_reduction)
+        max_abs = tensor_reshaped.abs().amax(dim=reduction_dims, keepdim=True)
+
+        scale = max_abs / quant_max
+        # Reshape scale back to match the expected output shape
+        # The scale tensor should have the same shape as the input divided by block_size
+        output_shape = [
+            input_size // block_size[i] if block_size[i] > 1 else input_size
+            for i, input_size in enumerate(tensor.shape)
+        ]
+        scale = scale.reshape(output_shape)
+
     return scale.to(dtype=scale_dtype)


@@ -2027,5 +2043,24 @@ def dequantize_affine_float8(
     # upcasted to `float32` to divide by the scale, since scale is a fp32 for float8 quantization.
     # In order to match numerics between eager and compile, we upcast manually here.
     fp8_tensor = tensor.to(torch.float32)
-    hp_tensor = fp8_tensor * scale
+    # For block-wise quantization, we need to broadcast the scale to match tensor dimensions
+    if scale.shape != tensor.shape:
+        # Calculate the block size from the shape difference
+        block_size = tuple(
+            tensor.shape[i] // scale.shape[i]
+            if scale.shape[i] != tensor.shape[i]
+            else 1
+            for i in range(len(tensor.shape))
+        )
+
+        scale_expanded = scale
+        for i in range(len(tensor.shape)):
+            if block_size[i] > 1:
+                # Repeat the scale values for each block
+                scale_expanded = scale_expanded.repeat_interleave(block_size[i], dim=i)
+    else:
+        # Tensor-wise quantization - scale already matches
+        scale_expanded = scale
+
+    hp_tensor = fp8_tensor * scale_expanded
     return hp_tensor.to(output_dtype)
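Putting the two primitives together: the qparams helper now reduces abs-max over each block, and the dequant helper re-expands the per-block scale, so a per-row fp8 round trip approximately reconstructs the input. A self-contained sketch of the same math in plain torch ops (deliberately not the torchao helpers; assumes a PyTorch build with float8_e4m3fn support):

import torch

torch.manual_seed(0)
weight = torch.randn(1024, 512)

# Per-row scale: abs-max over the last dim divided by the fp8 max (448 for e4m3fn).
quant_max = torch.finfo(torch.float8_e4m3fn).max
scale = weight.abs().amax(dim=-1, keepdim=True) / quant_max  # shape (1024, 1)

fp8 = (weight / scale).clamp(-quant_max, quant_max).to(torch.float8_e4m3fn)

# Dequantize: expand the (1024, 1) scale back across each row, analogous to the
# repeat_interleave path dequantize_affine_float8 takes when block_size > 1.
recon = fp8.to(torch.float32) * scale.repeat_interleave(weight.shape[-1], dim=-1)

print((weight - recon).abs().max())  # small per-row fp8 round-trip error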
