From 015a76f3a17de9580d8faf7e4079f6f90828c63f Mon Sep 17 00:00:00 2001 From: LRL-ModelCloud <165116337+LRL-ModelCloud@users.noreply.github.com> Date: Wed, 24 Jul 2024 05:59:08 +0800 Subject: [PATCH] [FIX] allow auto_round lm_head quantization (#282) * enable auto_round lm_head quantize * Update base.py --------- Co-authored-by: LRL-ModelCloud Co-authored-by: Qubitium-ModelCloud --- gptqmodel/models/base.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/gptqmodel/models/base.py b/gptqmodel/models/base.py index f7a0dcd9..085a758f 100644 --- a/gptqmodel/models/base.py +++ b/gptqmodel/models/base.py @@ -170,9 +170,8 @@ def quantize( logger.warning("According to the issue https://github.com/ModelCloud/GPTQModel/issues/278, transformers version 4.43.0 has broken batch_size. until the issue is resolved, hard set the batch_size to 1.") batch_size = 1 - # TODO: lm_head quantization is yet ready but pending - if self.quantize_config.lm_head: - raise ValueError("lm_head quantization is currently inference only and not applicable for quantization. Please set `lm_head=False`.") + if self.quantize_config.lm_head and not isinstance(self.quantize_config, AutoRoundQuantizeConfig): + raise ValueError("`lm_head=True` quantization is only available with AutoRound quantizer. Please use `AutoRoundQuantizeConfig` instead of `QuantizeConfig` and set `lm_head=True` or set `lm_head=False`.") if len(calibration_dataset) == 0: raise ValueError("Calibration dataset must not be empty.")