Skip to content

Commit 2ed4d2d

Browse files
committed
Add module swap quantization API from Quanty
**Summary:** This commit adds a module-swap-based PTQ API from Quanty, including: - Quantized linear and embedding modules - `IntQuantizer` to specify how to quantize weights and activations - `CodeBookQuantizer` as an alternative to IntQuantizer - Implementation of K-means to be used for codebook quantization - Range setting and data getter utility These new APIs will complement our existing `quantize_` API, which is primarily used for tensor-subclass-based quantization today (though it can also support module swaps). All APIs introduced in this commit are under prototype and highly subject to change. In particular, we plan to delete `quantize_module_swap` and `QuantizationRecipe`, and instead integrate this flow with the `quantize_` API by creating a new `AOBaseConfig`. All code is migrated from Quanty and written by @TiRune. **Test Plan:** python test/quantization/module_swap/test_*
1 parent 8c81863 commit 2ed4d2d

19 files changed

+2243
-0
lines changed
Lines changed: 56 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,56 @@
1+
import copy
2+
import unittest
3+
from typing import Union
4+
5+
import torch
6+
import torch.nn as nn
7+
8+
from torchao.prototype.quantization.module_swap import (
9+
CodeBookQuantizer,
10+
QuantizedLinear,
11+
)
12+
from torchao.prototype.quantization.module_swap.algorithms import kmeans_codebook
13+
14+
15+
class SimpleTestNetwork(nn.Module):
16+
def __init__(self, weight_group_size: Union[int, str] = "per_channel") -> None:
17+
super().__init__()
18+
if weight_group_size == "per_channel":
19+
weight_group_size = 8
20+
assert isinstance(weight_group_size, int)
21+
weight_quantizer = CodeBookQuantizer(
22+
n_bits=2,
23+
features=16,
24+
codebook_dim=2,
25+
)
26+
27+
self.linear = QuantizedLinear(
28+
in_features=16,
29+
out_features=8,
30+
bias=False,
31+
weight_quantizer=weight_quantizer,
32+
activation_bits=8,
33+
input_quantization=False,
34+
output_quantization=False,
35+
weight_quantization=True,
36+
activation_quantization=False,
37+
)
38+
39+
def forward(self, x: torch.Tensor) -> torch.Tensor:
40+
return self.linear(x)
41+
42+
43+
class TestKmeansCodebook(unittest.TestCase):
44+
@unittest.skip("No module named 'faiss'")
45+
def test_kmeans_codebook(self) -> None:
46+
model = SimpleTestNetwork()
47+
codebook_before = copy.deepcopy(model.linear.weight_quantizer.codebook)
48+
kmeans_codebook(model)
49+
assert not torch.allclose(
50+
codebook_before,
51+
model.linear.weight_quantizer.codebook,
52+
)
53+
54+
55+
if __name__ == "__main__":
56+
unittest.main()
Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,35 @@
1+
import unittest
2+
from typing import Tuple
3+
4+
import torch
5+
from transformers.models.llama.modeling_llama import LlamaConfig, LlamaForCausalLM
6+
7+
from torchao.prototype.quantization.module_swap.data_getters import LLMPTQDataGetter
8+
9+
test_config = LlamaConfig(
10+
vocab_size=10,
11+
hidden_size=32,
12+
num_hidden_layers=2,
13+
num_attention_heads=2,
14+
intermediate_size=64,
15+
)
16+
17+
18+
def get_test_llama_model_data() -> Tuple[LlamaForCausalLM, torch.Tensor]:
19+
model = LlamaForCausalLM(test_config)
20+
input_ids = torch.randint(0, test_config.vocab_size, (1, 10))
21+
return model, input_ids
22+
23+
24+
class TestPTQDataGetter(unittest.TestCase):
25+
@unittest.skip("TypeError: cannot unpack non-iterable NoneType object")
26+
def test_data_getter(self) -> None:
27+
model, data = get_test_llama_model_data()
28+
data_getter = LLMPTQDataGetter(model, data, 1)
29+
for name, module in model.named_modules():
30+
if isinstance(module, torch.nn.Linear):
31+
data = data_getter.pop(model, name)
32+
33+
34+
if __name__ == "__main__":
35+
unittest.main()
Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,35 @@
1+
import unittest
2+
3+
import torch
4+
import torch.nn as nn
5+
6+
from torchao.prototype.quantization.module_swap import (
7+
QuantizationRecipe,
8+
quantize_module_swap,
9+
)
10+
11+
12+
class SimpleEmbeddingTestNetwork(nn.Module):
13+
def __init__(self) -> None:
14+
super().__init__()
15+
self.embedding = nn.Embedding(10, 64)
16+
17+
def forward(self, x: torch.Tensor) -> torch.Tensor:
18+
return self.embedding(x)
19+
20+
21+
class TestEmbeddingSwap(unittest.TestCase):
22+
def test_embedding_swap(self) -> None:
23+
model = SimpleEmbeddingTestNetwork()
24+
recipe = QuantizationRecipe()
25+
recipe.embedding_bits = 4
26+
recipe.embedding_quantization = True
27+
model = quantize_module_swap(model, recipe)
28+
x = torch.randint(0, 10, (10, 64))
29+
model(x)
30+
assert model.embedding.weight_quantizer.num_bits == 4
31+
assert model.embedding.weight_quantizer.group_size == 32
32+
33+
34+
if __name__ == "__main__":
35+
unittest.main()

0 commit comments

Comments
 (0)