@@ -33,8 +33,10 @@
 from .quant_primitives import (
     get_group_qparams_symmetric,
     per_token_dynamic_quant,
+    group_quantize_tensor_symmetric,
 )
-from typing import Dict, Tuple
+from typing import Dict, Tuple, Any
+import logging
 
 __all__ = [
     "apply_weight_only_int8_quant",
@@ -50,21 +52,19 @@
 ############################# Unified Quantization APIs ##############################
 # API 1, single quantize call to create a quantized model with quantized state_dict
 class Quantizer:
-    # pyre-fixme[2]: Parameter must be annotated.
-    def quantize(self, model: torch.nn.Module, *args, **kwargs) -> torch.nn.Module:
+    def quantize(self, model: torch.nn.Module, *args: Any, **kwargs: Any) -> torch.nn.Module:
         # pyre-fixme[7]: Expected `Module` but got implicit return value of `None`.
         pass
 
 
 # API 2, flow that needs calibration or training
 class TwoStepQuantizer:
-    # pyre-fixme[2]: Parameter must be annotated.
-    def prepare(self, model: torch.nn.Module, *args, **kwargs) -> torch.nn.Module:
+    def prepare(self, model: torch.nn.Module, *args: Any, **kwargs: Any) -> torch.nn.Module:
         # pyre-fixme[7]: Expected `Module` but got implicit return value of `None`.
         pass
 
     # pyre-fixme[2]: Parameter must be annotated.
-    def convert(self, model: torch.nn.Module, *args, **kwargs) -> torch.nn.Module:
+    def convert(self, model: torch.nn.Module, *args: Any, **kwargs: Any) -> torch.nn.Module:
         # pyre-fixme[7]: Expected `Module` but got implicit return value of `None`.
         pass
 
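Note on the two APIs above: Quantizer.quantize does the whole job in a single post-training call, while TwoStepQuantizer expects a calibration or training loop to run between prepare and convert. A minimal sketch of how concrete subclasses would be driven; MyQuantizer, MyTwoStepQuantizer, and calibration_batches are hypothetical placeholders, not part of this change:

    # Sketch only: MyQuantizer, MyTwoStepQuantizer and calibration_batches are
    # hypothetical stand-ins for concrete subclasses and calibration data.
    model = torch.nn.Sequential(torch.nn.Linear(256, 256))

    # API 1: a single call returns the quantized model.
    model = MyQuantizer().quantize(model)

    # API 2: prepare -> calibrate/train -> convert.
    two_step = MyTwoStepQuantizer()
    model = two_step.prepare(model)      # e.g. insert observers / fake-quant
    for batch in calibration_batches:    # run representative inputs
        model(batch)
    model = two_step.convert(model)      # lower to the quantized representation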
@@ -250,7 +250,7 @@ def replace_conv2d_1x1(conv):
         MultiInput,
     )
 else:
-    print("lm_eval not available, skip defining GPTQQuantizer")
+    logging.info("lm_eval not available, skip defining GPTQQuantizer")
 
 
 class GPTQQuantizer(Quantizer):
@@ -432,11 +432,7 @@ def _convert_for_runtime(self, model: torch.nn.Module) -> "nn.Module":
 
     @torch.no_grad()
     # pyre-fixme[14]: `quantize` overrides method defined in `Quantizer` inconsistently.
-    def quantize(
-        self,
-        # pyre-fixme[2]: Parameter must be annotated.
-        model,
-    ) -> torch.nn.Module:
+    def quantize(self, model: torch.nn.Module, **kwargs) -> torch.nn.Module:
         state_dict = self._create_quantized_state_dict(
             model,
             # pyre-fixme[16]: `GPTQQuantizer` has no attribute `tokenizer`.
@@ -670,6 +666,91 @@ def replace_linear_8da4w(
     )
 
 
+class Int8DynActInt4WeightQuantizer(Quantizer):
+    def __init__(
+        self,
+        group_size: int = 256,
+        padding_allowed: bool = False,
+        precision: torch.dtype = torch.float32,
+        scales_precision: torch.dtype = torch.float32,
+    ) -> None:
+        self.group_size: int = group_size
+        self.padding_allowed: bool = padding_allowed
+        self.precision: torch.dtype = precision
+        self.scales_precision: torch.dtype = scales_precision
+        # assert group_size in [32, 64, 128, 256]
+
+    @torch.no_grad()
+    def _create_quantized_state_dict(self, model: torch.nn.Module) -> Dict[str, torch.Tensor]:
+        cur_state_dict = model.state_dict()
+        for fqn, mod in model.named_modules():
+            if isinstance(mod, torch.nn.Linear):
+                assert mod.bias is not None
+                out_features = mod.out_features
+                in_features = mod.in_features
+                # assert out_features % 8 == 0, "require out_features % 8 == 0"
+                print(f"linear: {fqn}, in={in_features}, out={out_features}")
+
+                assert (
+                    in_features % self.group_size == 0
+                ), f"require in_features:{in_features} % self.group_size:{self.group_size} == 0"
+
+                weight = mod.weight.data
+                """
+                if not _check_linear_int4_k(
+                    in_features, self.group_size
+                ):
+                    if self.padding_allowed:
+                        print(
+                            f"warning: {fqn} is padded to satisfy in_features % 1024 == 0"
+                        )
+                        padded_in_features = _calc_padded_size_linear_int4(
+                            in_features, self.group_size
+                        )
+                        weight = F.pad(
+                            weight, pad=(0, padded_in_features - in_features)
+                        )
+                    else:
+                        raise RuntimeError(
+                            f"warning: {fqn} is skipped, int4 requires that in_features is 32, 64, or is divisible by 1024, "
+                            + "and that group_size"
+                        )
+                """
+                (
+                    weight_int8,
+                    scales,
+                    zeros,
+                ) = group_quantize_tensor_symmetric(
+                    weight.to(self.precision),
+                    4,  # n_bit
+                    self.group_size,
+                    self.scales_precision,
+                )
+                cur_state_dict[f"{fqn}.weight"] = weight_int8.to("cpu")
+                cur_state_dict[f"{fqn}.scales"] = scales.to("cpu")
+                cur_state_dict[f"{fqn}.zeros"] = zeros.to("cpu")
+                # TODO: support bias?
+
+        return cur_state_dict
+
+    def _convert_for_runtime(self, model: torch.nn.Module) -> torch.nn.Module:
+        replace_linear_8da4w(
+            model,
+            self.group_size,
+            self.padding_allowed,
+            self.precision,
+            self.scales_precision,
+        )
+        return model
+
+    def quantize(self, model: torch.nn.Module, *args: Any, **kwargs: Any) -> torch.nn.Module:
+        state_dict = self._create_quantized_state_dict(model)
+        model = self._convert_for_runtime(model)
+        # TODO: make it strict
+        model.load_state_dict(state_dict, strict=False)
+        return model
+
+
 class Int8DynActInt4WeightGPTQQuantizer(GPTQQuantizer):
     # pyre-fixme[3]: Return type must be annotated.
     def __init__(
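The class added above plugs into the single-call Quantizer API: _create_quantized_state_dict symmetrically group-quantizes every nn.Linear weight to the 4-bit range (stored as int8, with per-group scales and zeros), _convert_for_runtime swaps the linears via replace_linear_8da4w, and quantize ties the two together with a non-strict load_state_dict. A rough usage sketch; the toy model is made up, chosen only so that each linear's in_features is divisible by group_size as the asserts above require:

    # Illustrative only; the model is a stand-in, but the quantizer call
    # mirrors the class introduced in this diff (512 % 256 == 0 as required).
    model = torch.nn.Sequential(
        torch.nn.Linear(512, 512),
        torch.nn.ReLU(),
        torch.nn.Linear(512, 512),
    )
    quantizer = Int8DynActInt4WeightQuantizer(group_size=256)
    model = quantizer.quantize(model)
    # model now carries int8-stored 4-bit weights plus scales/zeros, and its
    # linears have been replaced for 8-bit-dynamic-activation / 4-bit-weight
    # (8da4w) execution.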
|