
Commit eb25542

Add generic fake quantized linear for QAT
Summary: This commit adds a generic fake quantized linear module to replace the uses of the existing, more specific QAT linears. For example, `Int8DynActInt4WeightQATLinear` can be expressed as follows:

```
from torchao.quantization.prototype.qat.api import FakeQuantizeConfig
from torchao.quantization.prototype.qat.linear import FakeQuantizedLinear

activation_config = FakeQuantizeConfig(
    bit_width=8,
    granularity="per_token",
    symmetric=False,
    dynamic=True,
)
weight_config = FakeQuantizeConfig(
    bit_width=4,
    group_size=8,
    symmetric=True,
    dynamic=True,
)
fq_linear = FakeQuantizedLinear(
    16, 32, False, activation_config, weight_config,
)
```

The main motivation is to provide a more flexible way to perform QAT on models with linear layers. Previously, we would have had to create a new linear class every time we wished to experiment with different fake quantization settings, e.g. a different group size or bit width. Now we can express this easily using a single linear module.

Test Plan:

python test/quantization/test_qat.py -k test_fake_quantize_config
python test/quantization/test_qat.py -k test_fake_quantized_linear_8da4w
python test/quantization/test_qat.py -k test_fake_quantized_linear_4w

ghstack-source-id: 2598aa9
Pull Request resolved: #1020
1 parent 4c2db34 commit eb25542
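
For reference, the snippet below is a minimal usage sketch of the new module, mirroring the example in the commit message and the shapes used there; the random input batch and the printed shape are illustrative assumptions rather than part of this commit:

```
import torch

from torchao.quantization.prototype.qat.api import FakeQuantizeConfig
from torchao.quantization.prototype.qat.linear import FakeQuantizedLinear

# int8 dynamic per-token asymmetric activations + int4 per-group symmetric weights,
# matching the 8da4w example in the commit message above
activation_config = FakeQuantizeConfig(8, "per_token", symmetric=False)
weight_config = FakeQuantizeConfig(4, group_size=8)

# in_features=16, out_features=32, bias=False, as in the commit message
fq_linear = FakeQuantizedLinear(16, 32, False, activation_config, weight_config)

# Fake quantization is applied to activations and weights during the forward pass
x = torch.randn(4, 16)  # illustrative input batch (assumption, not from the commit)
y = fq_linear(x)
print(y.shape)  # expected: torch.Size([4, 32])
```

The same pair of `FakeQuantizeConfig` objects can be swapped for other bit widths or group sizes without defining a new linear class, which is the flexibility the commit is after.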

File tree

5 files changed: +558 -191 lines


test/quantization/test_qat.py

Lines changed: 186 additions & 43 deletions
@@ -11,17 +11,27 @@
 import unittest
 
 import torch
+import torch.nn.functional as F
 from torch.ao.quantization.fx._decomposed import quantized_decomposed_lib  # noqa: F401
 from torchao.dtypes import (
     TensorCoreTiledLayoutType,
 )
 from torchao.quantization.prototype.qat.api import (
     ComposableQATQuantizer,
+    FakeQuantizeConfig,
+    QuantizationGranularity,
+)
+from torchao.quantization.prototype.qat.fake_quantizer import (
+    FakeQuantizer,
+)
+from torchao.quantization.prototype.qat.linear import (
+    FakeQuantizedLinear,
 )
 from torchao.quantization.prototype.qat.utils import (
     _choose_qparams_per_token_asymmetric,
     _fake_quantize_per_channel_group,
     _fake_quantize_per_token,
+    _get_qmin_qmax,
     _GenericFakeQuantize,
 )
 from torchao.quantization.quant_api import (
@@ -92,15 +102,10 @@ def forward(self, x):
 class TestQAT(unittest.TestCase):
     SEED = 123
 
-    def _get_qmin_qmax(self, n_bit: int):
-        qmin = -(2 ** (n_bit - 1))
-        qmax = 2 ** (n_bit - 1) - 1
-        return (qmin, qmax)
-
     @unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_4, "skipping when torch version is 2.4 or lower")
     def test_fake_quantize_per_channel_group(self):
         n_bit = 4
-        (qmin, qmax) = self._get_qmin_qmax(n_bit)
+        (qmin, qmax) = _get_qmin_qmax(n_bit)
         group_size = 128
 
         torch.manual_seed(self.SEED)
@@ -126,7 +131,7 @@ def test_fake_quantize_per_channel_group(self):
 
     @unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_4, "skipping when torch version is 2.4 or lower")
     def test_fake_quantize_per_token(self):
-        (qmin, qmax) = self._get_qmin_qmax(8)
+        (qmin, qmax) = _get_qmin_qmax(8)
 
         torch.manual_seed(self.SEED)
         x = torch.randn(100, 256).requires_grad_()
@@ -165,11 +170,11 @@ def _set_ptq_weight(
             Int4WeightOnlyQATLinear,
         )
         n_bit = 4
-        (qmin, qmax) = self._get_qmin_qmax(n_bit)
+        (qmin, qmax) = _get_qmin_qmax(n_bit)
+        group_size = qat_linear.weight_fake_quantizer.config.group_size
         if isinstance(ptq_linear, Int8DynActInt4WeightLinear):
             assert isinstance(qat_linear, Int8DynActInt4WeightQATLinear)
             fp32_weight = qat_linear.weight
-            group_size = qat_linear.groupsize
             (s, zp) = get_group_qparams_symmetric(fp32_weight, n_bit, group_size)
             q_weight = torch.ops.quantized_decomposed.quantize_per_channel_group(
                 fp32_weight, s, zp, qmin, qmax, torch.int8, group_size,
@@ -180,7 +185,7 @@ def _set_ptq_weight(
         elif isinstance(ptq_linear, WeightOnlyInt4Linear):
             assert isinstance(qat_linear, Int4WeightOnlyQATLinear)
             (q_weight, scales_and_zeros) = groupwise_affine_quantize_tensor(
-                qat_linear.weight, n_bit, qat_linear.groupsize,
+                qat_linear.weight, n_bit, group_size,
             )
             q_weight = torch.ops.aten._convert_weight_to_int4pack(
                 q_weight.to("cuda"), qat_linear.inner_k_tiles,
@@ -218,31 +223,36 @@ def test_qat_8da4w_linear(self):
     @unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_4, "skipping when torch version is 2.4 or lower")
     def test_qat_8da4w_quantizer(self):
         from torchao.quantization.prototype.qat import Int8DynActInt4WeightQATQuantizer
-        from torchao.quantization.prototype.qat.linear import Int8DynActInt4WeightQATQuantizer
+        from torchao.quantization.GPTQ import Int8DynActInt4WeightQuantizer
 
         group_size = 16
         torch.manual_seed(self.SEED)
         m = M()
         m2 = copy.deepcopy(m)
-        subclass_quantizer = Int8DynActInt4WeightQATQuantizer(groupsize=group_size)
-        module_swap_quantizer = Int8DynActInt4WeightQATQuantizer(groupsize=group_size)
-        subclass_model = subclass_quantizer.prepare(m)
-        module_swap_model = module_swap_quantizer.prepare(m2)
+        qat_quantizer = Int8DynActInt4WeightQATQuantizer(groupsize=group_size)
+        ptq_quantizer = Int8DynActInt4WeightQuantizer(groupsize=group_size)
+        qat_model = qat_quantizer.prepare(m)
+        ptq_model = ptq_quantizer.quantize(m2)
 
         # Compare model values
         torch.manual_seed(self.SEED)
         x = m.example_inputs()
         x2 = copy.deepcopy(x)
-        subclass_out = subclass_model(*x)
-        module_swap_out = module_swap_model(*x2)
-        torch.testing.assert_close(subclass_out, module_swap_out, atol=0, rtol=0)
+        qat_out = qat_model(*x)
+        ptq_out = ptq_model(*x2)
+        torch.testing.assert_close(ptq_out, qat_out, atol=0, rtol=0)
 
         # Convert QAT model and compare model values
-        subclass_model = subclass_quantizer.convert(subclass_model)
-        module_swap_model = module_swap_quantizer.convert(module_swap_model)
-        subclass_out = subclass_model(*x)
-        module_swap_out = module_swap_model(*x2)
-        torch.testing.assert_close(subclass_out, module_swap_out, atol=0, rtol=0)
+        converted_model = qat_quantizer.convert(qat_model)
+        converted_out = converted_model(*x)
+        torch.testing.assert_close(ptq_out, converted_out, atol=0, rtol=0)
+
+        # Compare converted state dict
+        ptq_state_dict = ptq_model.state_dict()
+        converted_state_dict = converted_model.state_dict()
+        self.assertEqual(ptq_state_dict.keys(), converted_state_dict.keys())
+        for k in ptq_state_dict.keys():
+            torch.testing.assert_close(ptq_state_dict[k], converted_state_dict[k], atol=0, rtol=0)
 
     @unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_4, "skipping when torch version is 2.4 or lower")
     def test_qat_8da4w_quantizer_meta_weights(self):
@@ -275,9 +285,12 @@ def test_qat_8da4w_quantizer_disable_fake_quant(self):
         quantizer = Int8DynActInt4WeightQATQuantizer(groupsize=group_size)
         qat_model = quantizer.prepare(m)
         qat_model.apply(disable_8da4w_fake_quant)
-        self.assertFalse(qat_model.linear1._fake_quant_enabled)
-        self.assertFalse(qat_model.linear2._fake_quant_enabled)
-        self.assertFalse(qat_model.sub.linear._fake_quant_enabled)
+        self.assertFalse(qat_model.linear1.activation_fake_quantizer.enabled)
+        self.assertFalse(qat_model.linear1.weight_fake_quantizer.enabled)
+        self.assertFalse(qat_model.linear2.activation_fake_quantizer.enabled)
+        self.assertFalse(qat_model.linear2.weight_fake_quantizer.enabled)
+        self.assertFalse(qat_model.sub.linear.activation_fake_quantizer.enabled)
+        self.assertFalse(qat_model.sub.linear.weight_fake_quantizer.enabled)
 
         # Disabled fake quant is just a normal linear
         m2.linear1.weight = torch.nn.Parameter(qat_model.linear1.weight)
@@ -292,9 +305,12 @@ def test_qat_8da4w_quantizer_disable_fake_quant(self):
 
         # Renable fake quant
         qat_model.apply(enable_8da4w_fake_quant)
-        self.assertTrue(qat_model.linear1._fake_quant_enabled)
-        self.assertTrue(qat_model.linear2._fake_quant_enabled)
-        self.assertTrue(qat_model.sub.linear._fake_quant_enabled)
+        self.assertTrue(qat_model.linear1.activation_fake_quantizer.enabled)
+        self.assertTrue(qat_model.linear1.weight_fake_quantizer.enabled)
+        self.assertTrue(qat_model.linear2.activation_fake_quantizer.enabled)
+        self.assertTrue(qat_model.linear2.weight_fake_quantizer.enabled)
+        self.assertTrue(qat_model.sub.linear.activation_fake_quantizer.enabled)
+        self.assertTrue(qat_model.sub.linear.weight_fake_quantizer.enabled)
 
         # Fake quant should be applied as normal
         quantizer2 = Int8DynActInt4WeightQATQuantizer(groupsize=group_size)
@@ -407,7 +423,7 @@ def test_qat_generic_fake_quantize(self):
         the numerics of existing fake quantize ops in Pytorch in both
         the forward and the backward passes.
         """
-        (qmin, qmax) = self._get_qmin_qmax(4)
+        (qmin, qmax) = _get_qmin_qmax(4)
         py_input = torch.randn(16, 64).float().requires_grad_()
         py_s = torch.randn(16).float()
         py_zp = torch.randint(qmax, size=(16,), dtype=torch.int32)
@@ -521,7 +537,7 @@ def test_qat_4w_quantizer_gradients(self):
     @unittest.skipIf(not _CUDA_IS_AVAILABLE, "skipping when cuda is not available")
     def test_qat_4w_quantizer(self):
         from torchao.quantization.prototype.qat import Int4WeightOnlyQATQuantizer
-        from torchao.quantization.prototype.qat.linear import Int4WeightOnlyQATQuantizer
+        from torchao.quantization.GPTQ import Int4WeightOnlyQuantizer
 
         group_size = 32
         inner_k_tiles = 8
@@ -530,29 +546,34 @@ def test_qat_4w_quantizer(self):
         torch.manual_seed(self.SEED)
         m = M().to(device).to(dtype)
         m2 = copy.deepcopy(m)
-        subclass_quantizer = Int4WeightOnlyQATQuantizer(
+        qat_quantizer = Int4WeightOnlyQATQuantizer(
             groupsize=group_size, inner_k_tiles=inner_k_tiles,
         )
-        module_swap_quantizer = Int4WeightOnlyQATQuantizer(
+        ptq_quantizer = Int4WeightOnlyQuantizer(
             groupsize=group_size, inner_k_tiles=inner_k_tiles,
         )
-        subclass_model = subclass_quantizer.prepare(m)
-        module_swap_model = module_swap_quantizer.prepare(m2)
+        qat_model = qat_quantizer.prepare(m)
+        ptq_model = ptq_quantizer.quantize(m2)
 
         # Compare model values
         torch.manual_seed(self.SEED)
         x = [i.to(device).to(dtype) for i in m.example_inputs()]
         x2 = copy.deepcopy(x)
-        subclass_out = subclass_model(*x)
-        module_swap_out = module_swap_model(*x2)
-        torch.testing.assert_close(subclass_out, module_swap_out, atol=0, rtol=0)
+        qat_out = qat_model(*x)
+        ptq_out = ptq_model(*x2)
+        self._assert_close_4w(qat_out, ptq_out)
 
         # Convert QAT model and compare model values
-        subclass_model = subclass_quantizer.convert(subclass_model)
-        module_swap_model = module_swap_quantizer.convert(module_swap_model)
-        subclass_out = subclass_model(*x)
-        module_swap_out = module_swap_model(*x2)
-        torch.testing.assert_close(subclass_out, module_swap_out, atol=0, rtol=0)
+        converted_model = qat_quantizer.convert(qat_model)
+        converted_out = converted_model(*x)
+        torch.testing.assert_close(converted_out, ptq_out, atol=0, rtol=0)
+
+        # Compare converted state dict
+        ptq_state_dict = ptq_model.state_dict()
+        converted_state_dict = converted_model.state_dict()
+        self.assertEqual(ptq_state_dict.keys(), converted_state_dict.keys())
+        for k in ptq_state_dict.keys():
+            torch.testing.assert_close(ptq_state_dict[k], converted_state_dict[k], atol=0, rtol=0)
 
     class _MyQATQuantizer(TwoStepQuantizer):
         """
@@ -603,5 +624,127 @@ def test_qat_4w_embedding(self):
         converted = quantizer.convert(model)
         converted_out = converted(*x)
 
+    def test_fake_quantize_config(self):
+        """
+        Test initialization and property setting of `FakeQuantizeConfig`.
+        """
+        # basic configs
+        per_token_config = FakeQuantizeConfig(8, "per_token")
+        self.assertEqual(per_token_config.bit_width, 8)
+        self.assertEqual(per_token_config.granularity, QuantizationGranularity.PER_TOKEN)
+        self.assertIsNone(per_token_config.group_size)
+        per_channel_config = FakeQuantizeConfig(4, "per_channel")
+        self.assertEqual(per_channel_config.bit_width, 4)
+        self.assertEqual(per_channel_config.granularity, QuantizationGranularity.PER_CHANNEL)
+        self.assertIsNone(per_channel_config.group_size)
+
+        # initialize per_group config using only group size
+        per_group_config = FakeQuantizeConfig(4, group_size=32)
+        self.assertEqual(per_group_config.bit_width, 4)
+        self.assertEqual(per_group_config.granularity, QuantizationGranularity.PER_GROUP)
+        self.assertEqual(per_group_config.group_size, 32)
+
+        # set granularity after initialization, should accept str as before
+        per_group_config.granularity = "per_token"
+        self.assertEqual(per_token_config.granularity, QuantizationGranularity.PER_TOKEN)
+
+        # set group_size after initialization, should also update granularity
+        per_group_config.group_size = 16
+        self.assertEqual(per_group_config.granularity, QuantizationGranularity.PER_GROUP)
+        self.assertEqual(per_group_config.group_size, 16)
+
+        # bad config1: no granularity or group size provided
+        with self.assertRaisesRegex(ValueError, "group_size or granularity must be set"):
+            FakeQuantizeConfig(8)
+
+        # bad config2: 'per_group' but no group size
+        with self.assertRaisesRegex(ValueError, "no group_size was set"):
+            FakeQuantizeConfig(8, "per_group")
+
+        # bad config3: group size was set but granularity was not 'per_group'
+        with self.assertRaisesRegex(ValueError, "group_size was set"):
+            FakeQuantizeConfig(8, "per_token", group_size=16)
+
+    @unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_4, "skipping when torch version is 2.4 or lower")
+    def test_fake_quantized_linear_8da4w(self):
+        """
+        Test that we can express int8 dynamic activations + int4 weights with `FakeQuantizedLinear`.
+        """
+        group_size = 128
+        torch.manual_seed(self.SEED)
+        fq_linear = FakeQuantizedLinear(
+            256,
+            688,
+            bias=False,
+            activation_config=FakeQuantizeConfig(8, "per_token", symmetric=False),
+            weight_config=FakeQuantizeConfig(4, group_size=group_size),
+        )
+
+        def linear_forward_8da4w(x: torch.Tensor, weight: torch.Tensor) -> torch.Tensor:
+            """
+            Baseline for int8 dynamic per token asymmetric + int4 per group symmetric quant.
+            """
+            # activations
+            (s, zp) = _choose_qparams_per_token_asymmetric(x, torch.float32, torch.int32)
+            (qmin, qmax) = _get_qmin_qmax(8)
+            x_fq = _fake_quantize_per_token(x, s, zp, qmin, qmax)
+
+            # weights
+            (s, zp) = get_group_qparams_symmetric(weight, 4, group_size, torch.float32)
+            zp = zp.to(torch.int32)
+            (qmin, qmax) = _get_qmin_qmax(4)
+            w_fq = _fake_quantize_per_channel_group(weight, s, zp, qmin, qmax, group_size)
+            return F.linear(x_fq, w_fq)
+
+        # Compare linear values
+        torch.manual_seed(self.SEED)
+        x = torch.randn(100, 256)
+        x2 = copy.deepcopy(x)
+        fq_out = fq_linear(x)
+        baseline_out = linear_forward_8da4w(x2, fq_linear.weight)
+        torch.testing.assert_close(baseline_out, fq_out, atol=0, rtol=0)
+
+    @unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_4, "skipping when torch version is 2.4 or lower")
+    def test_fake_quantized_linear_4w(self):
+        """
+        Test that we can express int4 weight only (tinygemm) with `FakeQuantizedLinear`.
+        """
+        group_size = 128
+        weight_config = FakeQuantizeConfig(
+            bit_width=4,
+            group_size=group_size,
+            symmetric=False,
+            zero_point_domain=ZeroPointDomain.FLOAT,
+        )
+        torch.manual_seed(self.SEED)
+        fq_linear = FakeQuantizedLinear(
+            256,
+            688,
+            bias=False,
+            activation_config=None,
+            weight_config=weight_config,
+        )
+
+        def linear_forward_4w(x: torch.Tensor, weight: torch.Tensor) -> torch.Tensor:
+            """
+            Baseline for int4 weight only fake quantization that simulates the tinygemm kernel.
+            """
+            (qmin, qmax) = _get_qmin_qmax(4, symmetric=False)
+            (s, zp) = get_groupwise_affine_qparams(weight, 4, group_size, torch.float32)
+            zp = zp.to(torch.int32)
+            w_fq = _fake_quantize_per_channel_group(
+                weight, s, zp, qmin, qmax, group_size, zero_point_domain=ZeroPointDomain.FLOAT,
+            )
+            return F.linear(x, w_fq)
+
+        # Compare linear values
+        torch.manual_seed(self.SEED)
+        x = torch.randn(100, 256)
+        x2 = copy.deepcopy(x)
+        fq_out = fq_linear(x)
+        baseline_out = linear_forward_4w(x2, fq_linear.weight)
+        torch.testing.assert_close(baseline_out, fq_out, atol=0, rtol=0)
+
+
 if __name__ == "__main__":
     unittest.main()
