4-bit QLoRA via bitsandbytes (4-bit base model + LoRA) #23479

Merged
66 commits merged on May 24, 2023
Changes from 12 commits (of 66 total)
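To make the scope of the change concrete, here is a minimal loading sketch; it is not taken from the PR itself, the checkpoint name is a placeholder, and it assumes a bitsandbytes build with 4-bit support. It exercises only the two pieces the diff introduces: the new load_in_4bit flag and the requirement to pass a device_map. Attaching LoRA adapters on top of the 4-bit base model would be done separately with the peft library.

from transformers import AutoModelForCausalLM, AutoTokenizer

model_id = "facebook/opt-350m"  # hypothetical example checkpoint, not referenced in the PR

tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(
    model_id,
    load_in_4bit=True,  # new flag introduced by this PR
    device_map="auto",  # required: quantized loading needs a device map
)

inputs = tokenizer("QLoRA keeps the base model in 4-bit:", return_tensors="pt").to(model.device)
print(tokenizer.decode(model.generate(**inputs, max_new_tokens=20)[0]))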
596b1c0
Added lion and paged optimizers and made original tests pass.
TimDettmers May 8, 2023
e66d556
Added tests for paged and lion optimizers.
TimDettmers May 8, 2023
5cdc176
Added and fixed optimizer tests.
TimDettmers May 8, 2023
0773ae5
Style and quality checks.
TimDettmers May 8, 2023
24c49e5
Initial draft. Some tests fail.
TimDettmers May 10, 2023
68b8ba4
Merge remote-tracking branch 'origin/main' into bnb_4bit
TimDettmers May 10, 2023
03b4d78
Fixed dtype bug.
TimDettmers May 10, 2023
524be44
Fixed bug caused by torch_dtype='auto'.
TimDettmers May 10, 2023
06cf851
All test green for 8-bit and 4-bit layers.
TimDettmers May 11, 2023
42e1095
Merge remote-tracking branch 'forked/bnb_paged_optimizers' into forke…
TimDettmers May 11, 2023
2525aee
Added fix for fp32 layer norms and bf16 compute in LLaMA.
TimDettmers May 11, 2023
cb7e54a
Merge remote-tracking branch 'origin/main' into bnb_beta
TimDettmers May 19, 2023
90412ab
Initial draft. Some tests fail.
TimDettmers May 10, 2023
0e6015b
Fixed dtype bug.
TimDettmers May 10, 2023
866886c
Fixed bug caused by torch_dtype='auto'.
TimDettmers May 10, 2023
4c5ebf1
All test green for 8-bit and 4-bit layers.
TimDettmers May 11, 2023
170812b
Added lion and paged optimizers and made original tests pass.
TimDettmers May 8, 2023
6e0d3ac
Added tests for paged and lion optimizers.
TimDettmers May 8, 2023
1582692
Added and fixed optimizer tests.
TimDettmers May 8, 2023
1f25846
Style and quality checks.
TimDettmers May 8, 2023
56110ec
Fixing issues for PR #23479.
TimDettmers May 20, 2023
80396d0
Added fix for fp32 layer norms and bf16 compute in LLaMA.
TimDettmers May 11, 2023
d4b4e4d
Merge branch 'bnb_beta' of github.com:timdettmers/transformers into b…
TimDettmers May 20, 2023
6263752
Reverted variable name change.
TimDettmers May 20, 2023
831fc4a
Initial draft. Some tests fail.
TimDettmers May 10, 2023
b42644a
Fixed dtype bug.
TimDettmers May 10, 2023
9cd4319
Fixed bug caused by torch_dtype='auto'.
TimDettmers May 10, 2023
d68e564
All test green for 8-bit and 4-bit layers.
TimDettmers May 11, 2023
e8dcb57
Added lion and paged optimizers and made original tests pass.
TimDettmers May 8, 2023
ad30995
Added tests for paged and lion optimizers.
TimDettmers May 8, 2023
f1b2ab6
Added and fixed optimizer tests.
TimDettmers May 8, 2023
8b2e43d
Style and quality checks.
TimDettmers May 8, 2023
84cd0b3
Added missing tests.
TimDettmers May 20, 2023
61d2993
Merge branch 'bnb_beta' of github.com:timdettmers/transformers into f…
TimDettmers May 20, 2023
33dde75
Fixup changes.
TimDettmers May 20, 2023
1d830b5
Added fixup changes.
TimDettmers May 20, 2023
5c1a5e0
Merge branch 'bnb_beta' of github.com:timdettmers/transformers into b…
TimDettmers May 20, 2023
2f15b6e
Missed some variables to rename.
TimDettmers May 20, 2023
617b58c
Merge remote-tracking branch 'upstream/main' into HEAD
younesbelkada May 22, 2023
ea7175d
revert trainer tests
younesbelkada May 22, 2023
aac113d
revert test trainer
younesbelkada May 22, 2023
e43237d
another revert
younesbelkada May 22, 2023
13c86fd
fix tests and safety checkers
younesbelkada May 22, 2023
c72f302
protect import
younesbelkada May 22, 2023
7b1b1e6
simplify a bit
younesbelkada May 22, 2023
cf393cf
Update src/transformers/trainer.py
younesbelkada May 22, 2023
f19d80c
few fixes
younesbelkada May 22, 2023
ba287ff
add warning
younesbelkada May 22, 2023
1030921
replace with `load_in_kbit = load_in_4bit or load_in_8bit`
younesbelkada May 22, 2023
1cae462
fix test
younesbelkada May 22, 2023
25f762e
fix tests
younesbelkada May 22, 2023
2f43dc1
this time fix tests
younesbelkada May 22, 2023
a63b649
safety checker
younesbelkada May 22, 2023
49501db
add docs
younesbelkada May 22, 2023
4642523
revert torch_dtype
younesbelkada May 22, 2023
a6ba77b
Apply suggestions from code review
younesbelkada May 22, 2023
27cdff6
multiple fixes
younesbelkada May 22, 2023
b2bc3ab
update docs
younesbelkada May 22, 2023
976f7d0
version checks and multiple fixes
younesbelkada May 22, 2023
9c4946e
replace `is_loaded_in_kbit`
younesbelkada May 22, 2023
6f4f4dc
replace `load_in_kbit`
younesbelkada May 22, 2023
5359b59
change methods names
younesbelkada May 22, 2023
0c0bb65
better checks
younesbelkada May 22, 2023
f4a2a0b
oops
younesbelkada May 22, 2023
13a2ad7
oops
younesbelkada May 22, 2023
0b05092
address final comments
younesbelkada May 22, 2023
1 change: 1 addition & 0 deletions src/transformers/configuration_utils.py
@@ -833,6 +833,7 @@ def to_json_string(self, use_diff: bool = True) -> str:
config_dict = self.to_diff_dict()
else:
config_dict = self.to_dict()

return json.dumps(config_dict, indent=2, sort_keys=True) + "\n"

def to_json_file(self, json_file_path: Union[str, os.PathLike], use_diff: bool = True):
103 changes: 56 additions & 47 deletions src/transformers/modeling_utils.py
@@ -607,6 +607,7 @@ def _load_state_dict_into_meta_model(
state_dict_index=None,
dtype=None,
load_in_8bit=False,
load_in_4bit=False,
is_safetensors=False,
keep_in_fp32_modules=None,
):
@@ -627,8 +628,10 @@ def _load_state_dict_into_meta_model(
# - Is there a situation where some keys aren't in `loaded_state_dict_keys` and in which case
# they won't get loaded.

if load_in_8bit:
from .utils.bitsandbytes import set_module_8bit_tensor_to_device
if (load_in_4bit == True and load_in_8bit == True):
raise ValueError('You cannot set load_in_4bit=True and load_in_8bit=True at the same time! Choose one option.')
if load_in_8bit or load_in_4bit:
from .utils.bitsandbytes import set_module_kbit_tensor_to_device

error_msgs = []

@@ -699,12 +702,14 @@ def _load_state_dict_into_meta_model(
# TODO: group all errors and raise at the end.
raise ValueError(f"{param_name} doesn't have any device set.")
param_device = device_map[module_name]


if param_device == "disk":
if not is_safetensors:
offload_index = offload_weight(param, param_name, offload_folder, offload_index)
elif param_device == "cpu" and state_dict_index is not None:
state_dict_index = offload_weight(param, param_name, state_dict_folder, state_dict_index)
elif not load_in_8bit:
elif not (load_in_8bit or load_in_4bit):
# For backward compatibility with older versions of `accelerate`
set_module_tensor_to_device(model, param_name, param_device, **set_module_kwargs)
else:
@@ -714,7 +719,7 @@ def _load_state_dict_into_meta_model(
fp16_statistics = None

if "SCB" not in param_name:
set_module_8bit_tensor_to_device(
set_module_kbit_tensor_to_device(
model, param_name, param_device, value=param, fp16_statistics=fp16_statistics
)

@@ -1700,6 +1705,15 @@ def save_pretrained(
UserWarning,
)

if getattr(self, "is_loaded_in_4bit", False):
warnings.warn(
"You are calling `save_pretrained` to a 8-bit converted model you may likely encounter unexepected"
" behaviors. If you want to save 8-bit models, make sure to have `bitsandbytes>0.37.2` installed.",
UserWarning,
)
raise NotImplementedError("You are calling `save_pretrained` on a 4-bit converted model. \
This is currently not supported")

if "save_config" in kwargs:
warnings.warn(
"`save_config` is deprecated and will be removed in v5 of Transformers. Use `is_main_process` instead."
@@ -1876,29 +1890,29 @@ def get_memory_footprint(self, return_buffers=True):

def to(self, *args, **kwargs):
# Checks if the model has been loaded in 8-bit
if getattr(self, "is_loaded_in_8bit", False):
if getattr(self, "is_loaded_in_kbit", False):
raise ValueError(
"`.to` is not supported for `8-bit` models. Please use the model as it is, since the"
"`.to` is not supported for `k-bit` models. Please use the model as it is, since the"
" model has already been set to the correct devices and casted to the correct `dtype`."
)
else:
return super().to(*args, **kwargs)

def half(self, *args):
# Checks if the model has been loaded in 8-bit
if getattr(self, "is_loaded_in_8bit", False):
if getattr(self, "is_loaded_in_kbit", False):
raise ValueError(
"`.half()` is not supported for `8-bit` models. Please use the model as it is, since the"
"`.half()` is not supported for `k-bit` models. Please use the model as it is, since the"
" model has already been casted to the correct `dtype`."
)
else:
return super().half(*args)

def float(self, *args):
# Checks if the model has been loaded in 8-bit
if getattr(self, "is_loaded_in_8bit", False):
if getattr(self, "is_loaded_in_kbit", False):
raise ValueError(
"`.float()` is not supported for `8-bit` models. Please use the model as it is, since the"
"`.float()` is not supported for `k-bit` models. Please use the model as it is, since the"
" model has already been casted to the correct `dtype`."
)
else:
@@ -2156,6 +2170,7 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P
offload_folder = kwargs.pop("offload_folder", None)
offload_state_dict = kwargs.pop("offload_state_dict", False)
load_in_8bit = kwargs.pop("load_in_8bit", False)
load_in_4bit = kwargs.pop("load_in_4bit", False)
quantization_config = kwargs.pop("quantization_config", None)
subfolder = kwargs.pop("subfolder", "")
commit_hash = kwargs.pop("_commit_hash", None)
@@ -2194,10 +2209,11 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P

if quantization_config is None:
quantization_config, kwargs = BitsAndBytesConfig.from_dict(
config_dict={"load_in_8bit": load_in_8bit}, return_unused_kwargs=True, **kwargs
config_dict={"load_in_8bit": load_in_8bit, "load_in_4bit" : load_in_4bit}, return_unused_kwargs=True, **kwargs
)
elif quantization_config is not None:
load_in_8bit = quantization_config.load_in_8bit
load_in_4bit = quantization_config.load_in_4bit

quantization_config_kwargs = {
k: v for k, v in kwargs.items() if k in inspect.signature(BitsAndBytesConfig).parameters
@@ -2215,30 +2231,25 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P
if low_cpu_mem_usage is None:
low_cpu_mem_usage = True

if load_in_8bit:
if load_in_8bit or load_in_4bit:
if not (is_accelerate_available() and is_bitsandbytes_available()):
raise ImportError(
"Using `load_in_8bit=True` requires Accelerate: `pip install accelerate` and the latest version of"
" bitsandbytes `pip install -i https://test.pypi.org/simple/ bitsandbytes` or"
" pip install bitsandbytes` "
)
if torch_dtype != torch.float16:
# We force the `dtype` to be float16, this is a requirement from `bitsandbytes`
logger.warning(
f"Overriding torch_dtype={torch_dtype} with `torch_dtype=torch.float16` due to "
"requirements of `bitsandbytes` to enable model loading in mixed int8. "
"Either pass torch_dtype=torch.float16 or don't pass this argument at all to remove this warning."
)
torch_dtype = torch.float16
Comment (Collaborator): As in the Accelerate PR, is this not necessary anymore in the current released version?

Reply (Contributor, author): It was never needed in the first place. The problem with this is that it does not allow for other data types, which are necessary for mixed-precision 4-bit training.

Reply (Contributor): I think this should be revisited; it currently breaks some tests for the existing 8-bit integration for the following reason:

  • If a user calls model_8bit = AutoModelForCausalLM.from_pretrained(self.model_name, load_in_8bit=True, device_map="auto") as before this PR, they would get a model that is first converted to float16 and then cast to float16, meaning that all the non-Linear layers (embedding layers, etc.) would be cast to fp16.

As some existing systems rely on that behavior, I would replace the warning with a message stating that if torch_dtype is not set, the non-Linear modules are left in their current dtype.

Reply (Contributor): Added a proper warning message!
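To illustrate the behavior discussed in this thread, here is a rough sketch (the checkpoint name is an assumption and the snippet is not part of the diff) of how one could inspect which dtypes the parameters end up in after 8-bit loading:

import torch
from collections import Counter
from transformers import AutoModelForCausalLM

model_8bit = AutoModelForCausalLM.from_pretrained(
    "facebook/opt-350m",        # hypothetical checkpoint
    load_in_8bit=True,
    device_map="auto",
    torch_dtype=torch.float16,  # per the comment above, leaving this unset keeps non-Linear modules in their current dtype
)

# int8 entries come from the quantized Linear layers; the remaining parameters
# (embeddings, the modules kept in fp32, ...) follow torch_dtype or keep_in_fp32_modules.
print(Counter(p.dtype for p in model_8bit.parameters()))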


if torch_dtype is None:
torch_dtype = torch.float32

if device_map is None:
raise ValueError(
"A device map needs to be passed to run convert models into mixed-int8 format. Please run"
"A device map needs to be passed to run convert models into 8-bit and 4-bit formats. Please run"
"`.from_pretrained` with `device_map='auto'`"
)
if from_tf or from_flax:
raise ValueError(
"Converting into mixed 8-bit weights from tf/flax weights is currently not supported, please make"
"Converting into 4-bit or 8-bit weights from tf/flax weights is currently not supported, please make"
" sure the weights are in PyTorch format."
)

@@ -2296,7 +2307,8 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P
load_in_8bit = quantization_config.load_in_8bit

if load_in_8bit:
torch_dtype = torch.float16
if torch_dtype is None:
torch_dtype = torch.float32

if device_map is None:
device_map = "auto"
@@ -2582,7 +2594,8 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P

# Check if `_keep_in_fp32_modules` is not None
use_keep_in_fp32_modules = (
(cls._keep_in_fp32_modules is not None) and is_accelerate_available() and torch_dtype == torch.float16
(cls._keep_in_fp32_modules is not None) and is_accelerate_available() and
(torch_dtype == torch.float16 or load_in_4bit or load_in_8bit)
)
if (
(cls._keep_in_fp32_modules is not None)
@@ -2611,7 +2624,7 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P

logger.info("Detected DeepSpeed ZeRO-3: activating zero.init() for this model")
init_contexts = [deepspeed.zero.Init(config_dict_or_path=deepspeed_config())] + init_contexts
elif load_in_8bit or low_cpu_mem_usage:
elif load_in_8bit or load_in_4bit or low_cpu_mem_usage:
init_contexts.append(init_empty_weights())

with ContextManagers(init_contexts):
@@ -2624,20 +2637,20 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P
else:
keep_in_fp32_modules = []

if load_in_8bit:
from .utils.bitsandbytes import get_keys_to_not_convert, replace_8bit_linear
if load_in_8bit or load_in_4bit:
from .utils.bitsandbytes import get_keys_to_not_convert, replace_with_bnb_linear

load_in_8bit_skip_modules = quantization_config.llm_int8_skip_modules
bnb_kbit_skip_modules = quantization_config.bnb_kbit_skip_modules
load_in_8bit_threshold = quantization_config.llm_int8_threshold
load_in_8bit_fp32_cpu_offload = quantization_config.llm_int8_enable_fp32_cpu_offload

logger.info("Detected 8-bit loading: activating 8-bit loading for this model")

# We keep some modules such as the lm_head in their original dtype for numerical stability reasons
if load_in_8bit_skip_modules is None:
if bnb_kbit_skip_modules is None:
modules_to_not_convert = get_keys_to_not_convert(model)
else:
modules_to_not_convert = load_in_8bit_skip_modules
modules_to_not_convert = bnb_kbit_skip_modules

if not isinstance(modules_to_not_convert, list):
modules_to_not_convert = [modules_to_not_convert]
@@ -2657,12 +2670,12 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P

modules_to_not_convert.extend(keys_on_cpu)

model = replace_8bit_linear(
model, threshold=load_in_8bit_threshold, modules_to_not_convert=modules_to_not_convert
model = replace_with_bnb_linear(
model, modules_to_not_convert=modules_to_not_convert, quantization_config=quantization_config
)

# training in 8-bit is only available in 0.37.0+
model._is_int8_training_enabled = version.parse(
model._is_kbit_training_enabled = version.parse(
importlib_metadata.version("bitsandbytes")
) >= version.parse("0.37.0")

@@ -2671,15 +2684,6 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P

if isinstance(device_map, str):
special_dtypes = {}
if load_in_8bit:
special_dtypes.update(
{
name: torch_dtype
for name, _ in model.named_parameters()
if any(m in name for m in modules_to_not_convert)
}
)

special_dtypes.update(
{
name: torch.float32
@@ -2720,7 +2724,7 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P
model.tie_weights()
device_map = infer_auto_device_map(model, dtype=torch_dtype if not load_in_8bit else torch.int8, **kwargs)

if load_in_8bit:
if load_in_8bit or load_in_4bit:
# The LM head / tied weights or any last module can stay on disk / CPU
device_map_without_lm_head = {
key: device_map[key] for key in device_map.keys() if key not in modules_to_not_convert
@@ -2796,10 +2800,13 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P
offload_state_dict=offload_state_dict,
dtype=torch_dtype,
load_in_8bit=load_in_8bit,
load_in_4bit=load_in_4bit,
keep_in_fp32_modules=keep_in_fp32_modules,
)

model.is_loaded_in_4bit = load_in_4bit
model.is_loaded_in_8bit = load_in_8bit
model.is_loaded_in_kbit = load_in_8bit or load_in_4bit
Comment (Collaborator): is_loaded_in_kbit is not a good name. All models are loaded in a certain k-bit precision for k in {4, 8, 16, 32}. is_quantized would be more appropriate.


# make sure token embedding weights are still tied if needed
model.tie_weights()
@@ -2863,11 +2870,12 @@ def _load_pretrained_model(
offload_state_dict=None,
dtype=None,
load_in_8bit=False,
load_in_4bit=False,
keep_in_fp32_modules=None,
):
is_safetensors = False
if load_in_8bit:
from .utils.bitsandbytes import set_module_8bit_tensor_to_device
if load_in_8bit or load_in_4bit:
from .utils.bitsandbytes import set_module_kbit_tensor_to_device
Comment (Collaborator): Same for this function.


if device_map is not None and "disk" in device_map.values():
archive_file = (
@@ -2973,10 +2981,10 @@ def _fix_key(key):
target_dtype = torch.float32

if param.device == torch.device("meta"):
if not load_in_8bit:
if not (load_in_8bit or load_in_4bit):
set_module_tensor_to_device(model, key, "cpu", torch.empty(*param.size(), dtype=target_dtype))
else:
set_module_8bit_tensor_to_device(
set_module_kbit_tensor_to_device(
model, key, "cpu", torch.empty(*param.size(), dtype=target_dtype)
)

@@ -3135,6 +3143,7 @@ def _find_mismatched_keys(
state_dict_index=state_dict_index,
dtype=dtype,
load_in_8bit=load_in_8bit,
load_in_4bit=load_in_4bit,
is_safetensors=is_safetensors,
keep_in_fp32_modules=keep_in_fp32_modules,
)
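As a quick illustration of the new model attributes and guards added in this file, a sketch (not code from the PR; the checkpoint name is an assumption):

from transformers import AutoModelForCausalLM

model = AutoModelForCausalLM.from_pretrained(
    "facebook/opt-350m",  # hypothetical checkpoint
    load_in_4bit=True,
    device_map="auto",
)

print(model.is_loaded_in_4bit)  # True
print(model.is_loaded_in_kbit)  # True, i.e. load_in_8bit or load_in_4bit

try:
    model.to("cpu")  # quantized models reject .to(), .half() and .float()
except ValueError as err:
    print(err)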
8 changes: 3 additions & 5 deletions src/transformers/models/llama/modeling_llama.py
@@ -81,14 +81,12 @@ def __init__(self, hidden_size, eps=1e-6):
self.variance_epsilon = eps

def forward(self, hidden_states):
input_dtype = hidden_states.dtype
variance = hidden_states.to(torch.float32).pow(2).mean(-1, keepdim=True)
hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)

# convert into half-precision if necessary
if self.weight.dtype in [torch.float16, torch.bfloat16]:
hidden_states = hidden_states.to(self.weight.dtype)

return self.weight * hidden_states
hidden_states = (self.weight * hidden_states).to(input_dtype)
return hidden_states


class LlamaRotaryEmbedding(torch.nn.Module):
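The LlamaRMSNorm change above can be summarized with a small standalone sketch (not code from the PR): the variance is computed in float32 and the result is cast back to the input dtype, so bf16 activations stay bf16 while the normalization itself runs in higher precision.

import torch

def rms_norm(hidden_states, weight, eps=1e-6):
    input_dtype = hidden_states.dtype
    variance = hidden_states.to(torch.float32).pow(2).mean(-1, keepdim=True)
    hidden_states = hidden_states * torch.rsqrt(variance + eps)  # promoted to float32
    return (weight * hidden_states).to(input_dtype)              # cast back to the input dtype

x = torch.randn(2, 4, dtype=torch.bfloat16)
w = torch.ones(4)            # norm weight kept in float32
print(rms_norm(x, w).dtype)  # torch.bfloat16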
34 changes: 33 additions & 1 deletion src/transformers/trainer.py
@@ -392,7 +392,7 @@ def __init__(

# At this stage the model is already loaded
if getattr(model, "is_loaded_in_8bit", False):
if getattr(model, "_is_int8_training_enabled", False):
if getattr(model, "_is_kbit_training_enabled", False):
logger.info(
"The model is loaded in 8-bit precision. To train this model you need to add additional modules"
" inside the model such as adapters using `peft` library and freeze the model weights. Please"
@@ -1170,6 +1170,38 @@ def get_optimizer_cls_and_kwargs(args: TrainingArguments) -> Tuple[Any, Any]:
optimizer_kwargs.update(adam_kwargs)
except ImportError:
raise ValueError("Trainer tried to instantiate apex FusedAdam but apex is not installed!")
elif args.optim in [
OptimizerNames.ADAMW_BNB,
OptimizerNames.ADAMW_8BIT,
OptimizerNames.PAGED_ADAMW,
OptimizerNames.PAGED_ADAMW_8BIT,
OptimizerNames.LION,
OptimizerNames.LION_8BIT,
OptimizerNames.PAGED_LION,
OptimizerNames.PAGED_LION_8BIT,
]:
try:
from bitsandbytes.optim import AdamW, Lion

is_paged = False
optim_bits = 32
optimizer_cls = None
additional_optim_kwargs = adam_kwargs
if "paged" in args.optim:
is_paged = True
if "8bit" in args.optim:
optim_bits = 8
if "adam" in args.optim:
optimizer_cls = AdamW
elif "lion" in args.optim:
optimizer_cls = Lion
additional_optim_kwargs = {"betas": (args.adam_beta1, args.adam_beta2)}

bnb_kwargs = {"is_paged": is_paged, "optim_bits": optim_bits}
optimizer_kwargs.update(additional_optim_kwargs)
optimizer_kwargs.update(bnb_kwargs)
except ImportError:
raise ValueError("Trainer tried to instantiate bnb optimizer but bnb is not installed!")
elif args.optim == OptimizerNames.ADAMW_BNB:
try:
from bitsandbytes.optim import Adam8bit
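The new optimizer branch above is selected through the optim training argument. A minimal usage sketch (the output directory and batch size are arbitrary; this is not code from the PR):

from transformers import TrainingArguments

args = TrainingArguments(
    output_dir="out",          # assumed output path
    per_device_train_batch_size=4,
    optim="paged_adamw_8bit",  # mapped above to bitsandbytes.optim.AdamW with is_paged=True, optim_bits=8
)
# Trainer(model=model, args=args, train_dataset=train_dataset).train()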
9 changes: 8 additions & 1 deletion src/transformers/training_args.py
@@ -139,10 +139,17 @@ class OptimizerNames(ExplicitEnum):
ADAMW_TORCH_XLA = "adamw_torch_xla"
ADAMW_APEX_FUSED = "adamw_apex_fused"
ADAFACTOR = "adafactor"
ADAMW_BNB = "adamw_bnb_8bit"
ADAMW_ANYPRECISION = "adamw_anyprecision"
SGD = "sgd"
ADAGRAD = "adagrad"
ADAMW_BNB = "adamw_bnb_8bit"
ADAMW_8BIT = "adamw_8bit" # just an alias for adamw_bnb_8bit
LION_8BIT = "lion_8bit"
LION = "lion_32bit"
PAGED_ADAMW = "paged_adamw_32bit"
PAGED_ADAMW_8BIT = "paged_adamw_8bit"
PAGED_LION = "paged_lion_32bit"
PAGED_LION_8BIT = "paged_lion_8bit"


@dataclass
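Since OptimizerNames is a string enum, the new members above can be passed directly as optim values; a tiny sanity sketch (not part of the PR):

from transformers.training_args import OptimizerNames

print(OptimizerNames.PAGED_ADAMW_8BIT.value)  # "paged_adamw_8bit"
print(OptimizerNames.LION.value)              # "lion_32bit"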