Commit fbc5742

Add backward test
1 parent 9de4574 commit fbc5742

1 file changed (+49, -1)


test/quantization/test_qat.py

Lines changed: 49 additions & 1 deletion
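
At a glance: the existing fake-quant test is renamed from test_qat_8da4w_quantizer_enable_fake_quant to test_qat_8da4w_quantizer_disable_fake_quant and gains a docstring, and a new test_qat_8da4w_quantizer_disable_fake_quant_backward verifies that a QAT-prepared model with fake quantization disabled matches its nn.Linear counterpart through one forward/backward/optimizer step.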
@@ -201,7 +201,10 @@ def test_qat_8da4w_quantizer(self):
             torch.testing.assert_close(ptq_state_dict[k], converted_state_dict[k], atol=0, rtol=0)

     @unittest.skipIf(not TORCH_VERSION_AFTER_2_3, "skipping when torch version is 2.3 or lower")
-    def test_qat_8da4w_quantizer_enable_fake_quant(self):
+    def test_qat_8da4w_quantizer_disable_fake_quant(self):
+        """
+        Test that 8da4w QAT with disabled fake quant matches nn.Linear in forward.
+        """
         from torchao.quantization.prototype.qat import (
             Int8DynActInt4WeightQATQuantizer,
             disable_8da4w_fake_quant,
@@ -250,6 +253,51 @@ def test_qat_8da4w_quantizer_enable_fake_quant(self):
         qat_out2 = qat_model2(*x2)
         torch.testing.assert_close(qat_out, qat_out2, atol=0, rtol=0)

+    @unittest.skipIf(not TORCH_VERSION_AFTER_2_3, "skipping when torch version is 2.3 or lower")
+    def test_qat_8da4w_quantizer_disable_fake_quant_backward(self):
+        """
+        Test that 8da4w QAT with disabled fake quant matches nn.Linear in backward.
+        """
+        from torchao.quantization.prototype.qat import (
+            Int8DynActInt4WeightQATQuantizer,
+            disable_8da4w_fake_quant,
+        )
+
+        group_size = 16
+        torch.manual_seed(self.SEED)
+        m = M()
+        nn_model = copy.deepcopy(m)
+        quantizer = Int8DynActInt4WeightQATQuantizer(groupsize=group_size)
+        qat_model = quantizer.prepare(m)
+        qat_model.apply(disable_8da4w_fake_quant)
+        nn_model.linear1.weight = qat_model.linear1.weight
+        nn_model.linear2.weight = qat_model.linear2.weight
+        nn_model.sub.linear.weight = qat_model.sub.linear.weight
+
+        # Simulate training for both models
+        optimizer1 = torch.optim.SGD(nn_model.parameters(), lr=0.001, momentum=0.9, weight_decay=1e-5)
+        optimizer2 = torch.optim.SGD(qat_model.parameters(), lr=0.001, momentum=0.9, weight_decay=1e-5)
+        loss_fn1 = torch.nn.CrossEntropyLoss()
+        loss_fn2 = torch.nn.CrossEntropyLoss()
+        example_inputs = nn_model.example_inputs()
+        target = torch.randn(1, 64).float()
+        output1 = nn_model(*example_inputs)
+        output2 = qat_model(*example_inputs)
+        torch.testing.assert_close(output1, output2, atol=0, rtol=0)
+        loss1 = loss_fn1(output1, target)
+        loss2 = loss_fn2(output2, target)
+        optimizer1.zero_grad()
+        optimizer2.zero_grad()
+        loss1.backward()
+        loss2.backward()
+        optimizer1.step()
+        optimizer2.step()
+
+        # After 1 training step, weights should match exactly
+        torch.testing.assert_close(nn_model.linear1.weight, qat_model.linear1.weight, atol=0, rtol=0)
+        torch.testing.assert_close(nn_model.linear2.weight, qat_model.linear2.weight, atol=0, rtol=0)
+        torch.testing.assert_close(nn_model.sub.linear.weight, qat_model.sub.linear.weight, atol=0, rtol=0)
+

 if __name__ == "__main__":
     unittest.main()
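
Reduced to its essentials, the pattern the new test exercises is: prepare a model with the 8da4w QAT quantizer, disable fake quantization, and confirm the model behaves like plain nn.Linear in both forward and backward. A minimal standalone sketch of that flow, using the same torchao prototype imports as the diff above; ToyModel and its dimensions are hypothetical stand-ins for the test's M fixture:

    import torch
    from torchao.quantization.prototype.qat import (
        Int8DynActInt4WeightQATQuantizer,
        disable_8da4w_fake_quant,
    )

    # Hypothetical stand-in for the test's M fixture.
    class ToyModel(torch.nn.Module):
        def __init__(self):
            super().__init__()
            self.linear1 = torch.nn.Linear(64, 64, bias=False)

        def forward(self, x):
            return self.linear1(x)

    model = ToyModel()
    quantizer = Int8DynActInt4WeightQATQuantizer(groupsize=16)
    qat_model = quantizer.prepare(model)  # swaps nn.Linear for 8da4w QAT linears

    # With fake quant disabled, the QAT linears should compute a plain
    # floating-point matmul, so gradients should match nn.Linear exactly.
    qat_model.apply(disable_8da4w_fake_quant)

    out = qat_model(torch.randn(1, 64))
    out.sum().backward()  # backward flows as it would through nn.Linear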

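Since the file ends in unittest.main(), the new case can be run directly, for example:

    python test/quantization/test_qat.py -k test_qat_8da4w_quantizer_disable_fake_quant_backward

(unittest's -k filter requires Python 3.7+; per the skipIf decorator, the test only runs when the installed torch is newer than 2.3.)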