Skip to content

Commit 1dface2

Browse files
committed
[Float8Quant] Add rowwise scaling option to float8 dyanmic quant
stack-info: PR: #819, branch: drisspg/stack/11
1 parent 65d86c6 commit 1dface2

File tree

6 files changed

+272
-68
lines changed

6 files changed

+272
-68
lines changed

ruff.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,4 +10,5 @@ include = [
1010
"torchao/float8/float8_tensor.py",
1111
"torchao/quantization/linear_activation_weight_observer.py",
1212
"test/quantization/test_observer.py",
13+
"test/dtypes/test_affine_quantized_float.py",
1314
]

test/dtypes/test_affine_quantized_float.py

Lines changed: 139 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -1,34 +1,30 @@
11
from torchao.utils import (
22
TORCH_VERSION_AT_LEAST_2_5,
3-
unwrap_tensor_subclass,
43
)
54
import pytest
65

76
if not TORCH_VERSION_AT_LEAST_2_5:
87
pytest.skip("Unsupported PyTorch version", allow_module_level=True)
98

10-
from numpy import full
11-
from torch.testing._internal.common_utils import (
12-
run_tests,
13-
)
149
from torch._inductor.test_case import TestCase as InductorTestCase
1510
from torch.testing._internal import common_utils
16-
from torch._dynamo.testing import CompileCounterWithBackend
1711

1812
from torchao.quantization import (
1913
quantize_,
2014
float8_weight_only,
2115
float8_dynamic_activation_float8_weight,
2216
)
17+
from torchao.quantization.observer import PerTensor, PerRow
2318
from torchao.float8.float8_utils import compute_error
2419
import torch
2520
import unittest
2621
import pytest
27-
import tempfile
2822
import copy
2923
import random
30-
31-
from unittest.mock import patch
24+
from functools import partial
25+
from typing import Tuple
26+
from contextlib import nullcontext
27+
import io
3228

3329

3430
random.seed(0)
@@ -56,6 +52,9 @@ class TestAffineQuantizedFloat8Compile(InductorTestCase):
5652
@common_utils.parametrize("dtype", [torch.bfloat16, torch.float32])
5753
@common_utils.parametrize("mode", ["dynamic", "weight-only"])
5854
@common_utils.parametrize("compile", [True, False])
55+
@common_utils.parametrize(
56+
"granularity", [PerTensor(), PerRow()] if is_H100 else [PerTensor()]
57+
)
5958
# Inputs are (M,..), K, N
6059
@common_utils.parametrize(
6160
"sizes",
@@ -68,33 +67,142 @@ class TestAffineQuantizedFloat8Compile(InductorTestCase):
6867
],
6968
)
7069
def test_fp8_linear_variants(
71-
self, dtype: torch.dtype, mode: str, compile: bool, sizes: tuple
70+
self, dtype: torch.dtype, mode: str, compile: bool, sizes: Tuple, granularity
7271
):
73-
M, N, K = sizes
74-
input_tensor = torch.randn(*M, K, dtype=dtype, device="cuda")
75-
76-
mode_map = {
77-
"dynamic": float8_dynamic_activation_float8_weight,
78-
"weight-only": float8_weight_only,
79-
}
72+
raises = (
73+
isinstance(granularity, PerRow)
74+
and mode == "dynamic"
75+
and dtype != torch.bfloat16
76+
)
77+
context = (
78+
nullcontext()
79+
if not raises
80+
else pytest.raises(
81+
AssertionError,
82+
match="PerRow quantization only works for bfloat16 precision",
83+
)
84+
)
85+
with context:
86+
M, N, K = sizes
87+
input_tensor = torch.randn(*M, K, dtype=dtype, device="cuda")
88+
89+
mode_map = {
90+
"dynamic": partial(
91+
float8_dynamic_activation_float8_weight, granularity=granularity
92+
),
93+
"weight-only": float8_weight_only,
94+
}
95+
96+
# Create a linear layer with bfloat16 dtype
97+
model = ToyLinearModel(K, N).eval().to(dtype).to("cuda")
98+
99+
quantized_model = copy.deepcopy(model)
100+
factory = mode_map[mode]()
101+
quantize_(model, factory)
102+
103+
if compile:
104+
quantized_model = torch.compile(quantized_model, fullgraph=True)
105+
106+
output_original = model(input_tensor)
107+
output_quantized = quantized_model(input_tensor)
108+
109+
error = compute_error(output_original, output_quantized)
110+
assert (
111+
compute_error(output_original, output_quantized) > 20
112+
), f"Quantization error is too high got a SQNR of {error}"
113+
114+
def test_invalid_granularity(self):
115+
with pytest.raises(ValueError, match="Invalid granularity specification"):
116+
float8_dynamic_activation_float8_weight(granularity="invalid")
117+
118+
def test_mismatched_granularity(self):
119+
with pytest.raises(
120+
ValueError,
121+
match="Different granularities for activation and weight are not supported",
122+
):
123+
float8_dynamic_activation_float8_weight(granularity=(PerTensor(), PerRow()))
124+
125+
def test_unsupported_granularity(self):
126+
class UnsupportedGranularity:
127+
pass
128+
129+
with pytest.raises(ValueError, match="Invalid granularity types"):
130+
float8_dynamic_activation_float8_weight(
131+
granularity=(UnsupportedGranularity(), UnsupportedGranularity())
132+
)
80133

