
Optim: APOLLO optimizer integration #36062


Merged
merged 8 commits on Feb 12, 2025
91 changes: 91 additions & 0 deletions docs/source/en/trainer.md
@@ -443,6 +443,97 @@ trainer.train()

Note that layer-wise optimization is somewhat experimental and does not support DDP (Distributed Data Parallel), so you can run the training script only on a single GPU. Please see [this appropriate section](https://github.com/jiaweizzhao/GaLore?tab=readme-ov-file#train-7b-model-with-a-single-gpu-with-24gb-memory) for more details. Other features such as gradient clipping and DeepSpeed might not be supported out of the box. Please [raise an issue on GitHub](https://github.com/huggingface/transformers/issues) if you encounter such an issue.

### APOLLO

Approximated Gradient Scaling for Memory-Efficient LLM Optimization (APOLLO) is a memory-efficient training strategy that allows full-parameter learning for both pre-training and fine-tuning, while maintaining AdamW-level performance with SGD-like memory efficiency.

* **Ultra-low rank efficiency** → Requires a much lower rank than GaLore; even rank 1 (APOLLO-Mini) suffices.
* **No expensive SVD computations** → Unlike GaLore, APOLLO relies on random projections, avoiding training stalls.

You can read more about the method in the [original repository](https://github.com/zhuhanqing/APOLLO) or in the paper [APOLLO: SGD-like Memory, AdamW-level Performance](https://arxiv.org/abs/2412.05270).
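
To build some intuition for the random-projection idea above, here is a toy sketch (an illustration only, not APOLLO's actual algorithm; the shapes, moment updates, and scaling rule are simplified assumptions): the gradient of a linear layer is compressed with a cheap random matrix instead of an SVD, optimizer state is kept only in that tiny auxiliary space, and the result is used purely to rescale the full-rank gradient.

```python
import torch

torch.manual_seed(0)
n_out, n_in, rank = 1024, 4096, 1      # weight shape and auxiliary rank (APOLLO-Mini uses rank 1)
grad = torch.randn(n_out, n_in)        # full-rank gradient of a linear layer

# Cheap random projection instead of the SVD-based subspace used by GaLore
proj = torch.randn(n_in, rank) / rank ** 0.5
low_rank_grad = grad @ proj            # (n_out, rank): tiny auxiliary gradient

# Optimizer state (AdamW-style moments in the real method) only needs the tiny shape
exp_avg = torch.zeros_like(low_rank_grad)
exp_avg_sq = torch.zeros_like(low_rank_grad)
exp_avg.mul_(0.9).add_(low_rank_grad, alpha=0.1)
exp_avg_sq.mul_(0.999).addcmul_(low_rank_grad, low_rank_grad, value=0.001)

# Derive per-channel scaling factors in the auxiliary space and apply them to the
# raw full-rank gradient, so the weight update itself stays full-parameter.
adaptive = exp_avg / (exp_avg_sq.sqrt() + 1e-8)
channel_scale = adaptive.norm(dim=1, keepdim=True) / (low_rank_grad.norm(dim=1, keepdim=True) + 1e-8)
update = grad * channel_scale          # SGD-like memory: no full-rank moments were stored
```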

First, make sure to install APOLLO from its official repository:

```bash
pip install apollo-torch
```

Then, you can use the APOLLO optimizers by setting `optim="apollo_adamw"` and specifying `optim_target_modules`.
`optim_target_modules` can be a list of strings, regexes, or full paths corresponding to the target module names you want to adapt.
Currently, only linear layers matched by `optim_target_modules` use the APOLLO optimizer; the remaining parameters are still optimized with AdamW.


You can also enable layer-wise APOLLO by appending `layerwise` to the optimizer name (`optim="apollo_adamw_layerwise"`), just like layer-wise GaLore. This saves additional gradient memory by performing the weight updates layer by layer.
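
For instance, a minimal layer-wise configuration could look like the sketch below (the `output_dir` name is an arbitrary placeholder; note that the implementation restricts layer-wise mode to a single GPU and to `gradient_accumulation_steps=1`):

```python
from transformers import TrainingArguments

args = TrainingArguments(
    output_dir="./test-apollo-layerwise",
    max_steps=100,
    per_device_train_batch_size=2,
    gradient_accumulation_steps=1,  # layer-wise updates do not support gradient accumulation
    optim="apollo_adamw_layerwise",
    optim_target_modules=[r".*.attn.*", r".*.mlp.*"],
)
```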

Below is an end-to-end example script (make sure to `pip install trl datasets`):

```python
import torch
import datasets
import trl

from transformers import TrainingArguments, AutoTokenizer, AutoModelForCausalLM

train_dataset = datasets.load_dataset('imdb', split='train')

args = TrainingArguments(
    output_dir="./test-apollo",
    max_steps=100,
    per_device_train_batch_size=2,
    optim="apollo_adamw",
    optim_target_modules=[r".*.attn.*", r".*.mlp.*"]
)

model_id = "google/gemma-2b"

tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(model_id, low_cpu_mem_usage=True).to(0)

trainer = trl.SFTTrainer(
    model=model,
    args=args,
    train_dataset=train_dataset,
    dataset_text_field='text',
    max_seq_length=512,
)

trainer.train()
```


You can further customize APOLLO’s behavior by passing hyperparameters using `optim_args`.

| Parameter | Description |
|------------------|-------------|
| `rank` | Rank of the auxiliary sub-space used for gradient scaling. <br> **APOLLO (default=256)** → Works well for 1B and 7B models. <br> **APOLLO-Mini (default=1)** |
| `scale_type` | How scaling factors are applied. <br> **`channel`** → Per-channel scaling (used in APOLLO). <br> **`tensor`** → Per-tensor scaling (used in APOLLO-Mini). |
| `scale` | Adjusts gradient updates to stabilize training. <br> **APOLLO (default=1.0)** <br> **APOLLO-Mini (default=128)** |
| `update_proj_gap` | Steps before updating projection matrices. Default: **200**. |
| `proj` | Type of projection. Default: **`random`**. |
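
For reference, the standard APOLLO settings from the table can be spelled out explicitly through `optim_args`; the values below simply mirror the defaults listed above, and the remaining training arguments are placeholders:

```python
from transformers import TrainingArguments

args = TrainingArguments(
    output_dir="./test-apollo",
    max_steps=100,
    per_device_train_batch_size=2,
    optim="apollo_adamw",
    optim_target_modules=[r".*.attn.*", r".*.mlp.*"],
    optim_args="proj=random,rank=256,scale=1.0,scale_type=channel,update_proj_gap=200",
)
```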


<Tip>

The `scale` parameter can be set to `n/r`, where `n` is the original space dimension and `r` is the low-rank space dimension; for example, `n = 2048` with `r = 1` gives `scale = 2048`.
Alternatively, you can keep `scale` at its default value and achieve a similar effect by adjusting the learning rate.

</Tip>

For example, you can enable APOLLO-Mini (rank=1 for extreme memory efficiency) by passing `optim_args`:

```python
args = TrainingArguments(
    output_dir="./test-apollo",
    max_steps=100,
    per_device_train_batch_size=2,
    optim="apollo_adamw",
    optim_target_modules=[r".*.attn.*", r".*.mlp.*"],
    optim_args="proj=random,rank=1,scale=128.0,scale_type=tensor,update_proj_gap=200",
)
```

### LOMO optimizer

The LOMO optimizers have been introduced in [Full Parameter Fine-Tuning for Large Language Models with Limited Resources](https://hf.co/papers/2306.09782) and [AdaLomo: Low-memory Optimization with Adaptive Learning Rate](https://hf.co/papers/2310.10195).
9 changes: 9 additions & 0 deletions src/transformers/testing_utils.py
@@ -62,6 +62,7 @@
GGUF_MIN_VERSION,
is_accelerate_available,
is_apex_available,
is_apollo_torch_available,
is_aqlm_available,
is_auto_awq_available,
is_auto_gptq_available,
@@ -403,6 +404,14 @@ def require_galore_torch(test_case):
    return unittest.skipUnless(is_galore_torch_available(), "test requires GaLore")(test_case)


def require_apollo_torch(test_case):
    """
    Decorator marking a test that requires APOLLO. These tests are skipped when APOLLO isn't installed.
    https://github.com/zhuhanqing/APOLLO
    """
    return unittest.skipUnless(is_apollo_torch_available(), "test requires APOLLO")(test_case)


def require_lomo(test_case):
"""
Decorator marking a test that requires LOMO. These tests are skipped when LOMO-optim isn't installed.
221 changes: 128 additions & 93 deletions src/transformers/trainer.py
@@ -152,6 +152,7 @@
find_labels,
is_accelerate_available,
is_apex_available,
is_apollo_torch_available,
is_bitsandbytes_available,
is_datasets_available,
is_galore_torch_available,
@@ -1310,6 +1311,103 @@ def get_optimizer_cls_and_kwargs(
"betas": (args.adam_beta1, args.adam_beta2),
"eps": args.adam_epsilon,
}

def setup_low_rank_optimizer(
    optimizer_name: str,
    optimizer_mapping: Dict[str, Any],
    optim_kwargs: Dict[str, Any],
    is_layerwise_supported: bool = True,
) -> Tuple[Any, Any]:
    """
    Helper function to set up low-rank optimizers like GaLore and Apollo.

    Args:
        optimizer_name (str): Name of the optimizer.
        optimizer_mapping (dict): Mapping of optimizer names to their classes.
        optim_kwargs (dict): Keyword arguments for the optimizer.
        is_layerwise_supported (bool): Whether layerwise optimization is supported.

    Returns:
        Tuple[Any, Any]: Optimizer class and updated optimizer kwargs.
    """
    is_layerwise = optimizer_name.lower().endswith("layerwise")
    if is_layerwise and args.parallel_mode == ParallelMode.DISTRIBUTED and is_layerwise_supported:
        raise NotImplementedError(f"Layer-wise {optimizer_name} does not support DDP at this time")

    optimizer_cls = optimizer_mapping[optimizer_name]

    if args.optim_target_modules is None:
        raise ValueError(f"You need to define `optim_target_modules` to use {optimizer_name} optimizers")

    if not isinstance(args.optim_target_modules, (list, str)):
        raise ValueError(
            f"`optim_target_modules` must be a list of strings, a regex string, or 'all-linear'. Got: {args.optim_target_modules}"
        )

    if model is None:
        raise ValueError(f"You need to pass a model to initialize {optimizer_name} optimizer.")

    all_linear = (
        isinstance(args.optim_target_modules, str)
        and args.optim_target_modules.replace("_", "-") == "all-linear"
    )

    target_params = []
    target_params_names = []
    for module_name, module in model.named_modules():
        target_module_exists, is_regex = check_target_module_exists(
            args.optim_target_modules, module_name, return_is_regex=True
        )

        if not isinstance(module, nn.Linear):
            if target_module_exists and not is_regex:
                logger.warning(
                    f"{module_name} matched but ignored. {optimizer_name} only supports linear layers."
                )
            continue

        if not target_module_exists and not all_linear:
            continue

        target_params.append(module.weight)
        target_params_names.append(module_name + ".weight")

    if len(target_params) == 0:
        raise ValueError(f"No target modules found for {optimizer_name} ({args.optim_target_modules}).")

    non_target_params = [p for n, p in model.named_parameters() if n not in target_params_names]
    optim_kwargs.update(optim_args)

    param_groups = [
        {"params": non_target_params},
        {"params": target_params, **optim_kwargs},
    ]

    if is_layerwise:
        if args.gradient_accumulation_steps != 1:
            raise ValueError(f"Layerwise {optimizer_name} does not support gradient accumulation!")

        optimizer_dict = {}
        for param in non_target_params:
            optimizer_dict[param] = optimizer_cls([{"params": [param]}], **optimizer_kwargs)
        for param in target_params:
            optimizer_dict[param] = optimizer_cls([{"params": [param], **optim_kwargs}], **optimizer_kwargs)

        def optimizer_hook(param):
            if param.grad is not None:
                optimizer_dict[param].step()
                optimizer_dict[param].zero_grad()

        for param in model.parameters():
            if param.requires_grad:
                param.register_post_accumulate_grad_hook(optimizer_hook)

        optimizer_cls = LayerWiseDummyOptimizer
        optimizer_kwargs.update({"optimizer_dict": optimizer_dict})

    optimizer_kwargs.update({"params": param_groups})
    return optimizer_cls, optimizer_kwargs
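
# Illustration only (not part of the diff above): a standalone sketch of the
# layer-wise trick used in `setup_low_rank_optimizer`. Each parameter gets its own
# optimizer, and a post-accumulate-grad hook steps it as soon as that parameter's
# gradient is ready, so gradients never have to be kept for the whole model at once.
# Assumes PyTorch >= 2.1 for `register_post_accumulate_grad_hook`; the tiny model
# and plain AdamW are placeholders for the real APOLLO/GaLore optimizer classes.
import torch
from torch import nn

toy_model = nn.Sequential(nn.Linear(16, 32), nn.ReLU(), nn.Linear(32, 4))

# One optimizer instance per parameter, mirroring `optimizer_dict` above.
per_param_optimizers = {p: torch.optim.AdamW([p], lr=1e-3) for p in toy_model.parameters()}

def _step_immediately(param):
    # Runs right after the gradient for `param` has been accumulated.
    if param.grad is not None:
        per_param_optimizers[param].step()
        per_param_optimizers[param].zero_grad()

for p in toy_model.parameters():
    if p.requires_grad:
        p.register_post_accumulate_grad_hook(_step_immediately)

# A regular backward pass now also applies the per-parameter updates, which is why
# the Trainer pairs this with a no-op `LayerWiseDummyOptimizer`.
toy_model(torch.randn(8, 16)).sum().backward()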

if args.optim == OptimizerNames.ADAFACTOR:
optimizer_cls = Adafactor
optimizer_kwargs.update({"scale_parameter": False, "relative_step": False})
@@ -1471,10 +1569,6 @@ def get_optimizer_cls_and_kwargs(
)
from galore_torch import GaLoreAdafactor, GaLoreAdamW, GaLoreAdamW8bit

is_layerwise = args.optim.lower().endswith("layerwise")
if is_layerwise and args.parallel_mode == ParallelMode.DISTRIBUTED:
raise NotImplementedError("Layer-wise GaLore does not support DDP at this time")

optimizer_mapping = {
OptimizerNames.GALORE_ADAMW: GaLoreAdamW,
OptimizerNames.GALORE_ADAMW_8BIT: GaLoreAdamW8bit,
@@ -1484,105 +1578,46 @@
OptimizerNames.GALORE_ADAFACTOR_LAYERWISE: GaLoreAdafactor,
}

optimizer_cls = optimizer_mapping[args.optim]

if args.optim_target_modules is None:
raise ValueError(
"You need to define a `optim_target_modules` in order to properly use GaLore optimizers"
)

if not isinstance(args.optim_target_modules, (list, str)):
raise ValueError(
f"`optim_target_modules` has to be a list of strings, a string corresponding to a regex, or a specific module or 'all-linear', you passed {args.optim_target_modules}"
)

if model is None:
raise ValueError("You need to pass a model in order to correctly initialize a GaLore optimizer.")

logger.warning(
"Activated GaLoRE fine-tuning, depending on your model size and hardware, the training might take a while before starting. Please be patient !"
)

all_linear = (
isinstance(args.optim_target_modules, str)
and args.optim_target_modules.replace("_", "-") == "all-linear"
)

galore_params = []
galore_params_names = []
for module_name, module in model.named_modules():
target_module_exists, is_regex = check_target_module_exists(
args.optim_target_modules, module_name, return_is_regex=True
)

if not isinstance(module, nn.Linear):
# Warn in case we match but it's not a linear layer
if target_module_exists and not is_regex:
logger.warning(
f"{module_name} has been matched but ignored as GaLore only supports linear layers. Please double check your `optim_target_modules`!"
)

continue

if not target_module_exists and not all_linear:
continue

galore_params.append(module.weight)
galore_params_names.append(module_name + ".weight")

if len(galore_params) == 0:
raise ValueError(
f"None of the target modules were found! ({args.optim_target_modules}). Please make sure to pass a valid `target_modules`."
)

non_galore_params = [p for n, p in model.named_parameters() if n not in galore_params_names]

galore_optim_kwargs = {
"rank": int(optim_args.pop("rank", 128)),
"update_proj_gap": int(optim_args.pop("update_proj_gap", 200)),
"scale": float(optim_args.pop("scale", 0.25)),
"proj_type": optim_args.pop("proj_type", "std"),
}

# The default args are from the official repository: https://github.com/jiaweizzhao/GaLore
param_groups = [
{"params": non_galore_params},
{"params": galore_params, **galore_optim_kwargs},
]

if is_layerwise:
# For layer-wise optimizers, the optimization step is done through post accumulation
# gradient hooks. The trick is to first attach these hooks to the model parameters then
# create a dummy optimizer that will perform no-ops in the Trainer.
# See the original implementation or the nice implementation from @hiyouga
# here: https://github.com/hiyouga/LLaMA-Factory/commit/8664262cde3919e10eaecbd66e8c5d356856362e#diff-ebe08ab14496dfb9e06075f0fdd36799ef6d1535cc4dd4715b74c4e3e06fe3ba
if args.gradient_accumulation_steps != 1:
raise ValueError("Layerwise GaLoRE optimizer do not support gradient accumulation !")

optimizer_dict = {}
for param in non_galore_params:
param_groups = [{"params": [param]}]
optimizer_dict[param] = optimizer_cls(param_groups, **optimizer_kwargs)
for param in galore_params:
param_groups = [{"params": [param], **galore_optim_kwargs}]
optimizer_dict[param] = optimizer_cls(param_groups, **optimizer_kwargs)

def optimizer_hook(param):
if param.grad is not None:
optimizer_dict[param].step()
optimizer_dict[param].zero_grad()

for param in model.parameters():
if param.requires_grad:
param.register_post_accumulate_grad_hook(optimizer_hook)
optimizer_cls, optimizer_kwargs = setup_low_rank_optimizer(
args.optim, optimizer_mapping, galore_optim_kwargs
)
if args.optim == OptimizerNames.GALORE_ADAFACTOR:
optimizer_kwargs.update({"scale_parameter": False, "relative_step": False})
elif args.optim in [
OptimizerNames.APOLLO_ADAMW,
OptimizerNames.APOLLO_ADAMW_LAYERWISE,
]:
if not is_apollo_torch_available():
raise ImportError(
"You need to install `apollo_torch` in order to use APOLLO optimizers"
" install it with `pip install git+https://github.com/zhuhanqing/APOLLO`"
)
from apollo_torch import APOLLOAdamW

optimizer_cls = LayerWiseDummyOptimizer
optimizer_kwargs.update({"optimizer_dict": optimizer_dict})
optimizer_mapping = {
OptimizerNames.APOLLO_ADAMW: APOLLOAdamW,
OptimizerNames.APOLLO_ADAMW_LAYERWISE: APOLLOAdamW,
}

optimizer_kwargs.update({"params": param_groups})
apollo_optim_kwargs = {
"rank": int(optim_args.pop("rank", 128)),
"proj": optim_args.pop("proj", "random"),
"scale_type": optim_args.pop("scale_type", "channel"),
"update_proj_gap": int(optim_args.pop("update_proj_gap", 200)),
"scale": float(optim_args.pop("scale", 1.0)),
"proj_type": optim_args.pop("proj_type", "std"),
}

if args.optim == OptimizerNames.GALORE_ADAFACTOR:
optimizer_kwargs.update({"scale_parameter": False, "relative_step": False})
optimizer_cls, optimizer_kwargs = setup_low_rank_optimizer(
args.optim, optimizer_mapping, apollo_optim_kwargs
)
elif args.optim in [OptimizerNames.LOMO, OptimizerNames.ADALOMO]:
if not is_lomo_available():
raise ImportError(