Skip to content

Commit f44089a

Browse files
authored
Move some quantization ops to pytorch (#77)
Summary: Since executorch is removing its dependency on torchao and will only depend on pytorch for now, we need to move these ops to pytorch so executorch can have access to them; we can move them back when torchao is in a more mature state. Test Plan: CI Reviewers: Subscribers: Tasks: Tags:
1 parent 645d654 commit f44089a

File tree

5 files changed

+614
-974
lines changed

5 files changed

+614
-974
lines changed

requirements.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
11
torch
22
numpy
33
sentencepiece
4+
packaging

test/quantization/test_quant_api.py

Lines changed: 10 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -23,9 +23,9 @@
2323
from torchao.quantization.quant_api import (
2424
Quantizer,
2525
TwoStepQuantizer,
26-
Int8DynActInt4WeightGPTQQuantizer,
27-
Int8DynActInt4WeightQuantizer,
28-
Int8DynActInt4WeightLinear,
26+
)
27+
from torchao.quantization.utils import (
28+
TORCH_VERSION_AFTER_2_4,
2929
)
3030
from pathlib import Path
3131
from sentencepiece import SentencePieceProcessor
@@ -136,7 +136,11 @@ def test_dynamic_quant_gpu_unified_api_eager_mode_impl(self):
136136
compiled = m(*example_inputs)
137137
torch.testing.assert_close(quantized, compiled, atol=0, rtol=0)
138138

139+
@unittest.skipIf(not TORCH_VERSION_AFTER_2_4, "skipping when torch verion is 2.3 or lower")
139140
def test_8da4w_quantizer(self):
141+
from torchao.quantization.quant_api import Int8DynActInt4WeightQuantizer
142+
from torchao.quantization.quant_api import Int8DynActInt4WeightLinear
143+
140144
quantizer = Int8DynActInt4WeightQuantizer(group_size=32)
141145
m = M().eval()
142146
example_inputs = m.example_inputs()
@@ -147,6 +151,7 @@ def test_8da4w_quantizer(self):
147151

148152
@unittest.skip("skipping until we get checkpoints for gpt-fast")
149153
def test_gptq_quantizer(self):
154+
from torchao.quantization.quant_api import Int8DynActInt4WeightGPTQQuantizer
150155
# should be similar to TorchCompileDynamicQuantizer
151156
precision = torch.bfloat16
152157
device = "cpu"
@@ -163,8 +168,8 @@ def test_gptq_quantizer(self):
163168
blocksize = 128
164169
percdamp = 0.01
165170
groupsize = 128
166-
calibration_tasks = ["hellaswag"]
167-
calibration_limit = 200 # 1000
171+
calibration_tasks = ["wikitext"]
172+
calibration_limit = 5
168173
calibration_seq_length = 100
169174
pad_calibration_inputs = False
170175
quantizer = Int8DynActInt4WeightGPTQQuantizer(

0 commit comments

Comments (0)