Commit dd21c05

jerryzh168 authored and facebook-github-bot committed
Add support for Int8DynActInt4WeightQuantizer (pytorch#66)
Summary:
Stack from [ghstack](https://github.com/ezyang/ghstack) (oldest at bottom):
* __->__ pytorch#66

att

Pull Request resolved: pytorch#66

Test Plan: python test/quantization/test_quant_api.py -k test_8da4w_quantizer

Reviewed By: cpuhrsch

Differential Revision: D55101038

Pulled By: jerryzh168

fbshipit-source-id: 79d51fcfc758c73acdc149fbe155b6b75fe3276e
1 parent 2cca395 commit dd21c05

File tree

3 files changed: +107 −17 lines

test/quantization/test_quant_api.py

Lines changed: 12 additions & 3 deletions
@@ -24,6 +24,8 @@
     Quantizer,
     TwoStepQuantizer,
     Int8DynActInt4WeightGPTQQuantizer,
+    Int8DynActInt4WeightQuantizer,
+    Int8DynActInt4WeightLinear,
 )
 from pathlib import Path
 from sentencepiece import SentencePieceProcessor
@@ -85,8 +87,8 @@ def quantize(self, model: torch.nn.Module) -> torch.nn.Module:
 class M(torch.nn.Module):
     def __init__(self):
         super().__init__()
-        self.linear1 = torch.nn.Linear(5, 5).to(torch.float)
-        self.linear2 = torch.nn.Linear(5, 5).to(torch.float)
+        self.linear1 = torch.nn.Linear(64, 32).to(torch.float)
+        self.linear2 = torch.nn.Linear(32, 64).to(torch.float)

     def forward(self, x):
         x = self.linear1(x)
@@ -132,8 +134,15 @@ def test_dynamic_quant_gpu_unified_api_eager_mode_impl(self):
         compiled = m(*example_inputs)
         torch.testing.assert_close(quantized, compiled, atol=0, rtol=0)

+    def test_8da4w_quantizer(self):
+        quantizer = Int8DynActInt4WeightQuantizer(group_size=32)
+        m = M().eval()
+        m = quantizer.quantize(m)
+        assert isinstance(m.linear1, Int8DynActInt4WeightLinear)
+        assert isinstance(m.linear2, Int8DynActInt4WeightLinear)
+
     @unittest.skip("skipping until we get checkpoints for gpt-fast")
-    def test_gptq(self):
+    def test_gptq_quantizer(self):
         # should be similar to TorchCompileDynamicQuantizer
         precision = torch.bfloat16
         device = "cpu"

torchao/quantization/GPTQ.py

Lines changed: 2 additions & 1 deletion
@@ -17,6 +17,7 @@
 import torch.nn.functional as F
 # from model import Transformer
 from torch.utils._pytree import tree_flatten, tree_unflatten
+import logging

 # pyre-fixme[5]: Global expression must be annotated.
 aten = torch.ops.aten
@@ -61,7 +62,7 @@ def model_forward(model, x, input_pos):
     get_task_dict = tasks.get_task_dict
     evaluate = evaluator.evaluate
 else:
-    print("lm_eval is not installed, GPTQ may not be usable")
+    logging.info("lm_eval is not installed, GPTQ may not be usable")

 # pyre-fixme[3]: Return type must be annotated.
 def setup_cache_padded_seq_input_pos_max_seq_length_for_prefill(

torchao/quantization/quant_api.py

Lines changed: 93 additions & 13 deletions
@@ -33,8 +33,10 @@
 from .quant_primitives import (
     get_group_qparams_symmetric,
     per_token_dynamic_quant,
+    group_quantize_tensor_symmetric,
 )
-from typing import Dict, Tuple
+from typing import Dict, Tuple, Any
+import logging

 __all__ = [
     "apply_weight_only_int8_quant",
@@ -50,21 +52,18 @@
 ############################# Unified Quantization APIs ##############################
 # API 1, single quantize call to create a quantized model with quantized state_dict
 class Quantizer:
-    # pyre-fixme[2]: Parameter must be annotated.
-    def quantize(self, model: torch.nn.Module, *args, **kwargs) -> torch.nn.Module:
+    def quantize(self, model: torch.nn.Module, *args: Any, **kwargs: Any) -> torch.nn.Module:
         # pyre-fixme[7]: Expected `Module` but got implicit return value of `None`.
         pass


 # API 2, flow that needs calibration or training
 class TwoStepQuantizer:
-    # pyre-fixme[2]: Parameter must be annotated.
-    def prepare(self, model: torch.nn.Module, *args, **kwargs) -> torch.nn.Module:
+    def prepare(self, model: torch.nn.Module, *args: Any, **kwargs: Any) -> torch.nn.Module:
         # pyre-fixme[7]: Expected `Module` but got implicit return value of `None`.
         pass

-    # pyre-fixme[2]: Parameter must be annotated.
-    def convert(self, model: torch.nn.Module, *args, **kwargs) -> torch.nn.Module:
+    def convert(self, model: torch.nn.Module, *args: Any, **kwargs: Any) -> torch.nn.Module:
         # pyre-fixme[7]: Expected `Module` but got implicit return value of `None`.
         pass

@@ -250,7 +249,7 @@ def replace_conv2d_1x1(conv):
         MultiInput,
     )
 else:
-    print("lm_eval not available, skip defining GPTQQuantizer")
+    logging.info("lm_eval not available, skip defining GPTQQuantizer")


 class GPTQQuantizer(Quantizer):
@@ -432,11 +431,7 @@ def _convert_for_runtime(self, model: torch.nn.Module) -> "nn.Module":

     @torch.no_grad()
     # pyre-fixme[14]: `quantize` overrides method defined in `Quantizer` inconsistently.
-    def quantize(
-        self,
-        # pyre-fixme[2]: Parameter must be annotated.
-        model,
-    ) -> torch.nn.Module:
+    def quantize(self, model: torch.nn.Module, **kwargs: Any) -> torch.nn.Module:
         state_dict = self._create_quantized_state_dict(
             model,
             # pyre-fixme[16]: `GPTQQuantizer` has no attribute `tokenizer`.
@@ -670,6 +665,91 @@ def replace_linear_8da4w(
     )


+class Int8DynActInt4WeightQuantizer(Quantizer):
+    def __init__(
+        self,
+        group_size: int = 256,
+        padding_allowed: bool = False,
+        precision: torch.dtype = torch.float32,
+        scales_precision: torch.dtype = torch.float32,
+    ) -> None:
+        self.group_size: int = group_size
+        self.padding_allowed: bool = padding_allowed
+        self.precision: torch.dtype = precision
+        self.scales_precision: torch.dtype = scales_precision
+        # assert group_size in [32, 64, 128, 256]
+
+    @torch.no_grad()
+    def _create_quantized_state_dict(self, model: torch.nn.Module) -> Dict[str, torch.Tensor]:
+        cur_state_dict = model.state_dict()
+        for fqn, mod in model.named_modules():
+            if isinstance(mod, torch.nn.Linear):
+                assert mod.bias is not None
+                out_features = mod.out_features
+                in_features = mod.in_features
+                # assert out_features % 8 == 0, "require out_features % 8 == 0"
+                print(f"linear: {fqn}, in={in_features}, out={out_features}")
+
+                assert (
+                    in_features % self.group_size == 0
+                ), f"require in_features:{in_features} % self.group_size:{self.group_size} == 0"
+
+                weight = mod.weight.data
+                """
+                if not _check_linear_int4_k(
+                    in_features, self.group_size
+                ):
+                    if self.padding_allowed:
+                        print(
+                            f"warning: {fqn} is padded to satisfy in_features % 1024 == 0"
+                        )
+                        padded_in_features = _calc_padded_size_linear_int4(
+                            in_features, self.group_size
+                        )
+                        weight = F.pad(
+                            weight, pad=(0, padded_in_features - in_features)
+                        )
+                    else:
+                        raise RuntimeError(
+                            f"warning: {fqn} is skipped, int4 requires that in_features is 32, 64, or is divisible by 1024, "
+                            + "and that group_size"
+                        )
+                """
+                (
+                    weight_int8,
+                    scales,
+                    zeros,
+                ) = group_quantize_tensor_symmetric(
+                    weight.to(self.precision),
+                    4,  # n_bit
+                    self.group_size,
+                    self.scales_precision,
+                )
+                cur_state_dict[f"{fqn}.weight"] = weight_int8.to("cpu")
+                cur_state_dict[f"{fqn}.scales"] = scales.to("cpu")
+                cur_state_dict[f"{fqn}.zeros"] = zeros.to("cpu")
+                # TODO: support bias?
+
+        return cur_state_dict
+
+    def _convert_for_runtime(self, model: torch.nn.Module) -> torch.nn.Module:
+        replace_linear_8da4w(
+            model,
+            self.group_size,
+            self.padding_allowed,
+            self.precision,
+            self.scales_precision,
+        )
+        return model
+
+    def quantize(self, model: torch.nn.Module, *args: Any, **kwargs: Any) -> torch.nn.Module:
+        state_dict = self._create_quantized_state_dict(model)
+        model = self._convert_for_runtime(model)
+        # TODO: make it strict
+        model.load_state_dict(state_dict, strict=False)
+        return model
+
+
 class Int8DynActInt4WeightGPTQQuantizer(GPTQQuantizer):
     # pyre-fixme[3]: Return type must be annotated.
     def __init__(
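For context, the quantizer above delegates the per-group math to group_quantize_tensor_symmetric from quant_primitives. The following is a minimal illustrative sketch of per-group symmetric quantization, not torchao's actual implementation; the exact scale formula, clamping range, and output layout in quant_primitives may differ:

import torch

def group_quantize_symmetric_sketch(w: torch.Tensor, n_bit: int, group_size: int):
    # w: [out_features, in_features]; each row is split into contiguous groups.
    qmin, qmax = -(2 ** (n_bit - 1)), 2 ** (n_bit - 1) - 1  # -8 and 7 for int4
    grouped = w.reshape(w.shape[0], -1, group_size)
    # One scale per group: map the largest absolute value onto the int range.
    scales = grouped.abs().amax(dim=-1, keepdim=True).clamp(min=1e-6) / qmax
    q = torch.clamp(torch.round(grouped / scales), qmin, qmax).to(torch.int8)
    zeros = torch.zeros_like(scales)  # symmetric scheme: zero-point is always 0
    return q.reshape(w.shape), scales.squeeze(-1), zeros.squeeze(-1)

# Rough round-trip check: dequantize and compare against the original weight.
w = torch.randn(32, 64)
q, s, z = group_quantize_symmetric_sketch(w, n_bit=4, group_size=32)
w_hat = (q.reshape(32, -1, 32).float() * s.unsqueeze(-1)).reshape(32, 64)
print((w - w_hat).abs().max())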