 from torchao.quantization.prototype.qat.api import (
     ComposableQATQuantizer,
 )
-from torchao.quantization.prototype.qat.affine_fake_quantized_tensor import (
-    AffineFakeQuantizedTensor,
-)
 from torchao.quantization.prototype.qat.utils import (
     _choose_qparams_per_token_asymmetric,
     _fake_quantize_per_channel_group,
     _fake_quantize_per_token,
     _GenericFakeQuantize,
-    _QAT_LINEAR_SUBCLASS_INPUT_PREHOOK,
 )
 from torchao.quantization.quant_api import (
     int4_weight_only,
@@ -164,7 +160,7 @@ def _set_ptq_weight(
         Int8DynActInt4WeightLinear,
         WeightOnlyInt4Linear,
     )
-    from torchao.quantization.prototype.qat._module_swap_api import (
+    from torchao.quantization.prototype.qat.linear import (
         Int8DynActInt4WeightQATLinear,
         Int4WeightOnlyQATLinear,
     )
@@ -196,7 +192,7 @@ def _set_ptq_weight(
 
     @unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_4, "skipping when torch version is 2.4 or lower")
     def test_qat_8da4w_linear(self):
-        from torchao.quantization.prototype.qat._module_swap_api import Int8DynActInt4WeightQATLinear
+        from torchao.quantization.prototype.qat.linear import Int8DynActInt4WeightQATLinear
         from torchao.quantization.GPTQ import Int8DynActInt4WeightLinear
 
         group_size = 128
@@ -219,45 +215,17 @@ def test_qat_8da4w_linear(self):
         ptq_out = ptq_linear(x2)
         torch.testing.assert_close(ptq_out, qat_out, atol=0, rtol=0)
 
-    # TODO: compare against quantize_ API instead
     @unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_4, "skipping when torch version is 2.4 or lower")
     def test_qat_8da4w_quantizer(self):
         from torchao.quantization.prototype.qat import Int8DynActInt4WeightQATQuantizer
-        from torchao.quantization.GPTQ import Int8DynActInt4WeightQuantizer
-
-        group_size = 16
-        torch.manual_seed(self.SEED)
-        m = M()
-        m2 = copy.deepcopy(m)
-        qat_quantizer = Int8DynActInt4WeightQATQuantizer(groupsize=group_size)
-        ptq_quantizer = Int8DynActInt4WeightQuantizer(groupsize=group_size)
-        qat_model = qat_quantizer.prepare(m)
-        ptq_model = ptq_quantizer.quantize(m2)
-
-        # Compare model values
-        torch.manual_seed(self.SEED)
-        x = m.example_inputs()
-        x2 = copy.deepcopy(x)
-        qat_out = qat_model(*x)
-        ptq_out = ptq_model(*x2)
-        torch.testing.assert_close(ptq_out, qat_out, atol=0, rtol=0)
-
-        # Convert QAT model and compare model values
-        converted_model = qat_quantizer.convert(qat_model)
-        converted_out = converted_model(*x)
-        torch.testing.assert_close(ptq_out, converted_out, atol=0, rtol=0)
-
-    @unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_4, "skipping when torch version is 2.4 or lower")
-    def test_qat_8da4w_quantizer_module_swap(self):
-        from torchao.quantization.prototype.qat import Int8DynActInt4WeightQATQuantizer
-        from torchao.quantization.prototype.qat._module_swap_api import Int8DynActInt4WeightQATQuantizerModuleSwap
+        from torchao.quantization.prototype.qat.linear import Int8DynActInt4WeightQATQuantizer
 
         group_size = 16
         torch.manual_seed(self.SEED)
         m = M()
         m2 = copy.deepcopy(m)
         subclass_quantizer = Int8DynActInt4WeightQATQuantizer(groupsize=group_size)
-        module_swap_quantizer = Int8DynActInt4WeightQATQuantizerModuleSwap(groupsize=group_size)
+        module_swap_quantizer = Int8DynActInt4WeightQATQuantizer(groupsize=group_size)
         subclass_model = subclass_quantizer.prepare(m)
         module_swap_model = module_swap_quantizer.prepare(m2)
 
@@ -288,20 +256,6 @@ def test_qat_8da4w_quantizer_meta_weights(self):
         qat_model = qat_quantizer.prepare(m)
         self.assertTrue(all(v.is_meta for v in qat_model.state_dict().values()))
 
-    def _copy_subclass_weights(
-        self,
-        nn_linear: torch.nn.Linear,
-        subclass_linear: AffineFakeQuantizedTensor,
-    ):
-        nn_linear.weight = torch.nn.Parameter(subclass_linear.weight.original_tensor)
-
-    def _assert_matches_subclass_weights(
-        self,
-        nn_linear: torch.nn.Linear,
-        subclass_linear: AffineFakeQuantizedTensor,
-    ):
-        torch.testing.assert_close(nn_linear.weight, subclass_linear.weight.original_tensor, atol=0, rtol=0)
-
     @unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_4, "skipping when torch version is 2.4 or lower")
     def test_qat_8da4w_quantizer_disable_fake_quant(self):
         """
@@ -313,16 +267,6 @@ def test_qat_8da4w_quantizer_disable_fake_quant(self):
             enable_8da4w_fake_quant,
         )
 
-        def assert_fake_quant_enabled(m: torch.nn.Linear, enabled: bool):
-            self.assertTrue(isinstance(m.weight, AffineFakeQuantizedTensor))
-            self.assertEqual(m.weight.fake_quant_enabled, enabled)
-            self.assertTrue(hasattr(m, _QAT_LINEAR_SUBCLASS_INPUT_PREHOOK))
-            (_, handle) = getattr(m, _QAT_LINEAR_SUBCLASS_INPUT_PREHOOK)
-            if enabled:
-                self.assertIsNotNone(handle)
-            else:
-                self.assertIsNone(handle)
-
         group_size = 16
         torch.manual_seed(self.SEED)
         m = M()
@@ -331,14 +275,14 @@ def assert_fake_quant_enabled(m: torch.nn.Linear, enabled: bool):
         quantizer = Int8DynActInt4WeightQATQuantizer(groupsize=group_size)
         qat_model = quantizer.prepare(m)
         qat_model.apply(disable_8da4w_fake_quant)
-        assert_fake_quant_enabled(qat_model.linear1, enabled=False)
-        assert_fake_quant_enabled(qat_model.linear2, enabled=False)
-        assert_fake_quant_enabled(qat_model.sub.linear, enabled=False)
+        self.assertFalse(qat_model.linear1._fake_quant_enabled)
+        self.assertFalse(qat_model.linear2._fake_quant_enabled)
+        self.assertFalse(qat_model.sub.linear._fake_quant_enabled)
 
         # Disabled fake quant is just a normal linear
-        self._copy_subclass_weights(m2.linear1, qat_model.linear1)
-        self._copy_subclass_weights(m2.linear2, qat_model.linear2)
-        self._copy_subclass_weights(m2.sub.linear, qat_model.sub.linear)
+        m2.linear1.weight = torch.nn.Parameter(qat_model.linear1.weight)
+        m2.linear2.weight = torch.nn.Parameter(qat_model.linear2.weight)
+        m2.sub.linear.weight = torch.nn.Parameter(qat_model.sub.linear.weight)
         torch.manual_seed(self.SEED)
         x = m.example_inputs()
         x2 = copy.deepcopy(x)
@@ -348,16 +292,16 @@ def assert_fake_quant_enabled(m: torch.nn.Linear, enabled: bool):
 
         # Renable fake quant
         qat_model.apply(enable_8da4w_fake_quant)
-        assert_fake_quant_enabled(qat_model.linear1, enabled=True)
-        assert_fake_quant_enabled(qat_model.linear2, enabled=True)
-        assert_fake_quant_enabled(qat_model.sub.linear, enabled=True)
+        self.assertTrue(qat_model.linear1._fake_quant_enabled)
+        self.assertTrue(qat_model.linear2._fake_quant_enabled)
+        self.assertTrue(qat_model.sub.linear._fake_quant_enabled)
 
         # Fake quant should be applied as normal
         quantizer2 = Int8DynActInt4WeightQATQuantizer(groupsize=group_size)
         qat_model2 = quantizer2.prepare(m3)
-        qat_model2.linear1.weight.original_tensor = qat_model.linear1.weight.original_tensor
-        qat_model2.linear2.weight.original_tensor = qat_model.linear2.weight.original_tensor
-        qat_model2.sub.linear.weight.original_tensor = qat_model.sub.linear.weight.original_tensor
+        qat_model2.linear1.weight = qat_model.linear1.weight
+        qat_model2.linear2.weight = qat_model.linear2.weight
+        qat_model2.sub.linear.weight = qat_model.sub.linear.weight
         torch.manual_seed(self.SEED)
         x = m.example_inputs()
         x2 = copy.deepcopy(x)
@@ -382,9 +326,9 @@ def test_qat_8da4w_quantizer_disable_fake_quant_backward(self):
         quantizer = Int8DynActInt4WeightQATQuantizer(groupsize=group_size)
         qat_model = quantizer.prepare(m)
         qat_model.apply(disable_8da4w_fake_quant)
-        self._copy_subclass_weights(nn_model.linear1, qat_model.linear1)
-        self._copy_subclass_weights(nn_model.linear2, qat_model.linear2)
-        self._copy_subclass_weights(nn_model.sub.linear, qat_model.sub.linear)
+        nn_model.linear1.weight = torch.nn.Parameter(qat_model.linear1.weight)
+        nn_model.linear2.weight = torch.nn.Parameter(qat_model.linear2.weight)
+        nn_model.sub.linear.weight = torch.nn.Parameter(qat_model.sub.linear.weight)
 
         # Simulate training for both models
         optimizer1 = torch.optim.SGD(nn_model.parameters(), lr=0.001, momentum=0.9, weight_decay=1e-5)
@@ -406,9 +350,9 @@ def test_qat_8da4w_quantizer_disable_fake_quant_backward(self):
         optimizer2.step()
 
         # After 1 training step, weights should match exactly
-        self._assert_matches_subclass_weights(nn_model.linear1, qat_model.linear1)
-        self._assert_matches_subclass_weights(nn_model.linear2, qat_model.linear2)
-        self._assert_matches_subclass_weights(nn_model.sub.linear, qat_model.sub.linear)
+        torch.testing.assert_close(nn_model.linear1.weight, qat_model.linear1.weight, atol=0, rtol=0)
+        torch.testing.assert_close(nn_model.linear2.weight, qat_model.linear2.weight, atol=0, rtol=0)
+        torch.testing.assert_close(nn_model.sub.linear.weight, qat_model.sub.linear.weight, atol=0, rtol=0)
 
     def _test_qat_quantized_gradients(self, quantizer):
         """
@@ -542,7 +486,7 @@ def test_qat_4w_primitives(self):
     @unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_4, "skipping when torch version is 2.4 or lower")
     @unittest.skipIf(not _CUDA_IS_AVAILABLE, "skipping when cuda is not available")
     def test_qat_4w_linear(self):
-        from torchao.quantization.prototype.qat._module_swap_api import Int4WeightOnlyQATLinear
+        from torchao.quantization.prototype.qat.linear import Int4WeightOnlyQATLinear
         from torchao.quantization.GPTQ import WeightOnlyInt4Linear
 
         group_size = 128
@@ -567,39 +511,6 @@ def test_qat_4w_linear(self):
         ptq_out = ptq_linear(x2)
         self._assert_close_4w(qat_out, ptq_out)
 
-    @unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_4, "skipping when torch version is 2.4 or lower")
-    @unittest.skipIf(not _CUDA_IS_AVAILABLE, "skipping when cuda is not available")
-    def test_qat_4w_quantizer(self):
-        from torchao.quantization.prototype.qat import Int4WeightOnlyQATQuantizer
-        from torchao.quantization.GPTQ import Int4WeightOnlyQuantizer
-
-        group_size = 32
-        inner_k_tiles = 8
-        device = torch.device("cuda")
-        dtype = torch.bfloat16
-        torch.manual_seed(self.SEED)
-        m = M().to(device).to(dtype)
-        m2 = copy.deepcopy(m)
-        qat_quantizer = Int4WeightOnlyQATQuantizer(
-            groupsize=group_size, inner_k_tiles=inner_k_tiles,
-        )
-        qat_model = qat_quantizer.prepare(m)
-        ptq_model = m2
-        quantize_(ptq_model, int4_weight_only(group_size, TensorCoreTiledLayoutType(inner_k_tiles)))
-
-        # Compare model values
-        torch.manual_seed(self.SEED)
-        x = [i.to(device).to(dtype) for i in m.example_inputs()]
-        x2 = copy.deepcopy(x)
-        qat_out = qat_model(*x)
-        ptq_out = ptq_model(*x2)
-        self._assert_close_4w(qat_out, ptq_out)
-
-        # Convert QAT model and compare model values
-        converted_model = qat_quantizer.convert(qat_model)
-        converted_out = converted_model(*x)
-        torch.testing.assert_close(converted_out, ptq_out, atol=0, rtol=0)
-
     @unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_4, "skipping when torch version is 2.4 or lower")
     def test_qat_4w_quantizer_gradients(self):
         from torchao.quantization.prototype.qat import Int4WeightOnlyQATQuantizer
@@ -608,9 +519,9 @@ def test_qat_4w_quantizer_gradients(self):
 
     @unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_4, "skipping when torch version is 2.4 or lower")
     @unittest.skipIf(not _CUDA_IS_AVAILABLE, "skipping when cuda is not available")
-    def test_qat_4w_quantizer_module_swap(self):
+    def test_qat_4w_quantizer(self):
         from torchao.quantization.prototype.qat import Int4WeightOnlyQATQuantizer
-        from torchao.quantization.prototype.qat._module_swap_api import Int4WeightOnlyQATQuantizerModuleSwap
+        from torchao.quantization.prototype.qat.linear import Int4WeightOnlyQATQuantizer
 
         group_size = 32
         inner_k_tiles = 8
@@ -622,7 +533,7 @@ def test_qat_4w_quantizer_module_swap(self):
         subclass_quantizer = Int4WeightOnlyQATQuantizer(
             groupsize=group_size, inner_k_tiles=inner_k_tiles,
         )
-        module_swap_quantizer = Int4WeightOnlyQATQuantizerModuleSwap(
+        module_swap_quantizer = Int4WeightOnlyQATQuantizer(
            groupsize=group_size, inner_k_tiles=inner_k_tiles,
         )
         subclass_model = subclass_quantizer.prepare(m)
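For context, a minimal sketch of the prepare/convert flow these tests exercise, using the import path introduced in this diff (torchao.quantization.prototype.qat.linear instead of the old _module_swap_api). The tiny model and the shapes below are hypothetical, not taken from the test file:

```python
# Minimal sketch, assuming the post-refactor import path from this diff.
import torch
from torchao.quantization.prototype.qat.linear import Int8DynActInt4WeightQATQuantizer

class TinyModel(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = torch.nn.Linear(64, 64)

    def forward(self, x):
        return self.linear(x)

model = TinyModel()
quantizer = Int8DynActInt4WeightQATQuantizer(groupsize=16)

# prepare() swaps nn.Linear modules for fake-quantized QAT linears
qat_model = quantizer.prepare(model)
out = qat_model(torch.randn(2, 64))

# convert() replaces the QAT linears with the quantized linears
quantized_model = quantizer.convert(qat_model)
```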