Commit 38b1361

Add module swap quantization API from Quanty
**Summary:** This commit adds a module-swap-based PTQ API from Quanty, including:

- Quantized linear and embedding modules
- `IntQuantizer` to specify how to quantize weights and activations
- `CodeBookQuantizer` as an alternative to `IntQuantizer`
- An implementation of k-means to be used for codebook quantization
- Range-setting and data-getter utilities

These new APIs will complement our existing `quantize_` API, which today is used primarily for tensor-subclass-based quantization (though it can also support module swaps). All APIs introduced in this commit are prototypes and highly subject to change. In particular, we plan to delete `quantize_module_swap` and `QuantizationRecipe` and instead integrate this flow with the `quantize_` API by creating a new `AOBaseConfig`. All code is migrated from Quanty and was written by @TiRune.

**Test Plan:** `python test/quantization/module_swap/test_*`
1 parent 8c81863 · commit 38b1361

18 files changed: +2167 −0 lines
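Before the per-file diffs, here is a minimal sketch of the module-swap flow described in the summary, mirroring the embedding test added in this commit. The `TinyModel` wrapper is illustrative only, and the recipe fields shown are the ones the tests exercise; this is a prototype API and subject to change as noted above.

```python
import torch
import torch.nn as nn

from torchao.prototype.quantization.module_swap.module_swap import (
    QuantizationRecipe,
    quantize_module_swap,
)


class TinyModel(nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.embedding = nn.Embedding(10, 64)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.embedding(x)


# Describe what to quantize in a recipe; these are the fields
# exercised by the embedding test in this commit.
recipe = QuantizationRecipe()
recipe.embedding_quantization = True  # swap nn.Embedding for a quantized module
recipe.embedding_bits = 4             # quantize embedding weights to 4 bits

# quantize_module_swap returns the model with quantized modules swapped in
model = quantize_module_swap(TinyModel(), recipe)
out = model(torch.randint(0, 10, (2, 8)))
```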
Lines changed: 55 additions & 0 deletions

```python
import copy
import unittest

import torch
import torch.nn as nn

from torchao.prototype.quantization.module_swap import (
    CodeBookQuantizer,
    QuantizedLinear,
)
from torchao.prototype.quantization.module_swap.algorithms import kmeans_codebook


class SimpleTestNetwork(nn.Module):
    def __init__(self, weight_group_size: str | int = "per_channel") -> None:
        super().__init__()
        if weight_group_size == "per_channel":
            weight_group_size = 8
        assert isinstance(weight_group_size, int)
        weight_quantizer = CodeBookQuantizer(
            n_bits=2,
            features=16,
            codebook_dim=2,
        )

        self.linear = QuantizedLinear(
            in_features=16,
            out_features=8,
            bias=False,
            weight_quantizer=weight_quantizer,
            activation_bits=8,
            input_quantization=False,
            output_quantization=False,
            weight_quantization=True,
            activation_quantization=False,
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.linear(x)


class TestKmeansCodebook(unittest.TestCase):
    @unittest.skip("No module named 'faiss'")
    def test_kmeans_codebook(self) -> None:
        model = SimpleTestNetwork()
        codebook_before = copy.deepcopy(model.linear.weight_quantizer.codebook)
        kmeans_codebook(model)
        assert not torch.allclose(
            codebook_before,
            model.linear.weight_quantizer.codebook,
        )


if __name__ == "__main__":
    unittest.main()
```
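As a usage note, a minimal sketch of codebook calibration outside the test harness, assuming the same constructor arguments the test above exercises and that `kmeans_codebook` accepts any `nn.Module` containing `QuantizedLinear` layers (it requires `faiss`, per the test's skip reason):

```python
import torch
import torch.nn as nn

from torchao.prototype.quantization.module_swap import (
    CodeBookQuantizer,
    QuantizedLinear,
)
from torchao.prototype.quantization.module_swap.algorithms import kmeans_codebook

# Same configuration as the test above: a 2-bit codebook over
# 16 input features, weight-only quantization.
quantizer = CodeBookQuantizer(n_bits=2, features=16, codebook_dim=2)
net = nn.Sequential(
    QuantizedLinear(
        in_features=16,
        out_features=8,
        bias=False,
        weight_quantizer=quantizer,
        activation_bits=8,
        input_quantization=False,
        output_quantization=False,
        weight_quantization=True,
        activation_quantization=False,
    )
)

kmeans_codebook(net)         # fit codebook entries to the weights (needs faiss)
y = net(torch.randn(4, 16))  # forward pass with codebook-quantized weights
```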
Lines changed: 35 additions & 0 deletions

```python
import unittest
from typing import Tuple

import torch
from transformers.models.llama.modeling_llama import LlamaConfig, LlamaForCausalLM

from torchao.prototype.quantization.module_swap.data_getters import LLMPTQDataGetter

test_config = LlamaConfig(
    vocab_size=10,
    hidden_size=32,
    num_hidden_layers=2,
    num_attention_heads=2,
    intermediate_size=64,
)


def get_test_llama_model_data() -> Tuple[LlamaForCausalLM, torch.Tensor]:
    model = LlamaForCausalLM(test_config)
    input_ids = torch.randint(0, test_config.vocab_size, (1, 10))
    return model, input_ids


class TestPTQDataGetter(unittest.TestCase):
    @unittest.skip("TypeError: cannot unpack non-iterable NoneType object")
    def test_data_getter(self) -> None:
        model, data = get_test_llama_model_data()
        data_getter = LLMPTQDataGetter(model, data, 1)
        for name, module in model.named_modules():
            if isinstance(module, torch.nn.Linear):
                data = data_getter.pop(model, name)


if __name__ == "__main__":
    unittest.main()
```
Lines changed: 35 additions & 0 deletions

```python
import unittest

import torch
import torch.nn as nn

from torchao.prototype.quantization.module_swap.module_swap import (
    QuantizationRecipe,
    quantize_module_swap,
)


class SimpleEmbeddingTestNetwork(nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.embedding = nn.Embedding(10, 64)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.embedding(x)


class TestEmbeddingSwap(unittest.TestCase):
    def test_embedding_swap(self) -> None:
        model = SimpleEmbeddingTestNetwork()
        recipe = QuantizationRecipe()
        recipe.embedding_bits = 4
        recipe.embedding_quantization = True
        model = quantize_module_swap(model, recipe)
        x = torch.randint(0, 10, (10, 64))
        model(x)
        assert model.embedding.weight_quantizer.num_bits == 4
        assert model.embedding.weight_quantizer.group_size == 32


if __name__ == "__main__":
    unittest.main()
```
