Skip to content

migrate prototype codebook quant to configs #1858

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 80 commits into from
Mar 12, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
80 commits
Select commit Hold shift + click to select a range
73ea701
Update
vkuzo Mar 7, 2025
4f2c69d
Update
vkuzo Mar 7, 2025
75fb698
Update
vkuzo Mar 7, 2025
5796efc
Update
vkuzo Mar 7, 2025
3b016e7
Update
vkuzo Mar 7, 2025
11bce08
Update
vkuzo Mar 7, 2025
3a91a06
Update
vkuzo Mar 7, 2025
91c350f
Update
vkuzo Mar 7, 2025
125546b
Update
vkuzo Mar 7, 2025
130a52e
Update
vkuzo Mar 7, 2025
70e0e53
Update
vkuzo Mar 7, 2025
049d750
Update
vkuzo Mar 7, 2025
7b8fb8d
Update
vkuzo Mar 7, 2025
74965d3
Update
vkuzo Mar 7, 2025
1ea0e2f
Update
vkuzo Mar 7, 2025
03ea2e4
Update
vkuzo Mar 7, 2025
9d7210b
Update
vkuzo Mar 7, 2025
cf3ad33
Update
vkuzo Mar 7, 2025
19ac99d
Update
vkuzo Mar 7, 2025
5deed22
Update
vkuzo Mar 7, 2025
2a1f7b2
Update
vkuzo Mar 7, 2025
5fa0e27
Update
vkuzo Mar 7, 2025
160cc29
Update
vkuzo Mar 7, 2025
7e40b15
Update
vkuzo Mar 7, 2025
b0a32c3
Update
vkuzo Mar 7, 2025
0ecb02d
Update
vkuzo Mar 7, 2025
7cb810c
Update
vkuzo Mar 7, 2025
eb567cd
Update
vkuzo Mar 7, 2025
cd97b30
Update
vkuzo Mar 7, 2025
de38b6e
Update
vkuzo Mar 7, 2025
bfba1d9
Update
vkuzo Mar 7, 2025
9ac2334
Update
vkuzo Mar 7, 2025
c042458
Update
vkuzo Mar 7, 2025
6f3d127
Update
vkuzo Mar 8, 2025
cda5d18
Update
vkuzo Mar 8, 2025
96d74a3
Update
vkuzo Mar 8, 2025
c83c029
Update
vkuzo Mar 8, 2025
6f1c92d
Update
vkuzo Mar 8, 2025
95be23e
Update
vkuzo Mar 8, 2025
0776629
Update
vkuzo Mar 8, 2025
52a31a8
Update
vkuzo Mar 8, 2025
fdb292e
Update
vkuzo Mar 8, 2025
706ff1f
Update
vkuzo Mar 8, 2025
ac2314e
Update
vkuzo Mar 8, 2025
8002c39
Update
vkuzo Mar 8, 2025
a4dfaa1
Update
vkuzo Mar 8, 2025
ecdab3b
Update
vkuzo Mar 8, 2025
7679470
Update
vkuzo Mar 8, 2025
0506d32
Update
vkuzo Mar 8, 2025
237a72a
Update
vkuzo Mar 8, 2025
7183e83
Update
vkuzo Mar 8, 2025
10d0dff
Update
vkuzo Mar 8, 2025
be63b3c
Update
vkuzo Mar 8, 2025
fc44c79
Update
vkuzo Mar 8, 2025
2ec7827
Update
vkuzo Mar 8, 2025
3f10bc5
Update
vkuzo Mar 8, 2025
fa8c0f1
Update
vkuzo Mar 8, 2025
e8ee9a1
Update
vkuzo Mar 8, 2025
025bd67
Update
vkuzo Mar 8, 2025
50a7f9f
Update
vkuzo Mar 8, 2025
8c8d7e4
Update
vkuzo Mar 8, 2025
5b62372
Update
vkuzo Mar 8, 2025
89f6763
Update
vkuzo Mar 8, 2025
34ce5f4
Update
vkuzo Mar 10, 2025
c038451
Update
vkuzo Mar 10, 2025
c3815dc
Update
vkuzo Mar 10, 2025
9280bc9
Update
vkuzo Mar 10, 2025
4f195b6
Update
vkuzo Mar 10, 2025
0d388c8
Update
vkuzo Mar 10, 2025
42f323f
Update
vkuzo Mar 10, 2025
db3e3d3
Update
vkuzo Mar 12, 2025
5f742de
Update
vkuzo Mar 12, 2025
6500a26
Update
vkuzo Mar 12, 2025
8d16b60
Update
vkuzo Mar 12, 2025
bdf1aea
Update
vkuzo Mar 12, 2025
c2577a2
Update
vkuzo Mar 12, 2025
902f8fa
Update
vkuzo Mar 12, 2025
82b6281
Update
vkuzo Mar 12, 2025
34bfd94
Update
vkuzo Mar 12, 2025
b63f186
Update
vkuzo Mar 12, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 7 additions & 0 deletions test/prototype/test_codebook_quant.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,9 @@
from torchao.prototype.quantization.codebook import (
CodebookQuantizedTensor,
choose_qparams_codebook,
codebook_weight_only,
)
from torchao.quantization import quantize_
from torchao.quantization.utils import compute_error


