
Move quantization/prototype -> prototype/quantization #1088


Closed · wants to merge 1 commit
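This PR moves the prototype quantization code from `torchao.quantization.prototype` to `torchao.prototype.quantization` and updates the import sites shown in the diff below. As a minimal sketch of what the move means for code that imports these modules, using a symbol that appears in the diff (other names under the moved package follow the same pattern):

```python
# Before this PR (old location):
#   from torchao.quantization.prototype.qat import Int8DynActInt4WeightQATQuantizer

# After this PR (new location):
from torchao.prototype.quantization.qat import Int8DynActInt4WeightQATQuantizer

quantizer = Int8DynActInt4WeightQATQuantizer()
```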
README.md (1 addition & 1 deletion)
@@ -59,7 +59,7 @@ In practice these features alongside int4 weight only quantization allow us to *
Post-training quantization can result in a fast and compact model, but may also lead to accuracy degradation. We recommend exploring Quantization Aware Training (QAT) to overcome this limitation. In collaboration with Torchtune, we've developed a QAT recipe that demonstrates significant accuracy improvements over traditional PTQ, recovering **96% of the accuracy degradation on hellaswag and 68% of the perplexity degradation on wikitext** for Llama3 compared to post-training quantization (PTQ). And we've provided a full recipe [here](https://pytorch.org/blog/quantization-aware-training/)

```python
-from torchao.quantization.prototype.qat import Int8DynActInt4WeightQATQuantizer
+from torchao.prototype.quantization.qat import Int8DynActInt4WeightQATQuantizer

qat_quantizer = Int8DynActInt4WeightQATQuantizer()

# ...
```
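The README snippet above is cut off in this view. A minimal sketch of how the quantizer is typically driven end to end, assuming torchao's usual prepare/convert QAT flow (the model and training loop here are placeholders, not part of the diff):

```python
import torch.nn as nn
from torchao.prototype.quantization.qat import Int8DynActInt4WeightQATQuantizer

# Placeholder model; any nn.Module with linear layers works.
model = nn.Sequential(nn.Linear(512, 512), nn.ReLU(), nn.Linear(512, 10))

qat_quantizer = Int8DynActInt4WeightQATQuantizer()

# prepare() swaps nn.Linear for fake-quantized linears so training sees
# int8-activation / int4-weight quantization noise in the forward pass.
model = qat_quantizer.prepare(model)

# ... fine-tune `model` here as usual ...

# convert() replaces the fake-quantized layers with actually quantized ones.
model = qat_quantizer.convert(model)
```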
@@ -4,7 +4,7 @@
import torch.nn as nn
from torchao.quantization import quantize_, int8_weight_only, int4_weight_only
from torchao.quantization.utils import compute_error
-from torchao.quantization.prototype.mixed_precision.scripts.naive_intNwo import intN_weight_only
+from torchao.prototype.quantization.mixed_precision.scripts.naive_intNwo import intN_weight_only

_CUDA_IS_AVAILABLE = torch.cuda.is_available()

test/quantization/test_qat.py → test/prototype/test_qat.py (15 additions & 15 deletions)
@@ -22,17 +22,17 @@
PerRow,
PerToken,
)
-from torchao.quantization.prototype.qat.api import (
+from torchao.prototype.quantization.qat.api import (
ComposableQATQuantizer,
FakeQuantizeConfig,
)
-from torchao.quantization.prototype.qat.fake_quantizer import (
+from torchao.prototype.quantization.qat.fake_quantizer import (
FakeQuantizer,
)
-from torchao.quantization.prototype.qat.linear import (
+from torchao.prototype.quantization.qat.linear import (
FakeQuantizedLinear,
)
-from torchao.quantization.prototype.qat.utils import (
+from torchao.prototype.quantization.qat.utils import (
_choose_qparams_per_token_asymmetric,
_fake_quantize_per_channel_group,
_fake_quantize_per_token,
@@ -172,7 +172,7 @@ def _set_ptq_weight(
Int8DynActInt4WeightLinear,
WeightOnlyInt4Linear,
)
-from torchao.quantization.prototype.qat.linear import (
+from torchao.prototype.quantization.qat.linear import (
Int8DynActInt4WeightQATLinear,
Int4WeightOnlyQATLinear,
)
@@ -204,7 +204,7 @@ def _set_ptq_weight(

@unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_4, "skipping when torch version is 2.4 or lower")
def test_qat_8da4w_linear(self):
-from torchao.quantization.prototype.qat.linear import Int8DynActInt4WeightQATLinear
+from torchao.prototype.quantization.qat.linear import Int8DynActInt4WeightQATLinear
from torchao.quantization.GPTQ import Int8DynActInt4WeightLinear

group_size = 128
@@ -229,7 +229,7 @@ def test_qat_8da4w_linear(self):

@unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_4, "skipping when torch version is 2.4 or lower")
def test_qat_8da4w_quantizer(self):
-from torchao.quantization.prototype.qat import Int8DynActInt4WeightQATQuantizer
+from torchao.prototype.quantization.qat import Int8DynActInt4WeightQATQuantizer
from torchao.quantization.GPTQ import Int8DynActInt4WeightQuantizer

group_size = 16
@@ -263,7 +263,7 @@ def test_qat_8da4w_quantizer(self):

@unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_4, "skipping when torch version is 2.4 or lower")
def test_qat_8da4w_quantizer_meta_weights(self):
-from torchao.quantization.prototype.qat import Int8DynActInt4WeightQATQuantizer
+from torchao.prototype.quantization.qat import Int8DynActInt4WeightQATQuantizer

with torch.device("meta"):
m = M()
@@ -278,7 +278,7 @@ def test_qat_8da4w_quantizer_disable_fake_quant(self):
"""
Test that 8da4w QAT with disabled fake quant matches nn.Linear in forward.
"""
-from torchao.quantization.prototype.qat import (
+from torchao.prototype.quantization.qat import (
Int8DynActInt4WeightQATQuantizer,
disable_8da4w_fake_quant,
enable_8da4w_fake_quant,
@@ -337,7 +337,7 @@ def test_qat_8da4w_quantizer_disable_fake_quant_backward(self):
"""
Test that 8da4w QAT with disabled fake quant matches nn.Linear in backward.
"""
-from torchao.quantization.prototype.qat import (
+from torchao.prototype.quantization.qat import (
Int8DynActInt4WeightQATQuantizer,
disable_8da4w_fake_quant,
)
Expand Down Expand Up @@ -419,7 +419,7 @@ def _test_qat_quantized_gradients(self, quantizer):

@unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_4, "skipping when torch version is 2.4 or lower")
def test_qat_8da4w_quantizer_gradients(self):
-from torchao.quantization.prototype.qat import Int8DynActInt4WeightQATQuantizer
+from torchao.prototype.quantization.qat import Int8DynActInt4WeightQATQuantizer
quantizer = Int8DynActInt4WeightQATQuantizer(groupsize=16)
self._test_qat_quantized_gradients(quantizer)

Expand Down Expand Up @@ -509,7 +509,7 @@ def test_qat_4w_primitives(self):
@unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_4, "skipping when torch version is 2.4 or lower")
@unittest.skipIf(not _CUDA_IS_AVAILABLE, "skipping when cuda is not available")
def test_qat_4w_linear(self):
-from torchao.quantization.prototype.qat.linear import Int4WeightOnlyQATLinear
+from torchao.prototype.quantization.qat.linear import Int4WeightOnlyQATLinear
from torchao.quantization.GPTQ import WeightOnlyInt4Linear

group_size = 128
Expand All @@ -536,14 +536,14 @@ def test_qat_4w_linear(self):

@unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_4, "skipping when torch version is 2.4 or lower")
def test_qat_4w_quantizer_gradients(self):
-from torchao.quantization.prototype.qat import Int4WeightOnlyQATQuantizer
+from torchao.prototype.quantization.qat import Int4WeightOnlyQATQuantizer
quantizer = Int4WeightOnlyQATQuantizer(groupsize=32, inner_k_tiles=8)
self._test_qat_quantized_gradients(quantizer)

@unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_4, "skipping when torch version is 2.4 or lower")
@unittest.skipIf(not _CUDA_IS_AVAILABLE, "skipping when cuda is not available")
def test_qat_4w_quantizer(self):
-from torchao.quantization.prototype.qat import Int4WeightOnlyQATQuantizer
+from torchao.prototype.quantization.qat import Int4WeightOnlyQATQuantizer
from torchao.quantization.GPTQ import Int4WeightOnlyQuantizer

group_size = 32
@@ -621,7 +621,7 @@ def test_composable_qat_quantizer(self):

@unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_4, "skipping when torch version is 2.4 or lower")
def test_qat_4w_embedding(self):
-from torchao.quantization.prototype.qat import Int4WeightOnlyEmbeddingQATQuantizer
+from torchao.prototype.quantization.qat import Int4WeightOnlyEmbeddingQATQuantizer
model = M2()
x = model.example_inputs()
out = model(*x)
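Code outside this repository that has to work both before and after the move could guard the import. This is a hypothetical compatibility shim for downstream users, not something introduced by this PR:

```python
try:
    # New location (after this PR)
    from torchao.prototype.quantization.qat import Int8DynActInt4WeightQATQuantizer
except ImportError:
    # Old location (torchao releases before this move)
    from torchao.quantization.prototype.qat import Int8DynActInt4WeightQATQuantizer
```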
@@ -41,7 +41,7 @@ For example, on a single GPU:
```python
import torch
from torchtune.models.llama3 import llama3
-from torchao.quantization.prototype.qat import Int8DynActInt4WeightQATQuantizer
+from torchao.prototype.quantization.qat import Int8DynActInt4WeightQATQuantizer

# Smaller version of llama3 to fit in a single GPU
model = llama3(
# ...
```
@@ -46,7 +46,7 @@ def forward(
zero_point_domain: ZeroPointDomain = ZeroPointDomain.INT,
) -> torch.Tensor:
# avoid circular dependencies
-from torchao.quantization.prototype.qat.affine_fake_quantized_tensor import (
+from torchao.prototype.quantization.qat.affine_fake_quantized_tensor import (
AffineFakeQuantizedTensor,
)

@@ -88,7 +88,7 @@ def forward(
input: torch.Tensor,
) -> torch.Tensor:
# avoid circular dependencies
-from torchao.quantization.prototype.qat.affine_fake_quantized_tensor import (
+from torchao.prototype.quantization.qat.affine_fake_quantized_tensor import (
AffineFakeQuantizedTensor,
)
assert isinstance(input, AffineFakeQuantizedTensor)
torchao/quantization/quant_api.py (1 addition & 1 deletion)
@@ -220,7 +220,7 @@ def _replace_with_custom_fn_if_matches_filter(

def _is_linear(mod, *args):
# avoid circular dependencies
-from torchao.quantization.prototype.qat.affine_fake_quantized_tensor import (
+from torchao.prototype.quantization.qat.affine_fake_quantized_tensor import (
AffineFakeQuantizedTensor,
)

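Several hunks above keep the `AffineFakeQuantizedTensor` import inside the function body, with an "avoid circular dependencies" comment, because the QAT prototype module and the core quantization module import each other. A generic two-file sketch of that deferred-import pattern (module and symbol names here are illustrative, not torchao's):

```python
# pkg/api.py
from pkg import qat  # safe: qat does not import api at module scope


class CoreTensor:
    """Type that the helper module needs to recognize."""


def is_core(x) -> bool:
    return isinstance(x, CoreTensor)


# pkg/qat.py
def check(x) -> bool:
    # Deferred import: a top-level `from pkg.api import CoreTensor` would be
    # circular, because api.py imports this module at module scope.
    from pkg.api import CoreTensor
    return isinstance(x, CoreTensor)
```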