Commit

add qgalore
SunMarc committed Jul 12, 2024
1 parent 3dca6a2 commit 25278e8
Showing 4 changed files with 135 additions and 0 deletions.
127 changes: 127 additions & 0 deletions src/transformers/trainer.py
@@ -155,6 +155,7 @@
is_ipex_available,
is_lomo_available,
is_peft_available,
is_q_galore_torch_available,
is_safetensors_available,
is_sagemaker_dp_enabled,
is_sagemaker_mp_enabled,
@@ -1288,6 +1289,132 @@ def get_optimizer_cls_and_kwargs(
optimizer_cls = torch.optim.Adagrad
elif args.optim == OptimizerNames.RMSPROP:
optimizer_cls = torch.optim.RMSprop
elif args.optim in [OptimizerNames.QGALORE_ADAMW_8BIT, OptimizerNames.QGALORE_ADAMW_8BIT_LAYERWISE]:
if not is_q_galore_torch_available():
raise ImportError(
"You need to install `q-galore-torch` in order to use GaLore optimizers"
" install it with `pip install qgalore"
)
from q_galore_torch import QGaLoreAdamW8bit

is_layerwise = args.optim.lower().endswith("layerwise")
if is_layerwise and args.parallel_mode == ParallelMode.DISTRIBUTED:
# TODO: check if this is True
raise NotImplementedError("Layer-wise QGaLore does not support DDP at this time")

optimizer_cls = QGaLoreAdamW8bit

if args.optim_target_modules is None:
raise ValueError(
"You need to define `optim_target_modules` in order to properly use QGaLore optimizers"
)

if not isinstance(args.optim_target_modules, (list, str)):
raise ValueError(
f"`optim_target_modules` has to be a list of strings, a string corresponding to a regex, or a specific module or 'all-linear', you passed {args.optim_target_modules}"
)

if model is None:
raise ValueError("You need to pass a model in order to correctly initialize a GaLore optimizer.")

logger.warning(
"Activated GaLoRE fine-tuning, depending on your model size and hardware, the training might take a while before starting. Please be patient !"
)

all_linear = (
isinstance(args.optim_target_modules, str)
and args.optim_target_modules.replace("_", "-") == "all-linear"
)

galore_params = []
galore_params_names = []
for module_name, module in model.named_modules():
target_module_exists, is_regex = check_target_module_exists(
args.optim_target_modules, module_name, return_is_regex=True
)

if not isinstance(module, nn.Linear):
# Warn in case we match but it's not a linear layer
if target_module_exists and not is_regex:
logger.warning(
f"{module_name} has been matched but ignored as GaLore only supports linear layers. Please double check your `optim_target_modules`!"
)

continue

if not target_module_exists and not all_linear:
continue

galore_params.append(module.weight)
galore_params_names.append(module_name + ".weight")

if len(galore_params) == 0:
raise ValueError(
f"None of the target modules were found! ({args.optim_target_modules}). Please make sure to pass a valid `target_modules`."
)

non_galore_params = [p for n, p in model.named_parameters() if n not in galore_params_names]

# The default args are from the official repository: https://github.com/VITA-Group/Q-GaLore
galore_optim_kwargs = {
"rank": int(optim_args.pop("rank", 256)),
"update_proj_gap": int(optim_args.pop("update_proj_gap", 200)),
"scale": float(optim_args.pop("scale", 0.25)),
"proj_type": optim_args.pop("proj_type", "std"),
"quant": optim_args.pop("quant", True),
"quant_n_bit": optim_args.pop("quant_n_bit", 4),
"quant_group_size": optim_args.pop("quant_group_size", 256),
"cos_threshold": optim_args.pop("cos_threshold", 0.4),
"gamma_proj": optim_args.pop("gamma_proj", 2),
"queue_size": optim_args.pop("queue_size", 5),
}

param_groups = [
{"params": non_galore_params},
{"params": galore_params, **galore_optim_kwargs},
]

if is_layerwise:
# For layer-wise optimizers, the optimization step is done through post accumulation
# gradient hooks. The trick is to first attach these hooks to the model parameters then
# create a dummy optimizer that will perform no-ops in the Trainer.
# See the original implementation or the nice implementation from @hiyouga
# here: https://github.com/hiyouga/LLaMA-Factory/commit/8664262cde3919e10eaecbd66e8c5d356856362e#diff-ebe08ab14496dfb9e06075f0fdd36799ef6d1535cc4dd4715b74c4e3e06fe3ba
if args.gradient_accumulation_steps != 1:
raise ValueError("Layerwise QGaLoRE optimizer do not support gradient accumulation !")

optimizer_dict = {}
for param in non_galore_params:
if param.requires_grad:
param_groups = [{"params": [param]}]
optimizer_dict[param] = optimizer_cls(param_groups, **optimizer_kwargs)
# TODO: in the original repo, the `update_proj_gap` param is multiplied by 2 here; to be checked
for param in galore_params:
param_groups = [{"params": [param], **galore_optim_kwargs}]
optimizer_dict[param] = optimizer_cls(param_groups, **optimizer_kwargs)

def optimizer_hook(param):
if (not hasattr(param, "float_grad")) and param.grad is None:
return
optimizer_dict[param].step()
optimizer_dict[param].zero_grad()

id_galore_params = [id(p) for p in galore_params]

# TODO: strange, unlike GaLore we are not attaching the hook to every param here
for param in model.parameters():
if id(param) in id_galore_params or param.requires_grad:
setattr(param, "backward_hook", optimizer_hook)

optimizer_cls = LayerWiseDummyOptimizer
optimizer_kwargs.update({"optimizer_dict": optimizer_dict})

optimizer_kwargs.update({"params": param_groups})

elif args.optim in [
OptimizerNames.GALORE_ADAMW,
OptimizerNames.GALORE_ADAMW_8BIT,
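The layer-wise branch above keeps one optimizer per parameter and steps it from a gradient hook, while the Trainer receives a no-op dummy optimizer. A minimal sketch of that idea outside the Trainer, assuming PyTorch >= 2.1 and a toy nn.Linear model (names here are illustrative, not from this commit):

import torch
from torch import nn

model = nn.Linear(16, 16)
# one optimizer instance per parameter, as in the layer-wise branch above
optimizer_dict = {p: torch.optim.AdamW([p], lr=1e-3) for p in model.parameters()}

def optimizer_hook(param):
    # step and clear the gradient as soon as this parameter's grad is accumulated
    optimizer_dict[param].step()
    optimizer_dict[param].zero_grad()

for p in model.parameters():
    if p.requires_grad:
        p.register_post_accumulate_grad_hook(optimizer_hook)

loss = model(torch.randn(4, 16)).sum()
loss.backward()  # the hooks fire here, so no explicit optimizer.step() is needed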
2 changes: 2 additions & 0 deletions src/transformers/training_args.py
@@ -174,6 +174,8 @@ class OptimizerNames(ExplicitEnum):
GALORE_ADAMW_LAYERWISE = "galore_adamw_layerwise"
GALORE_ADAMW_8BIT_LAYERWISE = "galore_adamw_8bit_layerwise"
GALORE_ADAFACTOR_LAYERWISE = "galore_adafactor_layerwise"
QGALORE_ADAMW_8BIT = "qgalore_adamw_8bit"
QGALORE_ADAMW_8BIT_LAYERWISE = "qgalore_adamw_8bit_layerwise"
LOMO = "lomo"
ADALOMO = "adalomo"

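With the two enum values above, the new optimizers can be selected by name from TrainingArguments. A hedged usage sketch (assuming `q-galore-torch` is installed; the target-module regexes and hyperparameter overrides below are illustrative, not part of this commit):

from transformers import TrainingArguments

args = TrainingArguments(
    output_dir="qgalore-out",
    optim="qgalore_adamw_8bit",  # or "qgalore_adamw_8bit_layerwise"
    # comma-separated overrides for the defaults listed in trainer.py above
    optim_args="rank=128,update_proj_gap=200,quant_group_size=256",
    optim_target_modules=[r".*attn.*", r".*mlp.*"],  # regexes matched against module names
    gradient_accumulation_steps=1,  # required by the layer-wise variant
)
# Trainer(model=model, args=args, train_dataset=dataset).train()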
1 change: 1 addition & 0 deletions src/transformers/utils/__init__.py
@@ -164,6 +164,7 @@
is_pytesseract_available,
is_pytest_available,
is_pytorch_quantization_available,
is_q_galore_torch_available,
is_quanto_available,
is_rjieba_available,
is_sacremoses_available,
5 changes: 5 additions & 0 deletions src/transformers/utils/import_utils.py
@@ -99,6 +99,7 @@ def _is_package_available(pkg_name: str, return_version: bool = False) -> Union[
_bitsandbytes_available = _is_package_available("bitsandbytes")
_eetq_available = _is_package_available("eetq")
_galore_torch_available = _is_package_available("galore_torch")
_q_galore_torch_available = _is_package_available("q_galore_torch")
_lomo_available = _is_package_available("lomo_optim")
_torchao_available = _is_package_available("torchao")
# `importlib.metadata.version` doesn't work with `bs4` but `beautifulsoup4`. For `importlib.util.find_spec`, reversed.
@@ -346,6 +347,10 @@ def is_galore_torch_available():
return _galore_torch_available


def is_q_galore_torch_available():
return _q_galore_torch_available


def is_lomo_available():
return _lomo_available

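The flag and helper above follow the existing availability-check pattern, so callers can guard the optional import. A small sketch of the intended consumption, mirroring the guard added in trainer.py (the error text here is illustrative):

from transformers.utils import is_q_galore_torch_available

if is_q_galore_torch_available():
    from q_galore_torch import QGaLoreAdamW8bit  # only imported when the package is present
else:
    raise ImportError("You need to install `q-galore-torch` to use QGaLore optimizers.")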
