@@ -201,7 +201,10 @@ def test_qat_8da4w_quantizer(self):
             torch.testing.assert_close(ptq_state_dict[k], converted_state_dict[k], atol=0, rtol=0)
 
     @unittest.skipIf(not TORCH_VERSION_AFTER_2_3, "skipping when torch verion is 2.3 or lower")
-    def test_qat_8da4w_quantizer_enable_fake_quant(self):
+    def test_qat_8da4w_quantizer_disable_fake_quant(self):
+        """
+        Test that 8da4w QAT with disabled fake quant matches nn.Linear in forward.
+        """
         from torchao.quantization.prototype.qat import (
             Int8DynActInt4WeightQATQuantizer,
             disable_8da4w_fake_quant,
@@ -250,6 +253,51 @@ def test_qat_8da4w_quantizer_enable_fake_quant(self):
         qat_out2 = qat_model2(*x2)
         torch.testing.assert_close(qat_out, qat_out2, atol=0, rtol=0)
 
+    @unittest.skipIf(not TORCH_VERSION_AFTER_2_3, "skipping when torch verion is 2.3 or lower")
+    def test_qat_8da4w_quantizer_disable_fake_quant_backward(self):
+        """
+        Test that 8da4w QAT with disabled fake quant matches nn.Linear in backward.
+        """
+        from torchao.quantization.prototype.qat import (
+            Int8DynActInt4WeightQATQuantizer,
+            disable_8da4w_fake_quant,
+        )
+
+        group_size = 16
+        torch.manual_seed(self.SEED)
+        m = M()
+        nn_model = copy.deepcopy(m)
+        quantizer = Int8DynActInt4WeightQATQuantizer(groupsize=group_size)
+        qat_model = quantizer.prepare(m)
+        qat_model.apply(disable_8da4w_fake_quant)
+        nn_model.linear1.weight = qat_model.linear1.weight
+        nn_model.linear2.weight = qat_model.linear2.weight
+        nn_model.sub.linear.weight = qat_model.sub.linear.weight
+
+        # Simulate training for both models
+        optimizer1 = torch.optim.SGD(nn_model.parameters(), lr=0.001, momentum=0.9, weight_decay=1e-5)
+        optimizer2 = torch.optim.SGD(qat_model.parameters(), lr=0.001, momentum=0.9, weight_decay=1e-5)
+        loss_fn1 = torch.nn.CrossEntropyLoss()
+        loss_fn2 = torch.nn.CrossEntropyLoss()
+        example_inputs = nn_model.example_inputs()
+        target = torch.randn(1, 64).float()
+        output1 = nn_model(*example_inputs)
+        output2 = qat_model(*example_inputs)
+        torch.testing.assert_close(output1, output2, atol=0, rtol=0)
+        loss1 = loss_fn1(output1, target)
+        loss2 = loss_fn2(output2, target)
+        optimizer1.zero_grad()
+        optimizer2.zero_grad()
+        loss1.backward()
+        loss2.backward()
+        optimizer1.step()
+        optimizer2.step()
+
+        # After 1 training step, weights should match exactly
+        torch.testing.assert_close(nn_model.linear1.weight, qat_model.linear1.weight, atol=0, rtol=0)
+        torch.testing.assert_close(nn_model.linear2.weight, qat_model.linear2.weight, atol=0, rtol=0)
+        torch.testing.assert_close(nn_model.sub.linear.weight, qat_model.sub.linear.weight, atol=0, rtol=0)
+
 
 if __name__ == "__main__":
     unittest.main()
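
A minimal sketch for exercising the renamed and newly added tests in isolation. The module path and test class name used below are assumptions, not something this hunk confirms; only the test method names come from the diff itself.

    # Hedged sketch: run the two fake-quant-disabled QAT tests directly.
    # "test.quantization.test_qat" and "TestQAT" are assumed, not confirmed by this diff.
    import unittest

    suite = unittest.defaultTestLoader.loadTestsFromNames([
        "test.quantization.test_qat.TestQAT.test_qat_8da4w_quantizer_disable_fake_quant",
        "test.quantization.test_qat.TestQAT.test_qat_8da4w_quantizer_disable_fake_quant_backward",
    ])
    unittest.TextTestRunner(verbosity=2).run(suite)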