Commit afeac2f
migrate prototype codebook quant to configs (#1858)
1 parent ac411bf · commit afeac2f
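This commit swaps the old callable-factory API for torchao's config-object pattern: `codebook_weight_only` becomes a `CodebookWeightOnlyConfig` dataclass (with a backward-compat alias), and the actual module transform is registered via `register_quantize_module_handler`. A minimal sketch of the call site, which stays the same for existing users (illustrative; `m` is any model containing linear layers):

```python
import torch

from torchao.prototype.quantization.codebook import codebook_weight_only
from torchao.quantization import quantize_

m = torch.nn.Sequential(torch.nn.Linear(64, 64))

# Before this commit, codebook_weight_only(...) returned a callable that
# quantize_ applied to each linear weight; after it, the same expression
# constructs a CodebookWeightOnlyConfig instance that quantize_ dispatches on.
quantize_(m, codebook_weight_only())
```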

File tree: 2 files changed, +43 −21 lines

test/prototype/test_codebook_quant.py (7 additions, 0 deletions)

@@ -5,7 +5,9 @@
 from torchao.prototype.quantization.codebook import (
     CodebookQuantizedTensor,
     choose_qparams_codebook,
+    codebook_weight_only,
 )
+from torchao.quantization import quantize_
 from torchao.quantization.utils import compute_error
 
 
@@ -62,6 +64,11 @@ def test_codebook_quantized_tensor_from_float2(self):
         sqnr = compute_error(dequant, self.input)
         self.assertGreater(sqnr, 30)
 
+    def test_quantize_api(self):
+        m = torch.nn.Sequential(torch.nn.Linear(64, 64))
+        quantize_(m, codebook_weight_only())
+        assert type(m[0].weight) == CodebookQuantizedTensor
+
 
 if __name__ == "__main__":
     unittest.main()
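The new test exercises the backward-compat alias; the same check can also be written against the config class directly. A minimal sketch, assuming `CodebookWeightOnlyConfig` is imported from the module where the diff below defines it:

```python
import torch

from torchao.prototype.quantization.codebook import CodebookQuantizedTensor
from torchao.prototype.quantization.codebook.codebook_quantized_tensor import (
    CodebookWeightOnlyConfig,
)
from torchao.quantization import quantize_

m = torch.nn.Sequential(torch.nn.Linear(64, 64))
# Equivalent to quantize_(m, codebook_weight_only()) in the test above,
# since codebook_weight_only is aliased to CodebookWeightOnlyConfig.
quantize_(m, CodebookWeightOnlyConfig(dtype=torch.uint4, block_size=(1, 1)))
assert isinstance(m[0].weight, CodebookQuantizedTensor)
```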

torchao/prototype/quantization/codebook/codebook_quantized_tensor.py (36 additions, 21 deletions)

@@ -1,14 +1,18 @@
+from dataclasses import dataclass
 from typing import Optional, Tuple
 
 import torch
 
+from torchao.core.config import AOBaseConfig
 from torchao.dtypes.uintx.uintx_layout import _DTYPE_TO_BIT_WIDTH, UintxTensor
 from torchao.prototype.quantization.codebook.codebook_ops import (
     choose_qparams_codebook,
     dequantize_codebook,
     quantize_codebook,
 )
-from torchao.quantization.quant_api import _get_linear_subclass_inserter
+from torchao.quantization.transform_module import (
+    register_quantize_module_handler,
+)
 from torchao.utils import TorchAOBaseTensor
 
 aten = torch.ops.aten

@@ -254,10 +258,21 @@ def function_requires_grad_(tensor, *args, **kwargs):
     return tensor.requires_grad_(*args, **kwargs)
 
 
-def codebook_weight_only(
-    dtype=torch.uint4,
-    block_size: Tuple[int, int] = (1, 1),
-    scale_block_size: int = None,
+@dataclass
+class CodebookWeightOnlyConfig(AOBaseConfig):
+    dtype: torch.dtype = torch.uint4
+    block_size: Tuple[int, int] = (1, 1)
+    scale_block_size: int = None
+
+
+# for bc
+codebook_weight_only = CodebookWeightOnlyConfig
+
+
+@register_quantize_module_handler(CodebookWeightOnlyConfig)
+def _codebook_weight_only_transform(
+    module: torch.nn.Module,
+    config: CodebookWeightOnlyConfig,
 ):
     """
     Applies codebook weight-only quantization to linear layers.

@@ -269,20 +284,20 @@ def codebook_weight_only(
     Returns:
         Callable for quantization transformation.
     """
-
-    def apply_codebook_quantization(weight, scale_block_size):
-        if weight.numel() > 2**27:
-            return weight  # k_means is too numerically unstable
-        if scale_block_size is None:
-            scale_block_size = weight.shape[1]
-        quantized = CodebookQuantizedTensor.from_float(
-            weight,
-            block_size=block_size,
-            code_dtype=dtype,
-            scale_block_size=scale_block_size,
-        )
-        return quantized
-
-    return _get_linear_subclass_inserter(
-        apply_codebook_quantization, scale_block_size=scale_block_size
+    dtype = config.dtype
+    block_size = config.block_size
+    scale_block_size = config.scale_block_size
+    weight = module.weight
+
+    if weight.numel() > 2**27:
+        return module  # k_means is too numerically unstable
+    if scale_block_size is None:
+        scale_block_size = weight.shape[1]
+    quantized_weight = CodebookQuantizedTensor.from_float(
+        weight,
+        block_size=block_size,
+        code_dtype=dtype,
+        scale_block_size=scale_block_size,
     )
+    module.weight = torch.nn.Parameter(quantized_weight, requires_grad=False)
+    return module