Commit 06a5628

[Float8Quant] Add rowwise scaling option to float8 dynamic quant
stack-info: PR: #819, branch: drisspg/stack/11
1 parent a246d87 commit 06a5628

6 files changed, 90 additions, 35 deletions

ruff.toml

Lines changed: 1 addition & 0 deletions
@@ -8,4 +8,5 @@ include = [
     "torchao/dtypes/nf4tensor.py",
     "test/dtypes/test_nf4.py",
     "torchao/float8/float8_tensor.py",
+    "test/dtypes/test_affine_quantized_float.py"
 ]

test/dtypes/test_affine_quantized_float.py

Lines changed: 13 additions & 11 deletions
@@ -1,34 +1,28 @@
 from torchao.utils import (
     TORCH_VERSION_AT_LEAST_2_5,
-    unwrap_tensor_subclass,
 )
 import pytest
 
 if not TORCH_VERSION_AT_LEAST_2_5:
     pytest.skip("Unsupported PyTorch version", allow_module_level=True)
 
-from numpy import full
-from torch.testing._internal.common_utils import (
-    run_tests,
-)
 from torch._inductor.test_case import TestCase as InductorTestCase
 from torch.testing._internal import common_utils
-from torch._dynamo.testing import CompileCounterWithBackend
 
 from torchao.quantization import (
     quantize_,
     float8_weight_only,
     float8_dynamic_activation_float8_weight,
 )
+from torchao.quantization.observer import PerTensor, RowWise
 from torchao.float8.float8_utils import compute_error
 import torch
 import unittest
 import pytest
-import tempfile
 import copy
 import random
-
-from unittest.mock import patch
+from functools import partial
+from typing import Tuple
 
 
 random.seed(0)
@@ -56,6 +50,7 @@ class TestAffineQuantizedFloat8Compile(InductorTestCase):
     @common_utils.parametrize("dtype", [torch.bfloat16, torch.float32])
     @common_utils.parametrize("mode", ["dynamic", "weight-only"])
     @common_utils.parametrize("compile", [True, False])
+    @common_utils.parametrize("granularity", [PerTensor, RowWise])
     # Inputs are (M,..), K, N
     @common_utils.parametrize(
         "sizes",
@@ -68,13 +63,20 @@ class TestAffineQuantizedFloat8Compile(InductorTestCase):
         ],
     )
     def test_fp8_linear_variants(
-        self, dtype: torch.dtype, mode: str, compile: bool, sizes: tuple
+        self, dtype: torch.dtype, mode: str, compile: bool, sizes: Tuple, granularity
     ):
+        if granularity == RowWise and mode == "dynamic":
+            pytest.skip(
+                "RowWise quantization only works for bfloat16 precision input weight and activation for now"
+            )
+
         M, N, K = sizes
         input_tensor = torch.randn(*M, K, dtype=dtype, device="cuda")
 
         mode_map = {
-            "dynamic": float8_dynamic_activation_float8_weight,
+            "dynamic": partial(
+                float8_dynamic_activation_float8_weight, granularity=granularity
+            ),
             "weight-only": float8_weight_only,
         }
 
torchao/dtypes/affine_quantized_tensor.py

Lines changed: 36 additions & 21 deletions
@@ -26,6 +26,12 @@
     is_device,
     get_out_shape,
 )
+from torchao.float8.inference import (
+    preprocess_data,
+    Float8MMConfig,
+    addmm_float8_unwrapped_inference,
+    _is_rowwise_scaled
+)
 from torch.utils._python_dispatch import is_traceable_wrapper_subclass
 from dataclasses import dataclass
 from torchao.utils import (
@@ -1161,53 +1167,62 @@ def _linear_f16_act_fpx_weight_impl(input_tensor, weight_tensor, bias):
 
     return out.view(*act.shape[:-1], out_dim).to(act.dtype)
 
-def _linear_fp_act_fp8_tensor_wise_weight_check(
+def _linear_fp_act_fp8_weight_check(
     input_tensor: Union[torch.Tensor, AffineQuantizedTensor],
     weight_tensor: Union[torch.Tensor, AffineQuantizedTensor],
     bias: Optional[torch.Tensor],
 ) -> bool:
-    def check_aqt_tensorwise(aqt: Union[torch.Tensor, AffineQuantizedTensor]) -> bool:
+    def check_aqt(aqt: Union[torch.Tensor, AffineQuantizedTensor]) -> bool:
         return (
             isinstance(aqt, AffineQuantizedTensor) and
             isinstance(aqt.layout_type, Float8LayoutType)
             and aqt.layout_tensor.dtype in [torch.float8_e4m3fn, torch.float8_e5m2]
-            and aqt.shape == aqt.block_size
+            and (aqt.shape == aqt.block_size or _is_rowwise_scaled(aqt))
         )
-    return check_aqt_tensorwise(input_tensor) and check_aqt_tensorwise(weight_tensor)
+    return check_aqt(input_tensor) and check_aqt(weight_tensor)
+
 
+def preprocess_input_tensor(inpt_data: torch.Tensor, input_scale: torch.Tensor, input_shape: Tuple[int]):
+    """ Ensures input tensor is correctly formatted for _scaled_mm """
+    if input_scale.size(0) != 1:
+        input_scale = input_scale.unsqueeze(-1)
+
+    if input_scale.dim() > 2:
+        input_scale = input_scale.reshape(-1, input_scale.shape[-1])
+
+    return inpt_data, input_scale
 
 def _linear_fp_act_fp8_weight_impl(
     input_tensor: AffineQuantizedTensor,
     weight_tensor: AffineQuantizedTensor,
     bias: Optional[torch.Tensor],
 ):
     """Implements matmul between FP8 input and FP8 weight with compute using _scaled_mm"""
-    from torchao.float8.inference import (
-        preprocess_data,
-        Float8MMConfig,
-        addmm_float8_unwrapped_inference,
-    )
-
     scaled_mm_config = weight_tensor.layout_type.mm_config
-    scaled_mm_config = scaled_mm_config if scaled_mm_config is not None else Float8MMConfig()
+    out_shape = get_out_shape(input_tensor.shape, weight_tensor.shape)
 
+    # Weight tensor preprocessing
     w_layout = weight_tensor.layout_tensor
-    w_data = weight_tensor.layout_tensor.float8_data
-    w_data = w_data.T if w_layout.transposed else w_data
+    assert not w_layout.transposed, "Weight tensor must be contiguous"
+    w_data = w_layout.float8_data
     w_scale = w_layout.scale
-    w_scale = w_scale if w_layout.transposed else w_scale
-
-    out_shape = get_out_shape(input_tensor.shape, weight_tensor.shape)
 
+    # Input tensor preprocessing
     inpt_data = input_tensor.layout_tensor.float8_data
-    # Handle case where input tensor is more than 2D
-    inpt_data = inpt_data.reshape(-1, input_tensor.shape[-1])
     input_scale = input_tensor.layout_tensor.scale
-    if input_scale.dim() > 2:
-        input_scale = input_scale.reshape(-1, input_scale.shape[-1])
+    # Handle case where input tensor is more than 2D
+    inpt_data = inpt_data.reshape(-1, inpt_data.shape[-1])
+
+    # Handle rowwise case
+    if _is_rowwise_scaled(weight_tensor):
+        assert _is_rowwise_scaled(input_tensor), "Input tensor must be rowwise block size"
+        w_scale = w_scale.unsqueeze(-1).T
+        inpt_data, input_scale = preprocess_input_tensor(inpt_data, input_scale, input_tensor.shape)
 
+    # Preprocess data
     inpt_data, w_data = preprocess_data(inpt_data, w_data.T, scaled_mm_config)
 
+    # Perform the computation
     return addmm_float8_unwrapped_inference(
         inpt_data,
         input_scale,
@@ -1223,7 +1238,7 @@ def _register_aqt_quantized_linear_dispatches():
     for dispatch_condition, impl in [
         (_linear_int8_act_int8_weight_check, _linear_int8_act_int8_weight_impl),
         (_linear_int8_act_int8_weight_semi_structured_sparse_check, _linear_int8_act_int8_weight_semi_structured_sparse_impl),
-        (_linear_fp_act_fp8_tensor_wise_weight_check, _linear_fp_act_fp8_weight_impl),
+        (_linear_fp_act_fp8_weight_check, _linear_fp_act_fp8_weight_impl),
         (_linear_bf16_act_uint4_weight_check, _linear_bf16_act_uint4_weight_impl),
        (_linear_fp_act_int8_weight_check, _linear_fp_act_int8_weight_impl),
        (_linear_f16_act_fpx_weight_check, _linear_f16_act_fpx_weight_impl),
torchao/float8/inference.py

Lines changed: 8 additions & 0 deletions
@@ -97,3 +97,11 @@ def addmm_float8_unwrapped_inference(
         use_fast_accum=use_fast_accum,
     )
     return output
+
+
+def _is_rowwise_scaled(x) -> bool:
+    """Checks if an AQT tensor is rowwise scaled
+    Args:
+        x: AffineQuantizedTensor tensor
+    """
+    return x.block_size == (1,) * (x.dim() - 1) + (x.shape[-1],)

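For illustration only (not part of the commit), the predicate above matches a block_size that covers exactly one row at a time. A small stand-alone sketch, using a hypothetical stand-in object for the AffineQuantizedTensor attributes it reads:

from types import SimpleNamespace

def _is_rowwise_scaled(x) -> bool:
    # Same expression as the helper above: every leading block dim is 1,
    # and the last block dim spans the full row.
    return x.block_size == (1,) * (x.dim() - 1) + (x.shape[-1],)

rowwise = SimpleNamespace(shape=(128, 256), block_size=(1, 256), dim=lambda: 2)
tensorwise = SimpleNamespace(shape=(128, 256), block_size=(128, 256), dim=lambda: 2)

print(_is_rowwise_scaled(rowwise))     # True: one scale per row of 256 elements
print(_is_rowwise_scaled(tensorwise))  # False: a single block over the whole tensor
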
torchao/quantization/observer.py

Lines changed: 11 additions & 0 deletions
@@ -52,6 +52,17 @@ class PerAxis(GranularityType):
     """
     axis: int
 
+@dataclass(frozen=True)
+class RowWise(GranularityType):
+    """
+    Represents row-wise granularity in quantization.
+
+    This is a special case of per-axis quantization and is unique to Float8 matmuls
+    where the input is quantized with a block_size of (1, input.shape[1]). And the weight
+    is quantized with a block_size of (weight.shape[0], 1).
+    """
+    pass
+
 # borrowed from torch.ao.quantization.observer
 class _PartialWrapper:
     def __init__(self, p):

torchao/quantization/quant_api.py

Lines changed: 21 additions & 3 deletions
@@ -19,7 +19,7 @@
 import torchao
 import torch.nn as nn
 import torch.nn.functional as F
-from typing import Any, Callable, Union, Dict, Optional
+from typing import Any, Callable, Union, Dict, Optional, Literal
 import types
 
 from torchao.dtypes.uintx.Uintx import UintxLayoutType
@@ -60,6 +60,8 @@
 from .autoquant import autoquant, AutoQuantizableLinearWeight
 from torchao.float8.inference import Float8MMConfig
 
+from torchao.quantization.observer import GranularityType, PerTensor, PerAxis, RowWise
+
 logger = logging.getLogger(__name__)
 
 __all__ = [
@@ -550,6 +552,7 @@ def apply_float8wo_quant(weight):
 def float8_dynamic_activation_float8_weight(
     activation_dtype: torch.dtype = torch.float8_e4m3fn,
     weight_dtype: torch.dtype = torch.float8_e4m3fn,
+    granularity: Literal[PerTensor, RowWise] = PerTensor,
     mm_config: Optional[Float8MMConfig] = None
 ):
     """
@@ -566,20 +569,35 @@ def float8_dynamic_activation_float8_weight(
     if mm_config is None:
         mm_config = Float8MMConfig(use_fast_accum=True)
 
+    def get_block_size(x: torch.Tensor, granularity: Literal[PerTensor, RowWise]):
+        if granularity == PerTensor:
+            return x.shape
+        elif granularity == RowWise:
+            return (1,) * (x.dim() - 1) + (x.shape[-1],)
+        else:
+            raise ValueError(f"Unsupported granularity: {granularity}")
+
     #TODO we are hardcoding TensorWise scaling, will follow up PR for Tensorwise scaling
     def apply_float8_dynamic_activation_quant(weight: torch.Tensor):
+        if granularity == RowWise:
+            assert weight.dtype == torch.bfloat16, "RowWise quantization only works for bfloat16 precision input weight and activation"
+
+        block_size = get_block_size(weight, granularity)
         quantized_weight = to_affine_quantized_floatx(
             input_float=weight,
-            block_size=weight.shape,
+            block_size=block_size,
             target_dtype=weight_dtype,
             scale_dtype=torch.float32,
             layout_type=Float8LayoutType(mm_config=mm_config),
         )
 
         def input_quant_func(x: torch.Tensor):
+            if granularity == RowWise:
+                assert x.dtype == torch.bfloat16, "RowWise quantization only works for bfloat16 precision input weight and activation"
+            block_size = get_block_size(x, granularity)
             activation = to_affine_quantized_floatx(
                 input_float=x,
-                block_size=x.shape,
+                block_size=block_size,
                 target_dtype=activation_dtype,
                 scale_dtype=torch.float32,
                 layout_type=Float8LayoutType(mm_config=None),  # Config is stored on weight

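Putting the pieces together, a minimal usage sketch (assuming a CUDA device with float8 support; the model must be bfloat16, as the RowWise assertions above require):

import torch
import torch.nn as nn
from torchao.quantization import quantize_, float8_dynamic_activation_float8_weight
from torchao.quantization.observer import RowWise

# Toy model in bfloat16 on GPU.
model = nn.Sequential(nn.Linear(1024, 1024)).to(torch.bfloat16).to("cuda")

# Dynamic float8 activation + float8 weight quantization with rowwise scales:
# the weight gets one scale per row, and each activation row is scaled at runtime.
quantize_(model, float8_dynamic_activation_float8_weight(granularity=RowWise))

x = torch.randn(16, 1024, dtype=torch.bfloat16, device="cuda")
with torch.no_grad():
    y = model(x)
print(y.shape)  # torch.Size([16, 1024])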