
Commit 5e48e18

[reland] Add support for Int8DynActInt4WeightQuantizer (#66) (#74)
Summary: att

Test Plan: python test/quantization/test_quant_api.py -k test_8da4w_quantizer

Reviewed By: cpuhrsch

Differential Revision: D55101038

Pulled By: jerryzh168

[ghstack-poisoned]
1 parent efb6514 commit 5e48e18

3 files changed: +115 additions, -21 deletions
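Before the per-file diff, here is a minimal sketch of how the new quantizer is meant to be driven, mirroring the new test_8da4w_quantizer test below. It assumes Int8DynActInt4WeightQuantizer and Int8DynActInt4WeightLinear are importable from torchao.quantization.quant_api; the exact import path is above the visible hunk in the test file, so treat it as an assumption.

import torch
from torchao.quantization.quant_api import (
    Int8DynActInt4WeightQuantizer,
    Int8DynActInt4WeightLinear,  # assumed to be exposed alongside the quantizer
)

# Toy model: bias-free Linears whose in_features are divisible by group_size.
model = torch.nn.Sequential(
    torch.nn.Linear(64, 32, bias=False),
    torch.nn.Linear(32, 64, bias=False),
).eval()

quantizer = Int8DynActInt4WeightQuantizer(group_size=32)
model = quantizer.quantize(model)  # one-shot: quantize weights and swap modules

# Linear layers are replaced with int8-dynamic-activation / int4-weight linears.
assert all(isinstance(m, Int8DynActInt4WeightLinear) for m in model)
out = model(torch.randn(1, 64))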

test/quantization/test_quant_api.py

Lines changed: 20 additions & 7 deletions
@@ -24,6 +24,8 @@
     Quantizer,
     TwoStepQuantizer,
     Int8DynActInt4WeightGPTQQuantizer,
+    Int8DynActInt4WeightQuantizer,
+    Int8DynActInt4WeightLinear,
 )
 from pathlib import Path
 from sentencepiece import SentencePieceProcessor
@@ -85,8 +87,11 @@ def quantize(self, model: torch.nn.Module) -> torch.nn.Module:
 class M(torch.nn.Module):
     def __init__(self):
         super().__init__()
-        self.linear1 = torch.nn.Linear(5, 5).to(torch.float)
-        self.linear2 = torch.nn.Linear(5, 5).to(torch.float)
+        self.linear1 = torch.nn.Linear(64, 32, bias=False).to(torch.float)
+        self.linear2 = torch.nn.Linear(32, 64, bias=False).to(torch.float)
+
+    def example_inputs(self):
+        return (torch.randn(1, 64).to(torch.float),)

     def forward(self, x):
         x = self.linear1(x)
@@ -97,8 +102,7 @@ class TestQuantFlow(unittest.TestCase):
     def test_dynamic_quant_gpu_singleline(self):
         m = M().eval()
         m = _apply_dynamic_quant(m)
-        example_inputs = (torch.randn(1, 5).to(dtype=torch.float32),)
-        quantized = m(*example_inputs)
+        quantized = m(*m.example_inputs())
         # AssertionError: Expecting input to have dtype torch.float32, but got dtype: torch.float64
         # While executing %choose_qparams_tensor_1 : [num_users=2] = call_function[target=torch.ops.quantized_decomposed.choose_qparams.tensor](args = (%arg0_3, -128, 127, 0.000244140625, torch.int8), kwargs = {})
         # m = torch.compile(m, mode="max-autotune")
@@ -110,9 +114,9 @@ def test_dynamic_quant_gpu_singleline(self):
     def test_dynamic_quant_gpu_unified_api_unified_impl(self):
         quantizer = XNNPackDynamicQuantizer()
         m = M().eval()
+        example_inputs = m.example_inputs()
         m = quantizer.prepare(m)
         m = quantizer.convert(m)
-        example_inputs = (torch.randn(1, 5).to(dtype=torch.float32),)
         quantized = m(*example_inputs)
         # AssertionError: Expecting input to have dtype torch.float32, but got dtype: torch.float64
         # While executing %choose_qparams_tensor_1 : [num_users=2] = call_function[target=torch.ops.quantized_decomposed.choose_qparams.tensor](args = (%arg0_3, -128, 127, 0.000244140625, torch.int8), kwargs = {})
@@ -125,15 +129,24 @@ def test_dynamic_quant_gpu_unified_api_unified_impl(self):
     def test_dynamic_quant_gpu_unified_api_eager_mode_impl(self):
         quantizer = TorchCompileDynamicQuantizer()
         m = M().eval()
+        example_inputs = m.example_inputs()
         m = quantizer.quantize(m)
-        example_inputs = (torch.randn(1, 5).to(dtype=torch.float32),)
         quantized = m(*example_inputs)
         m = torch.compile(m, mode="max-autotune")
         compiled = m(*example_inputs)
         torch.testing.assert_close(quantized, compiled, atol=0, rtol=0)

+    def test_8da4w_quantizer(self):
+        quantizer = Int8DynActInt4WeightQuantizer(group_size=32)
+        m = M().eval()
+        example_inputs = m.example_inputs()
+        m = quantizer.quantize(m)
+        assert isinstance(m.linear1, Int8DynActInt4WeightLinear)
+        assert isinstance(m.linear2, Int8DynActInt4WeightLinear)
+        m(*example_inputs)
+
     @unittest.skip("skipping until we get checkpoints for gpt-fast")
-    def test_gptq(self):
+    def test_gptq_quantizer(self):
         # should be similar to TorchCompileDynamicQuantizer
         precision = torch.bfloat16
         device = "cpu"
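Note on why the toy module M changes from 5x5 linears to 64-to-32 and 32-to-64 with bias=False: the new quantizer (see quant_api.py below) asserts that each Linear has no bias and that in_features is divisible by group_size, so with group_size=32 the old 5x5 layers would fail that assert. A small illustrative check; the helper name below is hypothetical and only mirrors the assert in the diff.

def can_8da4w_quantize(in_features: int, group_size: int = 32) -> bool:
    # mirrors the divisibility assert in Int8DynActInt4WeightQuantizer._create_quantized_state_dict
    return in_features % group_size == 0

assert not can_8da4w_quantize(5)   # old M: Linear(5, 5) would be rejected
assert can_8da4w_quantize(64)      # new M: Linear(64, 32)
assert can_8da4w_quantize(32)      # new M: Linear(32, 64)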

torchao/quantization/GPTQ.py

Lines changed: 2 additions & 1 deletion
@@ -19,6 +19,7 @@

 # from model import Transformer # pyre-ignore[21]
 from torch.utils._pytree import tree_flatten, tree_unflatten
+import logging

 # pyre-fixme[5]: Global expression must be annotated.
 aten = torch.ops.aten
@@ -63,7 +64,7 @@ def model_forward(model, x, input_pos):
     get_task_dict = tasks.get_task_dict
     evaluate = evaluator.evaluate
 else:
-    print("lm_eval is not installed, GPTQ may not be usable")
+    logging.info("lm_eval is not installed, GPTQ may not be usable")

 # pyre-fixme[3]: Return type must be annotated.
 def setup_cache_padded_seq_input_pos_max_seq_length_for_prefill(
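One practical side effect of replacing print with logging.info (here and in quant_api.py below) is that the lm_eval notice is no longer shown by default, since the root logger only emits WARNING and above unless configured. A caller who wants to see it would opt in, for example:

import logging

# surface INFO-level messages such as the "lm_eval is not installed" notice above
logging.basicConfig(level=logging.INFO)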

torchao/quantization/quant_api.py

Lines changed: 93 additions & 13 deletions
@@ -36,8 +36,10 @@
 from .quant_primitives import (
     get_group_qparams_symmetric,
     per_token_dynamic_quant,
+    group_quantize_tensor_symmetric,
 )
-from typing import Dict, Tuple
+from typing import Dict, Tuple, Any
+import logging

 __all__ = [
     "apply_weight_only_int8_quant",
@@ -54,21 +56,18 @@
 ############################# Unified Quantization APIs ##############################
 # API 1, single quantize call to create a quantized model with quantized state_dict
 class Quantizer:
-    # pyre-fixme[2]: Parameter must be annotated.
-    def quantize(self, model: torch.nn.Module, *args, **kwargs) -> torch.nn.Module:
+    def quantize(self, model: torch.nn.Module, *args: Any, **kwargs: Any) -> torch.nn.Module:
         # pyre-fixme[7]: Expected `Module` but got implicit return value of `None`.
         pass


 # API 2, flow that needs calibration or training
 class TwoStepQuantizer:
-    # pyre-fixme[2]: Parameter must be annotated.
-    def prepare(self, model: torch.nn.Module, *args, **kwargs) -> torch.nn.Module:
+    def prepare(self, model: torch.nn.Module, *args: Any, **kwargs: Any) -> torch.nn.Module:
         # pyre-fixme[7]: Expected `Module` but got implicit return value of `None`.
         pass

-    # pyre-fixme[2]: Parameter must be annotated.
-    def convert(self, model: torch.nn.Module, *args, **kwargs) -> torch.nn.Module:
+    def convert(self, model: torch.nn.Module, *args: Any, **kwargs: Any) -> torch.nn.Module:
         # pyre-fixme[7]: Expected `Module` but got implicit return value of `None`.
         pass
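For orientation, the two unified APIs above differ only in how they are driven: API 1 is a single quantize() call, API 2 is a prepare/calibrate/convert flow. A self-contained sketch with stand-in quantizers (not the real torchao classes):

import torch
from typing import Any

class OneShotQuantizer:
    # stand-in for API 1 (Quantizer): one quantize() call returns the quantized model
    def quantize(self, model: torch.nn.Module, *args: Any, **kwargs: Any) -> torch.nn.Module:
        return model  # a real implementation rewrites weights / swaps modules here

class CalibratedQuantizer:
    # stand-in for API 2 (TwoStepQuantizer): prepare() instruments, convert() finalizes
    def prepare(self, model: torch.nn.Module, *args: Any, **kwargs: Any) -> torch.nn.Module:
        return model
    def convert(self, model: torch.nn.Module, *args: Any, **kwargs: Any) -> torch.nn.Module:
        return model

m1 = OneShotQuantizer().quantize(torch.nn.Linear(8, 8))   # API 1

q2 = CalibratedQuantizer()
m2 = q2.prepare(torch.nn.Linear(8, 8))                     # API 2, step 1
m2(torch.randn(2, 8))                                      # calibration pass(es)
m2 = q2.convert(m2)                                        # API 2, step 2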

@@ -260,7 +259,7 @@ def replace_conv2d_1x1(conv):
         MultiInput,
     )
 else:
-    print("lm_eval not available, skip defining GPTQQuantizer")
+    logging.info("lm_eval not available, skip defining GPTQQuantizer")


 class GPTQQuantizer(Quantizer):
@@ -442,11 +441,7 @@ def _convert_for_runtime(self, model: torch.nn.Module) -> "nn.Module":

     @torch.no_grad()
     # pyre-fixme[14]: `quantize` overrides method defined in `Quantizer` inconsistently.
-    def quantize(
-        self,
-        # pyre-fixme[2]: Parameter must be annotated.
-        model,
-    ) -> torch.nn.Module:
+    def quantize(self, model: torch.nn.Module, **kwargs: Any) -> torch.nn.Module:
         state_dict = self._create_quantized_state_dict(
             model,
             # pyre-fixme[16]: `GPTQQuantizer` has no attribute `tokenizer`.
@@ -686,6 +681,91 @@ def replace_linear_8da4w(
     )


+class Int8DynActInt4WeightQuantizer(Quantizer):
+    def __init__(
+        self,
+        group_size: int = 256,
+        padding_allowed: bool = False,
+        precision: torch.dtype = torch.float32,
+        scales_precision: torch.dtype = torch.float32,
+    ) -> None:
+        self.group_size: int = group_size
+        self.padding_allowed: bool = padding_allowed
+        self.precision: torch.dtype = precision
+        self.scales_precision: torch.dtype = scales_precision
+        # assert group_size in [32, 64, 128, 256]
+
+    @torch.no_grad()
+    def _create_quantized_state_dict(self, model: torch.nn.Module) -> Dict[str, torch.Tensor]:
+        cur_state_dict = model.state_dict()
+        for fqn, mod in model.named_modules():
+            if isinstance(mod, torch.nn.Linear):
+                assert not mod.bias
+                out_features = mod.out_features
+                in_features = mod.in_features
+                # assert out_features % 8 == 0, "require out_features % 8 == 0"
+                print(f"linear: {fqn}, in={in_features}, out={out_features}")
+
+                assert (
+                    in_features % self.group_size == 0
+                ), f"require in_features:{in_features} % self.group_size:{self.group_size} == 0"
+
+                weight = mod.weight.data
+                """
+                if not _check_linear_int4_k(
+                    in_features, self.group_size
+                ):
+                    if self.padding_allowed:
+                        print(
+                            f"warning: {fqn} is padded to satisfy in_features % 1024 == 0"
+                        )
+                        padded_in_features = _calc_padded_size_linear_int4(
+                            in_features, self.group_size
+                        )
+                        weight = F.pad(
+                            weight, pad=(0, padded_in_features - in_features)
+                        )
+                    else:
+                        raise RuntimeError(
+                            f"warning: {fqn} is skipped, int4 requires that in_features is 32, 64, or is divisible by 1024, "
+                            + "and that group_size"
+                        )
+                """
+                (
+                    weight_int8,
+                    scales,
+                    zeros,
+                ) = group_quantize_tensor_symmetric(
+                    weight.to(self.precision),
+                    4,  # n_bit
+                    self.group_size,
+                    self.scales_precision,
+                )
+                cur_state_dict[f"{fqn}.weight"] = weight_int8.to("cpu")
+                cur_state_dict[f"{fqn}.scales"] = scales.to("cpu")
+                cur_state_dict[f"{fqn}.zeros"] = zeros.to("cpu")
+                # TODO: support bias?
+
+        return cur_state_dict
+
+    def _convert_for_runtime(self, model: torch.nn.Module) -> torch.nn.Module:
+        replace_linear_8da4w(
+            model,
+            self.group_size,
+            self.padding_allowed,
+            self.precision,
+            self.scales_precision,
+        )
+        return model
+
+    def quantize(self, model: torch.nn.Module, *args: Any, **kwargs: Any) -> torch.nn.Module:
+        state_dict = self._create_quantized_state_dict(model)
+        model = self._convert_for_runtime(model)
+        # TODO: make it strict
+        model.load_state_dict(state_dict, strict=False)
+        return model
+
+
 class Int8DynActInt4WeightGPTQQuantizer(GPTQQuantizer):
     # pyre-fixme[3]: Return type must be annotated.
     def __init__(
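The heavy lifting in the new class is delegated to group_quantize_tensor_symmetric from quant_primitives. Independent of its exact signature and output layout, the underlying per-group symmetric quantization is simple; the sketch below is a self-contained illustration of the idea, with shapes and layout chosen for clarity rather than to match the torchao implementation.

import torch

def group_symmetric_quantize(w: torch.Tensor, n_bit: int = 4, group_size: int = 32):
    # w: (out_features, in_features); each row is split into in_features // group_size groups
    out_features, in_features = w.shape
    assert in_features % group_size == 0
    groups = w.reshape(out_features, in_features // group_size, group_size)

    # symmetric quantization: zero-point is 0, the per-group scale maps the largest
    # absolute value in the group onto the largest positive code, e.g. 7 for int4 [-8, 7]
    q_max = 2 ** (n_bit - 1) - 1
    scales = groups.abs().amax(dim=-1, keepdim=True).clamp(min=1e-6) / q_max
    q = (groups / scales).round().clamp(-q_max - 1, q_max).to(torch.int8)

    return q.reshape(out_features, in_features), scales.squeeze(-1)

w = torch.randn(32, 64)
q, scales = group_symmetric_quantize(w)  # q stores 4-bit values in int8, scales: (32, 2)
w_hat = (q.reshape(32, 2, 32).float() * scales.unsqueeze(-1)).reshape(32, 64)
print((w - w_hat).abs().max())           # small per-group quantization error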
