diff --git a/gptqmodel/integration/optimum/utils.py b/gptqmodel/integration/optimum/utils.py
index 26c9324e..782672db 100644
--- a/gptqmodel/integration/optimum/utils.py
+++ b/gptqmodel/integration/optimum/utils.py
@@ -117,6 +117,98 @@ def get_seqlen(model: nn.Module):
 
 
 def monkey_patch_gptqmodel_into_transformers():
+    # Monkey-patch transformers.utils.quantization_config.GPTQConfig.post_init(),
+    # because the upstream implementation checks the auto_gptq version.
+    def post_init(self):
+        r"""
+        Safety checker that arguments are correct
+        """
+        import importlib.metadata
+
+        from packaging import version
+        from transformers.utils import logging
+        from transformers.utils.quantization_config import ExllamaVersion
+
+        # post_init is grafted onto GPTQConfig, so it cannot rely on a
+        # module-level logger existing where it runs; fetch one explicitly.
+        logger = logging.get_logger(__name__)
+        print("monkey patch post_init")
+
+        if self.bits not in [2, 3, 4, 8]:
+            raise ValueError(f"Only quantization to [2,3,4,8] bits is supported, but found {self.bits}")
+        if self.group_size != -1 and self.group_size <= 0:
+            raise ValueError("group_size must be greater than 0 or equal to -1")
+        if not (0 < self.damp_percent < 1):
+            raise ValueError("damp_percent must be between 0 and 1.")
+        if self.dataset is not None:
+            if isinstance(self.dataset, str):
+                if self.dataset in ["ptb", "ptb-new"]:
+                    raise ValueError(
+                        f"""{self.dataset} dataset was deprecated. You can only choose between
+                        ['wikitext2','c4','c4-new']"""
+                    )
+                if self.dataset not in ["wikitext2", "c4", "c4-new"]:
+                    raise ValueError(
+                        f"""You have entered a string value for dataset. You can only choose between
+                        ['wikitext2','c4','c4-new'], but we found {self.dataset}"""
+                    )
+            elif not isinstance(self.dataset, list):
+                raise ValueError(
+                    f"""dataset needs to be either a list of strings or a value in
+                    ['wikitext2','c4','c4-new'], but we found {self.dataset}"""
+                )
+
+        if self.disable_exllama is None and self.use_exllama is None:
+            # New default behaviour
+            self.use_exllama = True
+        elif self.disable_exllama is not None and self.use_exllama is None:
+            # Follow pattern of old config
+            logger.warning(
+                "Using `disable_exllama` is deprecated and will be removed in version 4.37. Use `use_exllama` instead and specify the version with `exllama_config`. "
+                "The value of `use_exllama` will be overwritten by `disable_exllama` passed in `GPTQConfig` or stored in your config file."
+            )
+            self.use_exllama = not self.disable_exllama
+            self.disable_exllama = None
+        elif self.disable_exllama is not None and self.use_exllama is not None:
+            # Only happens if user explicitly passes in both arguments
+            raise ValueError("Cannot specify both `disable_exllama` and `use_exllama`. Please use just `use_exllama`")
+
+        if self.exllama_config is None:
+            self.exllama_config = {"version": ExllamaVersion.ONE}
+        else:
+            if "version" not in self.exllama_config:
+                raise ValueError("`exllama_config` needs to have a `version` key.")
+            elif self.exllama_config["version"] not in [ExllamaVersion.ONE, ExllamaVersion.TWO]:
+                exllama_version = self.exllama_config["version"]
+                raise ValueError(
+                    f"Only supported versions are in [ExllamaVersion.ONE, ExllamaVersion.TWO] - not recognized version {exllama_version}"
+                )
+
+        if self.bits == 4 and self.use_exllama:
+            if self.exllama_config["version"] == ExllamaVersion.ONE:
+                logger.info(
+                    "You have activated exllama backend. Note that you can get better inference "
+                    "speed using exllamav2 kernel by setting `exllama_config`."
+                )
+            elif self.exllama_config["version"] == ExllamaVersion.TWO:
+                optimum_version = version.parse(importlib.metadata.version("optimum"))
+                # Unlike upstream, do not also require auto_gptq > 0.4.2:
+                # autogptq_version = version.parse(importlib.metadata.version("auto_gptq"))
+                # if optimum_version <= version.parse("1.13.2") or autogptq_version <= version.parse("0.4.2"):
+                if optimum_version <= version.parse("1.13.2"):
+                    raise ValueError(
+                        f"You need optimum > 1.13.2. Make sure you have that version installed - detected version: optimum {optimum_version}"
+                    )
+        if self.modules_in_block_to_quantize is not None:
+            optimum_version = version.parse(importlib.metadata.version("optimum"))
+            if optimum_version < version.parse("1.15.0"):
+                raise ValueError(
+                    "Your current version of `optimum` does not support the `modules_in_block_to_quantize` quantization argument. Please upgrade the `optimum` package to version 1.15.0 or higher."
+                )
+
+    from transformers.utils.quantization_config import GPTQConfig
+    GPTQConfig.post_init = post_init
+
     from transformers.quantizers import auto
 
     from .hf_quantizer_gptq import GptqHfQuantizer
diff --git a/tests/test_transformers_integration.py b/tests/test_transformers_integration.py
index 8109f842..0e2d8159 100644
--- a/tests/test_transformers_integration.py
+++ b/tests/test_transformers_integration.py
@@ -97,7 +97,6 @@ def test_quant_and_load(self, exllama_version):
         self.assertResult(model, tokenizer, True, exllama_version, reference_output)
 
     def assertResult(self, model, tokenizer, load_quant_model, exllama_version, reference_output):
-        print("qlinear type", type(model.model.decoder.layers[0].self_attn.k_proj))
         if exllama_version == 1:
             self.assertIsInstance(model.model.decoder.layers[0].self_attn.k_proj, ExllamaQuantLinear)
         elif exllama_version == 2:
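
A minimal usage sketch for reviewers (hypothetical, not part of this patch; it assumes only the import path touched above and the fact that transformers' GPTQConfig.__init__ invokes post_init()):

    from transformers import GPTQConfig

    from gptqmodel.integration.optimum.utils import monkey_patch_gptqmodel_into_transformers

    # Install the patched post_init before constructing any GPTQConfig.
    monkey_patch_gptqmodel_into_transformers()

    # With bits=4 and neither use_exllama nor disable_exllama set, the patched
    # post_init defaults use_exllama to True and exllama_config to
    # {"version": ExllamaVersion.ONE}; unlike upstream, it never checks the
    # installed auto_gptq version.
    cfg = GPTQConfig(bits=4, dataset="wikitext2")
    print(cfg.use_exllama, cfg.exllama_config)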