import unittest

import torch
+ import torch.nn.functional as F
from torch.ao.quantization.fx._decomposed import quantized_decomposed_lib  # noqa: F401
from torchao.dtypes import (
    TensorCoreTiledLayoutType,
)
from torchao.quantization.prototype.qat.api import (
    ComposableQATQuantizer,
+     FakeQuantizeConfig,
+     QuantizationGranularity,
+ )
+ from torchao.quantization.prototype.qat.fake_quantizer import (
+     FakeQuantizer,
+ )
+ from torchao.quantization.prototype.qat.linear import (
+     FakeQuantizedLinear,
)
from torchao.quantization.prototype.qat.utils import (
    _choose_qparams_per_token_asymmetric,
    _fake_quantize_per_channel_group,
    _fake_quantize_per_token,
+     _get_qmin_qmax,
    _GenericFakeQuantize,
)
from torchao.quantization.quant_api import (
@@ -92,15 +102,10 @@ def forward(self, x):
class TestQAT(unittest.TestCase):
    SEED = 123

-     def _get_qmin_qmax(self, n_bit: int):
-         qmin = -(2 ** (n_bit - 1))
-         qmax = 2 ** (n_bit - 1) - 1
-         return (qmin, qmax)
-
    @unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_4, "skipping when torch version is 2.4 or lower")
    def test_fake_quantize_per_channel_group(self):
        n_bit = 4
-         (qmin, qmax) = self._get_qmin_qmax(n_bit)
+         (qmin, qmax) = _get_qmin_qmax(n_bit)
        group_size = 128

        torch.manual_seed(self.SEED)
@@ -126,7 +131,7 @@ def test_fake_quantize_per_channel_group(self):

    @unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_4, "skipping when torch version is 2.4 or lower")
    def test_fake_quantize_per_token(self):
-         (qmin, qmax) = self._get_qmin_qmax(8)
+         (qmin, qmax) = _get_qmin_qmax(8)

        torch.manual_seed(self.SEED)
        x = torch.randn(100, 256).requires_grad_()
@@ -165,11 +170,11 @@ def _set_ptq_weight(
            Int4WeightOnlyQATLinear,
        )
        n_bit = 4
-         (qmin, qmax) = self._get_qmin_qmax(n_bit)
+         (qmin, qmax) = _get_qmin_qmax(n_bit)
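+         # the group size lives on the weight fake quantizer's config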
+         group_size = qat_linear.weight_fake_quantizer.config.group_size
        if isinstance(ptq_linear, Int8DynActInt4WeightLinear):
            assert isinstance(qat_linear, Int8DynActInt4WeightQATLinear)
            fp32_weight = qat_linear.weight
-             group_size = qat_linear.groupsize
            (s, zp) = get_group_qparams_symmetric(fp32_weight, n_bit, group_size)
            q_weight = torch.ops.quantized_decomposed.quantize_per_channel_group(
                fp32_weight, s, zp, qmin, qmax, torch.int8, group_size,
@@ -180,7 +185,7 @@ def _set_ptq_weight(
        elif isinstance(ptq_linear, WeightOnlyInt4Linear):
            assert isinstance(qat_linear, Int4WeightOnlyQATLinear)
            (q_weight, scales_and_zeros) = groupwise_affine_quantize_tensor(
-                 qat_linear.weight, n_bit, qat_linear.groupsize,
+                 qat_linear.weight, n_bit, group_size,
            )
            q_weight = torch.ops.aten._convert_weight_to_int4pack(
                q_weight.to("cuda"), qat_linear.inner_k_tiles,
@@ -218,31 +223,36 @@ def test_qat_8da4w_linear(self):
    @unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_4, "skipping when torch version is 2.4 or lower")
    def test_qat_8da4w_quantizer(self):
        from torchao.quantization.prototype.qat import Int8DynActInt4WeightQATQuantizer
-         from torchao.quantization.prototype.qat.linear import Int8DynActInt4WeightQATQuantizer
+         from torchao.quantization.GPTQ import Int8DynActInt4WeightQuantizer

        group_size = 16
        torch.manual_seed(self.SEED)
        m = M()
        m2 = copy.deepcopy(m)
-         subclass_quantizer = Int8DynActInt4WeightQATQuantizer(groupsize=group_size)
-         module_swap_quantizer = Int8DynActInt4WeightQATQuantizer(groupsize=group_size)
-         subclass_model = subclass_quantizer.prepare(m)
-         module_swap_model = module_swap_quantizer.prepare(m2)
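+         # the QAT quantizer fake-quantizes via prepare()/convert(), while the
+         # PTQ quantizer below quantizes the model in a single quantize() step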
+         qat_quantizer = Int8DynActInt4WeightQATQuantizer(groupsize=group_size)
+         ptq_quantizer = Int8DynActInt4WeightQuantizer(groupsize=group_size)
+         qat_model = qat_quantizer.prepare(m)
+         ptq_model = ptq_quantizer.quantize(m2)

        # Compare model values
        torch.manual_seed(self.SEED)
        x = m.example_inputs()
        x2 = copy.deepcopy(x)
-         subclass_out = subclass_model(*x)
-         module_swap_out = module_swap_model(*x2)
-         torch.testing.assert_close(subclass_out, module_swap_out, atol=0, rtol=0)
+         qat_out = qat_model(*x)
+         ptq_out = ptq_model(*x2)
+         torch.testing.assert_close(ptq_out, qat_out, atol=0, rtol=0)

        # Convert QAT model and compare model values
-         subclass_model = subclass_quantizer.convert(subclass_model)
-         module_swap_model = module_swap_quantizer.convert(module_swap_model)
-         subclass_out = subclass_model(*x)
-         module_swap_out = module_swap_model(*x2)
-         torch.testing.assert_close(subclass_out, module_swap_out, atol=0, rtol=0)
+         converted_model = qat_quantizer.convert(qat_model)
+         converted_out = converted_model(*x)
+         torch.testing.assert_close(ptq_out, converted_out, atol=0, rtol=0)
+
+         # Compare converted state dict
+         ptq_state_dict = ptq_model.state_dict()
+         converted_state_dict = converted_model.state_dict()
+         self.assertEqual(ptq_state_dict.keys(), converted_state_dict.keys())
+         for k in ptq_state_dict.keys():
+             torch.testing.assert_close(ptq_state_dict[k], converted_state_dict[k], atol=0, rtol=0)

    @unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_4, "skipping when torch version is 2.4 or lower")
    def test_qat_8da4w_quantizer_meta_weights(self):
@@ -275,9 +285,12 @@ def test_qat_8da4w_quantizer_disable_fake_quant(self):
        quantizer = Int8DynActInt4WeightQATQuantizer(groupsize=group_size)
        qat_model = quantizer.prepare(m)
        qat_model.apply(disable_8da4w_fake_quant)
-         self.assertFalse(qat_model.linear1._fake_quant_enabled)
-         self.assertFalse(qat_model.linear2._fake_quant_enabled)
-         self.assertFalse(qat_model.sub.linear._fake_quant_enabled)
+         self.assertFalse(qat_model.linear1.activation_fake_quantizer.enabled)
+         self.assertFalse(qat_model.linear1.weight_fake_quantizer.enabled)
+         self.assertFalse(qat_model.linear2.activation_fake_quantizer.enabled)
+         self.assertFalse(qat_model.linear2.weight_fake_quantizer.enabled)
+         self.assertFalse(qat_model.sub.linear.activation_fake_quantizer.enabled)
+         self.assertFalse(qat_model.sub.linear.weight_fake_quantizer.enabled)

        # Disabled fake quant is just a normal linear
        m2.linear1.weight = torch.nn.Parameter(qat_model.linear1.weight)
@@ -292,9 +305,12 @@ def test_qat_8da4w_quantizer_disable_fake_quant(self):

        # Re-enable fake quant
        qat_model.apply(enable_8da4w_fake_quant)
-         self.assertTrue(qat_model.linear1._fake_quant_enabled)
-         self.assertTrue(qat_model.linear2._fake_quant_enabled)
-         self.assertTrue(qat_model.sub.linear._fake_quant_enabled)
+         self.assertTrue(qat_model.linear1.activation_fake_quantizer.enabled)
+         self.assertTrue(qat_model.linear1.weight_fake_quantizer.enabled)
+         self.assertTrue(qat_model.linear2.activation_fake_quantizer.enabled)
+         self.assertTrue(qat_model.linear2.weight_fake_quantizer.enabled)
+         self.assertTrue(qat_model.sub.linear.activation_fake_quantizer.enabled)
+         self.assertTrue(qat_model.sub.linear.weight_fake_quantizer.enabled)

        # Fake quant should be applied as normal
        quantizer2 = Int8DynActInt4WeightQATQuantizer(groupsize=group_size)
@@ -407,7 +423,7 @@ def test_qat_generic_fake_quantize(self):
        the numerics of existing fake quantize ops in PyTorch in both
        the forward and the backward passes.
        """
-         (qmin, qmax) = self._get_qmin_qmax(4)
+         (qmin, qmax) = _get_qmin_qmax(4)
        py_input = torch.randn(16, 64).float().requires_grad_()
        py_s = torch.randn(16).float()
        py_zp = torch.randint(qmax, size=(16,), dtype=torch.int32)
@@ -521,7 +537,7 @@ def test_qat_4w_quantizer_gradients(self):
    @unittest.skipIf(not _CUDA_IS_AVAILABLE, "skipping when cuda is not available")
    def test_qat_4w_quantizer(self):
        from torchao.quantization.prototype.qat import Int4WeightOnlyQATQuantizer
-         from torchao.quantization.prototype.qat.linear import Int4WeightOnlyQATQuantizer
+         from torchao.quantization.GPTQ import Int4WeightOnlyQuantizer

        group_size = 32
        inner_k_tiles = 8
@@ -530,29 +546,34 @@ def test_qat_4w_quantizer(self):
        torch.manual_seed(self.SEED)
        m = M().to(device).to(dtype)
        m2 = copy.deepcopy(m)
-         subclass_quantizer = Int4WeightOnlyQATQuantizer(
+         qat_quantizer = Int4WeightOnlyQATQuantizer(
            groupsize=group_size, inner_k_tiles=inner_k_tiles,
        )
-         module_swap_quantizer = Int4WeightOnlyQATQuantizer(
+         ptq_quantizer = Int4WeightOnlyQuantizer(
            groupsize=group_size, inner_k_tiles=inner_k_tiles,
        )
-         subclass_model = subclass_quantizer.prepare(m)
-         module_swap_model = module_swap_quantizer.prepare(m2)
+         qat_model = qat_quantizer.prepare(m)
+         ptq_model = ptq_quantizer.quantize(m2)

        # Compare model values
        torch.manual_seed(self.SEED)
        x = [i.to(device).to(dtype) for i in m.example_inputs()]
        x2 = copy.deepcopy(x)
-         subclass_out = subclass_model(*x)
-         module_swap_out = module_swap_model(*x2)
-         torch.testing.assert_close(subclass_out, module_swap_out, atol=0, rtol=0)
+         qat_out = qat_model(*x)
+         ptq_out = ptq_model(*x2)
+         self._assert_close_4w(qat_out, ptq_out)

        # Convert QAT model and compare model values
-         subclass_model = subclass_quantizer.convert(subclass_model)
-         module_swap_model = module_swap_quantizer.convert(module_swap_model)
-         subclass_out = subclass_model(*x)
-         module_swap_out = module_swap_model(*x2)
-         torch.testing.assert_close(subclass_out, module_swap_out, atol=0, rtol=0)
+         converted_model = qat_quantizer.convert(qat_model)
+         converted_out = converted_model(*x)
+         torch.testing.assert_close(converted_out, ptq_out, atol=0, rtol=0)
+
+         # Compare converted state dict
+         ptq_state_dict = ptq_model.state_dict()
+         converted_state_dict = converted_model.state_dict()
+         self.assertEqual(ptq_state_dict.keys(), converted_state_dict.keys())
+         for k in ptq_state_dict.keys():
+             torch.testing.assert_close(ptq_state_dict[k], converted_state_dict[k], atol=0, rtol=0)

class _MyQATQuantizer(TwoStepQuantizer):
    """
@@ -603,5 +624,127 @@ def test_qat_4w_embedding(self):
        converted = quantizer.convert(model)
        converted_out = converted(*x)

+     def test_fake_quantize_config(self):
+         """
+         Test initialization and property setting of `FakeQuantizeConfig`.
+         """
+         # basic configs
+         per_token_config = FakeQuantizeConfig(8, "per_token")
+         self.assertEqual(per_token_config.bit_width, 8)
+         self.assertEqual(per_token_config.granularity, QuantizationGranularity.PER_TOKEN)
+         self.assertIsNone(per_token_config.group_size)
+         per_channel_config = FakeQuantizeConfig(4, "per_channel")
+         self.assertEqual(per_channel_config.bit_width, 4)
+         self.assertEqual(per_channel_config.granularity, QuantizationGranularity.PER_CHANNEL)
+         self.assertIsNone(per_channel_config.group_size)
+
+         # initialize per_group config using only group size
+         per_group_config = FakeQuantizeConfig(4, group_size=32)
+         self.assertEqual(per_group_config.bit_width, 4)
+         self.assertEqual(per_group_config.granularity, QuantizationGranularity.PER_GROUP)
+         self.assertEqual(per_group_config.group_size, 32)
+
+         # set granularity after initialization, should accept str as before
+         per_group_config.granularity = "per_token"
+         self.assertEqual(per_group_config.granularity, QuantizationGranularity.PER_TOKEN)
+
+         # set group_size after initialization, should also update granularity
+         per_group_config.group_size = 16
+         self.assertEqual(per_group_config.granularity, QuantizationGranularity.PER_GROUP)
+         self.assertEqual(per_group_config.group_size, 16)
+
+         # bad config 1: no granularity or group size provided
+         with self.assertRaisesRegex(ValueError, "group_size or granularity must be set"):
+             FakeQuantizeConfig(8)
+
+         # bad config 2: 'per_group' granularity but no group size
+         with self.assertRaisesRegex(ValueError, "no group_size was set"):
+             FakeQuantizeConfig(8, "per_group")
+
+         # bad config 3: group size was set but granularity was not 'per_group'
+         with self.assertRaisesRegex(ValueError, "group_size was set"):
+             FakeQuantizeConfig(8, "per_token", group_size=16)
+
+     @unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_4, "skipping when torch version is 2.4 or lower")
+     def test_fake_quantized_linear_8da4w(self):
+         """
+         Test that we can express int8 dynamic activations + int4 weights with `FakeQuantizedLinear`.
+         """
+         group_size = 128
+         torch.manual_seed(self.SEED)
+         fq_linear = FakeQuantizedLinear(
+             256,
+             688,
+             bias=False,
+             activation_config=FakeQuantizeConfig(8, "per_token", symmetric=False),
+             weight_config=FakeQuantizeConfig(4, group_size=group_size),
+         )
+
+         def linear_forward_8da4w(x: torch.Tensor, weight: torch.Tensor) -> torch.Tensor:
+             """
+             Baseline for int8 dynamic per token asymmetric + int4 per group symmetric quant.
+             """
+             # activations
+             (s, zp) = _choose_qparams_per_token_asymmetric(x, torch.float32, torch.int32)
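+             # signed 8-bit range: (-128, 127), per the removed _get_qmin_qmax helper above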
+             (qmin, qmax) = _get_qmin_qmax(8)
+             x_fq = _fake_quantize_per_token(x, s, zp, qmin, qmax)
+
+             # weights
+             (s, zp) = get_group_qparams_symmetric(weight, 4, group_size, torch.float32)
+             zp = zp.to(torch.int32)
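+             # signed 4-bit range: (-8, 7)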
+             (qmin, qmax) = _get_qmin_qmax(4)
+             w_fq = _fake_quantize_per_channel_group(weight, s, zp, qmin, qmax, group_size)
+             return F.linear(x_fq, w_fq)
+
+         # Compare linear values
+         torch.manual_seed(self.SEED)
+         x = torch.randn(100, 256)
+         x2 = copy.deepcopy(x)
+         fq_out = fq_linear(x)
+         baseline_out = linear_forward_8da4w(x2, fq_linear.weight)
+         torch.testing.assert_close(baseline_out, fq_out, atol=0, rtol=0)
+
+     @unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_4, "skipping when torch version is 2.4 or lower")
+     def test_fake_quantized_linear_4w(self):
+         """
+         Test that we can express int4 weight only (tinygemm) with `FakeQuantizedLinear`.
+         """
+         group_size = 128
+         weight_config = FakeQuantizeConfig(
+             bit_width=4,
+             group_size=group_size,
+             symmetric=False,
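+             # a float zero point domain mirrors the tinygemm int4 kernel convention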
+             zero_point_domain=ZeroPointDomain.FLOAT,
+         )
+         torch.manual_seed(self.SEED)
+         fq_linear = FakeQuantizedLinear(
+             256,
+             688,
+             bias=False,
+             activation_config=None,
+             weight_config=weight_config,
+         )
+
+         def linear_forward_4w(x: torch.Tensor, weight: torch.Tensor) -> torch.Tensor:
+             """
+             Baseline for int4 weight only fake quantization that simulates the tinygemm kernel.
+             """
+             (qmin, qmax) = _get_qmin_qmax(4, symmetric=False)
+             (s, zp) = get_groupwise_affine_qparams(weight, 4, group_size, torch.float32)
+             zp = zp.to(torch.int32)
+             w_fq = _fake_quantize_per_channel_group(
+                 weight, s, zp, qmin, qmax, group_size, zero_point_domain=ZeroPointDomain.FLOAT,
+             )
+             return F.linear(x, w_fq)
+
+         # Compare linear values
+         torch.manual_seed(self.SEED)
+         x = torch.randn(100, 256)
+         x2 = copy.deepcopy(x)
+         fq_out = fq_linear(x)
+         baseline_out = linear_forward_4w(x2, fq_linear.weight)
+         torch.testing.assert_close(baseline_out, fq_out, atol=0, rtol=0)
+

if __name__ == "__main__":
    unittest.main()
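A minimal usage sketch of the prototype API exercised by the new tests, assuming the import paths in this diff (the prototype.qat module path may move in later releases); the constructor arguments mirror test_fake_quantized_linear_8da4w above:

import torch
from torchao.quantization.prototype.qat.api import FakeQuantizeConfig
from torchao.quantization.prototype.qat.linear import FakeQuantizedLinear

# int8 per-token asymmetric activations + int4 per-group symmetric weights
fq_linear = FakeQuantizedLinear(
    256,
    688,
    bias=False,
    activation_config=FakeQuantizeConfig(8, "per_token", symmetric=False),
    weight_config=FakeQuantizeConfig(4, group_size=128),
)
out = fq_linear(torch.randn(100, 256))  # forward pass with fake quantization applied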