@@ -33,10 +33,8 @@
 from .quant_primitives import (
     get_group_qparams_symmetric,
     per_token_dynamic_quant,
-    group_quantize_tensor_symmetric,
 )
-from typing import Dict, Tuple, Any
-import logging
+from typing import Dict, Tuple

 __all__ = [
     "apply_weight_only_int8_quant",
@@ -52,18 +50,21 @@
 ############################# Unified Quantization APIs ##############################
 # API 1, single quantize call to create a quantized model with quantized state_dict
 class Quantizer:
-    def quantize(self, model: torch.nn.Module, *args: Any, **kwargs: Any) -> torch.nn.Module:
+    # pyre-fixme[2]: Parameter must be annotated.
+    def quantize(self, model: torch.nn.Module, *args, **kwargs) -> torch.nn.Module:
         # pyre-fixme[7]: Expected `Module` but got implicit return value of `None`.
         pass


 # API 2, flow that needs calibration or training
 class TwoStepQuantizer:
-    def prepare(self, model: torch.nn.Module, *args: Any, **kwargs: Any) -> torch.nn.Module:
+    # pyre-fixme[2]: Parameter must be annotated.
+    def prepare(self, model: torch.nn.Module, *args, **kwargs) -> torch.nn.Module:
         # pyre-fixme[7]: Expected `Module` but got implicit return value of `None`.
         pass

-    def convert(self, model: torch.nn.Module, *args: Any, **kwargs: Any) -> torch.nn.Module:
+    # pyre-fixme[2]: Parameter must be annotated.
+    def convert(self, model: torch.nn.Module, *args, **kwargs) -> torch.nn.Module:
         # pyre-fixme[7]: Expected `Module` but got implicit return value of `None`.
         pass

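For readers skimming the diff, a minimal sketch of how these two unified APIs are meant to be called, assuming the `Quantizer` and `TwoStepQuantizer` base classes above are importable; the subclasses and calibration input below are hypothetical, added only to illustrate the calling convention:

```python
import torch
import torch.nn as nn

# Hypothetical no-op subclasses, only to show the two calling conventions.
class MyQuantizer(Quantizer):
    def quantize(self, model: nn.Module, *args, **kwargs) -> nn.Module:
        return model  # a real implementation would swap in quantized modules here

class MyTwoStepQuantizer(TwoStepQuantizer):
    def prepare(self, model: nn.Module, *args, **kwargs) -> nn.Module:
        return model  # e.g. insert observers for calibration

    def convert(self, model: nn.Module, *args, **kwargs) -> nn.Module:
        return model  # e.g. replace observed modules with quantized ones

model = nn.Linear(8, 8)

# API 1: a single quantize() call produces a model with a quantized state_dict.
model = MyQuantizer().quantize(model)

# API 2: calibration (or training) happens between prepare() and convert().
two_step = MyTwoStepQuantizer()
model = two_step.prepare(model)
model(torch.randn(2, 8))  # calibration pass
model = two_step.convert(model)
```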
@@ -249,7 +250,7 @@ def replace_conv2d_1x1(conv):
         MultiInput,
     )
 else:
-    logging.info("lm_eval not available, skip defining GPTQQuantizer")
+    print("lm_eval not available, skip defining GPTQQuantizer")


 class GPTQQuantizer(Quantizer):
@@ -431,7 +432,11 @@ def _convert_for_runtime(self, model: torch.nn.Module) -> "nn.Module":

     @torch.no_grad()
     # pyre-fixme[14]: `quantize` overrides method defined in `Quantizer` inconsistently.
-    def quantize(self, model: torch.nn.Module, **kwargs: Any) -> torch.nn.Module:
+    def quantize(
+        self,
+        # pyre-fixme[2]: Parameter must be annotated.
+        model,
+    ) -> torch.nn.Module:
         state_dict = self._create_quantized_state_dict(
             model,
             # pyre-fixme[16]: `GPTQQuantizer` has no attribute `tokenizer`.
@@ -665,91 +670,6 @@ def replace_linear_8da4w(
     )


-class Int8DynActInt4WeightQuantizer(Quantizer):
-    def __init__(
-        self,
-        group_size: int = 256,
-        padding_allowed: bool = False,
-        precision: torch.dtype = torch.float32,
-        scales_precision: torch.dtype = torch.float32,
-    ) -> None:
-        self.group_size: int = group_size
-        self.padding_allowed: bool = padding_allowed
-        self.precision: torch.dtype = precision
-        self.scales_precision: torch.dtype = scales_precision
-        # assert group_size in [32, 64, 128, 256]
-
-    @torch.no_grad()
-    def _create_quantized_state_dict(self, model: torch.nn.Module) -> Dict[str, torch.Tensor]:
-        cur_state_dict = model.state_dict()
-        for fqn, mod in model.named_modules():
-            if isinstance(mod, torch.nn.Linear):
-                assert mod.bias is not None
-                out_features = mod.out_features
-                in_features = mod.in_features
-                # assert out_features % 8 == 0, "require out_features % 8 == 0"
-                print(f"linear: {fqn}, in={in_features}, out={out_features}")
-
-                assert (
-                    in_features % self.group_size == 0
-                ), f"require in_features:{in_features} % self.group_size:{self.group_size} == 0"
-
-                weight = mod.weight.data
-                """
-                if not _check_linear_int4_k(
-                    in_features, self.group_size
-                ):
-                    if self.padding_allowed:
-                        print(
-                            f"warning: {fqn} is padded to satisfy in_features % 1024 == 0"
-                        )
-                        padded_in_features = _calc_padded_size_linear_int4(
-                            in_features, self.group_size
-                        )
-                        weight = F.pad(
-                            weight, pad=(0, padded_in_features - in_features)
-                        )
-                    else:
-                        raise RuntimeError(
-                            f"warning: {fqn} is skipped, int4 requires that in_features is 32, 64, or is divisible by 1024, "
-                            + "and that group_size"
-                        )
-                """
-                (
-                    weight_int8,
-                    scales,
-                    zeros,
-                ) = group_quantize_tensor_symmetric(
-                    weight.to(self.precision),
-                    4,  # n_bit
-                    self.group_size,
-                    self.scales_precision,
-                )
-                cur_state_dict[f"{fqn}.weight"] = weight_int8.to("cpu")
-                cur_state_dict[f"{fqn}.scales"] = scales.to("cpu")
-                cur_state_dict[f"{fqn}.zeros"] = zeros.to("cpu")
-                # TODO: support bias?
-
-        return cur_state_dict
-
-    def _convert_for_runtime(self, model: torch.nn.Module) -> torch.nn.Module:
-        replace_linear_8da4w(
-            model,
-            self.group_size,
-            self.padding_allowed,
-            self.precision,
-            self.scales_precision,
-        )
-        return model
-
-    def quantize(self, model: torch.nn.Module, *args: Any, **kwargs: Any) -> torch.nn.Module:
-        state_dict = self._create_quantized_state_dict(model)
-        model = self._convert_for_runtime(model)
-        # TODO: make it strict
-        model.load_state_dict(state_dict, strict=False)
-        return model
-
-
 class Int8DynActInt4WeightGPTQQuantizer(GPTQQuantizer):
     # pyre-fixme[3]: Return type must be annotated.
     def __init__(