Make module swap the main QAT flow again #1037

Merged 1 commit on Oct 9, 2024
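
This PR makes the module-swap quantizers under `torchao.quantization.prototype.qat` the main QAT entry points again, replacing the tensor-subclass `_module_swap_api` variants, and the diff below updates the tests and exports accordingly. As a hedged sketch of the prepare/train/convert flow those tests exercise (the toy model and group size here are illustrative assumptions, not part of the diff):

```python
# Sketch of the QAT flow exercised by the updated tests below.
# The toy model and groupsize value are illustrative assumptions.
import torch
from torchao.quantization.prototype.qat import Int8DynActInt4WeightQATQuantizer

model = torch.nn.Sequential(torch.nn.Linear(64, 64))  # hypothetical toy model
quantizer = Int8DynActInt4WeightQATQuantizer(groupsize=16)

# prepare() swaps eligible nn.Linear modules for fake-quantized QAT linears
qat_model = quantizer.prepare(model)

# ... run the usual training loop on qat_model ...

# convert() swaps the QAT linears for actually-quantized inference linears
quantized_model = quantizer.convert(qat_model)
```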
141 changes: 26 additions & 115 deletions test/quantization/test_qat.py
@@ -18,15 +18,11 @@
from torchao.quantization.prototype.qat.api import (
ComposableQATQuantizer,
)
-from torchao.quantization.prototype.qat.affine_fake_quantized_tensor import (
-AffineFakeQuantizedTensor,
-)
from torchao.quantization.prototype.qat.utils import (
_choose_qparams_per_token_asymmetric,
_fake_quantize_per_channel_group,
_fake_quantize_per_token,
_GenericFakeQuantize,
-_QAT_LINEAR_SUBCLASS_INPUT_PREHOOK,
)
from torchao.quantization.quant_api import (
int4_weight_only,
@@ -164,7 +160,7 @@ def _set_ptq_weight(
Int8DynActInt4WeightLinear,
WeightOnlyInt4Linear,
)
-from torchao.quantization.prototype.qat._module_swap_api import (
+from torchao.quantization.prototype.qat.linear import (
Int8DynActInt4WeightQATLinear,
Int4WeightOnlyQATLinear,
)
@@ -196,7 +192,7 @@ def _set_ptq_weight(

@unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_4, "skipping when torch version is 2.4 or lower")
def test_qat_8da4w_linear(self):
-from torchao.quantization.prototype.qat._module_swap_api import Int8DynActInt4WeightQATLinear
+from torchao.quantization.prototype.qat.linear import Int8DynActInt4WeightQATLinear
from torchao.quantization.GPTQ import Int8DynActInt4WeightLinear

group_size = 128
@@ -219,45 +215,17 @@ def test_qat_8da4w_linear(self):
ptq_out = ptq_linear(x2)
torch.testing.assert_close(ptq_out, qat_out, atol=0, rtol=0)

-# TODO: compare against quantize_ API instead
-@unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_4, "skipping when torch version is 2.4 or lower")
-def test_qat_8da4w_quantizer(self):
-from torchao.quantization.prototype.qat import Int8DynActInt4WeightQATQuantizer
-from torchao.quantization.GPTQ import Int8DynActInt4WeightQuantizer

-group_size = 16
-torch.manual_seed(self.SEED)
-m = M()
-m2 = copy.deepcopy(m)
-qat_quantizer = Int8DynActInt4WeightQATQuantizer(groupsize=group_size)
-ptq_quantizer = Int8DynActInt4WeightQuantizer(groupsize=group_size)
-qat_model = qat_quantizer.prepare(m)
-ptq_model = ptq_quantizer.quantize(m2)

-# Compare model values
-torch.manual_seed(self.SEED)
-x = m.example_inputs()
-x2 = copy.deepcopy(x)
-qat_out = qat_model(*x)
-ptq_out = ptq_model(*x2)
-torch.testing.assert_close(ptq_out, qat_out, atol=0, rtol=0)

-# Convert QAT model and compare model values
-converted_model = qat_quantizer.convert(qat_model)
-converted_out = converted_model(*x)
-torch.testing.assert_close(ptq_out, converted_out, atol=0, rtol=0)

@unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_4, "skipping when torch version is 2.4 or lower")
def test_qat_8da4w_quantizer_module_swap(self):
from torchao.quantization.prototype.qat import Int8DynActInt4WeightQATQuantizer
-from torchao.quantization.prototype.qat._module_swap_api import Int8DynActInt4WeightQATQuantizerModuleSwap
+from torchao.quantization.prototype.qat.linear import Int8DynActInt4WeightQATQuantizer

group_size = 16
torch.manual_seed(self.SEED)
m = M()
m2 = copy.deepcopy(m)
subclass_quantizer = Int8DynActInt4WeightQATQuantizer(groupsize=group_size)
-module_swap_quantizer = Int8DynActInt4WeightQATQuantizerModuleSwap(groupsize=group_size)
+module_swap_quantizer = Int8DynActInt4WeightQATQuantizer(groupsize=group_size)
subclass_model = subclass_quantizer.prepare(m)
module_swap_model = module_swap_quantizer.prepare(m2)

@@ -288,20 +256,6 @@ def test_qat_8da4w_quantizer_meta_weights(self):
qat_model = qat_quantizer.prepare(m)
self.assertTrue(all(v.is_meta for v in qat_model.state_dict().values()))

-def _copy_subclass_weights(
-self,
-nn_linear: torch.nn.Linear,
-subclass_linear: AffineFakeQuantizedTensor,
-):
-nn_linear.weight = torch.nn.Parameter(subclass_linear.weight.original_tensor)

-def _assert_matches_subclass_weights(
-self,
-nn_linear: torch.nn.Linear,
-subclass_linear: AffineFakeQuantizedTensor,
-):
-torch.testing.assert_close(nn_linear.weight, subclass_linear.weight.original_tensor, atol=0, rtol=0)

@unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_4, "skipping when torch version is 2.4 or lower")
def test_qat_8da4w_quantizer_disable_fake_quant(self):
"""
@@ -313,16 +267,6 @@ def test_qat_8da4w_quantizer_disable_fake_quant(self):
enable_8da4w_fake_quant,
)

-def assert_fake_quant_enabled(m: torch.nn.Linear, enabled: bool):
-self.assertTrue(isinstance(m.weight, AffineFakeQuantizedTensor))
-self.assertEqual(m.weight.fake_quant_enabled, enabled)
-self.assertTrue(hasattr(m, _QAT_LINEAR_SUBCLASS_INPUT_PREHOOK))
-(_, handle) = getattr(m, _QAT_LINEAR_SUBCLASS_INPUT_PREHOOK)
-if enabled:
-self.assertIsNotNone(handle)
-else:
-self.assertIsNone(handle)

group_size = 16
torch.manual_seed(self.SEED)
m = M()
@@ -331,14 +275,14 @@ def assert_fake_quant_enabled(m: torch.nn.Linear, enabled: bool):
quantizer = Int8DynActInt4WeightQATQuantizer(groupsize=group_size)
qat_model = quantizer.prepare(m)
qat_model.apply(disable_8da4w_fake_quant)
-assert_fake_quant_enabled(qat_model.linear1, enabled=False)
-assert_fake_quant_enabled(qat_model.linear2, enabled=False)
-assert_fake_quant_enabled(qat_model.sub.linear, enabled=False)
+self.assertFalse(qat_model.linear1._fake_quant_enabled)
+self.assertFalse(qat_model.linear2._fake_quant_enabled)
+self.assertFalse(qat_model.sub.linear._fake_quant_enabled)

# Disabled fake quant is just a normal linear
-self._copy_subclass_weights(m2.linear1, qat_model.linear1)
-self._copy_subclass_weights(m2.linear2, qat_model.linear2)
-self._copy_subclass_weights(m2.sub.linear, qat_model.sub.linear)
+m2.linear1.weight = torch.nn.Parameter(qat_model.linear1.weight)
+m2.linear2.weight = torch.nn.Parameter(qat_model.linear2.weight)
+m2.sub.linear.weight = torch.nn.Parameter(qat_model.sub.linear.weight)
torch.manual_seed(self.SEED)
x = m.example_inputs()
x2 = copy.deepcopy(x)
@@ -348,16 +292,16 @@ def assert_fake_quant_enabled(m: torch.nn.Linear, enabled: bool):

# Renable fake quant
qat_model.apply(enable_8da4w_fake_quant)
-assert_fake_quant_enabled(qat_model.linear1, enabled=True)
-assert_fake_quant_enabled(qat_model.linear2, enabled=True)
-assert_fake_quant_enabled(qat_model.sub.linear, enabled=True)
+self.assertTrue(qat_model.linear1._fake_quant_enabled)
+self.assertTrue(qat_model.linear2._fake_quant_enabled)
+self.assertTrue(qat_model.sub.linear._fake_quant_enabled)

# Fake quant should be applied as normal
quantizer2 = Int8DynActInt4WeightQATQuantizer(groupsize=group_size)
qat_model2 = quantizer2.prepare(m3)
-qat_model2.linear1.weight.original_tensor = qat_model.linear1.weight.original_tensor
-qat_model2.linear2.weight.original_tensor = qat_model.linear2.weight.original_tensor
-qat_model2.sub.linear.weight.original_tensor = qat_model.sub.linear.weight.original_tensor
+qat_model2.linear1.weight = qat_model.linear1.weight
+qat_model2.linear2.weight = qat_model.linear2.weight
+qat_model2.sub.linear.weight = qat_model.sub.linear.weight
torch.manual_seed(self.SEED)
x = m.example_inputs()
x2 = copy.deepcopy(x)
@@ -382,9 +326,9 @@ def test_qat_8da4w_quantizer_disable_fake_quant_backward(self):
quantizer = Int8DynActInt4WeightQATQuantizer(groupsize=group_size)
qat_model = quantizer.prepare(m)
qat_model.apply(disable_8da4w_fake_quant)
-self._copy_subclass_weights(nn_model.linear1, qat_model.linear1)
-self._copy_subclass_weights(nn_model.linear2, qat_model.linear2)
-self._copy_subclass_weights(nn_model.sub.linear, qat_model.sub.linear)
+nn_model.linear1.weight = torch.nn.Parameter(qat_model.linear1.weight)
+nn_model.linear2.weight = torch.nn.Parameter(qat_model.linear2.weight)
+nn_model.sub.linear.weight = torch.nn.Parameter(qat_model.sub.linear.weight)

# Simulate training for both models
optimizer1 = torch.optim.SGD(nn_model.parameters(), lr=0.001, momentum=0.9, weight_decay=1e-5)
@@ -406,9 +350,9 @@ def test_qat_8da4w_quantizer_disable_fake_quant_backward(self):
optimizer2.step()

# After 1 training step, weights should match exactly
-self._assert_matches_subclass_weights(nn_model.linear1, qat_model.linear1)
-self._assert_matches_subclass_weights(nn_model.linear2, qat_model.linear2)
-self._assert_matches_subclass_weights(nn_model.sub.linear, qat_model.sub.linear)
+torch.testing.assert_close(nn_model.linear1.weight, qat_model.linear1.weight, atol=0, rtol=0)
+torch.testing.assert_close(nn_model.linear2.weight, qat_model.linear2.weight, atol=0, rtol=0)
+torch.testing.assert_close(nn_model.sub.linear.weight, qat_model.sub.linear.weight, atol=0, rtol=0)

def _test_qat_quantized_gradients(self, quantizer):
"""
@@ -542,7 +486,7 @@ def test_qat_4w_primitives(self):
@unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_4, "skipping when torch version is 2.4 or lower")
@unittest.skipIf(not _CUDA_IS_AVAILABLE, "skipping when cuda is not available")
def test_qat_4w_linear(self):
-from torchao.quantization.prototype.qat._module_swap_api import Int4WeightOnlyQATLinear
+from torchao.quantization.prototype.qat.linear import Int4WeightOnlyQATLinear
from torchao.quantization.GPTQ import WeightOnlyInt4Linear

group_size = 128
@@ -567,39 +511,6 @@ def test_qat_4w_linear(self):
ptq_out = ptq_linear(x2)
self._assert_close_4w(qat_out, ptq_out)

-@unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_4, "skipping when torch version is 2.4 or lower")
-@unittest.skipIf(not _CUDA_IS_AVAILABLE, "skipping when cuda is not available")
-def test_qat_4w_quantizer(self):
-from torchao.quantization.prototype.qat import Int4WeightOnlyQATQuantizer
-from torchao.quantization.GPTQ import Int4WeightOnlyQuantizer

-group_size = 32
-inner_k_tiles = 8
-device = torch.device("cuda")
-dtype = torch.bfloat16
-torch.manual_seed(self.SEED)
-m = M().to(device).to(dtype)
-m2 = copy.deepcopy(m)
-qat_quantizer = Int4WeightOnlyQATQuantizer(
-groupsize=group_size, inner_k_tiles=inner_k_tiles,
-)
-qat_model = qat_quantizer.prepare(m)
-ptq_model = m2
-quantize_(ptq_model, int4_weight_only(group_size, TensorCoreTiledLayoutType(inner_k_tiles)))

-# Compare model values
-torch.manual_seed(self.SEED)
-x = [i.to(device).to(dtype) for i in m.example_inputs()]
-x2 = copy.deepcopy(x)
-qat_out = qat_model(*x)
-ptq_out = ptq_model(*x2)
-self._assert_close_4w(qat_out, ptq_out)

-# Convert QAT model and compare model values
-converted_model = qat_quantizer.convert(qat_model)
-converted_out = converted_model(*x)
-torch.testing.assert_close(converted_out, ptq_out, atol=0, rtol=0)

@unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_4, "skipping when torch version is 2.4 or lower")
def test_qat_4w_quantizer_gradients(self):
from torchao.quantization.prototype.qat import Int4WeightOnlyQATQuantizer
@@ -608,9 +519,9 @@ def test_qat_4w_quantizer_gradients(self):

@unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_4, "skipping when torch version is 2.4 or lower")
@unittest.skipIf(not _CUDA_IS_AVAILABLE, "skipping when cuda is not available")
-def test_qat_4w_quantizer_module_swap(self):
+def test_qat_4w_quantizer(self):
from torchao.quantization.prototype.qat import Int4WeightOnlyQATQuantizer
-from torchao.quantization.prototype.qat._module_swap_api import Int4WeightOnlyQATQuantizerModuleSwap
+from torchao.quantization.prototype.qat.linear import Int4WeightOnlyQATQuantizer

group_size = 32
inner_k_tiles = 8
@@ -622,7 +533,7 @@ def test_qat_4w_quantizer_module_swap(self):
subclass_quantizer = Int4WeightOnlyQATQuantizer(
groupsize=group_size, inner_k_tiles=inner_k_tiles,
)
-module_swap_quantizer = Int4WeightOnlyQATQuantizerModuleSwap(
+module_swap_quantizer = Int4WeightOnlyQATQuantizer(
groupsize=group_size, inner_k_tiles=inner_k_tiles,
)
subclass_model = subclass_quantizer.prepare(m)
13 changes: 4 additions & 9 deletions torchao/quantization/prototype/qat/__init__.py
@@ -1,17 +1,14 @@
from .api import (
+ComposableQATQuantizer,
+)
+from .linear import (
disable_4w_fake_quant,
disable_8da4w_fake_quant,
enable_4w_fake_quant,
enable_8da4w_fake_quant,
-int4_weight_only_fake_quantize,
-int8_dynamic_activation_int4_weight_fake_quantize,
-ComposableQATQuantizer,
Int4WeightOnlyQATQuantizer,
Int8DynActInt4WeightQATQuantizer,
)

-from ._module_swap_api import (
-Int8DynActInt4WeightQATLinear,
-Int8DynActInt4WeightQATQuantizer,
-)
from .embedding import (
Int4WeightOnlyEmbeddingQATQuantizer,
@@ -22,8 +19,6 @@
"disable_8da4w_fake_quant",
"enable_4w_fake_quant",
"enable_8da4w_fake_quant",
"int4_weight_only_fake_quantize",
"int8_dynamic_activation_int4_weight_fake_quantize",
"ComposableQATQuantizer",
"Int4WeightOnlyQATQuantizer",
"Int4WeightOnlyEmbeddingQATQuantizer"
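
A usage note on the fake-quant toggles that the updated `__init__.py` above keeps exporting: as in `test_qat_8da4w_quantizer_disable_fake_quant`, they are applied with `Module.apply`. A minimal sketch, again with a hypothetical toy model:

```python
# Sketch of toggling fake quantization on a prepared QAT model, mirroring
# test_qat_8da4w_quantizer_disable_fake_quant. The toy model is hypothetical.
import torch
from torchao.quantization.prototype.qat import (
    Int8DynActInt4WeightQATQuantizer,
    disable_8da4w_fake_quant,
    enable_8da4w_fake_quant,
)

model = torch.nn.Sequential(torch.nn.Linear(64, 64))  # hypothetical toy model
qat_model = Int8DynActInt4WeightQATQuantizer(groupsize=16).prepare(model)

# With fake quant disabled, the swapped QAT linears compute like plain nn.Linear
qat_model.apply(disable_8da4w_fake_quant)

# ... e.g. run a few warm-up steps without quantization noise ...

# Re-enable fake quant to resume QAT
qat_model.apply(enable_8da4w_fake_quant)
```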