Commit c0aa309

Revert "Add support for Int8DynActInt4WeightQuantizer (#66)"
This reverts commit 12b9194.
1 parent: 12b9194

3 files changed: +17 -107 lines


test/quantization/test_quant_api.py

Lines changed: 3 additions & 12 deletions
@@ -24,8 +24,6 @@
     Quantizer,
     TwoStepQuantizer,
     Int8DynActInt4WeightGPTQQuantizer,
-    Int8DynActInt4WeightQuantizer,
-    Int8DynActInt4WeightLinear,
 )
 from pathlib import Path
 from sentencepiece import SentencePieceProcessor
@@ -87,8 +85,8 @@ def quantize(self, model: torch.nn.Module) -> torch.nn.Module:
 class M(torch.nn.Module):
     def __init__(self):
         super().__init__()
-        self.linear1 = torch.nn.Linear(64, 32).to(torch.float)
-        self.linear2 = torch.nn.Linear(32, 64).to(torch.float)
+        self.linear1 = torch.nn.Linear(5, 5).to(torch.float)
+        self.linear2 = torch.nn.Linear(5, 5).to(torch.float)
 
     def forward(self, x):
         x = self.linear1(x)
@@ -134,15 +132,8 @@ def test_dynamic_quant_gpu_unified_api_eager_mode_impl(self):
         compiled = m(*example_inputs)
         torch.testing.assert_close(quantized, compiled, atol=0, rtol=0)
 
-    def test_8da4w_quantizer(self):
-        quantizer = Int8DynActInt4WeightQuantizer(group_size=32)
-        m = M().eval()
-        m = quantizer.quantize(m)
-        assert isinstance(m.linear1, Int8DynActInt4WeightLinear)
-        assert isinstance(m.linear2, Int8DynActInt4WeightLinear)
-
     @unittest.skip("skipping until we get checkpoints for gpt-fast")
-    def test_gptq_quantizer(self):
+    def test_gptq(self):
         # should be similar to TorchCompileDynamicQuantizer
         precision = torch.bfloat16
         device = "cpu"
torchao/quantization/GPTQ.py

Lines changed: 1 addition & 2 deletions
@@ -17,7 +17,6 @@
 import torch.nn.functional as F
 # from model import Transformer
 from torch.utils._pytree import tree_flatten, tree_unflatten
-import logging
 
 # pyre-fixme[5]: Global expression must be annotated.
 aten = torch.ops.aten
@@ -62,7 +61,7 @@ def model_forward(model, x, input_pos):
     get_task_dict = tasks.get_task_dict
     evaluate = evaluator.evaluate
 else:
-    logging.info("lm_eval is not installed, GPTQ may not be usable")
+    print("lm_eval is not installed, GPTQ may not be usable")
 
 # pyre-fixme[3]: Return type must be annotated.
 def setup_cache_padded_seq_input_pos_max_seq_length_for_prefill(
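
The changed line sits inside the optional-dependency guard for `lm_eval`; only the fallback branch changes, from `logging.info` to `print`. A rough sketch of that guard pattern (names such as `lm_eval_available` are illustrative, not a verbatim copy of `GPTQ.py`):

```python
try:
    import lm_eval  # noqa: F401  # optional: only needed for GPTQ calibration/eval
    lm_eval_available = True
except ImportError:
    lm_eval_available = False

if lm_eval_available:
    from lm_eval import evaluator, tasks

    get_task_dict = tasks.get_task_dict
    evaluate = evaluator.evaluate
else:
    # After this commit the fallback is a plain print rather than logging.info.
    print("lm_eval is not installed, GPTQ may not be usable")
```
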

torchao/quantization/quant_api.py

Lines changed: 13 additions & 93 deletions
@@ -33,10 +33,8 @@
 from .quant_primitives import (
     get_group_qparams_symmetric,
     per_token_dynamic_quant,
-    group_quantize_tensor_symmetric,
 )
-from typing import Dict, Tuple, Any
-import logging
+from typing import Dict, Tuple
 
 __all__ = [
     "apply_weight_only_int8_quant",
@@ -52,18 +50,21 @@
 ############################# Unified Quantization APIs ##############################
 # API 1, single quantize call to create a quantized model with quantized state_dict
 class Quantizer:
-    def quantize(self, model: torch.nn.Module, *args: Any, **kwargs: Any) -> torch.nn.Module:
+    # pyre-fixme[2]: Parameter must be annotated.
+    def quantize(self, model: torch.nn.Module, *args, **kwargs) -> torch.nn.Module:
         # pyre-fixme[7]: Expected `Module` but got implicit return value of `None`.
         pass
 
 
 # API 2, flow that needs calibration or training
 class TwoStepQuantizer:
-    def prepare(self, model: torch.nn.Module, *args: Any, **kwargs: Any) -> torch.nn.Module:
+    # pyre-fixme[2]: Parameter must be annotated.
+    def prepare(self, model: torch.nn.Module, *args, **kwargs) -> torch.nn.Module:
         # pyre-fixme[7]: Expected `Module` but got implicit return value of `None`.
         pass
 
-    def convert(self, model: torch.nn.Module, *args: Any, **kwargs: Any) -> torch.nn.Module:
+    # pyre-fixme[2]: Parameter must be annotated.
+    def convert(self, model: torch.nn.Module, *args, **kwargs) -> torch.nn.Module:
         # pyre-fixme[7]: Expected `Module` but got implicit return value of `None`.
         pass
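
To make the intent of the two unified APIs concrete, here is a hedged sketch of how they are meant to be subclassed, assuming the same import path the test file uses; `MyPTQQuantizer` and `MyQATQuantizer` are hypothetical names, not part of torchao:

```python
import torch
from torchao.quantization.quant_api import Quantizer, TwoStepQuantizer

class MyPTQQuantizer(Quantizer):
    # API 1: a single quantize() call returns a model with a quantized state_dict.
    def quantize(self, model: torch.nn.Module, *args, **kwargs) -> torch.nn.Module:
        # ... rewrite weights / swap modules in place ...
        return model

class MyQATQuantizer(TwoStepQuantizer):
    # API 2: prepare() instruments the model, the caller then calibrates or trains,
    # and convert() produces the final quantized model.
    def prepare(self, model: torch.nn.Module, *args, **kwargs) -> torch.nn.Module:
        # ... insert observers / fake-quant modules ...
        return model

    def convert(self, model: torch.nn.Module, *args, **kwargs) -> torch.nn.Module:
        # ... fold observed statistics into quantized modules ...
        return model

# One-shot:  model = MyPTQQuantizer().quantize(model)
# Two-step:  q = MyQATQuantizer(); model = q.prepare(model); train(model); model = q.convert(model)
```
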

@@ -249,7 +250,7 @@ def replace_conv2d_1x1(conv):
         MultiInput,
     )
 else:
-    logging.info("lm_eval not available, skip defining GPTQQuantizer")
+    print("lm_eval not available, skip defining GPTQQuantizer")
 
 
 class GPTQQuantizer(Quantizer):
@@ -431,7 +432,11 @@ def _convert_for_runtime(self, model: torch.nn.Module) -> "nn.Module":
 
     @torch.no_grad()
     # pyre-fixme[14]: `quantize` overrides method defined in `Quantizer` inconsistently.
-    def quantize(self, model: torch.nn.Module, **kwargs: Any) -> torch.nn.Module:
+    def quantize(
+        self,
+        # pyre-fixme[2]: Parameter must be annotated.
+        model,
+    ) -> torch.nn.Module:
         state_dict = self._create_quantized_state_dict(
             model,
             # pyre-fixme[16]: `GPTQQuantizer` has no attribute `tokenizer`.
@@ -665,91 +670,6 @@ def replace_linear_8da4w(
     )
 
 
-class Int8DynActInt4WeightQuantizer(Quantizer):
-    def __init__(
-        self,
-        group_size: int = 256,
-        padding_allowed: bool = False,
-        precision: torch.dtype = torch.float32,
-        scales_precision: torch.dtype = torch.float32,
-    ) -> None:
-        self.group_size: int = group_size
-        self.padding_allowed: bool = padding_allowed
-        self.precision: torch.dtype = precision
-        self.scales_precision: torch.dtype = scales_precision
-        # assert group_size in [32, 64, 128, 256]
-
-    @torch.no_grad()
-    def _create_quantized_state_dict(self, model: torch.nn.Module) -> Dict[str, torch.Tensor]:
-        cur_state_dict = model.state_dict()
-        for fqn, mod in model.named_modules():
-            if isinstance(mod, torch.nn.Linear):
-                assert mod.bias is not None
-                out_features = mod.out_features
-                in_features = mod.in_features
-                # assert out_features % 8 == 0, "require out_features % 8 == 0"
-                print(f"linear: {fqn}, in={in_features}, out={out_features}")
-
-                assert (
-                    in_features % self.group_size == 0
-                ), f"require in_features:{in_features} % self.group_size:{self.group_size} == 0"
-
-                weight = mod.weight.data
-                """
-                if not _check_linear_int4_k(
-                    in_features, self.group_size
-                ):
-                    if self.padding_allowed:
-                        print(
-                            f"warning: {fqn} is padded to satisfy in_features % 1024 == 0"
-                        )
-                        padded_in_features = _calc_padded_size_linear_int4(
-                            in_features, self.group_size
-                        )
-                        weight = F.pad(
-                            weight, pad=(0, padded_in_features - in_features)
-                        )
-                    else:
-                        raise RuntimeError(
-                            f"warning: {fqn} is skipped, int4 requires that in_features is 32, 64, or is divisible by 1024, "
-                            + "and that group_size"
-                        )
-                """
-                (
-                    weight_int8,
-                    scales,
-                    zeros,
-                ) = group_quantize_tensor_symmetric(
-                    weight.to(self.precision),
-                    4, # n_bit
-                    self.group_size,
-                    self.scales_precision,
-                )
-                cur_state_dict[f"{fqn}.weight"] = weight_int8.to("cpu")
-                cur_state_dict[f"{fqn}.scales"] = scales.to("cpu")
-                cur_state_dict[f"{fqn}.zeros"] = zeros.to("cpu")
-                # TODO: support bias?
-
-        return cur_state_dict
-
-    def _convert_for_runtime(self, model: torch.nn.Module) -> torch.nn.Module:
-        replace_linear_8da4w(
-            model,
-            self.group_size,
-            self.padding_allowed,
-            self.precision,
-            self.scales_precision,
-        )
-        return model
-
-    def quantize(self, model: torch.nn.Module, *args: Any, **kwargs: Any) -> torch.nn.Module:
-        state_dict = self._create_quantized_state_dict(model)
-        model = self._convert_for_runtime(model)
-        # TODO: make it strict
-        model.load_state_dict(state_dict, strict=False)
-        return model
-
-
 class Int8DynActInt4WeightGPTQQuantizer(GPTQQuantizer):
     # pyre-fixme[3]: Return type must be annotated.
     def __init__(
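
The deleted `_create_quantized_state_dict` leaned on `group_quantize_tensor_symmetric(weight, 4, group_size, scales_precision)` to produce the `{fqn}.weight` / `{fqn}.scales` / `{fqn}.zeros` entries. As a rough illustration of what a per-group symmetric 4-bit scheme computes (a self-contained sketch, not the actual torchao primitive):

```python
import torch

def group_quantize_symmetric_sketch(w: torch.Tensor, n_bit: int = 4, group_size: int = 32):
    """Per-group symmetric quantization: one scale per group, zero-point fixed at 0."""
    out_features, in_features = w.shape
    assert in_features % group_size == 0
    qmin, qmax = -(2 ** (n_bit - 1)), 2 ** (n_bit - 1) - 1   # -8..7 for 4 bits
    wg = w.reshape(out_features, in_features // group_size, group_size)
    scales = wg.abs().amax(dim=-1, keepdim=True).clamp(min=1e-6) / qmax
    q = torch.clamp(torch.round(wg / scales), qmin, qmax).to(torch.int8)  # int8 container
    zeros = torch.zeros_like(scales)  # symmetric scheme => all zero-points are zero
    return q.reshape(out_features, in_features), scales.squeeze(-1), zeros.squeeze(-1)
```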
