
Commit b10c3cc

Fix Per Row scaling for inference
stack-info: PR: #2253, branch: drisspg/stack/56
1 parent a776b1f commit b10c3cc

File tree

6 files changed: +194 additions, −88 deletions


test/dtypes/test_affine_quantized_float.py

Lines changed: 47 additions & 13 deletions
@@ -297,21 +297,55 @@ def test_fp8_weight_dimension_warning(self):
     @unittest.skipIf(
         not is_sm_at_least_89(), "Requires GPU with compute capability >= 8.9"
     )
-    def test_mm_float8dq(self):
+    @common_utils.parametrize(
+        "in_features,out_features", [(512, 1024), (256, 768), (1024, 512)]
+    )
+    @common_utils.parametrize(
+        "leading_shape", [(1,), (8,), (16,), (2, 8,), (2, 2, 16,)]
+    ) # fmt: skip
+    @common_utils.parametrize("bias", [True, False])
+    def test_mm_float8dq_per_row(
+        self, in_features, out_features, leading_shape, bias: bool
+    ):
         device = "cuda"
         dtype = torch.bfloat16
-        weight = torch.randn(512, 1024).to(device).to(dtype)
-        weight = weight.t()
-
-        l = torch.nn.Linear(512, 1024).to(device).to(dtype)
-        l.weight = torch.nn.Parameter(weight)
-        quantize_(l, Float8DynamicActivationFloat8WeightConfig(granularity=PerRow()))
-        # weight shape: 1024 x 512
-        weight = l.weight
-
-        input = torch.randn(1, 512, device=device, dtype=dtype)
-        # make sure it runs
-        torch.nn.functional.linear(input, weight)
+        input_shape = leading_shape + (in_features,)
+
+        ref_linear = (
+            torch.nn.Linear(in_features, out_features, bias=bias).to(device).to(dtype)
+        )
+        test_linear = copy.deepcopy(ref_linear)
+        quantize_(
+            test_linear, Float8DynamicActivationFloat8WeightConfig(granularity=PerRow())
+        )
+
+        quant_weight = test_linear.weight
+
+        self.assertTrue(hasattr(quant_weight, "original_weight_tensor"))
+        weight_impl = quant_weight.original_weight_tensor.tensor_impl
+
+        self.assertTrue(hasattr(weight_impl, "float8_data"))
+        self.assertTrue(hasattr(weight_impl, "scale"))
+        self.assertFalse(weight_impl.transposed)
+
+        # Verify scale shape for row-wise quantization
+        expected_scale_shape = (out_features, 1)
+        actual_scale_shape = weight_impl.scale.shape
+        self.assertEqual(actual_scale_shape, expected_scale_shape)
+
+        self.assertEqual(weight_impl.float8_data.shape, (out_features, in_features))
+
+        input_tensor = torch.randn(*input_shape, device=device, dtype=dtype)
+
+        with torch.no_grad():
+            ref_output = ref_linear(input_tensor)
+            quant_output = torch.nn.functional.linear(input_tensor, quant_weight)
+
+        expected_output_shape = input_tensor.shape[:-1] + (out_features,)
+        self.assertEqual(quant_output.shape, expected_output_shape)
+
+        error = compute_error(ref_output, quant_output)
+        assert error > 20, f"Quantization error is too high got a SQNR of {error}"


 common_utils.instantiate_parametrized_tests(TestAffineQuantizedFloat8Compile)
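
For reference, the flow the new test exercises can be reproduced outside the test harness. This is a minimal sketch, assuming a CUDA device that supports PerRow (SM 9.0+ or MI300, per the hardware check added in this commit) and import paths matching those used elsewhere in the repo:

    import torch

    from torchao.quantization import quantize_
    from torchao.quantization.granularity import PerRow
    from torchao.quantization.quant_api import Float8DynamicActivationFloat8WeightConfig

    # Quantize a bf16 linear with per-row (rowwise) fp8 dynamic activation quantization.
    linear = torch.nn.Linear(512, 1024).to("cuda", torch.bfloat16)
    quantize_(linear, Float8DynamicActivationFloat8WeightConfig(granularity=PerRow()))

    impl = linear.weight.original_weight_tensor.tensor_impl
    print(impl.float8_data.shape)  # (1024, 512) -> (out_features, in_features)
    print(impl.scale.shape)        # (1024, 1)   -> one scale per output row

    x = torch.randn(8, 512, device="cuda", dtype=torch.bfloat16)
    y = torch.nn.functional.linear(x, linear.weight)  # (8, 1024)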

torchao/dtypes/affine_quantized_tensor.py

Lines changed: 3 additions & 4 deletions
@@ -462,10 +462,10 @@ def from_hp_to_floatx(
         if target_dtype in FP8_TYPES:
             original_shape = input_float.shape
             input_float = _layout.pre_process(input_float)
-
-            scale = choose_qparams_affine_float8(input_float, float8_dtype=target_dtype)
+            scale = choose_qparams_affine_float8(
+                input_float, float8_dtype=target_dtype, block_size=block_size
+            )
             data = quantize_affine_float8(input_float, scale, target_dtype)
-
             data, scale, zero_point = _layout.post_process(
                 data, scale, None, block_size
             )
@@ -503,7 +503,6 @@ def from_hp_to_floatx_static(
             input_float,
             scale,
             target_dtype,
-            scale_dtype,
         )

         data, scale, zero_point = _layout.post_process(
torchao/dtypes/floatx/float8_layout.py

Lines changed: 15 additions & 12 deletions
@@ -195,28 +195,32 @@ def __torch_dispatch__(cls, func, types, args, kwargs):
         elif func is aten.slice.Tensor:
             self, dim, start, end, step = fill_defaults(args, 5, [0, None, None, 1])
             if dim == 0:
-                # TODO: scale replecation should be dependent on block size
-                if self.scale.ndim == 1:
+                if self.scale.ndim == 0 or (
+                    self.scale.ndim == 1 and self.scale.size(0) == 1
+                ):
+                    # Per Tensor
                     return return_and_correct_aliasing(
                         func,
                         args,
                         kwargs,
-                        args[0]._apply_fn_to_data(
-                            lambda x: aten.slice.Tensor(x, dim, start, end, step)
+                        Float8AQTTensorImpl(
+                            aten.slice.Tensor(self.float8_data, dim, start, end, step),
+                            self.scale,
+                            False,
+                            self._layout,
                         ),
                     )
-                elif self.scale.ndim == 0:
+                elif self.scale.ndim == 2:
+                    # TODO: scale replecation should be dependent on block size
                     return return_and_correct_aliasing(
                         func,
                         args,
                         kwargs,
-                        Float8AQTTensorImpl(
-                            aten.slice.Tensor(self.float8_data, dim, start, end, step),
-                            self.scale,
-                            None,
-                            self._layout,
+                        args[0]._apply_fn_to_data(
+                            lambda x: aten.slice.Tensor(x, dim, start, end, step)
                         ),
                     )
+
                 else:
                     raise NotImplementedError(
                         f"Float8AQTTensorImpl dispatch: attempting to run {func}, with scale ndim={dim}, that is not supported"
@@ -333,13 +337,12 @@ def _linear_fp8_act_fp8_weight_impl(
     input_scale = input_tensor.tensor_impl.scale
     # Handle case where input tensor is more than 2D
     inpt_data = inpt_data.reshape(-1, inpt_data.shape[-1])
-
     # Handle rowwise case
     if _is_rowwise_scaled(weight_tensor):
         assert _is_rowwise_scaled(input_tensor), (
             "Input tensor must be rowwise block size"
         )
-        w_scale = w_scale.unsqueeze(-1).T
+        w_scale = w_scale.T
         input_scale = preprocess_scale(input_scale, input_tensor.shape)

     # Preprocess data
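
Because the per-row weight scale is now kept 2-D with shape (out_features, 1), the rowwise matmul path only needs a transpose to (1, out_features) so the scale broadcasts across output columns; the old unsqueeze is no longer needed. A shape-level sketch of why, using plain dequantize-then-matmul rather than the fp8 matmul kernel the real path ultimately calls:

    import torch

    M, K, N = 8, 512, 1024
    x = torch.randn(M, K)             # stand-in for fp8 activation data
    w = torch.randn(N, K)             # stand-in for fp8 weight data, (out_features, in_features)
    x_scale = torch.rand(M, 1) + 0.5  # per-row activation scales
    w_scale = torch.rand(N, 1) + 0.5  # per-row weight scales, shape (N, 1)

    # Scaling the inputs up front ...
    out = (x * x_scale) @ (w * w_scale).T
    # ... matches scaling the raw matmul rowwise by x_scale and columnwise by
    # w_scale.T, which is why the (N, 1) -> (1, N) transpose is all that is needed.
    out_alt = (x @ w.T) * x_scale * w_scale.T
    torch.testing.assert_close(out, out_alt, rtol=1e-4, atol=1e-4)
    print(out.shape)  # (8, 1024)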

torchao/float8/inference.py

Lines changed: 73 additions & 1 deletion
@@ -7,11 +7,20 @@
 Defines an nn module designed to be used during inference
 """

-from typing import NamedTuple, Optional, Tuple
+from typing import NamedTuple, Optional, Tuple, Union

 import torch

 from torchao.float8.float8_utils import is_row_major, pad_tensor_for_matmul
+from torchao.quantization.granularity import (
+    PerRow,
+    PerTensor,
+)
+from torchao.utils import (
+    is_MI300,
+    is_sm_at_least_89,
+    is_sm_at_least_90,
+)

 Tensor = torch.Tensor

@@ -106,3 +115,66 @@ def _is_rowwise_scaled(x) -> bool:
     x: AffineQuantizedTensor tensor
     """
     return x.block_size == (1,) * (x.dim() - 1) + (x.shape[-1],)
+
+
+FP8Granularity = Union[PerTensor, PerRow]
+
+
+def _normalize_granularity(
+    granularity: Optional[
+        Union[
+            FP8Granularity,
+            Tuple[FP8Granularity, FP8Granularity],
+            list[FP8Granularity],
+        ]
+    ],
+) -> Tuple[FP8Granularity, FP8Granularity]:
+    processed_granularity = None
+    if granularity is None:
+        processed_granularity = (PerTensor(), PerTensor())
+    elif isinstance(granularity, (PerTensor, PerRow)):
+        processed_granularity = (granularity, granularity)
+    elif isinstance(granularity, (tuple, list)) and len(granularity) == 2:
+        if not (
+            isinstance(granularity[0], (PerTensor, PerRow))
+            and isinstance(granularity[1], (PerTensor, PerRow))
+        ):
+            raise ValueError(
+                f"Invalid granularity types: {granularity}, only PerTensor or PerRow are supported."
+            )
+        if not isinstance(granularity[0], type(granularity[1])):
+            raise ValueError(
+                f"Different granularities for activation and weight are not supported: {granularity}, only PerTensor or PerRow are supported."
+            )
+        processed_granularity = tuple(granularity)
+    else:
+        raise ValueError(
+            f"Invalid granularity specification: {granularity}, only PerTensor or PerRow are supported."
+        )
+    return processed_granularity
+
+
+def _check_hardware_support(
+    granularities: Tuple[FP8Granularity, FP8Granularity],
+) -> None:
+    """
+    Validate that the hardware supports the requested granularities.
+
+    Args:
+        granularities: Tuple of (activation_granularity, weight_granularity)
+
+    Raises:
+        AssertionError: If hardware doesn't support the requested granularity
+        ValueError: If invalid granularity type is provided
+    """
+    for _granularity in granularities:
+        if isinstance(_granularity, PerTensor):
+            assert is_sm_at_least_89() or is_MI300(), (
+                "PerTensor quantization only works for CUDA>=8.9 and MI300+"
+            )
+        elif isinstance(_granularity, PerRow):
+            assert is_sm_at_least_90() or is_MI300(), (
+                "PerRow quantization only works for CUDA>=9.0 and MI300+"
+            )
+        else:
+            raise ValueError(f"Invalid granularity type: {_granularity}")

torchao/quantization/quant_api.py

Lines changed: 17 additions & 52 deletions
@@ -54,7 +54,12 @@
 from torchao.dtypes.utils import Layout
 from torchao.float8.config import e4m3_dtype, e5m2_dtype
 from torchao.float8.float8_linear import Float8Linear
-from torchao.float8.inference import Float8MMConfig
+from torchao.float8.inference import (
+    Float8MMConfig,
+    FP8Granularity,
+    _check_hardware_support,
+    _normalize_granularity,
+)
 from torchao.quantization.linear_activation_weight_observed_tensor import (
     LinearActivationWeightObservedTensor,
 )
@@ -1431,56 +1436,9 @@ def _float8_weight_only_transform(
     return module


-_fp8_granularities = Union[PerTensor, PerRow]
-
-
-# Validate and process granularity input
-def _normalize_granularity(
-    granularity: Optional[
-        Union[_fp8_granularities, Tuple[_fp8_granularities, _fp8_granularities]]
-    ],
-) -> Tuple[_fp8_granularities, _fp8_granularities]:
-    processed_granularity = None
-    if granularity is None:
-        processed_granularity = (PerTensor(), PerTensor())
-    elif isinstance(granularity, (PerTensor, PerRow)):
-        processed_granularity = (granularity, granularity)
-    elif isinstance(granularity, tuple) and len(granularity) == 2:
-        if not (
-            isinstance(granularity[0], (PerTensor, PerRow))
-            and isinstance(granularity[1], (PerTensor, PerRow))
-        ):
-            raise ValueError(
-                f"Invalid granularity types: {granularity}, only PerTensor or PerRow are supported."
-            )
-        if not isinstance(granularity[0], type(granularity[1])):
-            raise ValueError(
-                f"Different granularities for activation and weight are not supported: {granularity}, only PerTensor or PerRow are supported."
-            )
-        processed_granularity = granularity
-    else:
-        raise ValueError(
-            f"Invalid granularity specification: {granularity}, only PerTensor or PerRow are supported."
-        )
-    # Validate granularity with supported Hardware
-    for _granularity in processed_granularity:
-        if isinstance(_granularity, PerTensor):
-            assert is_sm_at_least_89() or is_MI300(), (
-                "PerTensor quantization only works for CUDA>=8.9 and MI300+"
-            )
-        elif isinstance(_granularity, PerRow):
-            assert is_sm_at_least_90() or is_MI300(), (
-                "PerRow quantization only works for CUDA>=9.0 and MI300+"
-            )
-        else:
-            raise ValueError(f"Invalid granularity type: {_granularity}")
-
-    return processed_granularity
-
-
 def _input_activation_quant_func_fp8(
     x: torch.Tensor,
-    activation_granularity: _fp8_granularities,
+    activation_granularity: FP8Granularity,
     activation_dtype: torch.dtype,
     scale: Optional[torch.Tensor] = None,
     zero_point: Optional[torch.Tensor] = None,
@@ -1567,7 +1525,7 @@ class Float8DynamicActivationFloat8WeightConfig(AOBaseConfig):
     activation_dtype: torch.dtype = e4m3_dtype
     weight_dtype: torch.dtype = e4m3_dtype
     granularity: Optional[
-        Union[_fp8_granularities, Tuple[_fp8_granularities, _fp8_granularities]]
+        Union[FP8Granularity, Tuple[FP8Granularity, FP8Granularity]]
     ] = None
     mm_config: Optional[Float8MMConfig] = None
     set_inductor_config: bool = True
@@ -1576,6 +1534,11 @@ def __post_init__(self):
         if self.mm_config is None:
            self.mm_config = Float8MMConfig(use_fast_accum=True)

+        activation_granularity, weight_granularity = _normalize_granularity(
+            self.granularity
+        )
+        self.granularity = (activation_granularity, weight_granularity)
+

 # for bc
 float8_dynamic_activation_float8_weight = Float8DynamicActivationFloat8WeightConfig
@@ -1587,7 +1550,9 @@ def _float8_dynamic_activation_float8_weight_quantize_tensor(weight, config):
     granularity = config.granularity
     mm_config = config.mm_config

-    activation_granularity, weight_granularity = _normalize_granularity(granularity)
+    # Ensure works on device
+    _check_hardware_support(granularity)
+    activation_granularity, weight_granularity = granularity

     if not _fp8_mm_compat(weight):
         # TODO(future PR): this should really throw an exception instead of silently
@@ -1704,7 +1669,7 @@ class Float8StaticActivationFloat8WeightConfig(AOBaseConfig):
     activation_dtype: torch.dtype = e4m3_dtype
     weight_dtype: torch.dtype = e4m3_dtype
     granularity: Optional[
-        Union[_fp8_granularities, Tuple[_fp8_granularities, _fp8_granularities]]
+        Union[FP8Granularity, Tuple[FP8Granularity, FP8Granularity]]
     ] = None
     mm_config: Optional[Float8MMConfig] = None
     set_inductor_config: bool = True
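
Net effect on the user-facing flow: Float8DynamicActivationFloat8WeightConfig.__post_init__ now normalizes granularity into an (activation, weight) pair up front, and the quantize transform only re-checks hardware support. A hedged usage sketch (per-row needs an SM 9.0+ GPU or MI300, per _check_hardware_support):

    import torch

    from torchao.quantization import quantize_
    from torchao.quantization.granularity import PerRow
    from torchao.quantization.quant_api import Float8DynamicActivationFloat8WeightConfig

    config = Float8DynamicActivationFloat8WeightConfig(granularity=PerRow())
    # __post_init__ has already normalized the single granularity into a pair.
    print(config.granularity)  # expected: (PerRow(), PerRow())

    if torch.cuda.is_available():
        model = torch.nn.Sequential(torch.nn.Linear(512, 1024)).to("cuda", torch.bfloat16)
        # _check_hardware_support(config.granularity) runs inside this transform.
        quantize_(model, config)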