Expand Down Expand Up @@ -62,6 +64,11 @@ def test_codebook_quantized_tensor_from_float2(self):
sqnr = compute_error(dequant, self.input)
self.assertGreater(sqnr, 30)

def test_quantize_api(self):
    """Smoke-test that quantize_ + codebook config swaps the linear weight."""
    model = torch.nn.Sequential(torch.nn.Linear(64, 64))
    quantize_(model, codebook_weight_only())
    # After quantization the weight must be exactly the codebook subclass.
    assert type(model[0].weight) == CodebookQuantizedTensor


# Allow running this test file directly (outside a pytest invocation).
if __name__ == "__main__":
    unittest.main()
Original file line number Diff line number Diff line change
@@ -1,14 +1,18 @@
from dataclasses import dataclass
from typing import Optional, Tuple

import torch

from torchao.core.config import AOBaseConfig
from torchao.dtypes.uintx.uintx_layout import _DTYPE_TO_BIT_WIDTH, UintxTensor
from torchao.prototype.quantization.codebook.codebook_ops import (
choose_qparams_codebook,
dequantize_codebook,
quantize_codebook,
)
from torchao.quantization.quant_api import _get_linear_subclass_inserter
from torchao.quantization.transform_module import (
register_quantize_module_handler,
)
from torchao.utils import TorchAOBaseTensor

aten = torch.ops.aten
Expand Down Expand Up @@ -254,10 +258,21 @@ def function_requires_grad_(tensor, *args, **kwargs):
return tensor.requires_grad_(*args, **kwargs)


def codebook_weight_only(
dtype=torch.uint4,
block_size: Tuple[int, int] = (1, 1),
scale_block_size: int = None,
@dataclass
class CodebookWeightOnlyConfig(AOBaseConfig):
    """Configuration for codebook (look-up-table) weight-only quantization.

    Consumed by ``_codebook_weight_only_transform`` via
    ``register_quantize_module_handler``.
    """

    # Dtype of the quantized codes (e.g. torch.uint4).
    dtype: torch.dtype = torch.uint4
    # Codebook block granularity over the weight, as (rows, cols).
    block_size: Tuple[int, int] = (1, 1)
    # Block size for the scales; None means "use the weight's full input
    # dimension", resolved at transform time — so Optional, not int.
    scale_block_size: Optional[int] = None


# for bc: keep the old callable name as an alias of the config class
codebook_weight_only = CodebookWeightOnlyConfig


@register_quantize_module_handler(CodebookWeightOnlyConfig)
def _codebook_weight_only_transform(
    module: torch.nn.Module,
    config: CodebookWeightOnlyConfig,
):
    """
    Applies codebook weight-only quantization to a linear module, in place.

    Args:
        module: module whose ``weight`` is replaced with a quantized version.
        config: codebook quantization settings (code dtype, block size,
            scale block size).

    Returns:
        The same module; its weight becomes a ``CodebookQuantizedTensor``
        wrapped in a non-trainable ``torch.nn.Parameter``. Weights larger
        than 2**27 elements are left untouched.
    """
    dtype = config.dtype
    block_size = config.block_size
    scale_block_size = config.scale_block_size
    weight = module.weight

    if weight.numel() > 2**27:
        return module  # k_means is too numerically unstable
    if scale_block_size is None:
        # Default: one scale block spanning the full input dimension.
        scale_block_size = weight.shape[1]
    quantized_weight = CodebookQuantizedTensor.from_float(
        weight,
        block_size=block_size,
        code_dtype=dtype,
        scale_block_size=scale_block_size,
    )
    # requires_grad=False: the quantized weight is not meant to be trained.
    module.weight = torch.nn.Parameter(quantized_weight, requires_grad=False)
    return module
Loading