Skip to content

Commit bf0a842

Browse files
committed
Add module swap quantization API from Quanty
**Summary:** This commit adds a module-swap-based PTQ API from Quanty, including: - Quantized linear and embedding modules - `IntQuantizer` to specify how to quantize weights and activations - `CodeBookQuantizer` as an alternative to IntQuantizer - Implementation of K-means to be used for codebook quantization - Range setting and data getter utility These new APIs will complement our existing `quantize_` API, which is primarily used for tensor-subclass-based quantization today (though it can also support module swaps). All APIs introduced in this commit are under prototype and highly subject to change. In particular, we plan to delete `quantize_module_swap` and `QuantizationRecipe`, and instead integrate this flow with the `quantize_` API by creating a new `AOBaseConfig`. All code is migrated from Quanty and written by @TiRune. **Test Plan:** python test/quantization/module_swap/test_*
1 parent 8c81863 commit bf0a842

19 files changed

+2172
-0
lines changed

ruff.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,4 +7,5 @@ lint.ignore = ["E731"]
77
# Exclude third-party modules
88
exclude = [
99
"third_party/*",
10+
"prototype/quantization/module_swap/*",
1011
]
Lines changed: 56 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,56 @@
1+
import copy
2+
import unittest
3+
4+
import torch
5+
import torch.nn as nn
6+
7+
from torchao.prototype.quantization.module_swap.algorithms import kmeans_codebook
8+
from torchao.prototype.quantization.module_swap import (
9+
CodeBookQuantizer,
10+
QuantizedLinear,
11+
)
12+
13+
14+
class SimpleTestNetwork(nn.Module):
15+
def __init__(self, weight_group_size: str | int = "per_channel") -> None:
16+
super().__init__()
17+
if weight_group_size == "per_channel":
18+
weight_group_size = 8
19+
assert isinstance(weight_group_size, int)
20+
weight_quantizer = CodeBookQuantizer(
21+
n_bits=2,
22+
features=16,
23+
codebook_dim=2,
24+
)
25+
26+
self.linear = QuantizedLinear(
27+
in_features=16,
28+
out_features=8,
29+
bias=False,
30+
weight_quantizer=weight_quantizer,
31+
activation_bits=8,
32+
input_quantization=False,
33+
output_quantization=False,
34+
weight_quantization=True,
35+
activation_quantization=False,
36+
)
37+
38+
def forward(self, x: torch.Tensor) -> torch.Tensor:
39+
return self.linear(x)
40+
41+
42+
class TestKmeansCodebook(unittest.TestCase):
43+
44+
@unittest.skip("No module named 'faiss'")
45+
def test_kmeans_codebook(self) -> None:
46+
model = SimpleTestNetwork()
47+
codebook_before = copy.deepcopy(model.linear.weight_quantizer.codebook)
48+
kmeans_codebook(model)
49+
assert not torch.allclose(
50+
codebook_before,
51+
model.linear.weight_quantizer.codebook,
52+
)
53+
54+
55+
if __name__ == "__main__":
56+
unittest.main()
Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,38 @@
1+
from typing import Tuple
2+
import unittest
3+
4+
import torch
5+
6+
from torchao.prototype.quantization.module_swap.data_getters import LLMPTQDataGetter
7+
8+
from transformers.models.llama.modeling_llama import LlamaConfig, LlamaForCausalLM
9+
10+
11+
test_config = LlamaConfig(
12+
vocab_size=10,
13+
hidden_size=32,
14+
num_hidden_layers=2,
15+
num_attention_heads=2,
16+
intermediate_size=64,
17+
)
18+
19+
20+
def get_test_llama_model_data() -> Tuple[LlamaForCausalLM, torch.Tensor]:
21+
model = LlamaForCausalLM(test_config)
22+
input_ids = torch.randint(0, test_config.vocab_size, (1, 10))
23+
return model, input_ids
24+
25+
26+
class TestPTQDataGetter(unittest.TestCase):
27+
28+
@unittest.skip("TypeError: cannot unpack non-iterable NoneType object")
29+
def test_data_getter(self) -> None:
30+
model, data = get_test_llama_model_data()
31+
data_getter = LLMPTQDataGetter(model, data, 1)
32+
for name, module in model.named_modules():
33+
if isinstance(module, torch.nn.Linear):
34+
data = data_getter.pop(model, name)
35+
36+
37+
if __name__ == "__main__":
38+
unittest.main()
Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,35 @@
1+
import unittest
2+
3+
import torch
4+
import torch.nn as nn
5+
6+
from torchao.prototype.quantization.module_swap.module_swap import (
7+
QuantizationRecipe,
8+
quantize_module_swap,
9+
)
10+
11+
12+
class SimpleEmbeddingTestNetwork(nn.Module):
13+
def __init__(self) -> None:
14+
super().__init__()
15+
self.embedding = nn.Embedding(10, 64)
16+
17+
def forward(self, x: torch.Tensor) -> torch.Tensor:
18+
return self.embedding(x)
19+
20+
21+
class TestEmbeddingSwap(unittest.TestCase):
22+
def test_embedding_swap(self) -> None:
23+
model = SimpleEmbeddingTestNetwork()
24+
recipe = QuantizationRecipe()
25+
recipe.embedding_bits = 4
26+
recipe.embedding_quantization = True
27+
model = quantize_module_swap(model, recipe)
28+
x = torch.randint(0, 10, (10, 64))
29+
model(x)
30+
assert model.embedding.weight_quantizer.num_bits == 4
31+
assert model.embedding.weight_quantizer.group_size == 32
32+
33+
34+
if __name__ == "__main__":
35+
unittest.main()

0 commit comments

Comments
 (0)