monkey_patch transformers.utils.quantization_config.GPTQConfig.post_init() (#435)
ZX-ModelCloud authored Oct 12, 2024
1 parent bd8d07e commit d27589f
Showing 2 changed files with 85 additions and 1 deletion.
85 changes: 85 additions & 0 deletions gptqmodel/integration/optimum/utils.py
@@ -117,6 +117,91 @@ def get_seqlen(model: nn.Module):


def monkey_patch_gptqmodel_into_transformers():
    # monkey_patch transformers.utils.quantization_config.GPTQConfig.post_init(),
    # because the original post_init() validates the installed auto_gptq version,
    # which is not required when gptqmodel provides the backend
def post_init(self):
r"""
Safety checker that arguments are correct
"""
from packaging import version
import importlib
from transformers.utils.quantization_config import ExllamaVersion
print("monkey patch postin")
if self.bits not in [2, 3, 4, 8]:
raise ValueError(f"Only support quantization to [2,3,4,8] bits but found {self.bits}")
if self.group_size != -1 and self.group_size <= 0:
raise ValueError("group_size must be greater than 0 or equal to -1")
if not (0 < self.damp_percent < 1):
raise ValueError("damp_percent must between 0 and 1.")
if self.dataset is not None:
if isinstance(self.dataset, str):
if self.dataset in ["ptb", "ptb-new"]:
raise ValueError(
f"""{self.dataset} dataset was deprecated. You can only choose between
['wikitext2','c4','c4-new']"""
)
if self.dataset not in ["wikitext2", "c4", "c4-new"]:
raise ValueError(
f"""You have entered a string value for dataset. You can only choose between
['wikitext2','c4','c4-new'], but we found {self.dataset}"""
)
elif not isinstance(self.dataset, list):
raise ValueError(
f"""dataset needs to be either a list of string or a value in
['wikitext2','c4','c4-new'], but we found {self.dataset}"""
)

if self.disable_exllama is None and self.use_exllama is None:
# New default behaviour
self.use_exllama = True
elif self.disable_exllama is not None and self.use_exllama is None:
# Follow pattern of old config
logger.warning(
"Using `disable_exllama` is deprecated and will be removed in version 4.37. Use `use_exllama` instead and specify the version with `exllama_config`."
"The value of `use_exllama` will be overwritten by `disable_exllama` passed in `GPTQConfig` or stored in your config file."
)
self.use_exllama = not self.disable_exllama
self.disable_exllama = None
elif self.disable_exllama is not None and self.use_exllama is not None:
# Only happens if user explicitly passes in both arguments
raise ValueError("Cannot specify both `disable_exllama` and `use_exllama`. Please use just `use_exllama`")

if self.exllama_config is None:
self.exllama_config = {"version": ExllamaVersion.ONE}
else:
if "version" not in self.exllama_config:
raise ValueError("`exllama_config` needs to have a `version` key.")
elif self.exllama_config["version"] not in [ExllamaVersion.ONE, ExllamaVersion.TWO]:
exllama_version = self.exllama_config["version"]
raise ValueError(
f"Only supported versions are in [ExllamaVersion.ONE, ExllamaVersion.TWO] - not recognized version {exllama_version}"
)

if self.bits == 4 and self.use_exllama:
if self.exllama_config["version"] == ExllamaVersion.ONE:
logger.info(
"You have activated exllama backend. Note that you can get better inference "
"speed using exllamav2 kernel by setting `exllama_config`."
)
elif self.exllama_config["version"] == ExllamaVersion.TWO:
optimum_version = version.parse(importlib.metadata.version("optimum"))
# autogptq_version = version.parse(importlib.metadata.version("auto_gptq"))
# if optimum_version <= version.parse("1.13.2") or autogptq_version <= version.parse("0.4.2"):
if optimum_version <= version.parse("1.13.2"):
raise ValueError(
# f"You need optimum > 1.13.2 and auto-gptq > 0.4.2 . Make sure to have that version installed - detected version : optimum {optimum_version} and autogptq {autogptq_version}"
f"You need optimum > 1.13.2 . Make sure to have that version installed - detected version : optimum {optimum_version}"
)
if self.modules_in_block_to_quantize is not None:
optimum_version = version.parse(importlib.metadata.version("optimum"))
if optimum_version < version.parse("1.15.0"):
raise ValueError(
"You current version of `optimum` does not support `modules_in_block_to_quantize` quantization argument, please upgrade `optimum` package to a version superior than 1.15.0 ."
)

    # Install the patch: every GPTQConfig created by transformers now validates with the post_init() defined above.
    from transformers.utils.quantization_config import GPTQConfig
    GPTQConfig.post_init = post_init

from transformers.quantizers import auto

from .hf_quantizer_gptq import GptqHfQuantizer
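For context, a minimal sketch of how this patch would be exercised end to end (assumptions: the entry point is imported from the module path matching the file above, and the model id is illustrative, not taken from this commit):

# Minimal usage sketch. Assumption: gptqmodel exposes the patch at the module
# path matching the file above; the checkpoint id is illustrative only.
from gptqmodel.integration.optimum.utils import monkey_patch_gptqmodel_into_transformers
from transformers import AutoModelForCausalLM, AutoTokenizer, GPTQConfig

# Apply the patch before constructing any GPTQConfig, so validation runs the
# replacement post_init() instead of the original auto_gptq version check.
monkey_patch_gptqmodel_into_transformers()

model_id = "facebook/opt-125m"  # illustrative; any causal LM follows the same flow
tokenizer = AutoTokenizer.from_pretrained(model_id)
quant_config = GPTQConfig(bits=4, dataset="wikitext2", tokenizer=tokenizer)
model = AutoModelForCausalLM.from_pretrained(
    model_id, quantization_config=quant_config, device_map="auto"
)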
1 change: 0 additions & 1 deletion tests/test_transformers_integration.py
@@ -97,7 +97,6 @@ def test_quant_and_load(self, exllama_version):
self.assertResult(model, tokenizer, True, exllama_version, reference_output)

def assertResult(self, model, tokenizer, load_quant_model, exllama_version, reference_output):
print("qlinear type", type(model.model.decoder.layers[0].self_attn.k_proj))
if exllama_version == 1:
self.assertIsInstance(model.model.decoder.layers[0].self_attn.k_proj, ExllamaQuantLinear)
elif exllama_version == 2:
