From 485eab818e4a126e4952c0d17ffbc937ae980897 Mon Sep 17 00:00:00 2001
From: NanoCode012
Date: Mon, 28 Oct 2024 16:44:18 +0700
Subject: [PATCH] refactor: load * bit handling

---
 src/axolotl/utils/models.py | 14 ++++----------
 1 file changed, 4 insertions(+), 10 deletions(-)

diff --git a/src/axolotl/utils/models.py b/src/axolotl/utils/models.py
index 51c704429..2b8728405 100644
--- a/src/axolotl/utils/models.py
+++ b/src/axolotl/utils/models.py
@@ -637,9 +637,7 @@ def set_quantization_config(self) -> None:
             self.model_kwargs["quantization_config"] = BitsAndBytesConfig(
                 **self.model_config.quantization_config
             )
-        elif self.cfg.adapter == "qlora" and (
-            "load_in_4bit" in self.model_kwargs and self.model_kwargs["load_in_4bit"]
-        ):
+        elif self.cfg.adapter == "qlora" and self.model_kwargs["load_in_4bit"]:
             bnb_config = {
                 "load_in_4bit": True,
                 "llm_int8_threshold": 6.0,
@@ -662,9 +660,7 @@ def set_quantization_config(self) -> None:
             self.model_kwargs["quantization_config"] = BitsAndBytesConfig(
                 **bnb_config,
             )
-        elif self.cfg.adapter == "lora" and (
-            "load_in_8bit" in self.model_kwargs and self.model_kwargs["load_in_8bit"]
-        ):
+        elif self.cfg.adapter == "lora" and self.model_kwargs["load_in_8bit"]:
             bnb_config = {
                 "load_in_8bit": True,
             }
@@ -677,10 +673,8 @@ def set_quantization_config(self) -> None:
 
         # no longer needed per https://github.com/huggingface/transformers/pull/26610
         if "quantization_config" in self.model_kwargs or self.cfg.gptq:
-            if "load_in_8bit" in self.model_kwargs:
-                del self.model_kwargs["load_in_8bit"]
-            if "load_in_4bit" in self.model_kwargs:
-                del self.model_kwargs["load_in_4bit"]
+            self.model_kwargs.pop("load_in_8bit", None)
+            self.model_kwargs.pop("load_in_4bit", None)
 
     def set_attention_config(self) -> None:
         """
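
Note (not part of the patch): a minimal standalone sketch of the key-removal idiom the last hunk adopts. dict.pop(key, None) behaves like the guarded "if key in d: del d[key]" pattern it replaces, but in one call and without raising KeyError when the key is absent. The model_kwargs dict below is purely illustrative and only mimics the shape of the real ModelLoader.model_kwargs.

    # Hypothetical stand-in for self.model_kwargs; only "load_in_8bit" is present here.
    model_kwargs = {"load_in_8bit": False, "quantization_config": {"load_in_4bit": True}}

    # Pattern removed by the patch: guard each delete with a membership check.
    if "load_in_8bit" in model_kwargs:
        del model_kwargs["load_in_8bit"]
    if "load_in_4bit" in model_kwargs:
        del model_kwargs["load_in_4bit"]

    # Pattern introduced by the patch: pop with a default is a no-op for missing keys,
    # so both calls are safe even though "load_in_4bit" was never set.
    model_kwargs.pop("load_in_8bit", None)
    model_kwargs.pop("load_in_4bit", None)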