diff --git a/src/axolotl/utils/models.py b/src/axolotl/utils/models.py
index 51c704429..2b8728405 100644
--- a/src/axolotl/utils/models.py
+++ b/src/axolotl/utils/models.py
@@ -637,9 +637,7 @@ def set_quantization_config(self) -> None:
             self.model_kwargs["quantization_config"] = BitsAndBytesConfig(
                 **self.model_config.quantization_config
             )
-        elif self.cfg.adapter == "qlora" and (
-            "load_in_4bit" in self.model_kwargs and self.model_kwargs["load_in_4bit"]
-        ):
+        elif self.cfg.adapter == "qlora" and self.model_kwargs["load_in_4bit"]:
             bnb_config = {
                 "load_in_4bit": True,
                 "llm_int8_threshold": 6.0,
@@ -662,9 +660,7 @@ def set_quantization_config(self) -> None:
             self.model_kwargs["quantization_config"] = BitsAndBytesConfig(
                 **bnb_config,
             )
-        elif self.cfg.adapter == "lora" and (
-            "load_in_8bit" in self.model_kwargs and self.model_kwargs["load_in_8bit"]
-        ):
+        elif self.cfg.adapter == "lora" and self.model_kwargs["load_in_8bit"]:
             bnb_config = {
                 "load_in_8bit": True,
             }
@@ -677,10 +673,8 @@ def set_quantization_config(self) -> None:
 
         # no longer needed per https://github.com/huggingface/transformers/pull/26610
         if "quantization_config" in self.model_kwargs or self.cfg.gptq:
-            if "load_in_8bit" in self.model_kwargs:
-                del self.model_kwargs["load_in_8bit"]
-            if "load_in_4bit" in self.model_kwargs:
-                del self.model_kwargs["load_in_4bit"]
+            self.model_kwargs.pop("load_in_8bit", None)
+            self.model_kwargs.pop("load_in_4bit", None)
 
     def set_attention_config(self) -> None:
         """
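
Note (not part of the diff above): a minimal standalone sketch of the dictionary-cleanup idiom this change adopts, using a hypothetical kwargs dict rather than the real model_kwargs. dict.pop(key, default) removes the key if present and otherwise returns the default without raising KeyError, so it replaces the guarded "if key in d: del d[key]" pattern in a single call.

    # Hypothetical stand-in for self.model_kwargs; illustration only.
    kwargs = {"load_in_8bit": True}

    # Old style: a membership check guards the del against KeyError.
    if "load_in_8bit" in kwargs:
        del kwargs["load_in_8bit"]

    # New style: pop() with a default silently ignores missing keys.
    kwargs.pop("load_in_8bit", None)  # no-op here; the key was already deleted above
    kwargs.pop("load_in_4bit", None)  # also safe even though this key never existed

    print(kwargs)  # -> {}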