81-
# Create a linear layer with bfloat16 dtype
82-
model = ToyLinearModel(K, N).eval().to(dtype).to("cuda")
134+
@unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available")
135+
@unittest.skipIf(not is_cuda_8_9, "Requires GPU with compute capability >= 8.9")
136+
def test_per_row_with_float32(self):
137+
with pytest.raises(
138+
AssertionError,
139+
match="PerRow quantization only works for bfloat16 precision",
140+
):
141+
model = ToyLinearModel(64, 64).eval().to(torch.float32).to("cuda")
142+
quantize_(
143+
model, float8_dynamic_activation_float8_weight(granularity=PerRow())
144+
)
83145

84-
quantized_model = copy.deepcopy(model)
85-
factory = mode_map[mode]()
146+
@unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available")
147+
@unittest.skipIf(not is_cuda_8_9, "Requires GPU with compute capability >= 8.9")
148+
@common_utils.parametrize("mode", ["dynamic", "weight-only"])
149+
def test_serialization(self, mode: str):
150+
# Create and quantize the model
151+
model = ToyLinearModel(16, 32).to(device="cuda")
152+
if mode == "dynamic":
153+
factory = float8_dynamic_activation_float8_weight()
154+
else:
155+
factory = float8_weight_only()
86156
quantize_(model, factory)
87157

88-
if compile:
89-
quantized_model = torch.compile(quantized_model, fullgraph=True)
90-
91-
output_original = model(input_tensor)
92-
output_quantized = quantized_model(input_tensor)
93-
94-
error = compute_error(output_original, output_quantized)
95-
assert (
96-
compute_error(output_original, output_quantized) > 20
97-
), f"Quantization error is too high got a SQNR of {error}"
158+
# Save the state dict to an in-memory buffer
159+
buffer = io.BytesIO()
160+
torch.save(model.state_dict(), buffer)
161+
162+
# Reset the buffer position
163+
buffer.seek(0)
164+
165+
# Load the state dict from the buffer
166+
loaded_state_dict = torch.load(buffer)
167+
168+
# Create a new model and load the state dict
169+
with torch.device("meta"):
170+
new_model = ToyLinearModel(16, 32)
171+
new_model.load_state_dict(loaded_state_dict, assign=True)
172+
173+
# Compare the original and loaded models
174+
if mode == "weight-only":
175+
model_weight_1 = model.linear1.weight.layout_tensor.float8_data.to(
176+
torch.float32
177+
)
178+
new_model_weight_1 = new_model.linear1.weight.layout_tensor.float8_data.to(
179+
torch.float32
180+
)
181+
182+
model_weight_2 = model.linear2.weight.layout_tensor.float8_data.to(
183+
torch.float32
184+
)
185+
new_model_weight_2 = new_model.linear2.weight.layout_tensor.float8_data.to(
186+
torch.float32
187+
)
188+
189+
else:
190+
model_weight_1 = model.linear1.weight.original_weight_tensor.layout_tensor.float8_data.to(
191+
torch.float32
192+
)
193+
new_model_weight_1 = new_model.linear1.weight.original_weight_tensor.layout_tensor.float8_data.to(
194+
torch.float32
195+
)
196+
197+
model_weight_2 = model.linear2.weight.original_weight_tensor.layout_tensor.float8_data.to(
198+
torch.float32
199+
)
200+
new_model_weight_2 = new_model.linear2.weight.original_weight_tensor.layout_tensor.float8_data.to(
201+
torch.float32
202+
)
203+
204+
assert torch.allclose(model_weight_1, new_model_weight_1)
205+
assert torch.allclose(model_weight_2, new_model_weight_2)
98206

