
Commit 190b3ea

Autoquant
Summary: Adds autoquantization functionality. Using the do_quant api we can test kernel speeds and pick the best quantization type (or no quantization) for each layer.

Test Plan: python test/test.py -k "autoquant"; also tested on SAM and SDXL (pytorch-labs/segment-anything-fast#114, HDCharles/sdxl-fast@8d9942a)

Reviewers:
Subscribers:
Tasks:
Tags:

ghstack-source-id: 3c1199d
Pull Request resolved: #81
1 parent bed4cb4 commit 190b3ea
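For orientation, here is a minimal sketch of the workflow this commit enables, following the example added to the README below. The `torchao.autoquant` call and the inductor flags come from that example; the toy model and input are illustrative only.

```python
import torch
import torchao

# inductor settings that improve torch.compile performance for quantized modules
torch._inductor.config.force_fuse_int_mm_with_mul = True
torch._inductor.config.use_mixed_mm = True

# a toy model and an example input shaped like real workloads
model = torch.nn.Sequential(torch.nn.Linear(32, 64)).cuda().to(torch.bfloat16)
example_input = torch.randn(32, 32, dtype=torch.bfloat16, device="cuda")

# autoquant records the activation shapes each linear sees, benchmarks the
# candidate kernels, and swaps in the fastest option per layer
torchao.autoquant(model, example_input)

# compile afterwards so the chosen kernels benefit from inductor fusions
model = torch.compile(model, mode="max-autotune")
model(example_input)
```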

File tree

7 files changed: +531 −20 lines changed


README.md

Lines changed: 23 additions & 11 deletions
@@ -1,8 +1,8 @@
-# torchao: PyTorch Architecture Optimization 
+# torchao: PyTorch Architecture Optimization
 
 **Note: This repository is currently under heavy development - if you have suggestions on the API or use-cases you'd like to be covered, please open a GitHub issue**
 
-The `torchao` package allows you to quantize and prune your models using native PyTorch. 
+The `torchao` package allows you to quantize and prune your models using native PyTorch.
 
 The repo hosts both
 1. lower precision [dtypes](./torchao/dtypes) such as nf4, uint4
@@ -38,31 +38,43 @@ pip install -e .
 
 Typically quantization algorithms will have different schemes for how the activation and weights are quantized, so A16W8, for instance, means the activations are quantized to 16 bits whereas the weights are quantized to 8 bits. Trying out different quantization schemes in `torchao` is generally a 1 line change.
 
-### A8W8 Dynamic Quantization
+### Autoquantization
 
-```Python
+The `autoquant` API can be used to quickly and accurately quantize your model. When used as in the example below, the API first identifies the shapes
+of the activations that the different linear layers see. It then benchmarks these shapes across different types of quantized and non-quantized layers in order to pick the fastest one, attempting to take fusions into account where possible. Finally, once the best class is found for each layer, it swaps in that linear. Currently this API chooses between no quantization, int8 dynamic quantization and int8 weight-only quantization for each layer.
+
+```python
 import torch
-from torchao.quantization import quant_api
+import torchao
 
-# Fuse the int8*int8 -> int32 matmul and subsequent mul op avoiding materialization of the int32 intermediary tensor
-torch._inductor.config.force_fuse_int_mm_with_mul = True
+# inductor settings which improve torch.compile performance for quantized modules
+torch._inductor.config.force_fuse_int_mm_with_mul = True
+torch._inductor.config.use_mixed_mm = True
 
 # Plug in your model and example input
 model = torch.nn.Sequential(torch.nn.Linear(32, 64)).cuda().to(torch.bfloat16)
 input = torch.randn(32, 32, dtype=torch.bfloat16, device='cuda')
 
-# convert linear modules to quantized linear modules
-quant_api.change_linear_weights_to_int8_dqtensors(model)
+# perform autoquantization
+torchao.autoquant(model, input)
 
 # compile the model to improve performance
 model = torch.compile(model, mode='max-autotune')
 model(input)
 ```
 
+
+### A8W8 Dynamic Quantization
+
+```python
+# convert linear modules to quantized linear modules
+torchao.change_linear_weights_to_int8_dqtensors(model)
+```
+
 ### A16W8 WeightOnly Quantization
 
 ```python
-quant_api.change_linear_weights_to_int8_woqtensors(model)
+torchao.change_linear_weights_to_int8_woqtensors(model)
 ```
 
 This technique works best when the torch._inductor.config.use_mixed_mm option is enabled. This avoids dequantizing the weight tensor before the matmul, instead fusing the dequantization into the matmul, thereby avoiding materialization of a large floating point weight tensor.
@@ -71,7 +83,7 @@ This technique works best when the torch._inductor.config.use_mixed_mm option is
 ### A16W4 WeightOnly Quantization
 
 ```python
-quant_api.change_linear_weights_to_int4_woqtensors(model)
+torchao.change_linear_weights_to_int4_woqtensors(model)
 ```
 
 Note: The quantization error incurred by applying int4 quantization to your model can be fairly significant, so using external techniques like GPTQ may be necessary to obtain a usable model.
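To make the selection step described in the new Autoquantization section concrete, the sketch below shows one way the per-layer benchmarking could look. `pick_fastest_weight_class` is a hypothetical helper written for illustration; it is not the API added by this commit, whose real logic lives in `torchao/quantization/autoquant.py`.

```python
import torch
import torch.nn.functional as F

def pick_fastest_weight_class(linear, activation_shapes, candidate_classes, n_iter=10):
    # Hypothetical illustration of the per-layer selection described above:
    # benchmark each candidate weight representation on the activation shapes
    # this layer actually saw, and keep the fastest one overall.
    best_cls, best_time_ms = None, float("inf")
    for cls in candidate_classes:
        # cls is None means "leave the layer unquantized"
        weight = linear.weight if cls is None else cls.from_float(linear.weight)
        total_ms = 0.0
        for shape in activation_shapes:
            x = torch.randn(*shape, dtype=linear.weight.dtype, device=linear.weight.device)
            start, end = (torch.cuda.Event(enable_timing=True) for _ in range(2))
            torch.cuda.synchronize()
            start.record()
            for _ in range(n_iter):
                F.linear(x, weight, linear.bias)
            end.record()
            torch.cuda.synchronize()
            total_ms += start.elapsed_time(end)
        if total_ms < best_time_ms:
            best_cls, best_time_ms = cls, total_ms
    return best_cls
```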

test/test.py

Lines changed: 83 additions & 0 deletions
@@ -12,6 +12,7 @@
 import torch.nn as nn
 from torch._inductor.utils import run_and_get_code
 from torch._dynamo import config
+import torchao
 from torch.ao.quantization import MinMaxObserver, QConfigMapping
 
 from torchao.quantization.dynamic_quant import (
@@ -54,6 +55,13 @@
     _fqn_to_op_to_shape_to_count,
     LoggingTensorMode,
 )
+from torchao.quantization.autoquant import (
+    AQInt8DynamicallyQuantizedLinearWeight,
+    AQWeightOnlyQuantizedLinearWeight,
+    AQWeightOnlyQuantizedLinearWeight2,
+    AQWeightOnlyQuantizedLinearWeight3,
+)
 from torch.ao.quantization.quantize_fx import convert_to_reference_fx, prepare_fx
 import os
 
@@ -880,6 +888,36 @@ def test_int8_weight_only_quant_subclass(self):
                 Int8WeightOnlyQuantizedLinearWeight.from_float, 40, test_dtype
             )
 
+    def test_aq_int8_dynamic_quant_subclass(self):
+        for test_dtype in [torch.float32, torch.float16, torch.bfloat16]:
+            self._test_lin_weight_subclass_impl(
+                AQInt8DynamicallyQuantizedLinearWeight.from_float, 35, test_dtype
+            )
+
+    def test_aq_int8_weight_only_quant_subclass(self):
+        for test_dtype in [torch.float32, torch.float16, torch.bfloat16]:
+            self._test_lin_weight_subclass_impl(
+                AQWeightOnlyQuantizedLinearWeight.from_float, 35, test_dtype
+            )
+
+    def test_aq_int8_weight_only_quant_2_subclass(self):
+        for test_dtype in [torch.float32, torch.float16, torch.bfloat16]:
+            self._test_lin_weight_subclass_impl(
+                AQWeightOnlyQuantizedLinearWeight2.from_float, 35, test_dtype
+            )
+
+    def test_aq_int8_weight_only_quant_3_subclass(self):
+        for test_dtype in [torch.float32, torch.float16, torch.bfloat16]:
+            self._test_lin_weight_subclass_impl(
+                AQWeightOnlyQuantizedLinearWeight3.from_float, 35, test_dtype
+            )
+
     def test_int4_weight_only_quant_subclass(self):
         self._test_lin_weight_subclass_impl(
             Int4WeightOnlyQuantizedLinearWeight.from_float, 10, test_shape=[1, 1024, 8]
@@ -1195,6 +1233,51 @@ def test_on_dummy_distilbert(self):
         print("sqnr_pt_quant", sqnr_pt_quant)
         self.assertTrue(sqnr_sq >= 8.0)
 
+class TestAutoQuant(unittest.TestCase):
+    def test_autoquant_one_input(self):
+        torch._inductor.config.epilogue_fusion = False
+        torch._inductor.config.use_mixed_mm = True
+        torch._inductor.config.force_fuse_int_mm_with_mul = True
+        torch._dynamo.config.automatic_dynamic_shapes = False
+
+        for m, k, n in [
+            (1, 1024, 1024),
+            (64, 1024, 1024),
+            (2**15, 1024, 1024),
+            (1, 1024, 4096),
+            (64, 1024, 4096),
+            (1, 4096, 1024),
+            (64, 4096, 1024),
+            (4096, 4096, 1024),
+        ]:
+            example_input = torch.randn(m, k, device="cuda", dtype=torch.bfloat16)
+            model = torch.nn.Sequential(
+                torch.nn.ReLU(),
+                torch.nn.Linear(k, n),
+                torch.nn.ReLU(),
+            ).to("cuda").to(torch.bfloat16)
+            out = model(example_input)
+            torchao.autoquant(model, example_input)
+            out2 = model(example_input)
+            sqnr = SQNR(out, out2)
+            self.assertTrue(sqnr >= 30)
+
+    def test_autoquant_multi_input(self):
+        m1, m2, k, n = 1, 8, 1024, 1024
+        model = torch.nn.Sequential(
+            torch.nn.ReLU(),
+            torch.nn.Linear(k, n),
+            torch.nn.ReLU(),
+        ).cuda().to(torch.bfloat16)
+        example_input = torch.randn(m1, k, device="cuda", dtype=torch.bfloat16)
+        example_input2 = torch.randn(m2, k, device="cuda", dtype=torch.bfloat16)
+        torchao.change_linears_to_autoquantizable(model)
+        out = model(example_input)
+        model(example_input2)
+        torchao.change_autoquantizable_to_quantized(model)
+        out2 = model(example_input)
+        sqnr = SQNR(out, out2)
+        self.assertTrue(sqnr >= 30)
 
 if __name__ == "__main__":
     unittest.main()
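The multi-input path exercised in `test_autoquant_multi_input` above can also be driven directly. The following is a sketch of that flow using the `change_linears_to_autoquantizable` and `change_autoquantizable_to_quantized` helpers this commit exports; the model and shapes are illustrative only.

```python
import torch
import torchao

model = torch.nn.Sequential(
    torch.nn.ReLU(),
    torch.nn.Linear(1024, 1024),
    torch.nn.ReLU(),
).cuda().to(torch.bfloat16)

# step 1: swap linear weights for "autoquantizable" placeholders that record
# the activation shapes they see during calibration
torchao.change_linears_to_autoquantizable(model)

# run every input shape you care about so each layer logs its realistic shapes
for batch_size in (1, 8):
    model(torch.randn(batch_size, 1024, device="cuda", dtype=torch.bfloat16))

# step 2: benchmark the recorded shapes and pick the fastest quantized
# (or unquantized) implementation for each layer
torchao.change_autoquantizable_to_quantized(model)

# the model can now be compiled and used as usual
model = torch.compile(model, mode="max-autotune")
```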

torchao/__init__.py

Lines changed: 24 additions & 4 deletions
@@ -1,8 +1,28 @@
+from torchao.quantization import (
+    apply_weight_only_int8_quant,
+    apply_dynamic_quant,
+    change_linear_weights_to_int8_dqtensors,
+    change_linear_weights_to_int8_woqtensors,
+    change_linear_weights_to_int4_woqtensors,
+    swap_conv2d_1x1_to_linear,
+    autoquant,
+    change_linears_to_autoquantizable,
+    change_autoquantizable_to_quantized,
+)
 from . import dtypes
-from .quantization.quant_api import apply_dynamic_quant
-from .quantization.quant_api import apply_weight_only_int8_quant
 
 __all__ = [
-    "dtypes",
-    "apply_dynamic_quant",
+    "dtypes",
+    "apply_dynamic_quant",
+    "apply_weight_only_int8_quant",
+    "change_linear_weights_to_int8_dqtensors",
+    "change_linear_weights_to_int8_woqtensors",
+    "change_linear_weights_to_int4_woqtensors",
+    "swap_conv2d_1x1_to_linear",
+    "safe_int_mm",
+    "autoquant",
+    "change_linears_to_autoquantizable",
+    "change_autoquantizable_to_quantized",
 ]

torchao/quantization/__init__.py

Lines changed: 3 additions & 0 deletions
@@ -25,6 +25,9 @@
     "dynamically_quantize_per_channel",
     "dequantize_per_tensor",
     "dequantize_per_channel",
+    "autoquant",
+    "change_linears_to_autoquantizable",
+    "change_autoquantizable_to_quantized",
     "quant_int8_dynamic_linear",
     "quant_int8_matmul",
     "quant_int8_dynamic_per_token_linear",
