1
+ from dataclasses import dataclass
1
2
from typing import Optional , Tuple
2
3
3
4
import torch
4
5
6
+ from torchao .core .config import AOBaseConfig
5
7
from torchao .dtypes .uintx .uintx_layout import _DTYPE_TO_BIT_WIDTH , UintxTensor
6
8
from torchao .prototype .quantization .codebook .codebook_ops import (
7
9
choose_qparams_codebook ,
8
10
dequantize_codebook ,
9
11
quantize_codebook ,
10
12
)
11
- from torchao .quantization .quant_api import _get_linear_subclass_inserter
13
+ from torchao .quantization .transform_module import (
14
+ register_quantize_module_handler ,
15
+ )
12
16
from torchao .utils import TorchAOBaseTensor
13
17
14
18
# Shorthand alias for the ATen operator namespace; kept at module level so op
# implementations in this file can refer to operators as `aten.<op>`.
aten = torch.ops.aten
@@ -254,10 +258,21 @@ def function_requires_grad_(tensor, *args, **kwargs):
254
258
return tensor .requires_grad_ (* args , ** kwargs )
255
259
256
260
257
@dataclass
class CodebookWeightOnlyConfig(AOBaseConfig):
    """Configuration for codebook (look-up table) weight-only quantization.

    Attributes:
        dtype: Code dtype for the quantized weight (default: ``torch.uint4``).
        block_size: Block granularity passed to the codebook quantizer.
        scale_block_size: Number of elements that share one scale. ``None``
            means "decide later" — the transform substitutes the weight's
            second dimension when this is unset.
    """

    dtype: torch.dtype = torch.uint4
    block_size: Tuple[int, int] = (1, 1)
    # Was annotated `int = None`; Optional[int] is the correct type for a
    # None-able field (Optional is already imported in this module).
    scale_block_size: Optional[int] = None
268
# Backwards-compatibility alias: the old function-style entry point
# `codebook_weight_only` now resolves to the config class, so existing
# callers that did `quantize_(model, codebook_weight_only(...))` keep working.
codebook_weight_only = CodebookWeightOnlyConfig
272
@register_quantize_module_handler(CodebookWeightOnlyConfig)
def _codebook_weight_only_transform(
    module: torch.nn.Module,
    config: CodebookWeightOnlyConfig,
) -> torch.nn.Module:
    """Apply codebook weight-only quantization to a linear module in place.

    Args:
        module: Module whose ``weight`` is replaced by its codebook-quantized
            equivalent. Assumed to be a linear-like module with a 2D
            ``weight`` attribute (``weight.shape[1]`` is read below).
        config: Quantization parameters (code dtype, block size, scale block
            size).

    Returns:
        The same ``module``, with ``module.weight`` swapped for a frozen
        ``torch.nn.Parameter`` wrapping a ``CodebookQuantizedTensor`` —
        unless the weight exceeds the size guard, in which case the module is
        returned unchanged.
    """
    weight = module.weight

    # k-means is too numerically unstable on very large weights; skip them.
    if weight.numel() > 2**27:
        return module

    scale_block_size = config.scale_block_size
    if scale_block_size is None:
        # Default: one scale group spans the full reduction dimension.
        scale_block_size = weight.shape[1]

    quantized_weight = CodebookQuantizedTensor.from_float(
        weight,
        block_size=config.block_size,
        code_dtype=config.dtype,
        scale_block_size=scale_block_size,
    )
    # requires_grad=False: the quantized weight is not meant to be trained.
    module.weight = torch.nn.Parameter(quantized_weight, requires_grad=False)
    return module
0 commit comments