99207

100208
common_utils.instantiate_parametrized_tests(TestAffineQuantizedFloat8Compile)

torchao/dtypes/affine_quantized_tensor.py

Lines changed: 35 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,12 @@
2727
is_device,
2828
get_out_shape,
2929
)
30+
from torchao.float8.inference import (
31+
preprocess_data,
32+
Float8MMConfig,
33+
addmm_float8_unwrapped_inference,
34+
_is_rowwise_scaled
35+
)
3036
from torch.utils._python_dispatch import is_traceable_wrapper_subclass
3137
from dataclasses import dataclass
3238
from torchao.utils import (
@@ -1355,53 +1361,61 @@ def _linear_f16_act_fpx_weight_impl(input_tensor, weight_tensor, bias):
13551361

13561362
return out.view(*act.shape[:-1], out_dim).to(act.dtype)
13571363

1358-
def _linear_fp_act_fp8_tensor_wise_weight_check(
1364+
def _linear_fp_act_fp8_weight_check(
13591365
input_tensor: Union[torch.Tensor, AffineQuantizedTensor],
13601366
weight_tensor: Union[torch.Tensor, AffineQuantizedTensor],
13611367
bias: Optional[torch.Tensor],
13621368
) -> bool:
1363-
def check_aqt_tensorwise(aqt: Union[torch.Tensor, AffineQuantizedTensor]) -> bool:
1369+
def check_aqt(aqt: Union[torch.Tensor, AffineQuantizedTensor]) -> bool:
13641370
return (
13651371
isinstance(aqt, AffineQuantizedTensor) and
13661372
isinstance(aqt.layout_type, Float8LayoutType)
13671373
and aqt.layout_tensor.dtype in [torch.float8_e4m3fn, torch.float8_e5m2]
1368-
and aqt.shape == aqt.block_size
1374+
and (aqt.shape == aqt.block_size or _is_rowwise_scaled(aqt))
13691375
)
1370-
return check_aqt_tensorwise(input_tensor) and check_aqt_tensorwise(weight_tensor)
1376+
return check_aqt(input_tensor) and check_aqt(weight_tensor)
1377+
13711378

1379+
def preprocess_scale(input_scale: torch.Tensor, input_shape: Tuple[int]):
1380+
""" Ensures input tensor is correctly formated for _scaled_mm """
1381+
input_scale = input_scale.unsqueeze(-1)
1382+
1383+
if input_scale.dim() > 2:
1384+
input_scale = input_scale.reshape(-1, input_scale.shape[-1])
1385+
1386+
return input_scale
13721387

13731388
def _linear_fp_act_fp8_weight_impl(
13741389
input_tensor: AffineQuantizedTensor,
13751390
weight_tensor: AffineQuantizedTensor,
13761391
bias: Optional[torch.Tensor],
13771392
):
13781393
"""Implements matmul between FP8 input and FP8 weight with compute using _scaled_mm"""
1379-
from torchao.float8.inference import (
1380-
preprocess_data,
1381-
Float8MMConfig,
1382-
addmm_float8_unwrapped_inference,
1383-
)
1384-
13851394
scaled_mm_config = weight_tensor.layout_type.mm_config
1386-
scaled_mm_config = scaled_mm_config if scaled_mm_config is not None else Float8MMConfig()
1395+
out_shape = get_out_shape(input_tensor.shape, weight_tensor.shape)
13871396

1397+
# Weight tensor preprocessing
13881398
w_layout = weight_tensor.layout_tensor
1389-
w_data = weight_tensor.layout_tensor.float8_data
1390-
w_data = w_data.T if w_layout.transposed else w_data
1399+
assert not w_layout.transposed, "Weight tensor must be contiguous"
1400+
w_data = w_layout.float8_data
13911401
w_scale = w_layout.scale
1392-
w_scale = w_scale if w_layout.transposed else w_scale
1393-
1394-
out_shape = get_out_shape(input_tensor.shape, weight_tensor.shape)
13951402

1403+
# Input tensor preprocessing
13961404
inpt_data = input_tensor.layout_tensor.float8_data
1397-
# Handle case where input tensor is more than 2D
1398-
inpt_data = inpt_data.reshape(-1, input_tensor.shape[-1])
13991405
input_scale = input_tensor.layout_tensor.scale
1400-
if input_scale.dim() > 2:
1401-
input_scale = input_scale.reshape(-1, input_scale.shape[-1])
1406+
# Handle case where input tensor is more than 2D
1407+
inpt_data = inpt_data.reshape(-1, inpt_data.shape[-1])
1408+
1409+
# Handle rowwise case
1410+
if _is_rowwise_scaled(weight_tensor):
1411+
assert _is_rowwise_scaled(input_tensor), "Input tensor must be rowwise block size"
1412+
w_scale = w_scale.unsqueeze(-1).T
1413+
input_scale = preprocess_scale(input_scale, input_tensor.shape)
14021414

1415+
# Preprocess data
14031416
inpt_data, w_data = preprocess_data(inpt_data, w_data.T, scaled_mm_config)
14041417

1418+
# Perform the computation
14051419
return addmm_float8_unwrapped_inference(
14061420
inpt_data,
14071421
input_scale,
@@ -1458,7 +1472,7 @@ def _register_aqt_quantized_linear_dispatches():
14581472
for dispatch_condition, impl in [
14591473
(_linear_int8_act_int8_weight_check, _linear_int8_act_int8_weight_impl),
14601474
(_linear_int8_act_int8_weight_semi_structured_sparse_check, _linear_int8_act_int8_weight_semi_structured_sparse_impl),
1461-
(_linear_fp_act_fp8_tensor_wise_weight_check, _linear_fp_act_fp8_weight_impl),
1475+
(_linear_fp_act_fp8_weight_check, _linear_fp_act_fp8_weight_impl),
14621476
(_linear_bf16_act_uint4_weight_check, _linear_bf16_act_uint4_weight_impl),
14631477
(_linear_fp_act_int8_weight_check, _linear_fp_act_int8_weight_impl),
14641478
(_linear_f16_act_fpx_weight_check, _linear_f16_act_fpx_weight_impl),

torchao/float8/inference.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -97,3 +97,11 @@ def addmm_float8_unwrapped_inference(
9797
use_fast_accum=use_fast_accum,
9898
)
9999
return output
100+
101+
102+
def _is_rowwise_scaled(x) -> bool:
103+
"""Checks if an AQT tensor is rowwise scaled
104+
Args:
105+
x: AffineQuantizedTensor tensor
106+
"""
107+
return x.block_size == (1,) * (x.dim() - 1) + (x.shape[-1],)

torchao/quantization/observer.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -53,6 +53,16 @@ class PerAxis(GranularityType):
5353
"""
5454
axis: int
5555

56+
@dataclass(frozen=True)
57+
class PerRow(GranularityType):
58+
"""
59+
Represents row-wise granularity in quantization.
60+
61+
This is a special case of per-axis quantization and is unique to Float8 matmuls
62+
where the input is quantized with a block_size of (1, ..., input.shape[-1]). And the weight
63+
is quantized with a block_size of (1, weight.shape[1]).
64+
"""
65+
pass
5666

5767
# borrowed from torch.ao.quantization.observer
5868
class _PartialWrapper:
@@ -104,6 +114,8 @@ def get_block_size(
104114
block_size = list(input_shape)
105115
block_size[granularity_type.axis] = 1
106116
return tuple(block_size)
117+
elif isinstance(granularity_type, PerRow):
118+
return (1,) * (len(input_shape) - 1) + (input_shape[-1],)
107119
raise ValueError(f"Unsupported GranularityType: {granularity_type}")
108120

109121

0 commit comments

Comments
 (0)