|
12 | 12 | import torch.nn as nn
|
13 | 13 | from torch._inductor.utils import run_and_get_code
|
14 | 14 | from torch._dynamo import config
|
| 15 | +import torchao |
15 | 16 | from torch.ao.quantization import MinMaxObserver, QConfigMapping
|
16 | 17 |
|
17 | 18 | from torchao.quantization.dynamic_quant import (
|
|
54 | 55 | _fqn_to_op_to_shape_to_count,
|
55 | 56 | LoggingTensorMode,
|
56 | 57 | )
|
| 58 | +from torchao.quantization.autoquant import ( |
| 59 | + AQInt8DynamicallyQuantizedLinearWeight, |
| 60 | + AQWeightOnlyQuantizedLinearWeight, |
| 61 | + AQWeightOnlyQuantizedLinearWeight2, |
| 62 | + AQWeightOnlyQuantizedLinearWeight3 |
| 63 | + |
| 64 | +) |
57 | 65 | from torch.ao.quantization.quantize_fx import convert_to_reference_fx, prepare_fx
|
58 | 66 | import os
|
59 | 67 |
|
@@ -880,6 +888,36 @@ def test_int8_weight_only_quant_subclass(self):
|
880 | 888 | Int8WeightOnlyQuantizedLinearWeight.from_float, 40, test_dtype
|
881 | 889 | )
|
882 | 890 |
|
| 891 | + def test_aq_int8_dynamic_quant_subclass(self): |
| 892 | + for test_dtype in [torch.float32, torch.float16, torch.bfloat16]: |
| 893 | + self._test_lin_weight_subclass_impl( |
| 894 | + AQInt8DynamicallyQuantizedLinearWeight.from_float, 35, test_dtype |
| 895 | + ) |
| 896 | + |
| 897 | + def test_aq_int8_weight_only_quant_subclass(self): |
| 898 | + for test_dtype in [torch.float32, torch.float16, torch.bfloat16]: |
| 899 | + self._test_lin_weight_subclass_impl( |
| 900 | + AQInt8DynamicallyQuantizedLinearWeight.from_float, 35, test_dtype |
| 901 | + ) |
| 902 | + |
| 903 | + def test_aq_int8_weight_only_quant_subclass(self): |
| 904 | + for test_dtype in [torch.float32, torch.float16, torch.bfloat16]: |
| 905 | + self._test_lin_weight_subclass_impl( |
| 906 | + AQWeightOnlyQuantizedLinearWeight.from_float, 35, test_dtype |
| 907 | + ) |
| 908 | + |
| 909 | + def test_aq_int8_weight_only_quant_2_subclass(self): |
| 910 | + for test_dtype in [torch.float32, torch.float16, torch.bfloat16]: |
| 911 | + self._test_lin_weight_subclass_impl( |
| 912 | + AQWeightOnlyQuantizedLinearWeight2.from_float, 35, test_dtype |
| 913 | + ) |
| 914 | + |
| 915 | + def test_aq_int8_weight_only_quant_3_subclass(self): |
| 916 | + for test_dtype in [torch.float32, torch.float16, torch.bfloat16]: |
| 917 | + self._test_lin_weight_subclass_impl( |
| 918 | + AQWeightOnlyQuantizedLinearWeight3.from_float, 35, test_dtype |
| 919 | + ) |
| 920 | + |
883 | 921 | def test_int4_weight_only_quant_subclass(self):
|
884 | 922 | self._test_lin_weight_subclass_impl(
|
885 | 923 | Int4WeightOnlyQuantizedLinearWeight.from_float, 10, test_shape=[1, 1024, 8]
|
@@ -1195,6 +1233,51 @@ def test_on_dummy_distilbert(self):
|
1195 | 1233 | print("sqnr_pt_quant", sqnr_pt_quant)
|
1196 | 1234 | self.assertTrue(sqnr_sq >= 8.0)
|
1197 | 1235 |
|
class TestAutoQuant(unittest.TestCase):
    """Integration tests for torchao's autoquant APIs.

    Both tests build a small ReLU/Linear/ReLU model on CUDA in bfloat16,
    quantize it via autoquant, and require SQNR(float_out, quant_out) >= 30.

    Fixes vs. the original: the tests now skip (rather than crash) on
    machines without CUDA, and the global inductor/dynamo config flags are
    restored after the test instead of leaking into every later test.
    """

    @unittest.skipIf(not torch.cuda.is_available(), "autoquant test requires CUDA")
    def test_autoquant_one_input(self):
        """autoquant with a single calibration input, over a spread of shapes."""
        # NOTE(review): these flags are presumably required by autoquant's
        # quantized matmul kernels — confirm. Save and restore them so other
        # tests in the suite see the original configuration.
        saved = (
            torch._inductor.config.epilogue_fusion,
            torch._inductor.config.use_mixed_mm,
            torch._inductor.config.force_fuse_int_mm_with_mul,
            torch._dynamo.config.automatic_dynamic_shapes,
        )
        torch._inductor.config.epilogue_fusion = False
        torch._inductor.config.use_mixed_mm = True
        torch._inductor.config.force_fuse_int_mm_with_mul = True
        torch._dynamo.config.automatic_dynamic_shapes = False
        try:
            # (batch, in_features, out_features) — includes m=1 and very large
            # m to exercise different shape regimes.
            for m, k, n in [
                (1, 1024, 1024),
                (64, 1024, 1024),
                (2**15, 1024, 1024),
                (1, 1024, 4096),
                (64, 1024, 4096),
                (1, 4096, 1024),
                (64, 4096, 1024),
                (4096, 4096, 1024),
            ]:
                example_input = torch.randn(m, k, device="cuda", dtype=torch.bfloat16)
                model = torch.nn.Sequential(
                    torch.nn.ReLU(),
                    torch.nn.Linear(k, n),
                    torch.nn.ReLU(),
                ).to("cuda").to(torch.bfloat16)
                # Reference output is taken BEFORE quantization.
                out = model(example_input)
                torchao.autoquant(model, example_input)
                out2 = model(example_input)
                sqnr = SQNR(out, out2)
                self.assertTrue(sqnr >= 30)
        finally:
            (
                torch._inductor.config.epilogue_fusion,
                torch._inductor.config.use_mixed_mm,
                torch._inductor.config.force_fuse_int_mm_with_mul,
                torch._dynamo.config.automatic_dynamic_shapes,
            ) = saved

    @unittest.skipIf(not torch.cuda.is_available(), "autoquant test requires CUDA")
    def test_autoquant_multi_input(self):
        """Calibrate autoquant with two batch sizes before finalizing quantization."""
        m1, m2, k, n = 1, 8, 1024, 1024
        model = torch.nn.Sequential(
            torch.nn.ReLU(),
            torch.nn.Linear(k, n),
            torch.nn.ReLU(),
        ).cuda().to(torch.bfloat16)
        example_input = torch.randn(m1, k, device="cuda", dtype=torch.bfloat16)
        example_input2 = torch.randn(m2, k, device="cuda", dtype=torch.bfloat16)
        # Swap Linear weights for autoquantizable wrappers, run both calibration
        # inputs through, then finalize the quantization choice.
        torchao.change_linears_to_autoquantizable(model)
        out = model(example_input)
        model(example_input2)
        torchao.change_autoquantizable_to_quantized(model)
        out2 = model(example_input)
        sqnr = SQNR(out, out2)
        self.assertTrue(sqnr >= 30)
1198 | 1281 |
|
1199 | 1282 | if __name__ == "__main__":
|
1200 | 1283 | unittest.main()
|
0 commit comments