
Commit 338d87c

Refactor int4 and int8 weight only quantization to use quantize (#301)
* Replace implementation for int8 dynamic quantization with call to `quantize`

  Summary: Previously we added `quantize` as a general API (#256) for the affine quantized tensor subclass, and for tensor-subclass-based dtype conversion in general. The plan is to use it to replace the existing quant APIs, including int4 weight only, int8 weight only, int8 dynamic quant, and 8da4w (for executorch). In this PR we start by replacing the implementation of the int8 dynamic quant API with the `quantize` API and the affine quantized tensor subclass, and verify that performance does not regress for the ViT model.

  Test Plan: TORCH_LOGS='output_code' python tutorials/quantize_vit/run_vit_b_quant.py
  reference: elapsed_time: 1.4821058654785155 milliseconds
  after refactor: elapsed_time: 1.4804757690429688 milliseconds
  generated code diff: https://gist.github.com/jerryzh168/90c71107a5aaaa5d8dd2170c573e076d

* Refactor int8 weight only quant to use `quantize`

  Summary: Similar to #294, we replaced the implementation of int8 weight only quant with the newly added `quantize` function, as part of the unification effort for affine quantization.

  Test Plan:
  1. unit perf test: python test/quantization/test_quant_api.py -k test_quantized_tensor_subclass_int8_wo_quant_perf
     elapsed time: 0.23909856796264647, ref elapsed time: 0.25150911331176756
     elapsed time: 0.24894208908081056, ref elapsed time: 0.2570047950744629
     elapsed time: 0.21607391357421876, ref elapsed time: 0.22809568405151368
  2. integration test: TORCH_LOGS='output_code' python tutorials/quantize_vit/run_vit_b_quant.py
     Reference: elapsed_time: 1.355208740234375 milliseconds
     After refactor: elapsed_time: 1.32778857421875 milliseconds
     code diff (gist): https://gist.github.com/jerryzh168/921a722cf20d476c8fc5888482e722dc
     code diff (meta-only paste): https://www.internalfb.com/phabricator/paste/view/P1387333845

* Refactor int4 weight only quantization with call to `quantize`

  Summary: This is similar to #294 but applied to int4 weight only quantization.

  Test Plan:
  unit perf test: python test/quantization/test_quant_api.py -k test_quantized_tensor_subclass_int4_wo_quant_perf
  elapsed time: 0.2166275215148926, ref elapsed time: 0.2191881561279297
  elapsed time: 0.2376406478881836, ref elapsed time: 0.22721023559570314
  elapsed time: 0.21919679641723633, ref elapsed time: 0.2154969596862793
  integration perf test: TORCH_LOGS='output_code' python tutorials/quantize_vit/run_vit_b_quant.py
  reference: elapsed_time: 2.5900126953125 milliseconds
  after refactor: elapsed_time: 2.56680078125 milliseconds
  generated code diff: no diff

---------

Co-authored-by: Mark Saroufim <marksaroufim@meta.com>
1 parent 729fa4d commit 338d87c
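
For orientation, here is a minimal sketch of the call pattern this commit converges on, contrasted with the deprecated per-subclass entry points it replaces. It assumes the `quantize(model, apply_fn)` convention from #256 (a module plus a callable that maps an eligible linear weight to a quantized tensor subclass); `apply_int8wo` below is an illustrative stand-in, not the exact helper added to quant_api.py.

import torch
from torchao.quantization.quant_api import quantize
from torchao.quantization.subclass import Int8WeightOnlyQuantizedLinearWeight

# Illustrative apply function: weight tensor in, quantized tensor subclass out.
# The deprecated subclass's from_float is reused here only to show the contract.
def apply_int8wo(weight: torch.Tensor) -> torch.Tensor:
    return Int8WeightOnlyQuantizedLinearWeight.from_float(weight)

model = torch.nn.Sequential(torch.nn.Linear(1024, 1024, bias=False)).to(torch.bfloat16)

# Deprecated path: one entry point per scheme, e.g.
#   change_linear_weights_to_int8_woqtensors(model)
# Unified path: one entry point, parameterized by the apply function:
quantize(model, apply_int8wo)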

File tree

8 files changed: +397, -169 lines changed


benchmarks/benchmark_aq.py

Lines changed: 118 additions & 0 deletions
@@ -0,0 +1,118 @@

"""Benchmarks for affine quantized tensor, this includes int8 dynamic quant, int8 weight only quant and int4 weight only quant APIs
"""
import torch
from torchao.quantization.subclass import (
    Int8WeightOnlyQuantizedLinearWeight,
    Int4WeightOnlyQuantizedLinearWeight,
)
from torchao.quantization.utils import (
    TORCH_VERSION_AFTER_2_4,
)
from torchao.quantization.quant_api import (
    _replace_with_custom_fn_if_matches_filter,
)
import copy

class ToyLinearModel(torch.nn.Module):
    def __init__(self, m=64, n=32, k=64):
        super().__init__()
        self.linear1 = torch.nn.Linear(m, n, bias=False).to(torch.float)
        self.linear2 = torch.nn.Linear(n, k, bias=False).to(torch.float)

    def example_inputs(self, batch_size=1, dtype=torch.float, device="cpu"):
        return (torch.randn(batch_size, self.linear1.in_features, dtype=dtype, device=device),)

    def forward(self, x):
        x = self.linear1(x)
        x = self.linear2(x)
        return x

def _ref_change_linear_weights_to_int8_dqtensors(model, filter_fn=None, **kwargs):
    """
    The deprecated implementation for int8 dynamic quant API, used as a reference for
    numerics and performance
    """
    from torchao.quantization.quant_api import _in_features_greater_than_16
    from torchao.quantization.quant_api import _is_linear
    from torchao.quantization.quant_api import _get_subclass_inserter
    from torchao.quantization.subclass import Int8DynamicallyQuantizedLinearWeight

    if filter_fn is None:
        filter_fn = lambda *args: _is_linear(*args) and _in_features_greater_than_16(
            *args
        )

    _replace_with_custom_fn_if_matches_filter(
        model, _get_subclass_inserter(Int8DynamicallyQuantizedLinearWeight, enable_parametrization=False, **kwargs), filter_fn
    )

def _get_ref_change_linear_weights_to_woqtensors(deprecated_tenosr_subclass):
    def _ref_change_linear_weights_to_woqtensors(model, filter_fn=None, **kwargs):
        """
        The deprecated implementation for weight only quant API, used as a reference for
        numerics and performance
        """
        from torchao.quantization.quant_api import _is_linear
        from torchao.quantization.quant_api import _get_subclass_inserter

        filter_fn = kwargs.pop("filter_fn", _is_linear)

        _replace_with_custom_fn_if_matches_filter(
            model,
            _get_subclass_inserter(deprecated_tenosr_subclass, enable_parametrization=True, **kwargs),
            filter_fn,
        )

    return _ref_change_linear_weights_to_woqtensors

_ref_change_linear_weights_to_int8_woqtensors = _get_ref_change_linear_weights_to_woqtensors(Int8WeightOnlyQuantizedLinearWeight)
_ref_change_linear_weights_to_int4_woqtensors = _get_ref_change_linear_weights_to_woqtensors(Int4WeightOnlyQuantizedLinearWeight)


def _bench_quantized_tensor_subclass_perf(api, ref_api, kwargs=None):
    if kwargs is None:
        kwargs = {}

    m = ToyLinearModel(1024, 1024, 1024).eval().to(torch.bfloat16).to("cuda")
    m_ref = copy.deepcopy(m)
    # setting batch_size to 20 to be compatible with the kernel
    example_inputs = m.example_inputs(batch_size=20, dtype=torch.bfloat16, device="cuda")

    api(m, **kwargs)

    # reference
    ref_api(m_ref, **kwargs)

    res = m(*example_inputs)
    ref = m_ref(*example_inputs)

    assert torch.equal(res, ref)

    # perf comparison
    from torchao.utils import benchmark_model
    # warmup
    WARMUP = 5
    RUNS = 100
    input_tensor = example_inputs[0]
    m = torch.compile(m, mode='max-autotune', fullgraph=True)

    benchmark_model(m, WARMUP, input_tensor)
    elapsed_time = benchmark_model(m, RUNS, input_tensor)

    m_ref = torch.compile(m_ref, mode='max-autotune', fullgraph=True)
    benchmark_model(m_ref, WARMUP, input_tensor)
    ref_elapsed_time = benchmark_model(m_ref, RUNS, input_tensor)

    print(f"elapsed time: {elapsed_time}, ref elapsed time: {ref_elapsed_time}")
    assert elapsed_time < 1.05 * ref_elapsed_time

if __name__ == "__main__" and TORCH_VERSION_AFTER_2_4 and torch.cuda.is_available():
    from torchao.quantization.quant_api import change_linear_weights_to_int8_dqtensors
    _bench_quantized_tensor_subclass_perf(change_linear_weights_to_int8_dqtensors, _ref_change_linear_weights_to_int8_dqtensors)

    from torchao.quantization.quant_api import change_linear_weights_to_int8_woqtensors
    _bench_quantized_tensor_subclass_perf(change_linear_weights_to_int8_woqtensors, _ref_change_linear_weights_to_int8_woqtensors)

    kwargs = {"groupsize": 32}
    from torchao.quantization.quant_api import change_linear_weights_to_int4_woqtensors
    _bench_quantized_tensor_subclass_perf(change_linear_weights_to_int4_woqtensors, _ref_change_linear_weights_to_int4_woqtensors, kwargs)
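
The harness above is written so that a new quantization entry point can be benchmarked against a deprecated reference by adding one more call to the __main__ block. A sketch, assuming it is appended to this file (my_new_quant_api is hypothetical; everything else is defined above):

# Hypothetical example: compare an additional API against the deprecated
# int8 weight-only reference, reusing the helper defined above.
def my_new_quant_api(model, **kwargs):
    # stand-in for a real entry point under test
    from torchao.quantization.quant_api import change_linear_weights_to_int8_woqtensors
    change_linear_weights_to_int8_woqtensors(model, **kwargs)

_bench_quantized_tensor_subclass_perf(my_new_quant_api, _ref_change_linear_weights_to_int8_woqtensors)

As in the existing comparisons, the assertion at the end of _bench_quantized_tensor_subclass_perf fails if the new path is more than 5% slower than the reference.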

test/integration/test_integration.py

Lines changed: 3 additions & 0 deletions
@@ -930,6 +930,7 @@ def _test_lin_weight_subclass_impl(
         )
 
     @parameterized.expand(COMMON_DEVICE_DTYPE)
+    @unittest.skipIf(TORCH_VERSION_AFTER_2_4, "skip because there is some bug in inductor codegen")
     def test_int8_dynamic_quant_subclass(self, device, dtype):
         self._test_lin_weight_subclass_impl(
             Int8DynamicallyQuantizedLinearWeight.from_float, device, 35, test_dtype=dtype

@@ -1217,6 +1218,8 @@ def forward(self, x):
     @parameterized.expand(COMMON_DEVICE_DTYPE)
     @torch.no_grad()
     def test_save_load_dqtensors(self, device, dtype):
+        if device == "cpu":
+            self.skipTest(f"indcutor failed for cpu right now")
         self._test_handle_save_load_meta_impl(change_linear_weights_to_int8_dqtensors, device, test_dtype=dtype)
 
     @parameterized.expand(COMMON_DEVICE_DTYPE)

test/quantization/test_quant_api.py

Lines changed: 30 additions & 48 deletions
@@ -29,6 +29,8 @@
 from torchao.quantization.subclass import (
     to_laq,
     LinearActQuantizedTensor,
+    Int8WeightOnlyQuantizedLinearWeight,
+    Int4WeightOnlyQuantizedLinearWeight,
 )
 from torchao.quantization.quant_api import (
     _replace_with_custom_fn_if_matches_filter,

@@ -138,6 +140,28 @@ def _ref_change_linear_weights_to_int8_dqtensors(model, filter_fn=None, **kwargs
         model, _get_subclass_inserter(Int8DynamicallyQuantizedLinearWeight, enable_parametrization=False, **kwargs), filter_fn
     )
 
+def _get_ref_change_linear_weights_to_woqtensors(deprecated_tenosr_subclass):
+    def _ref_change_linear_weights_to_woqtensors(model, filter_fn=None, **kwargs):
+        """
+        The deprecated implementation for weight only quant API, used as a reference for
+        numerics and performance
+        """
+        from torchao.quantization.quant_api import _is_linear
+        from torchao.quantization.quant_api import _get_subclass_inserter
+
+        filter_fn = kwargs.pop("filter_fn", _is_linear)
+
+        _replace_with_custom_fn_if_matches_filter(
+            model,
+            _get_subclass_inserter(deprecated_tenosr_subclass, enable_parametrization=True, **kwargs),
+            filter_fn,
+        )
+
+    return _ref_change_linear_weights_to_woqtensors
+
+_ref_change_linear_weights_to_int8_woqtensors = _get_ref_change_linear_weights_to_woqtensors(Int8WeightOnlyQuantizedLinearWeight)
+_ref_change_linear_weights_to_int4_woqtensors = _get_ref_change_linear_weights_to_woqtensors(Int4WeightOnlyQuantizedLinearWeight)
+
 class TestQuantFlow(unittest.TestCase):
     def test_dynamic_quant_gpu_singleline(self):
         m = ToyLinearModel().eval()

@@ -478,8 +502,7 @@ def test_quantized_tensor_subclass_int4(self):
         assert isinstance(m.linear2.weight, AffineQuantizedTensor)
 
         # reference
-        from torchao.quantization.quant_api import change_linear_weights_to_int4_woqtensors
-        change_linear_weights_to_int4_woqtensors(m_copy, groupsize=groupsize)
+        _ref_change_linear_weights_to_int4_woqtensors(m_copy, groupsize=groupsize)
 
         res = m(*example_inputs)
         ref = m_copy(*example_inputs)

@@ -489,7 +512,7 @@ def test_quantized_tensor_subclass_int4(self):
 
     @unittest.skipIf(not TORCH_VERSION_AFTER_2_4, "Test only enabled for 2.4+")
     @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available")
-    def test_quantized_tensor_subclass_int8(self):
+    def test_quantized_tensor_subclass_int8_wo(self):
         m = ToyLinearModel().eval().to(torch.bfloat16)
         m_copy = copy.deepcopy(m)
         example_inputs = tuple(map(lambda x: x.to(torch.bfloat16), m.example_inputs()))

@@ -500,13 +523,13 @@ def test_quantized_tensor_subclass_int8(self):
         assert isinstance(m.linear2.weight, AffineQuantizedTensor)
 
         # reference
-        from torchao.quantization.quant_api import change_linear_weights_to_int8_woqtensors
-        change_linear_weights_to_int8_woqtensors(m_copy)
+        _ref_change_linear_weights_to_int8_woqtensors(m_copy)
+
 
         res = m(*example_inputs)
         ref = m_copy(*example_inputs)
 
-        torch.testing.assert_close(res, ref, rtol=0.00001, atol=1e-2)
+        self.assertTrue(torch.equal(res, ref))
 
 
     @unittest.skipIf(not TORCH_VERSION_AFTER_2_4, "Test only enabled for 2.4+")

@@ -525,8 +548,7 @@ def test_quantized_tensor_subclass_int8_dyn_quant(self):
         assert isinstance(m.linear2.weight.original_weight_tensor, AffineQuantizedTensor)
 
         # reference
-        from torchao.quantization.quant_api import change_linear_weights_to_int8_dqtensors
-        change_linear_weights_to_int8_dqtensors(m_copy)
+        _ref_change_linear_weights_to_int8_dqtensors(m_copy)
 
         res = m(*example_inputs)
         ref = m_copy(*example_inputs)

@@ -545,45 +567,5 @@ def test_quantized_tensor_subclass_int8_dyn_quant(self):
         # make sure it compiles
         torch._export.aot_compile(m_unwrapped, example_inputs)
 
-    @unittest.skipIf(not TORCH_VERSION_AFTER_2_4, "Test only enabled for 2.4+")
-    @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available")
-    @unittest.skip("This perf test is supposed to be run locally for sanity check performance when there is a change of int8 dynamic quant implementation")
-    def test_quantized_tensor_subclass_int8_dyn_quant_perf(self):
-        m = ToyLinearModel(1024, 1024, 1024).eval().to(torch.bfloat16).to("cuda")
-        m_ref = copy.deepcopy(m)
-        # setting batch_size to 20 to be compatible with the kernel
-        example_inputs = m.example_inputs(batch_size=20, dtype=torch.bfloat16, device="cuda")
-
-        from torchao.quantization.quant_api import change_linear_weights_to_int8_dqtensors
-        change_linear_weights_to_int8_dqtensors(m)
-
-        # reference
-        _ref_change_linear_weights_to_int8_dqtensors(m_ref)
-
-        res = m(*example_inputs)
-        ref = m_ref(*example_inputs)
-
-        self.assertTrue(torch.equal(res, ref))
-
-        # perf comparison
-        from torchao.utils import benchmark_model
-        # warmup
-        WARMUP = 5
-        RUNS = 100
-        input_tensor = example_inputs[0]
-        m = torch.compile(m, mode='max-autotune', fullgraph=True)
-
-        benchmark_model(m, WARMUP, input_tensor)
-        elapsed_time = benchmark_model(m, RUNS, input_tensor)
-
-        m_ref = torch.compile(m_ref, mode='max-autotune', fullgraph=True)
-        benchmark_model(m_ref, WARMUP, input_tensor)
-        ref_elapsed_time = benchmark_model(m_ref, RUNS, input_tensor)
-
-        print(f"elapsed time: {elapsed_time}, ref elapsed time: {ref_elapsed_time}")
-        self.assertTrue(elapsed_time < 1.05 * ref_elapsed_time)
-
-
 if __name__ == "__main__":
     unittest.main()
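
Taken together, the refactored tests follow a single pattern: quantize the original model through the public API (whose implementation now goes through `quantize`), quantize a copy through the deprecated `_ref_*` reference path, and require bit-exact agreement. A condensed sketch of that pattern, using names defined in this file (ToyLinearModel, _ref_change_linear_weights_to_int8_woqtensors) and the public change_linear_weights_to_int8_woqtensors API:

import copy
import torch
from torchao.quantization.quant_api import change_linear_weights_to_int8_woqtensors

m = ToyLinearModel().eval().to(torch.bfloat16)
m_copy = copy.deepcopy(m)
example_inputs = tuple(x.to(torch.bfloat16) for x in m.example_inputs())

change_linear_weights_to_int8_woqtensors(m)             # new: routed through `quantize`
_ref_change_linear_weights_to_int8_woqtensors(m_copy)   # deprecated reference path

# Both paths compute the same affine quantization, so outputs should match exactly.
assert torch.equal(m(*example_inputs), m_copy(*example_inputs))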
