Commit b3eb6a7

[Float8Quant] Add rowwise scaling option to float8 dynamic quant
stack-info: PR: #819, branch: drisspg/stack/11
1 parent 65d86c6 commit b3eb6a7

6 files changed: +171, -63 lines changed

ruff.toml

Lines changed: 1 addition & 0 deletions
@@ -10,4 +10,5 @@ include = [
     "torchao/float8/float8_tensor.py",
     "torchao/quantization/linear_activation_weight_observer.py",
     "test/quantization/test_observer.py",
+    "test/dtypes/test_affine_quantized_float.py",
 ]

test/dtypes/test_affine_quantized_float.py

Lines changed: 49 additions & 35 deletions
@@ -1,34 +1,29 @@
 from torchao.utils import (
     TORCH_VERSION_AT_LEAST_2_5,
-    unwrap_tensor_subclass,
 )
 import pytest
 
 if not TORCH_VERSION_AT_LEAST_2_5:
     pytest.skip("Unsupported PyTorch version", allow_module_level=True)
 
-from numpy import full
-from torch.testing._internal.common_utils import (
-    run_tests,
-)
 from torch._inductor.test_case import TestCase as InductorTestCase
 from torch.testing._internal import common_utils
-from torch._dynamo.testing import CompileCounterWithBackend
 
 from torchao.quantization import (
     quantize_,
     float8_weight_only,
     float8_dynamic_activation_float8_weight,
 )
+from torchao.quantization.observer import PerTensor, PerRow
 from torchao.float8.float8_utils import compute_error
 import torch
 import unittest
 import pytest
-import tempfile
 import copy
 import random
-
-from unittest.mock import patch
+from functools import partial
+from typing import Tuple
+from contextlib import nullcontext
 
 
 random.seed(0)
@@ -56,6 +51,9 @@ class TestAffineQuantizedFloat8Compile(InductorTestCase):
     @common_utils.parametrize("dtype", [torch.bfloat16, torch.float32])
     @common_utils.parametrize("mode", ["dynamic", "weight-only"])
     @common_utils.parametrize("compile", [True, False])
+    @common_utils.parametrize(
+        "granularity", [PerTensor(), PerRow()] if is_H100 else [PerTensor()]
+    )
     # Inputs are (M,..), K, N
     @common_utils.parametrize(
         "sizes",
@@ -68,33 +66,49 @@ class TestAffineQuantizedFloat8Compile(InductorTestCase):
         ],
     )
     def test_fp8_linear_variants(
-        self, dtype: torch.dtype, mode: str, compile: bool, sizes: tuple
+        self, dtype: torch.dtype, mode: str, compile: bool, sizes: Tuple, granularity
     ):
-        M, N, K = sizes
-        input_tensor = torch.randn(*M, K, dtype=dtype, device="cuda")
-
-        mode_map = {
-            "dynamic": float8_dynamic_activation_float8_weight,
-            "weight-only": float8_weight_only,
-        }
-
-        # Create a linear layer with bfloat16 dtype
-        model = ToyLinearModel(K, N).eval().to(dtype).to("cuda")
-
-        quantized_model = copy.deepcopy(model)
-        factory = mode_map[mode]()
-        quantize_(model, factory)
-
-        if compile:
-            quantized_model = torch.compile(quantized_model, fullgraph=True)
-
-        output_original = model(input_tensor)
-        output_quantized = quantized_model(input_tensor)
-
-        error = compute_error(output_original, output_quantized)
-        assert (
-            compute_error(output_original, output_quantized) > 20
-        ), f"Quantization error is too high got a SQNR of {error}"
+        raises = (
+            isinstance(granularity, PerRow)
+            and mode == "dynamic"
+            and dtype != torch.bfloat16
+        )
+        context = (
+            nullcontext()
+            if not raises
+            else pytest.raises(
+                AssertionError,
+                match="PerRow quantization only works for bfloat16 precision",
+            )
+        )
+        with context:
+            M, N, K = sizes
+            input_tensor = torch.randn(*M, K, dtype=dtype, device="cuda")
+
+            mode_map = {
+                "dynamic": partial(
+                    float8_dynamic_activation_float8_weight, granularity=granularity
+                ),
+                "weight-only": float8_weight_only,
+            }
+
+            # Create a linear layer with bfloat16 dtype
+            model = ToyLinearModel(K, N).eval().to(dtype).to("cuda")
+
+            quantized_model = copy.deepcopy(model)
+            factory = mode_map[mode]()
+            quantize_(model, factory)
+
+            if compile:
+                quantized_model = torch.compile(quantized_model, fullgraph=True)
+
+            output_original = model(input_tensor)
+            output_quantized = quantized_model(input_tensor)
+
+            error = compute_error(output_original, output_quantized)
+            assert (
+                compute_error(output_original, output_quantized) > 20
+            ), f"Quantization error is too high got a SQNR of {error}"
 
 
 common_utils.instantiate_parametrized_tests(TestAffineQuantizedFloat8Compile)

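The new granularity parametrization also covers the failure path: dynamic PerRow quantization on a non-bfloat16 model is expected to raise. A minimal sketch of that path outside the test harness, assuming a CUDA device and using the same public entry points as the test above:

import torch
import torch.nn as nn
import pytest

from torchao.quantization import quantize_, float8_dynamic_activation_float8_weight
from torchao.quantization.observer import PerRow

# float32 weights, so the bfloat16 guard added in quant_api.py should fire
model = nn.Linear(64, 32).to(torch.float32).to("cuda").eval()
with pytest.raises(AssertionError, match="PerRow quantization only works for bfloat16 precision"):
    quantize_(model, float8_dynamic_activation_float8_weight(granularity=PerRow()))
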
torchao/dtypes/affine_quantized_tensor.py

Lines changed: 35 additions & 21 deletions
@@ -27,6 +27,12 @@
     is_device,
     get_out_shape,
 )
+from torchao.float8.inference import (
+    preprocess_data,
+    Float8MMConfig,
+    addmm_float8_unwrapped_inference,
+    _is_rowwise_scaled
+)
 from torch.utils._python_dispatch import is_traceable_wrapper_subclass
 from dataclasses import dataclass
 from torchao.utils import (
@@ -1355,53 +1361,61 @@ def _linear_f16_act_fpx_weight_impl(input_tensor, weight_tensor, bias):
 
     return out.view(*act.shape[:-1], out_dim).to(act.dtype)
 
-def _linear_fp_act_fp8_tensor_wise_weight_check(
+def _linear_fp_act_fp8_weight_check(
     input_tensor: Union[torch.Tensor, AffineQuantizedTensor],
     weight_tensor: Union[torch.Tensor, AffineQuantizedTensor],
     bias: Optional[torch.Tensor],
 ) -> bool:
-    def check_aqt_tensorwise(aqt: Union[torch.Tensor, AffineQuantizedTensor]) -> bool:
+    def check_aqt(aqt: Union[torch.Tensor, AffineQuantizedTensor]) -> bool:
         return (
             isinstance(aqt, AffineQuantizedTensor) and
             isinstance(aqt.layout_type, Float8LayoutType)
             and aqt.layout_tensor.dtype in [torch.float8_e4m3fn, torch.float8_e5m2]
-            and aqt.shape == aqt.block_size
+            and (aqt.shape == aqt.block_size or _is_rowwise_scaled(aqt))
         )
-    return check_aqt_tensorwise(input_tensor) and check_aqt_tensorwise(weight_tensor)
+    return check_aqt(input_tensor) and check_aqt(weight_tensor)
+
 
+def preprocess_scale(input_scale: torch.Tensor, input_shape: Tuple[int]):
+    """ Ensures input tensor is correctly formated for _scaled_mm """
+    input_scale = input_scale.unsqueeze(-1)
+
+    if input_scale.dim() > 2:
+        input_scale = input_scale.reshape(-1, input_scale.shape[-1])
+
+    return input_scale
 
 def _linear_fp_act_fp8_weight_impl(
     input_tensor: AffineQuantizedTensor,
     weight_tensor: AffineQuantizedTensor,
     bias: Optional[torch.Tensor],
 ):
     """Implements matmul between FP8 input and FP8 weight with compute using _scaled_mm"""
-    from torchao.float8.inference import (
-        preprocess_data,
-        Float8MMConfig,
-        addmm_float8_unwrapped_inference,
-    )
-
     scaled_mm_config = weight_tensor.layout_type.mm_config
-    scaled_mm_config = scaled_mm_config if scaled_mm_config is not None else Float8MMConfig()
+    out_shape = get_out_shape(input_tensor.shape, weight_tensor.shape)
 
+    # Weight tensor preprocessing
     w_layout = weight_tensor.layout_tensor
-    w_data = weight_tensor.layout_tensor.float8_data
-    w_data = w_data.T if w_layout.transposed else w_data
+    assert not w_layout.transposed, "Weight tensor must be contiguous"
+    w_data = w_layout.float8_data
     w_scale = w_layout.scale
-    w_scale = w_scale if w_layout.transposed else w_scale
-
-    out_shape = get_out_shape(input_tensor.shape, weight_tensor.shape)
 
+    # Input tensor preprocessing
     inpt_data = input_tensor.layout_tensor.float8_data
-    # Handle case where input tensor is more than 2D
-    inpt_data = inpt_data.reshape(-1, input_tensor.shape[-1])
     input_scale = input_tensor.layout_tensor.scale
-    if input_scale.dim() > 2:
-        input_scale = input_scale.reshape(-1, input_scale.shape[-1])
+    # Handle case where input tensor is more than 2D
+    inpt_data = inpt_data.reshape(-1, inpt_data.shape[-1])
+
+    # Handle rowwise case
+    if _is_rowwise_scaled(weight_tensor):
+        assert _is_rowwise_scaled(input_tensor), "Input tensor must be rowwise block size"
+        w_scale = w_scale.unsqueeze(-1).T
+        input_scale = preprocess_scale(input_scale, input_tensor.shape)
 
+    # Preprocess data
     inpt_data, w_data = preprocess_data(inpt_data, w_data.T, scaled_mm_config)
 
+    # Perform the computation
     return addmm_float8_unwrapped_inference(
         inpt_data,
         input_scale,
@@ -1458,7 +1472,7 @@ def _register_aqt_quantized_linear_dispatches():
     for dispatch_condition, impl in [
         (_linear_int8_act_int8_weight_check, _linear_int8_act_int8_weight_impl),
         (_linear_int8_act_int8_weight_semi_structured_sparse_check, _linear_int8_act_int8_weight_semi_structured_sparse_impl),
-        (_linear_fp_act_fp8_tensor_wise_weight_check, _linear_fp_act_fp8_weight_impl),
+        (_linear_fp_act_fp8_weight_check, _linear_fp_act_fp8_weight_impl),
         (_linear_bf16_act_uint4_weight_check, _linear_bf16_act_uint4_weight_impl),
         (_linear_fp_act_int8_weight_check, _linear_fp_act_int8_weight_impl),
         (_linear_f16_act_fpx_weight_check, _linear_f16_act_fpx_weight_impl),

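For the rowwise path, the scales handed to the scaled matmul have to line up with the flattened operands: the activation scale becomes a column vector against the (rows, K) input and the weight scale becomes a row vector against the (K, N) weight. A shape-only sketch of what preprocess_scale and the w_scale.unsqueeze(-1).T line above do, using made-up sizes:

import torch

batch, seq, K, N = 4, 16, 128, 256

# Rowwise activation scale: one scale per row of the (possibly >2D) input.
input_scale = torch.rand(batch, seq)     # block_size (1, 1, K) -> one scale per token
input_scale = input_scale.unsqueeze(-1)  # (4, 16, 1)
if input_scale.dim() > 2:
    # flatten to match inpt_data.reshape(-1, K)
    input_scale = input_scale.reshape(-1, input_scale.shape[-1])

# Rowwise weight scale: one scale per output row of the (N, K) weight.
w_scale = torch.rand(N)
w_scale = w_scale.unsqueeze(-1).T        # (1, 256), matching the (K, N) operand after w_data.T

print(input_scale.shape, w_scale.shape)  # torch.Size([64, 1]) torch.Size([1, 256])
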
torchao/float8/inference.py

Lines changed: 8 additions & 0 deletions
@@ -97,3 +97,11 @@ def addmm_float8_unwrapped_inference(
         use_fast_accum=use_fast_accum,
     )
     return output
+
+
+def _is_rowwise_scaled(x) -> bool:
+    """Checks if an AQT tensor is rowwise scaled
+    Args:
+        x: AffineQuantizedTensor tensor
+    """
+    return x.block_size == (1,) * (x.dim() - 1) + (x.shape[-1],)

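_is_rowwise_scaled recognizes rowwise scaling purely from the block size: every dimension except the last is 1, and the last covers the full row. A small illustration with plain tuples (no AffineQuantizedTensor needed):

weight_shape = (256, 128)          # (out_features, in_features)
rowwise_block = (1,) * (len(weight_shape) - 1) + (weight_shape[-1],)
print(rowwise_block)               # (1, 128): one scale per output row

activation_shape = (4, 16, 128)    # (batch, seq, in_features)
rowwise_block = (1,) * (len(activation_shape) - 1) + (activation_shape[-1],)
print(rowwise_block)               # (1, 1, 128): one scale per token
# Tensorwise scaling keeps block_size == shape, which the dispatch check above
# still accepts via `aqt.shape == aqt.block_size`.
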
torchao/quantization/observer.py

Lines changed: 10 additions & 0 deletions
@@ -53,6 +53,16 @@ class PerAxis(GranularityType):
     """
     axis: int
 
+@dataclass(frozen=True)
+class PerRow(GranularityType):
+    """
+    Represents row-wise granularity in quantization.
+
+    This is a special case of per-axis quantization and is unique to Float8 matmuls
+    where the input is quantized with a block_size of (1, ..., input.shape[-1]). And the weight
+    is quantized with a block_size of (1, weight.shape[1]).
+    """
+    pass
 
 # borrowed from torch.ao.quantization.observer
 class _PartialWrapper:

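Because PerRow is a frozen dataclass with no fields, instances compare by value, which is what lets the tests parametrize over PerTensor() / PerRow() and the quantization API dispatch on them with isinstance. A quick sketch, assuming PerTensor is declared the same way as the other granularity types in this file:

from torchao.quantization.observer import PerTensor, PerRow

assert PerRow() == PerRow()          # dataclass equality, no identity tricks needed
assert PerTensor() != PerRow()
granularity = (PerRow(), PerRow())   # (activation, weight) tuple accepted by the new API
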
torchao/quantization/quant_api.py

Lines changed: 68 additions & 7 deletions
@@ -19,7 +19,7 @@
 import torchao
 import torch.nn as nn
 import torch.nn.functional as F
-from typing import Any, Callable, Union, Dict, Optional
+from typing import Any, Callable, Union, Dict, Optional, Literal, Tuple
 import types
 
 from torchao.dtypes.uintx.Uintx import UintxLayoutType
@@ -65,6 +65,8 @@
 )
 from torchao.float8.inference import Float8MMConfig
 
+from torchao.quantization.observer import PerTensor, PerAxis, PerRow
+
 logger = logging.getLogger(__name__)
 
 __all__ = [
@@ -641,17 +643,53 @@ def apply_float8wo_quant(weight):
     return _get_linear_subclass_inserter(apply_float8wo_quant)
 
 
+_fp8_granularities = Literal[PerTensor, PerRow]
+
+
+# Validate and process granularity input
+def _validate_granularity(
+    granularity: Optional[
+        Union[_fp8_granularities, Tuple[_fp8_granularities, _fp8_granularities]]
+    ]
+) -> Tuple[_fp8_granularities, _fp8_granularities]:
+    if granularity is None:
+        return (PerTensor(), PerTensor())
+    elif isinstance(granularity, (PerTensor, PerRow)):
+        return (granularity, granularity)
+    elif isinstance(granularity, tuple) and len(granularity) == 2:
+        if not (
+            isinstance(granularity[0], (PerTensor, PerRow))
+            and isinstance(granularity[1], (PerTensor, PerRow))
+        ):
+            raise ValueError(f"Invalid granularity types: {granularity}, only PerTensor or PerRow are supported.")
+        if type(granularity[0]) != type(granularity[1]):
+            raise ValueError(
+                f"Different granularities for activation and weight are not supported: {granularity}, only PerTensor or PerRow are supported."
+            )
+        return granularity
+    else:
+        raise ValueError(f"Invalid granularity specification: {granularity}, only PerTensor or PerRow are supported.")
+
+
 def float8_dynamic_activation_float8_weight(
     activation_dtype: torch.dtype = torch.float8_e4m3fn,
     weight_dtype: torch.dtype = torch.float8_e4m3fn,
-    mm_config: Optional[Float8MMConfig] = None
+    granularity: Optional[
+        Union[_fp8_granularities, Tuple[_fp8_granularities, _fp8_granularities]]
+    ] = None,
+    mm_config: Optional[Float8MMConfig] = None,
 ):
     """
-    Applies float8 dynamic symmetric per-tensor quantization to both activations and weights of linear layers.
+    Applies float8 dynamic symmetric quantization to both activations and weights of linear layers.
 
     Args:
         activation_dtype (torch.dtype): The target data type for activation quantization. Default is torch.float8_e4m3fn.
         weight_dtype (torch.dtype): The target data type for weight quantization. Default is torch.float8_e4m3fn.
+        granularity:
+            The granularity for quantization. Can be either a single granularity (applied to both
+            activations and weights) or a tuple of two granularities (one for activations, one for weights).
+            If None, defaults to PerTensor for both. Currently both quantizations need to be the same type. And
+            only PerTensor and PerRow are supported.
         mm_config (Float8MMConfig): Configuration for the matrix multiplication. Default uses fast accumulation.
 
     """
@@ -660,23 +698,46 @@ def float8_dynamic_activation_float8_weight(
     if mm_config is None:
         mm_config = Float8MMConfig(use_fast_accum=True)
 
-    #TODO we are hardcoding TensorWise scaling, will follow up PR for Tensorwise scaling
+    activation_granularity, weight_granularity = _validate_granularity(granularity)
+
+    def get_block_size(x: torch.Tensor, granularity: _fp8_granularities):
+        if isinstance(granularity, PerTensor):
+            return x.shape
+        elif isinstance(granularity, PerRow):
+            return (1,) * (x.dim() - 1) + (x.shape[-1],)
+        else:
+            raise ValueError(f"Unsupported granularity: {granularity}")
+
     def apply_float8_dynamic_activation_quant(weight: torch.Tensor):
+        if isinstance(weight_granularity, PerRow):
+            assert (
+                weight.dtype == torch.bfloat16
+            ), "PerRow quantization only works for bfloat16 precision input weight"
+
+        block_size = get_block_size(weight, weight_granularity)
         quantized_weight = to_affine_quantized_floatx(
             input_float=weight,
-            block_size=weight.shape,
+            block_size=block_size,
            target_dtype=weight_dtype,
            scale_dtype=torch.float32,
            layout_type=Float8LayoutType(mm_config=mm_config),
        )
 
         def input_quant_func(x: torch.Tensor):
+            if isinstance(activation_granularity, PerRow):
+                assert (
+                    x.dtype == torch.bfloat16
+                ), "PerRow quantization only works for bfloat16 precision input activation"
+
+            block_size = get_block_size(x, activation_granularity)
             activation = to_affine_quantized_floatx(
                 input_float=x,
-                block_size=x.shape,
+                block_size=block_size,
                 target_dtype=activation_dtype,
                 scale_dtype=torch.float32,
-                layout_type=Float8LayoutType(mm_config=None),  # Config is stored on weight
+                layout_type=Float8LayoutType(
+                    mm_config=None
+                ),  # Config is stored on weight
             )
             return activation
 

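Putting the pieces together, a minimal end-to-end usage sketch of the new option (assuming an H100-class GPU, since the rowwise path requires bfloat16 inputs and the test above only enables PerRow when is_H100 is true):

import torch
import torch.nn as nn

from torchao.quantization import quantize_, float8_dynamic_activation_float8_weight
from torchao.quantization.observer import PerRow

model = nn.Sequential(nn.Linear(1024, 1024)).to(torch.bfloat16).to("cuda").eval()

# A single granularity is applied to both activations and weights; an
# (activation, weight) tuple is also accepted but both must be the same type.
quantize_(model, float8_dynamic_activation_float8_weight(granularity=PerRow()))

x = torch.randn(8, 1024, dtype=torch.bfloat16, device="cuda")
with torch.no_grad():
    y = model(x)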