
Commit b6761ac (parent 9c048eb)

[quant] Add support for Int8DynActInt4WeightQuantizer

Summary: att
Test Plan: python test/quantization/test_quant_api.py -k test_8da4w_quantizer
Reviewers:
Subscribers:
Tasks:
Tags:
ghstack-source-id: 364956f
Pull Request resolved: #66

3 files changed: 107 additions, 16 deletions


test/quantization/test_quant_api.py

Lines changed: 12 additions & 3 deletions
@@ -24,6 +24,8 @@
     Quantizer,
     TwoStepQuantizer,
     Int8DynActInt4WeightGPTQQuantizer,
+    Int8DynActInt4WeightQuantizer,
+    Int8DynActInt4WeightLinear,
 )
 from pathlib import Path
 from sentencepiece import SentencePieceProcessor
@@ -85,8 +87,8 @@ def quantize(self, model: torch.nn.Module) -> torch.nn.Module:
 class M(torch.nn.Module):
     def __init__(self):
         super().__init__()
-        self.linear1 = torch.nn.Linear(5, 5).to(torch.float)
-        self.linear2 = torch.nn.Linear(5, 5).to(torch.float)
+        self.linear1 = torch.nn.Linear(64, 32).to(torch.float)
+        self.linear2 = torch.nn.Linear(32, 64).to(torch.float)

     def forward(self, x):
         x = self.linear1(x)
@@ -132,8 +134,15 @@ def test_dynamic_quant_gpu_unified_api_eager_mode_impl(self):
         compiled = m(*example_inputs)
         torch.testing.assert_close(quantized, compiled, atol=0, rtol=0)

+    def test_8da4w_quantizer(self):
+        quantizer = Int8DynActInt4WeightQuantizer(group_size=32)
+        m = M().eval()
+        m = quantizer.quantize(m)
+        assert isinstance(m.linear1, Int8DynActInt4WeightLinear)
+        assert isinstance(m.linear2, Int8DynActInt4WeightLinear)
+
     @unittest.skip("skipping until we get checkpoints for gpt-fast")
-    def test_gptq(self):
+    def test_gptq_quantizer(self):
         # should be similar to TorchCompileDynamicQuantizer
         precision = torch.bfloat16
         device = "cpu"

torchao/quantization/GPTQ.py

Lines changed: 2 additions & 1 deletion
@@ -17,6 +17,7 @@
 import torch.nn.functional as F
 # from model import Transformer
 from torch.utils._pytree import tree_flatten, tree_unflatten
+import logging

 # pyre-fixme[5]: Global expression must be annotated.
 aten = torch.ops.aten
@@ -61,7 +62,7 @@ def model_forward(model, x, input_pos):
     get_task_dict = tasks.get_task_dict
     evaluate = evaluator.evaluate
 else:
-    print("lm_eval is not installed, GPTQ may not be usable")
+    logging.info("lm_eval is not installed, GPTQ may not be usable")

 # pyre-fixme[3]: Return type must be annotated.
 def setup_cache_padded_seq_input_pos_max_seq_length_for_prefill(
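The only functional change in GPTQ.py is routing the missing-dependency message through logging instead of print. A short sketch of the optional-import pattern this touches, hand-written for illustration with names mirroring the surrounding context lines:

import logging

try:
    from lm_eval import evaluator, tasks  # optional dependency for the GPTQ eval flow
    lm_eval_available = True
except ImportError:
    lm_eval_available = False

if lm_eval_available:
    get_task_dict = tasks.get_task_dict
    evaluate = evaluator.evaluate
else:
    # logging.info respects the caller's configured log level, unlike an unconditional print
    logging.info("lm_eval is not installed, GPTQ may not be usable")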

torchao/quantization/quant_api.py

Lines changed: 93 additions & 12 deletions
@@ -33,8 +33,10 @@
3333
from .quant_primitives import (
3434
get_group_qparams_symmetric,
3535
per_token_dynamic_quant,
36+
group_quantize_tensor_symmetric,
3637
)
37-
from typing import Dict, Tuple
38+
from typing import Dict, Tuple, Any
39+
import logging
3840

3941
__all__ = [
4042
"apply_weight_only_int8_quant",
@@ -50,21 +52,19 @@
5052
############################# Unified Quantization APIs ##############################
5153
# API 1, single quantize call to create a quantized model with quantized state_dict
5254
class Quantizer:
53-
# pyre-fixme[2]: Parameter must be annotated.
54-
def quantize(self, model: torch.nn.Module, *args, **kwargs) -> torch.nn.Module:
55+
def quantize(self, model: torch.nn.Module, *args: Any, **kwargs: Any) -> torch.nn.Module:
5556
# pyre-fixme[7]: Expected `Module` but got implicit return value of `None`.
5657
pass
5758

5859

5960
# API 2, flow that needs calibration or training
6061
class TwoStepQuantizer:
61-
# pyre-fixme[2]: Parameter must be annotated.
62-
def prepare(self, model: torch.nn.Module, *args, **kwargs) -> torch.nn.Module:
62+
def prepare(self, model: torch.nn.Module, *args: Any, **kwargs: Any) -> torch.nn.Module:
6363
# pyre-fixme[7]: Expected `Module` but got implicit return value of `None`.
6464
pass
6565

6666
# pyre-fixme[2]: Parameter must be annotated.
67-
def convert(self, model: torch.nn.Module, *args, **kwargs) -> torch.nn.Module:
67+
def convert(self, model: torch.nn.Module, *args: Any, **kwargs: Any) -> torch.nn.Module:
6868
# pyre-fixme[7]: Expected `Module` but got implicit return value of `None`.
6969
pass
7070

@@ -250,7 +250,7 @@ def replace_conv2d_1x1(conv):
250250
MultiInput,
251251
)
252252
else:
253-
print("lm_eval not available, skip defining GPTQQuantizer")
253+
logging.info("lm_eval not available, skip defining GPTQQuantizer")
254254

255255

256256
class GPTQQuantizer(Quantizer):
@@ -432,11 +432,7 @@ def _convert_for_runtime(self, model: torch.nn.Module) -> "nn.Module":
432432

433433
@torch.no_grad()
434434
# pyre-fixme[14]: `quantize` overrides method defined in `Quantizer` inconsistently.
435-
def quantize(
436-
self,
437-
# pyre-fixme[2]: Parameter must be annotated.
438-
model,
439-
) -> torch.nn.Module:
435+
def quantize(self, model: torch.nn.Module, **kwargs) -> torch.nn.Module:
440436
state_dict = self._create_quantized_state_dict(
441437
model,
442438
# pyre-fixme[16]: `GPTQQuantizer` has no attribute `tokenizer`.
@@ -670,6 +666,91 @@ def replace_linear_8da4w(
670666
)
671667

672668

669+
class Int8DynActInt4WeightQuantizer(Quantizer):
670+
def __init__(
671+
self,
672+
group_size: int = 256,
673+
padding_allowed: bool = False,
674+
precision: torch.dtype = torch.float32,
675+
scales_precision: torch.dtype = torch.float32,
676+
) -> None:
677+
self.group_size: int = group_size
678+
self.padding_allowed: bool = padding_allowed
679+
self.precision: torch.dtype = precision
680+
self.scales_precision: torch.dtype = scales_precision
681+
# assert group_size in [32, 64, 128, 256]
682+
683+
@torch.no_grad()
684+
def _create_quantized_state_dict(self, model: torch.nn.Module) -> Dict[str, torch.Tensor]:
685+
cur_state_dict = model.state_dict()
686+
for fqn, mod in model.named_modules():
687+
if isinstance(mod, torch.nn.Linear):
688+
assert mod.bias is not None
689+
out_features = mod.out_features
690+
in_features = mod.in_features
691+
# assert out_features % 8 == 0, "require out_features % 8 == 0"
692+
print(f"linear: {fqn}, in={in_features}, out={out_features}")
693+
694+
assert (
695+
in_features % self.group_size == 0
696+
), f"require in_features:{in_features} % self.group_size:{self.group_size} == 0"
697+
698+
weight = mod.weight.data
699+
"""
700+
if not _check_linear_int4_k(
701+
in_features, self.group_size
702+
):
703+
if self.padding_allowed:
704+
print(
705+
f"warning: {fqn} is padded to satisfy in_features % 1024 == 0"
706+
)
707+
padded_in_features = _calc_padded_size_linear_int4(
708+
in_features, self.group_size
709+
)
710+
weight = F.pad(
711+
weight, pad=(0, padded_in_features - in_features)
712+
)
713+
else:
714+
raise RuntimeError(
715+
f"warning: {fqn} is skipped, int4 requires that in_features is 32, 64, or is divisible by 1024, "
716+
+ "and that group_size"
717+
)
718+
"""
719+
(
720+
weight_int8,
721+
scales,
722+
zeros,
723+
) = group_quantize_tensor_symmetric(
724+
weight.to(self.precision),
725+
4, # n_bit
726+
self.group_size,
727+
self.scales_precision,
728+
)
729+
cur_state_dict[f"{fqn}.weight"] = weight_int8.to("cpu")
730+
cur_state_dict[f"{fqn}.scales"] = scales.to("cpu")
731+
cur_state_dict[f"{fqn}.zeros"] = zeros.to("cpu")
732+
# TODO: support bias?
733+
734+
return cur_state_dict
735+
736+
def _convert_for_runtime(self, model: torch.nn.Module) -> torch.nn.Module:
737+
replace_linear_8da4w(
738+
model,
739+
self.group_size,
740+
self.padding_allowed,
741+
self.precision,
742+
self.scales_precision,
743+
)
744+
return model
745+
746+
def quantize(self, model: torch.nn.Module, *args: Any, **kwargs: Any) -> torch.nn.Module:
747+
state_dict = self._create_quantized_state_dict(model)
748+
model = self._convert_for_runtime(model)
749+
# TODO: make it strict
750+
model.load_state_dict(state_dict, strict=False)
751+
return model
752+
753+
673754
class Int8DynActInt4WeightGPTQQuantizer(GPTQQuantizer):
674755
# pyre-fixme[3]: Return type must be annotated.
675756
def __init__(
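The core of _create_quantized_state_dict is the call to group_quantize_tensor_symmetric with n_bit=4, which returns the quantized weight (int4 values carried in an int8 tensor), per-group scales, and per-group zero points. A hand-rolled reference of what symmetric group-wise quantization does, for illustration only; rounding, clamping range, and dtype details may differ from the actual quant_primitives implementation.

import torch

def groupwise_symmetric_quant_reference(w: torch.Tensor, n_bit: int = 4, group_size: int = 32):
    """Illustrative reference, not the torchao implementation."""
    out_features, in_features = w.shape
    assert in_features % group_size == 0
    w_grouped = w.reshape(out_features, in_features // group_size, group_size)
    q_max = 2 ** (n_bit - 1) - 1   # 7 for 4-bit
    q_min = -(2 ** (n_bit - 1))    # -8 for 4-bit
    # Symmetric scheme: one scale per group from the max magnitude; zero point stays 0.
    max_abs = w_grouped.abs().amax(dim=-1, keepdim=True)
    scales = (max_abs / q_max).clamp_min(torch.finfo(w.dtype).eps)
    w_int = torch.clamp(torch.round(w_grouped / scales), q_min, q_max).to(torch.int8)
    zeros = torch.zeros_like(scales)
    return w_int.reshape(out_features, in_features), scales.squeeze(-1), zeros.squeeze(-1)

w = torch.randn(32, 64)
w_int8, scales, zeros = groupwise_symmetric_quant_reference(w, n_bit=4, group_size=32)
# Dequantized approximation: expand the per-group scales back over the groups.
w_hat = (w_int8.reshape(32, 2, 32).float() * scales.unsqueeze(-1)).reshape(32, 64)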
