
Commit 6dd82d8

Add static quantization as an example for calibration flow (#487)
Summary: So far the quantization flow API we provide (`quantize_`) does not require calibration (calibrating a model with sample data). This PR adds a static quantization example that demonstrates the calibration flow:

1. Prepare the model for calibration.
2. Calibrate the prepared model with sample data.
3. Convert the calibrated model to a quantized model.

Test Plan: python torchao/prototype/calibration_flow/static_quant.py

Reviewers: Subscribers: Tasks: Tags:
1 parent e9e6671 commit 6dd82d8
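In outline, the flow added here works as sketched below. This is a minimal sketch, not part of the commit: it assumes the `ObservedLinear`, `insert_observers_`, and `apply_static_quant` helpers from the new demo file are in scope (the demo is a standalone script, so in practice you would copy them or run the script directly):

```python
import torch
from torch.ao.quantization.observer import MinMaxObserver, PerChannelMinMaxObserver
from torchao.quantization import quantize_
# ObservedLinear, insert_observers_, apply_static_quant: assumed copied from
# the demo file added in this commit (static_quant.py)

m = torch.nn.Sequential(torch.nn.Linear(64, 32, bias=False),
                        torch.nn.Linear(32, 64, bias=False)).eval()

# 1. prepare the model for calibration: swap each nn.Linear for an
#    ObservedLinear that records activation and weight statistics
act_obs = MinMaxObserver(dtype=torch.uint8, qscheme=torch.per_tensor_affine)
weight_obs = PerChannelMinMaxObserver(dtype=torch.uint8, qscheme=torch.per_channel_affine)
insert_observers_(m, act_obs, weight_obs)

# 2. calibrate the prepared model with sample data
for _ in range(10):
    m(torch.randn(1, 64))

# 3. convert the calibrated model to a quantized model
quantize_(m, apply_static_quant, lambda mod, fqn: isinstance(mod, ObservedLinear))
```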

File tree: 6 files changed (+202, −13 lines)

test/dtypes/test_affine_quantized.py (4 additions, 2 deletions)

```diff
@@ -14,10 +14,12 @@ class TestAffineQuantized(TestCase):
     @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available")
     @unittest.skipIf(TORCH_VERSION_AFTER_2_5, "int4 skipping 2.5+ for now")
     def test_tensor_core_layout_transpose(self):
-        t = torch.rand(128, 256, dtype=torch.bfloat16, device="cuda")
+        l = torch.nn.Linear(128, 256, dtype=torch.bfloat16, device="cuda")
+        t = l.weight
         shape = t.shape
         apply_int4_weight_only_quant = int4_weight_only(group_size=32)
-        aqt = apply_int4_weight_only_quant(t)
+        ql = apply_int4_weight_only_quant(l)
+        aqt = ql.weight
         aqt_shape = aqt.shape
         self.assertEqual(aqt_shape, shape)
```

torchao/dtypes/__init__.py (2 additions, 0 deletions)

```diff
@@ -4,6 +4,7 @@
 from .affine_quantized_tensor import (
     AffineQuantizedTensor,
     to_affine_quantized,
+    to_affine_quantized_static,
     LayoutType,
     PlainLayoutType,
     TensorCoreTiledLayoutType,
@@ -15,6 +16,7 @@
     "UInt4Tensor"
     "AffineQuantizedTensor",
     "to_affine_quantized",
+    "to_affine_quantized_static",
     "LayoutType",
     "PlainLayoutType",
     "TensorCoreTiledLayoutType",
```

torchao/dtypes/affine_quantized_tensor.py (35 additions, 1 deletion)

```diff
@@ -232,6 +232,7 @@ def from_float(
 
         scale, zero_point = choose_qparams_affine(input_float, mapping_type, block_size, target_dtype, quant_min, quant_max, eps, scale_dtype, zero_point_dtype, preserve_zero, zero_point_domain)
         int_data = quantize_affine(input_float, block_size, scale, zero_point, target_dtype, quant_min, quant_max, zero_point_domain)
+
         int_data = layout_type.post_process(int_data)
 
         layout_tensor_ctr = get_layout_tensor_constructor(type(layout_type))
@@ -246,8 +247,40 @@ def from_float(
             dtype=input_float.dtype
         )
 
+    @classmethod
+    def from_float_static(
+        cls,
+        input_float: torch.Tensor,
+        scale: torch.Tensor,
+        zero_point: torch.Tensor,
+        block_size: Tuple[int, ...],
+        target_dtype: torch.dtype,
+        quant_min: Optional[int] = None,
+        quant_max: Optional[int] = None,
+        zero_point_domain: ZeroPointDomain = ZeroPointDomain.INT,
+        layout_type: LayoutType = PlainLayoutType(),
+    ):
+        original_shape = input_float.shape
+        input_float = layout_type.pre_process(input_float)
+
+        int_data = quantize_affine(input_float, block_size, scale, zero_point, target_dtype, quant_min, quant_max, zero_point_domain)
+
+        int_data = layout_type.post_process(int_data)
+
+        layout_tensor_ctr = get_layout_tensor_constructor(type(layout_type))
+        layout_tensor = layout_tensor_ctr(int_data, scale, zero_point, layout_type)
+        return cls(
+            layout_tensor,
+            block_size,
+            original_shape,
+            quant_min,
+            quant_max,
+            zero_point_domain,
+            dtype=input_float.dtype,
+        )
+
     @property
-    def layout_type(self) -> str:
+    def layout_type(self) -> LayoutType:
         return self.layout_tensor.layout_type
 
     @classmethod
@@ -809,3 +842,4 @@ def t(func, *args, **kwargs):
         return return_and_correct_aliasing(func, args, kwargs, new)
 
 to_affine_quantized = AffineQuantizedTensor.from_float
+to_affine_quantized_static = AffineQuantizedTensor.from_float_static
```
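`to_affine_quantized_static` takes precomputed `scale`/`zero_point` instead of deriving them from the input, which is what a calibration flow needs. A hedged sketch of a per-channel call (the parameter values are illustrative, not from the commit; the `block_size = (1, in_features)` convention follows the demo file below):

```python
import torch
from torchao.dtypes import to_affine_quantized_static

weight = torch.randn(32, 64)

# per-output-channel qparams, as an observer would produce them;
# the values here are illustrative
scale = weight.abs().amax(dim=1) / 127.5
zero_point = torch.full((32,), 128, dtype=torch.int64)

# one set of qparams per row: block_size = (1, in_features)
qweight = to_affine_quantized_static(
    weight, scale, zero_point,
    block_size=(1, weight.shape[1]),
    target_dtype=torch.uint8,
)
```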

torchao/prototype/quant_llm/quant_llm.py (3 additions, 2 deletions)

```diff
@@ -7,6 +7,7 @@
 from torchao.prototype.custom_fp_utils import _f32_to_fpx_unpacked, _fpx_unpacked_to_f32, _n_ones
 from torchao.ops import quant_llm_linear
 from torchao.dtypes.utils import _implements, _ATEN_OP_OR_TORCH_FN_TABLE
+from torchao.quantization.quant_api import _get_linear_subclass_inserter
 
 
 _ONES_TABLE = [_n_ones(i) for i in range(8)]
@@ -456,8 +457,8 @@ def apply_quant_llm(weight: Tensor) -> Tensor:
         if (in_dim % 64 != 0) or (out_dim % 256 != 0):
             return weight
         return QuantLlmLinearWeight.from_float(weight, ebits, mbits)
-    return apply_quant_llm
+    return _get_linear_subclass_inserter(apply_quant_llm)
 
 
 def fp6_llm_weight_only():
-    return quant_llm_fpx_weight_only(3, 2)
+    return _get_linear_subclass_inserter(quant_llm_fpx_weight_only(3, 2))
```
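After this change the quant-llm constructors return module-level inserters, so presumably they plug straight into `quantize_`. A sketch under that assumption (the import path mirrors the file location, and the layer shapes are chosen to pass the `in_dim % 64` / `out_dim % 256` check above):

```python
import torch
from torchao.quantization import quantize_
from torchao.prototype.quant_llm.quant_llm import quant_llm_fpx_weight_only

# shapes chosen to satisfy the in_dim % 64 == 0 and out_dim % 256 == 0 check
model = torch.nn.Sequential(torch.nn.Linear(64, 256)).half().cuda()
quantize_(model, quant_llm_fpx_weight_only(3, 2))  # fp6: 3 exponent, 2 mantissa bits
```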

torchao/quantization/quant_api.py (13 additions, 8 deletions)

```diff
@@ -259,12 +259,12 @@ def insert_subclass(lin):
 
     return insert_subclass
 
-def quantize_(model: torch.nn.Module, apply_tensor_subclass: Callable[[torch.Tensor], torch.Tensor], filter_fn: Optional[Callable[[torch.nn.Module, str], bool]]=None, set_inductor_config: bool=True):
+def quantize_(model: torch.nn.Module, apply_tensor_subclass: Callable[[torch.nn.Module], torch.nn.Module], filter_fn: Optional[Callable[[torch.nn.Module, str], bool]]=None, set_inductor_config: bool=True):
     """Convert the weight of linear modules in the model with `apply_tensor_subclass`, model is modified inplace
 
     Args:
         model (torch.nn.Module): input model
-        apply_tensor_subclass (Callable[[torch.Tensor], torch.Tensor]): function that convert a floating point Tensor to a (quantized) tensor subclass instance (e.g. affine quantized tensor instance)
+        apply_tensor_subclass (Callable[[torch.nn.Module], torch.nn.Module]): function that applies tensor subclass conversion to the weight of a module and return the module (e.g. convert the weight tensor of linear to affine quantized tensor)
         filter_fn (Optional[Callable[[torch.nn.Module, str], bool]]): function that takes a nn.Module instance and fully qualified name of the module, returns True if we want to run `apply_tensor_subclass` on
         the weight of the module
         set_inductor_config (bool, optional): Whether to automatically use recommended inductor config settings (defaults to True)
@@ -300,19 +300,24 @@ def quantize_(model: torch.nn.Module, apply_tensor_subclass: Callable[[torch.Ten
                 x, "asymmetric", (1, groupsize), torch.int32, 0, 15, 1e-6,
                 zero_point_dtype=torch.bfloat16, preserve_zero=False, zero_point_domain="float")
 
+        def apply_weight_quant_to_linear(linear):
+            linear.weight = torch.nn.Parameter(apply_weight_quant(linear.weight), requires_grad=False)
+            return linear
+
         # apply to modules under block0 submodule
         def filter_fn(module: nn.Module, fqn: str) -> bool:
             return isinstance(module, nn.Linear)
 
         m = nn.Sequential(nn.Linear(32, 1024), nn.Linear(1024, 32))
-        quantize_(m, apply_weight_quant, filter_fn)
+        quantize_(m, apply_weight_quant_to_linear, filter_fn)
 
     """
     if set_inductor_config:
         torchao.quantization.utils.recommended_inductor_config_setter()
+
     _replace_with_custom_fn_if_matches_filter(
         model,
-        _get_linear_subclass_inserter(apply_tensor_subclass),
+        apply_tensor_subclass,
         _is_linear if filter_fn is None else filter_fn,
     )
 
@@ -356,7 +361,7 @@ def get_per_token_block_size(x):
         weight = to_linear_act_quantized(weight, input_quant_func)
         return weight
 
-    return apply_int8_dynamic_activation_int4_weight_quant
+    return _get_linear_subclass_inserter(apply_int8_dynamic_activation_int4_weight_quant)
 
 
 def int4_weight_only(group_size=128, inner_k_tiles=8):
@@ -394,7 +399,7 @@ def apply_int4_weight_only_quant(weight):
         layout_type = TensorCoreTiledLayoutType(inner_k_tiles=inner_k_tiles)
         return to_affine_quantized(weight, mapping_type, block_size, target_dtype, quant_min, quant_max, eps, zero_point_dtype=zero_point_dtype, preserve_zero=preserve_zero, zero_point_domain=zero_point_domain, layout_type=layout_type)
 
-    return apply_int4_weight_only_quant
+    return _get_linear_subclass_inserter(apply_int4_weight_only_quant)
 
 
 def int8_weight_only():
@@ -412,7 +417,7 @@ def apply_int8wo_quant(weight):
         block_size = (1, weight.shape[1])
         return to_affine_quantized(weight, mapping_type, block_size, target_dtype, eps=eps, zero_point_dtype=zero_point_dtype)
 
-    return apply_int8wo_quant
+    return _get_linear_subclass_inserter(apply_int8wo_quant)
 
 def int8_dynamic_activation_int8_weight():
     """
@@ -454,4 +459,4 @@ def get_per_token_block_size(x):
         weight = to_linear_act_quantized(weight, input_quant_func)
         return weight
 
-    return apply_int8_dynamic_activation_int8_weight_quant
+    return _get_linear_subclass_inserter(apply_int8_dynamic_activation_int8_weight_quant)
```
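The updated signature means `apply_tensor_subclass` now maps a module to a module rather than a weight tensor to a tensor; `quantize_` no longer wraps it in `_get_linear_subclass_inserter` itself. A tensor-level function can be lifted to the new contract with a small wrapper following the pattern in the updated docstring (a sketch; `to_module_fn` is a hypothetical name, not a torchao API):

```python
import torch

def to_module_fn(weight_quant_fn):
    # lift a weight -> quantized-tensor function into the module -> module
    # contract quantize_ now expects (same pattern as the docstring's
    # apply_weight_quant_to_linear)
    def insert_subclass(linear):
        linear.weight = torch.nn.Parameter(weight_quant_fn(linear.weight), requires_grad=False)
        return linear
    return insert_subclass

# usage: quantize_(model, to_module_fn(my_weight_quant), filter_fn)
```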
torchao/prototype/calibration_flow/static_quant.py (new file, 145 additions, 0 deletions)
```python
"""
Demo for static quantization flow
"""
import torch
import copy

# TODO: use the generalized observer for affine quantization in the future
from torch.ao.quantization.observer import MinMaxObserver, PerChannelMinMaxObserver
import torch.nn.functional as F
from torch import Tensor
from torchao.dtypes import to_affine_quantized_static
from torchao.quantization.utils import compute_error
from torchao.quantization import quantize_
from torchao.quantization.subclass import to_linear_act_quantized
from torchao.quantization.quant_api import _replace_with_custom_fn_if_matches_filter


class ObservedLinear(torch.nn.Linear):
    def __init__(self, in_features: int, out_features: int, act_obs: torch.nn.Module, weight_obs: torch.nn.Module, bias: bool = True, device=None, dtype=None):
        super().__init__(in_features, out_features, bias, device, dtype)
        self.act_obs = act_obs
        self.weight_obs = weight_obs

    def forward(self, input: Tensor):
        observed_input = self.act_obs(input)
        observed_weight = self.weight_obs(self.weight)
        return F.linear(observed_input, observed_weight, self.bias)

    @classmethod
    def from_float(cls, float_linear, act_obs, weight_obs):
        observed_linear = cls(float_linear.in_features, float_linear.out_features, act_obs, weight_obs, False, device=float_linear.weight.device, dtype=float_linear.weight.dtype)
        observed_linear.weight = float_linear.weight
        observed_linear.bias = float_linear.bias
        return observed_linear

def insert_observers_(model, act_obs, weight_obs):
    _is_linear = lambda m, fqn: isinstance(m, torch.nn.Linear)
    replacement_fn = lambda m: ObservedLinear.from_float(m, act_obs, weight_obs)
    act_obs = copy.deepcopy(act_obs)
    weight_obs = copy.deepcopy(weight_obs)
    _replace_with_custom_fn_if_matches_filter(model, replacement_fn, _is_linear)

# converting observed linear module to linear module with quantized weights (and quantized activations)
# with tensor subclasses
def apply_static_quant(observed_linear):
    target_dtype = torch.uint8

    # weight quantization
    weight_scale, weight_zero_point = observed_linear.weight_obs.calculate_qparams()
    def weight_quant_func(weight):
        block_size = (1, weight.shape[1])
        return to_affine_quantized_static(weight, weight_scale, weight_zero_point, block_size, target_dtype)
    linear = torch.nn.Linear(observed_linear.in_features, observed_linear.out_features, False, device=observed_linear.weight.device, dtype=observed_linear.weight.dtype)
    linear.weight = observed_linear.weight
    linear.bias = observed_linear.bias

    linear.weight = torch.nn.Parameter(weight_quant_func(linear.weight), requires_grad=False)

    # activation quantization
    act_scale, act_zero_point = observed_linear.act_obs.calculate_qparams()
    input_quant_func = lambda x: to_affine_quantized_static(x, act_scale, act_zero_point, x.shape, target_dtype)
    linear.weight = torch.nn.Parameter(to_linear_act_quantized(linear.weight, input_quant_func), requires_grad=False)

    return linear


# alternative for converting observed linear module to quantized linear module
class QuantizedLinear(torch.nn.Module):
    def __init__(self, in_features: int, out_features: int, act_obs: torch.nn.Module, weight_obs: torch.nn.Module, weight: torch.Tensor, bias: torch.Tensor):
        super().__init__()
        self.act_scale, self.act_zero_point = act_obs.calculate_qparams()
        weight_scale, weight_zero_point = weight_obs.calculate_qparams()
        assert weight.dim() == 2
        block_size = (1, weight.shape[1])
        target_dtype = torch.uint8
        self.qweight = to_affine_quantized_static(weight, weight_scale, weight_zero_point, block_size, target_dtype)
        self.bias = bias

    def forward(self, input: Tensor):
        block_size = input.shape
        target_dtype = torch.uint8
        qinput = to_affine_quantized_static(input, self.act_scale, self.act_zero_point, block_size, target_dtype)
        return F.linear(qinput, self.qweight, self.bias)

    @classmethod
    def from_observed(cls, observed_linear):
        quantized_linear = cls(observed_linear.in_features, observed_linear.out_features, observed_linear.act_obs, observed_linear.weight_obs, observed_linear.weight, observed_linear.bias)
        return quantized_linear

def apply_static_quant2(observed_linear):
    return QuantizedLinear.from_observed(observed_linear)

class ToyLinearModel(torch.nn.Module):
    def __init__(self, m=64, n=32, k=64):
        super().__init__()
        self.linear1 = torch.nn.Linear(m, n, bias=False)
        self.linear2 = torch.nn.Linear(n, k, bias=False)

    def example_inputs(self, batch_size=1, dtype=torch.float32, device="cpu"):
        return (torch.randn(batch_size, self.linear1.in_features, dtype=dtype, device=device),)

    def forward(self, x):
        x = self.linear1(x)
        x = self.linear2(x)
        return x

dtype = torch.bfloat16
m = ToyLinearModel(1024, 1024, 1024).eval().to(dtype).to("cuda")
m_bf16 = copy.deepcopy(m)
example_inputs = m.example_inputs(dtype=dtype, device="cuda")

m_bf16 = torch.compile(m_bf16, mode='max-autotune')

# TODO: use the generalized observer for affine quantization in the future
act_obs = MinMaxObserver(dtype=torch.uint8, qscheme=torch.per_tensor_affine).to("cuda")
weight_obs = PerChannelMinMaxObserver(dtype=torch.uint8, qscheme=torch.per_channel_affine).to("cuda")

before_quant = m(*example_inputs)

insert_observers_(m, act_obs, weight_obs)
# calibrating / training
for _ in range(10):
    m(*example_inputs)

after_obs = m(*example_inputs)

m2 = copy.deepcopy(m)

is_observed_linear = lambda m, fqn: isinstance(m, ObservedLinear)

# quantized linear represented as an nn.Linear with modified tensor subclass weights
# for both activation and weight quantization
quantize_(m, apply_static_quant, is_observed_linear)
print("quantized model (applying tensor subclass to weight):", m)
after_quant = m(*example_inputs)
assert compute_error(before_quant, after_quant) > 30
print("test passed")

# quantized linear as a standalone module
quantize_(m2, apply_static_quant2, is_observed_linear)
print("quantized model (quantized module):", m2)
after_quant = m2(*example_inputs)
assert compute_error(before_quant, after_quant) > 30
print("test passed")
```
