@@ -36,8 +36,10 @@
 from .quant_primitives import (
     get_group_qparams_symmetric,
     per_token_dynamic_quant,
+    group_quantize_tensor_symmetric,
 )
-from typing import Dict, Tuple
+from typing import Dict, Tuple, Any
+import logging

 __all__ = [
     "apply_weight_only_int8_quant",
@@ -54,21 +56,18 @@
 ############################# Unified Quantization APIs ##############################
 # API 1, single quantize call to create a quantized model with quantized state_dict
 class Quantizer:
-    # pyre-fixme[2]: Parameter must be annotated.
-    def quantize(self, model: torch.nn.Module, *args, **kwargs) -> torch.nn.Module:
+    def quantize(self, model: torch.nn.Module, *args: Any, **kwargs: Any) -> torch.nn.Module:
         # pyre-fixme[7]: Expected `Module` but got implicit return value of `None`.
         pass


 # API 2, flow that needs calibration or training
 class TwoStepQuantizer:
-    # pyre-fixme[2]: Parameter must be annotated.
-    def prepare(self, model: torch.nn.Module, *args, **kwargs) -> torch.nn.Module:
+    def prepare(self, model: torch.nn.Module, *args: Any, **kwargs: Any) -> torch.nn.Module:
         # pyre-fixme[7]: Expected `Module` but got implicit return value of `None`.
         pass

-    # pyre-fixme[2]: Parameter must be annotated.
-    def convert(self, model: torch.nn.Module, *args, **kwargs) -> torch.nn.Module:
+    def convert(self, model: torch.nn.Module, *args: Any, **kwargs: Any) -> torch.nn.Module:
         # pyre-fixme[7]: Expected `Module` but got implicit return value of `None`.
         pass

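For reference, a minimal sketch of how the two APIs above are meant to be called; `MyQuantizer`, `MyTwoStepQuantizer`, `calibrate`, and `calibration_data` are hypothetical placeholders, not names defined in this file:

    # API 1: a single quantize() call returns a model with a quantized state_dict.
    quantizer = MyQuantizer()            # hypothetical Quantizer subclass
    model = quantizer.quantize(model)

    # API 2: prepare -> calibrate (or train) -> convert.
    quantizer = MyTwoStepQuantizer()     # hypothetical TwoStepQuantizer subclass
    model = quantizer.prepare(model)     # e.g. insert observers / fake-quant modules
    calibrate(model, calibration_data)   # hypothetical user-supplied calibration loop
    model = quantizer.convert(model)     # materialize the final quantized model
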
@@ -260,7 +259,7 @@ def replace_conv2d_1x1(conv):
         MultiInput,
     )
 else:
-    print("lm_eval not available, skip defining GPTQQuantizer")
+    logging.info("lm_eval not available, skip defining GPTQQuantizer")


 class GPTQQuantizer(Quantizer):
@@ -442,11 +441,7 @@ def _convert_for_runtime(self, model: torch.nn.Module) -> "nn.Module":

     @torch.no_grad()
     # pyre-fixme[14]: `quantize` overrides method defined in `Quantizer` inconsistently.
-    def quantize(
-        self,
-        # pyre-fixme[2]: Parameter must be annotated.
-        model,
-    ) -> torch.nn.Module:
+    def quantize(self, model: torch.nn.Module, **kwargs: Any) -> torch.nn.Module:
         state_dict = self._create_quantized_state_dict(
             model,
             # pyre-fixme[16]: `GPTQQuantizer` has no attribute `tokenizer`.
@@ -686,6 +681,91 @@ def replace_linear_8da4w(
     )


+class Int8DynActInt4WeightQuantizer(Quantizer):
+    def __init__(
+        self,
+        group_size: int = 256,
+        padding_allowed: bool = False,
+        precision: torch.dtype = torch.float32,
+        scales_precision: torch.dtype = torch.float32,
+    ) -> None:
+        self.group_size: int = group_size
+        self.padding_allowed: bool = padding_allowed
+        self.precision: torch.dtype = precision
+        self.scales_precision: torch.dtype = scales_precision
+        # assert group_size in [32, 64, 128, 256]
+
+    @torch.no_grad()
+    def _create_quantized_state_dict(self, model: torch.nn.Module) -> Dict[str, torch.Tensor]:
+        cur_state_dict = model.state_dict()
+        for fqn, mod in model.named_modules():
+            if isinstance(mod, torch.nn.Linear):
+                assert not mod.bias
+                out_features = mod.out_features
+                in_features = mod.in_features
+                # assert out_features % 8 == 0, "require out_features % 8 == 0"
+                print(f"linear: {fqn}, in={in_features}, out={out_features}")
+
+                assert (
+                    in_features % self.group_size == 0
+                ), f"require in_features:{in_features} % self.group_size:{self.group_size} == 0"
+
+                weight = mod.weight.data
+                """
+                if not _check_linear_int4_k(
+                    in_features, self.group_size
+                ):
+                    if self.padding_allowed:
+                        print(
+                            f"warning: {fqn} is padded to satisfy in_features % 1024 == 0"
+                        )
+                        padded_in_features = _calc_padded_size_linear_int4(
+                            in_features, self.group_size
+                        )
+                        weight = F.pad(
+                            weight, pad=(0, padded_in_features - in_features)
+                        )
+                    else:
+                        raise RuntimeError(
+                            f"warning: {fqn} is skipped, int4 requires that in_features is 32, 64, or is divisible by 1024, "
+                            + "and that group_size"
+                        )
+                """
+                (
+                    weight_int8,
+                    scales,
+                    zeros,
+                ) = group_quantize_tensor_symmetric(
+                    weight.to(self.precision),
+                    4,  # n_bit
+                    self.group_size,
+                    self.scales_precision,
+                )
+                cur_state_dict[f"{fqn}.weight"] = weight_int8.to("cpu")
+                cur_state_dict[f"{fqn}.scales"] = scales.to("cpu")
+                cur_state_dict[f"{fqn}.zeros"] = zeros.to("cpu")
+                # TODO: support bias?
+
+        return cur_state_dict
+
+    def _convert_for_runtime(self, model: torch.nn.Module) -> torch.nn.Module:
+        replace_linear_8da4w(
+            model,
+            self.group_size,
+            self.padding_allowed,
+            self.precision,
+            self.scales_precision,
+        )
+        return model
+
+    def quantize(self, model: torch.nn.Module, *args: Any, **kwargs: Any) -> torch.nn.Module:
+        state_dict = self._create_quantized_state_dict(model)
+        model = self._convert_for_runtime(model)
+        # TODO: make it strict
+        model.load_state_dict(state_dict, strict=False)
+        return model
+
+
 class Int8DynActInt4WeightGPTQQuantizer(GPTQQuantizer):
     # pyre-fixme[3]: Return type must be annotated.
     def __init__(
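A minimal usage sketch for the `Int8DynActInt4WeightQuantizer` added above; the import path is an assumption (adjust it to wherever this quant_api module lives in your checkout), and the toy model is chosen so every `nn.Linear` is bias-free with `in_features` divisible by the group size, as the assertions in `_create_quantized_state_dict` require:

    import torch.nn as nn
    # Assumed import path; adjust to the actual location of this quant_api module.
    from torchao.quantization.quant_api import Int8DynActInt4WeightQuantizer

    # Toy model: bias-free Linear layers whose in_features are multiples of 256.
    model = nn.Sequential(
        nn.Linear(512, 1024, bias=False),
        nn.Linear(1024, 512, bias=False),
    )

    quantizer = Int8DynActInt4WeightQuantizer(group_size=256)
    # quantize() builds the int8-weight/scales/zeros state_dict, swaps in the
    # 8da4w linear modules via replace_linear_8da4w, and loads the new weights.
    model = quantizer.quantize(model)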