Skip to content

Commit

Permalink
Support adamw_torch_8bit (huggingface#34993)
Browse files Browse the repository at this point in the history
* var

* more

* test
  • Loading branch information
fzyzcjy authored and elvircrn committed Feb 13, 2025
1 parent 332c243 commit 50eb5d0
Show file tree
Hide file tree
Showing 3 changed files with 19 additions and 3 deletions.
14 changes: 11 additions & 3 deletions src/transformers/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -1618,7 +1618,10 @@ def optimizer_hook(param):
"gradient_clipping": float(optim_args.get("gradient_clipping", 1.0)),
}
)
elif args.optim == OptimizerNames.ADAMW_TORCH_4BIT:
elif args.optim in [
OptimizerNames.ADAMW_TORCH_4BIT,
OptimizerNames.ADAMW_TORCH_8BIT,
]:
if not is_torchao_available() or version.parse(importlib.metadata.version("torchao")) < version.parse(
"0.4.0"
):
Expand All @@ -1631,9 +1634,14 @@ def optimizer_hook(param):
"You need to have `torch>2.4` in order to use torch 4-bit optimizers. "
"Install it with `pip install --upgrade torch` it is available on pipy. Otherwise, you need to install torch nightly."
)
from torchao.prototype.low_bit_optim import AdamW4bit
from torchao.prototype.low_bit_optim import AdamW4bit, AdamW8bit

optimizer_cls = AdamW4bit
if args.optim == OptimizerNames.ADAMW_TORCH_4BIT:
optimizer_cls = AdamW4bit
elif args.optim == OptimizerNames.ADAMW_TORCH_8BIT:
optimizer_cls = AdamW8bit
else:
raise ValueError("Invalid optimizer")
optimizer_kwargs.update(adam_kwargs)
elif args.optim in [
OptimizerNames.SCHEDULE_FREE_ADAMW,
Expand Down
1 change: 1 addition & 0 deletions src/transformers/training_args.py
Original file line number Diff line number Diff line change
Expand Up @@ -154,6 +154,7 @@ class OptimizerNames(ExplicitEnum):
ADAFACTOR = "adafactor"
ADAMW_ANYPRECISION = "adamw_anyprecision"
ADAMW_TORCH_4BIT = "adamw_torch_4bit"
ADAMW_TORCH_8BIT = "adamw_torch_8bit"
ADEMAMIX = "ademamix"
SGD = "sgd"
ADAGRAD = "adagrad"
Expand Down
7 changes: 7 additions & 0 deletions tests/trainer/test_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -5017,6 +5017,13 @@ def hp_name(trial):
default_adam_kwargs,
)
)
optim_test_params.append(
(
TrainingArguments(optim=OptimizerNames.ADAMW_TORCH_8BIT, output_dir="None"),
torchao.prototype.low_bit_optim.AdamW8bit,
default_adam_kwargs,
)
)


@require_torch
Expand Down

0 comments on commit 50eb5d0

Please sign in to comment.