Commit 9e0a59f

Make module swap the main QAT flow again (#1037)
Summary: Following #987, this commit makes module swap the main QAT flow today. We remove all tensor subclass fake quantize injection logic, since it is not needed in either the short-term or the long-term plans for QAT. In the short term, we will continue to use a full module swap flow, and we will only migrate to the long-term flow once there is general distributed support for tensor subclasses and tensor subclass composability provides meaningful benefits.

Test Plan: python test/quantization/test_qat.py

[ghstack-poisoned]
1 parent: 49b1fb6 · commit: 9e0a59f
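
The module swap flow this commit standardizes on looks roughly like the sketch below. This is a minimal illustration, assuming a torchao build that includes this commit; the toy model, tensor shapes, and group size are made up for the example, while the quantizer name and the prepare/convert API come from the tests changed below.

import torch
from torchao.quantization.prototype.qat import Int8DynActInt4WeightQATQuantizer

# Toy model standing in for the test suite's M(); any model with nn.Linear
# layers whose in_features are divisible by the group size should work.
model = torch.nn.Sequential(torch.nn.Linear(256, 256), torch.nn.Linear(256, 64))

quantizer = Int8DynActInt4WeightQATQuantizer(groupsize=32)

# prepare() swaps each nn.Linear for a QAT linear that fake-quantizes
# int8 dynamic activations and int4 grouped weights during training.
qat_model = quantizer.prepare(model)

# One simulated fine-tuning step; gradients flow through the fake quant ops.
out = qat_model(torch.randn(8, 256))
out.sum().backward()

# convert() swaps the QAT linears for actually quantized linears.
converted_model = quantizer.convert(qat_model)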

File tree: 6 files changed, +418 −785 lines

test/quantization/test_qat.py

Lines changed: 26 additions & 115 deletions
@@ -18,15 +18,11 @@
 from torchao.quantization.prototype.qat.api import (
     ComposableQATQuantizer,
 )
-from torchao.quantization.prototype.qat.affine_fake_quantized_tensor import (
-    AffineFakeQuantizedTensor,
-)
 from torchao.quantization.prototype.qat.utils import (
     _choose_qparams_per_token_asymmetric,
     _fake_quantize_per_channel_group,
     _fake_quantize_per_token,
     _GenericFakeQuantize,
-    _QAT_LINEAR_SUBCLASS_INPUT_PREHOOK,
 )
 from torchao.quantization.quant_api import (
     int4_weight_only,
@@ -164,7 +160,7 @@ def _set_ptq_weight(
             Int8DynActInt4WeightLinear,
             WeightOnlyInt4Linear,
         )
-        from torchao.quantization.prototype.qat._module_swap_api import (
+        from torchao.quantization.prototype.qat.linear import (
             Int8DynActInt4WeightQATLinear,
             Int4WeightOnlyQATLinear,
         )
@@ -196,7 +192,7 @@ def _set_ptq_weight(
 
     @unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_4, "skipping when torch version is 2.4 or lower")
     def test_qat_8da4w_linear(self):
-        from torchao.quantization.prototype.qat._module_swap_api import Int8DynActInt4WeightQATLinear
+        from torchao.quantization.prototype.qat.linear import Int8DynActInt4WeightQATLinear
         from torchao.quantization.GPTQ import Int8DynActInt4WeightLinear
 
         group_size = 128
@@ -219,45 +215,17 @@ def test_qat_8da4w_linear(self):
         ptq_out = ptq_linear(x2)
         torch.testing.assert_close(ptq_out, qat_out, atol=0, rtol=0)
 
-    # TODO: compare against quantize_ API instead
     @unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_4, "skipping when torch version is 2.4 or lower")
     def test_qat_8da4w_quantizer(self):
         from torchao.quantization.prototype.qat import Int8DynActInt4WeightQATQuantizer
-        from torchao.quantization.GPTQ import Int8DynActInt4WeightQuantizer
-
-        group_size = 16
-        torch.manual_seed(self.SEED)
-        m = M()
-        m2 = copy.deepcopy(m)
-        qat_quantizer = Int8DynActInt4WeightQATQuantizer(groupsize=group_size)
-        ptq_quantizer = Int8DynActInt4WeightQuantizer(groupsize=group_size)
-        qat_model = qat_quantizer.prepare(m)
-        ptq_model = ptq_quantizer.quantize(m2)
-
-        # Compare model values
-        torch.manual_seed(self.SEED)
-        x = m.example_inputs()
-        x2 = copy.deepcopy(x)
-        qat_out = qat_model(*x)
-        ptq_out = ptq_model(*x2)
-        torch.testing.assert_close(ptq_out, qat_out, atol=0, rtol=0)
-
-        # Convert QAT model and compare model values
-        converted_model = qat_quantizer.convert(qat_model)
-        converted_out = converted_model(*x)
-        torch.testing.assert_close(ptq_out, converted_out, atol=0, rtol=0)
-
-    @unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_4, "skipping when torch version is 2.4 or lower")
-    def test_qat_8da4w_quantizer_module_swap(self):
-        from torchao.quantization.prototype.qat import Int8DynActInt4WeightQATQuantizer
-        from torchao.quantization.prototype.qat._module_swap_api import Int8DynActInt4WeightQATQuantizerModuleSwap
+        from torchao.quantization.prototype.qat.linear import Int8DynActInt4WeightQATQuantizer
 
         group_size = 16
         torch.manual_seed(self.SEED)
         m = M()
         m2 = copy.deepcopy(m)
         subclass_quantizer = Int8DynActInt4WeightQATQuantizer(groupsize=group_size)
-        module_swap_quantizer = Int8DynActInt4WeightQATQuantizerModuleSwap(groupsize=group_size)
+        module_swap_quantizer = Int8DynActInt4WeightQATQuantizer(groupsize=group_size)
         subclass_model = subclass_quantizer.prepare(m)
         module_swap_model = module_swap_quantizer.prepare(m2)
 
@@ -288,20 +256,6 @@ def test_qat_8da4w_quantizer_meta_weights(self):
         qat_model = qat_quantizer.prepare(m)
         self.assertTrue(all(v.is_meta for v in qat_model.state_dict().values()))
 
-    def _copy_subclass_weights(
-        self,
-        nn_linear: torch.nn.Linear,
-        subclass_linear: AffineFakeQuantizedTensor,
-    ):
-        nn_linear.weight = torch.nn.Parameter(subclass_linear.weight.original_tensor)
-
-    def _assert_matches_subclass_weights(
-        self,
-        nn_linear: torch.nn.Linear,
-        subclass_linear: AffineFakeQuantizedTensor,
-    ):
-        torch.testing.assert_close(nn_linear.weight, subclass_linear.weight.original_tensor, atol=0, rtol=0)
-
     @unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_4, "skipping when torch version is 2.4 or lower")
     def test_qat_8da4w_quantizer_disable_fake_quant(self):
         """
@@ -313,16 +267,6 @@ def test_qat_8da4w_quantizer_disable_fake_quant(self):
             enable_8da4w_fake_quant,
         )
 
-        def assert_fake_quant_enabled(m: torch.nn.Linear, enabled: bool):
-            self.assertTrue(isinstance(m.weight, AffineFakeQuantizedTensor))
-            self.assertEqual(m.weight.fake_quant_enabled, enabled)
-            self.assertTrue(hasattr(m, _QAT_LINEAR_SUBCLASS_INPUT_PREHOOK))
-            (_, handle) = getattr(m, _QAT_LINEAR_SUBCLASS_INPUT_PREHOOK)
-            if enabled:
-                self.assertIsNotNone(handle)
-            else:
-                self.assertIsNone(handle)
-
         group_size = 16
         torch.manual_seed(self.SEED)
         m = M()
@@ -331,14 +275,14 @@ def assert_fake_quant_enabled(m: torch.nn.Linear, enabled: bool):
         quantizer = Int8DynActInt4WeightQATQuantizer(groupsize=group_size)
         qat_model = quantizer.prepare(m)
         qat_model.apply(disable_8da4w_fake_quant)
-        assert_fake_quant_enabled(qat_model.linear1, enabled=False)
-        assert_fake_quant_enabled(qat_model.linear2, enabled=False)
-        assert_fake_quant_enabled(qat_model.sub.linear, enabled=False)
+        self.assertFalse(qat_model.linear1._fake_quant_enabled)
+        self.assertFalse(qat_model.linear2._fake_quant_enabled)
+        self.assertFalse(qat_model.sub.linear._fake_quant_enabled)
 
         # Disabled fake quant is just a normal linear
-        self._copy_subclass_weights(m2.linear1, qat_model.linear1)
-        self._copy_subclass_weights(m2.linear2, qat_model.linear2)
-        self._copy_subclass_weights(m2.sub.linear, qat_model.sub.linear)
+        m2.linear1.weight = torch.nn.Parameter(qat_model.linear1.weight)
+        m2.linear2.weight = torch.nn.Parameter(qat_model.linear2.weight)
+        m2.sub.linear.weight = torch.nn.Parameter(qat_model.sub.linear.weight)
         torch.manual_seed(self.SEED)
         x = m.example_inputs()
         x2 = copy.deepcopy(x)
@@ -348,16 +292,16 @@ def assert_fake_quant_enabled(m: torch.nn.Linear, enabled: bool):
 
         # Renable fake quant
         qat_model.apply(enable_8da4w_fake_quant)
-        assert_fake_quant_enabled(qat_model.linear1, enabled=True)
-        assert_fake_quant_enabled(qat_model.linear2, enabled=True)
-        assert_fake_quant_enabled(qat_model.sub.linear, enabled=True)
+        self.assertTrue(qat_model.linear1._fake_quant_enabled)
+        self.assertTrue(qat_model.linear2._fake_quant_enabled)
+        self.assertTrue(qat_model.sub.linear._fake_quant_enabled)
 
         # Fake quant should be applied as normal
         quantizer2 = Int8DynActInt4WeightQATQuantizer(groupsize=group_size)
         qat_model2 = quantizer2.prepare(m3)
-        qat_model2.linear1.weight.original_tensor = qat_model.linear1.weight.original_tensor
-        qat_model2.linear2.weight.original_tensor = qat_model.linear2.weight.original_tensor
-        qat_model2.sub.linear.weight.original_tensor = qat_model.sub.linear.weight.original_tensor
+        qat_model2.linear1.weight = qat_model.linear1.weight
+        qat_model2.linear2.weight = qat_model.linear2.weight
+        qat_model2.sub.linear.weight = qat_model.sub.linear.weight
         torch.manual_seed(self.SEED)
         x = m.example_inputs()
         x2 = copy.deepcopy(x)
@@ -382,9 +326,9 @@ def test_qat_8da4w_quantizer_disable_fake_quant_backward(self):
         quantizer = Int8DynActInt4WeightQATQuantizer(groupsize=group_size)
         qat_model = quantizer.prepare(m)
        qat_model.apply(disable_8da4w_fake_quant)
-        self._copy_subclass_weights(nn_model.linear1, qat_model.linear1)
-        self._copy_subclass_weights(nn_model.linear2, qat_model.linear2)
-        self._copy_subclass_weights(nn_model.sub.linear, qat_model.sub.linear)
+        nn_model.linear1.weight = torch.nn.Parameter(qat_model.linear1.weight)
+        nn_model.linear2.weight = torch.nn.Parameter(qat_model.linear2.weight)
+        nn_model.sub.linear.weight = torch.nn.Parameter(qat_model.sub.linear.weight)
 
         # Simulate training for both models
         optimizer1 = torch.optim.SGD(nn_model.parameters(), lr=0.001, momentum=0.9, weight_decay=1e-5)
@@ -406,9 +350,9 @@ def test_qat_8da4w_quantizer_disable_fake_quant_backward(self):
         optimizer2.step()
 
         # After 1 training step, weights should match exactly
-        self._assert_matches_subclass_weights(nn_model.linear1, qat_model.linear1)
-        self._assert_matches_subclass_weights(nn_model.linear2, qat_model.linear2)
-        self._assert_matches_subclass_weights(nn_model.sub.linear, qat_model.sub.linear)
+        torch.testing.assert_close(nn_model.linear1.weight, qat_model.linear1.weight, atol=0, rtol=0)
+        torch.testing.assert_close(nn_model.linear2.weight, qat_model.linear2.weight, atol=0, rtol=0)
+        torch.testing.assert_close(nn_model.sub.linear.weight, qat_model.sub.linear.weight, atol=0, rtol=0)
 
     def _test_qat_quantized_gradients(self, quantizer):
         """
@@ -542,7 +486,7 @@ def test_qat_4w_primitives(self):
     @unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_4, "skipping when torch version is 2.4 or lower")
     @unittest.skipIf(not _CUDA_IS_AVAILABLE, "skipping when cuda is not available")
     def test_qat_4w_linear(self):
-        from torchao.quantization.prototype.qat._module_swap_api import Int4WeightOnlyQATLinear
+        from torchao.quantization.prototype.qat.linear import Int4WeightOnlyQATLinear
         from torchao.quantization.GPTQ import WeightOnlyInt4Linear
 
         group_size = 128
@@ -567,39 +511,6 @@ def test_qat_4w_linear(self):
         ptq_out = ptq_linear(x2)
         self._assert_close_4w(qat_out, ptq_out)
 
-    @unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_4, "skipping when torch version is 2.4 or lower")
-    @unittest.skipIf(not _CUDA_IS_AVAILABLE, "skipping when cuda is not available")
-    def test_qat_4w_quantizer(self):
-        from torchao.quantization.prototype.qat import Int4WeightOnlyQATQuantizer
-        from torchao.quantization.GPTQ import Int4WeightOnlyQuantizer
-
-        group_size = 32
-        inner_k_tiles = 8
-        device = torch.device("cuda")
-        dtype = torch.bfloat16
-        torch.manual_seed(self.SEED)
-        m = M().to(device).to(dtype)
-        m2 = copy.deepcopy(m)
-        qat_quantizer = Int4WeightOnlyQATQuantizer(
-            groupsize=group_size, inner_k_tiles=inner_k_tiles,
-        )
-        qat_model = qat_quantizer.prepare(m)
-        ptq_model = m2
-        quantize_(ptq_model, int4_weight_only(group_size, TensorCoreTiledLayoutType(inner_k_tiles)))
-
-        # Compare model values
-        torch.manual_seed(self.SEED)
-        x = [i.to(device).to(dtype) for i in m.example_inputs()]
-        x2 = copy.deepcopy(x)
-        qat_out = qat_model(*x)
-        ptq_out = ptq_model(*x2)
-        self._assert_close_4w(qat_out, ptq_out)
-
-        # Convert QAT model and compare model values
-        converted_model = qat_quantizer.convert(qat_model)
-        converted_out = converted_model(*x)
-        torch.testing.assert_close(converted_out, ptq_out, atol=0, rtol=0)
-
     @unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_4, "skipping when torch version is 2.4 or lower")
     def test_qat_4w_quantizer_gradients(self):
         from torchao.quantization.prototype.qat import Int4WeightOnlyQATQuantizer
@@ -608,9 +519,9 @@ def test_qat_4w_quantizer_gradients(self):
 
     @unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_4, "skipping when torch version is 2.4 or lower")
     @unittest.skipIf(not _CUDA_IS_AVAILABLE, "skipping when cuda is not available")
-    def test_qat_4w_quantizer_module_swap(self):
+    def test_qat_4w_quantizer(self):
         from torchao.quantization.prototype.qat import Int4WeightOnlyQATQuantizer
-        from torchao.quantization.prototype.qat._module_swap_api import Int4WeightOnlyQATQuantizerModuleSwap
+        from torchao.quantization.prototype.qat.linear import Int4WeightOnlyQATQuantizer
 
         group_size = 32
         inner_k_tiles = 8
@@ -622,7 +533,7 @@ def test_qat_4w_quantizer_module_swap(self):
         subclass_quantizer = Int4WeightOnlyQATQuantizer(
             groupsize=group_size, inner_k_tiles=inner_k_tiles,
        )
-        module_swap_quantizer = Int4WeightOnlyQATQuantizerModuleSwap(
+        module_swap_quantizer = Int4WeightOnlyQATQuantizer(
            groupsize=group_size, inner_k_tiles=inner_k_tiles,
        )
        subclass_model = subclass_quantizer.prepare(m)
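
With module swap as the main flow, enabling and disabling fake quantization becomes a per-module flag rather than tensor-subclass state, which is what the updated assertions above check. A minimal sketch, assuming the same toy setup as earlier; the module and shapes are illustrative, while the helper functions and the _fake_quant_enabled attribute come from the diff:

import torch
from torchao.quantization.prototype.qat import (
    Int8DynActInt4WeightQATQuantizer,
    disable_8da4w_fake_quant,
    enable_8da4w_fake_quant,
)

model = torch.nn.Sequential(torch.nn.Linear(256, 256))
qat_model = Int8DynActInt4WeightQATQuantizer(groupsize=32).prepare(model)

# Disabling fake quant makes the swapped module behave like a plain linear.
qat_model.apply(disable_8da4w_fake_quant)
assert not qat_model[0]._fake_quant_enabled

# Re-enabling restores fake quantization in the forward pass.
qat_model.apply(enable_8da4w_fake_quant)
assert qat_model[0]._fake_quant_enabled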

torchao/quantization/prototype/qat/__init__.py

Lines changed: 4 additions & 9 deletions
@@ -1,17 +1,14 @@
 from .api import (
+    ComposableQATQuantizer,
+)
+from .linear import (
     disable_4w_fake_quant,
     disable_8da4w_fake_quant,
     enable_4w_fake_quant,
     enable_8da4w_fake_quant,
-    int4_weight_only_fake_quantize,
-    int8_dynamic_activation_int4_weight_fake_quantize,
-    ComposableQATQuantizer,
     Int4WeightOnlyQATQuantizer,
-    Int8DynActInt4WeightQATQuantizer,
-)
-
-from ._module_swap_api import (
     Int8DynActInt4WeightQATLinear,
+    Int8DynActInt4WeightQATQuantizer,
 )
 from .embedding import (
     Int4WeightOnlyEmbeddingQATQuantizer,
@@ -22,8 +19,6 @@
     "disable_8da4w_fake_quant",
     "enable_4w_fake_quant",
     "enable_8da4w_fake_quant",
-    "int4_weight_only_fake_quantize",
-    "int8_dynamic_activation_int4_weight_fake_quantize",
     "ComposableQATQuantizer",
    "Int4WeightOnlyQATQuantizer",
    "Int4WeightOnlyEmbeddingQATQuantizer"
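
After this change, the module swap linears and quantizers live in torchao.quantization.prototype.qat.linear and are re-exported from the package root, while the tests and __init__.py no longer import from _module_swap_api and the *_fake_quantize config functions are dropped. A quick sanity check of the new import surface, using only names that appear in the updated __init__.py above:

# Old import path used by the tests before this commit:
#   from torchao.quantization.prototype.qat._module_swap_api import Int8DynActInt4WeightQATLinear
# New import paths:
from torchao.quantization.prototype.qat.linear import Int8DynActInt4WeightQATLinear
from torchao.quantization.prototype.qat import (
    ComposableQATQuantizer,
    Int4WeightOnlyQATQuantizer,
    Int4WeightOnlyEmbeddingQATQuantizer,
    Int8DynActInt4WeightQATQuantizer,
    disable_8da4w_fake_quant,
    enable_8da4w_fake_quant,
)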
