Skip to content

Commit

Permalink
refactor: load * bit handling
Browse files Browse the repository at this point in the history
  • Loading branch information
NanoCode012 committed Oct 28, 2024
1 parent b9c79f8 commit 485eab8
Showing 1 changed file with 4 additions and 10 deletions.
14 changes: 4 additions & 10 deletions src/axolotl/utils/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -637,9 +637,7 @@ def set_quantization_config(self) -> None:
self.model_kwargs["quantization_config"] = BitsAndBytesConfig(
**self.model_config.quantization_config
)
elif self.cfg.adapter == "qlora" and (
"load_in_4bit" in self.model_kwargs and self.model_kwargs["load_in_4bit"]
):
elif self.cfg.adapter == "qlora" and self.model_kwargs["load_in_4bit"]:
bnb_config = {
"load_in_4bit": True,
"llm_int8_threshold": 6.0,
Expand All @@ -662,9 +660,7 @@ def set_quantization_config(self) -> None:
self.model_kwargs["quantization_config"] = BitsAndBytesConfig(
**bnb_config,
)
elif self.cfg.adapter == "lora" and (
"load_in_8bit" in self.model_kwargs and self.model_kwargs["load_in_8bit"]
):
elif self.cfg.adapter == "lora" and self.model_kwargs["load_in_8bit"]:
bnb_config = {
"load_in_8bit": True,
}
Expand All @@ -677,10 +673,8 @@ def set_quantization_config(self) -> None:

# no longer needed per https://github.com/huggingface/transformers/pull/26610
if "quantization_config" in self.model_kwargs or self.cfg.gptq:
if "load_in_8bit" in self.model_kwargs:
del self.model_kwargs["load_in_8bit"]
if "load_in_4bit" in self.model_kwargs:
del self.model_kwargs["load_in_4bit"]
self.model_kwargs.pop("load_in_8bit", None)
self.model_kwargs.pop("load_in_4bit", None)

def set_attention_config(self) -> None:
"""
Expand Down

0 comments on commit 485eab8

Please sign in to comment.