diff --git a/examples/language-modeling/run_lora_clm.py b/examples/language-modeling/run_lora_clm.py index d057fc3c94..6da55a503f 100644 --- a/examples/language-modeling/run_lora_clm.py +++ b/examples/language-modeling/run_lora_clm.py @@ -30,7 +30,7 @@ import torch import transformers from datasets import load_dataset -from peft import LoraConfig, TaskType, get_peft_model, tuners +from peft import AdaLoraConfig, IA3Config, LoraConfig, TaskType, get_peft_model, tuners from peft.utils.other import fsdp_auto_wrap_policy from transformers import ( AutoConfig, @@ -278,14 +278,51 @@ class FinetuneArguments: default=0.05, metadata={"help": "Dropout parameter in the LoRA method."}, ) - lora_target_modules: List[str] = field( + target_modules: List[str] = field( default_factory=lambda: None, - metadata={"help": "Target modules for the LoRA method."}, + metadata={"help": "Target modules for the LoRA/IA3/AdaLoRA method."}, ) train_on_inputs: bool = field( default=True, metadata={"help": "if False, masks out inputs in loss"}, ) + adalora_init_r: int = field( + default=12, + metadata={"help": "Initial AdaLoRA rank"}, + ) + adalora_target_r: int = field( + default=4, + metadata={"help": "Target AdaLoRA rank"}, + ) + adalora_tinit: int = field( + default=50, + metadata={"help": "Number of warmup steps for AdaLoRA wherein no pruning is performed"}, + ) + adalora_tfinal: int = field( + default=100, + metadata={ + "help": "Fix the resulting budget distribution and fine-tune the model for tfinal steps when using AdaLoRA" + }, + ) + adalora_delta_t: int = field( + default=10, + metadata={"help": "Interval of steps for AdaLoRA to update rank"}, + ) + adalora_orth_reg_weight: float = field( + default=0.5, + metadata={"help": "Orthogonal regularization weight for AdaLoRA"}, + ) + peft_type: str = field( + default="lora", + metadata={ + "help": ("The PEFT type to use."), + "choices": ["lora", "ia3", "adalora"], + }, + ) + feedforward_modules: List[str] = field( + default_factory=lambda: 
None, + metadata={"help": "Target feedforward modules for the IA3 method."}, + ) PROMPT_DICT = { @@ -674,14 +711,38 @@ def compute_metrics(eval_preds): if training_args.do_train or training_args.do_eval: # PEFT settings - peft_config = LoraConfig( - r=finetune_args.lora_rank, - lora_alpha=finetune_args.lora_alpha, - lora_dropout=finetune_args.lora_dropout, - target_modules=finetune_args.lora_target_modules, - bias="none", - task_type=TaskType.CAUSAL_LM, - ) + if finetune_args.peft_type == "lora": + peft_config = LoraConfig( + r=finetune_args.lora_rank, + lora_alpha=finetune_args.lora_alpha, + lora_dropout=finetune_args.lora_dropout, + target_modules=finetune_args.target_modules, + bias="none", + task_type=TaskType.CAUSAL_LM, + ) + elif finetune_args.peft_type == "adalora": + peft_config = AdaLoraConfig( + init_r=finetune_args.adalora_init_r, + target_r=finetune_args.adalora_target_r, + tinit=finetune_args.adalora_tinit, + tfinal=finetune_args.adalora_tfinal, + deltaT=finetune_args.adalora_delta_t, + lora_alpha=finetune_args.lora_alpha, + lora_dropout=finetune_args.lora_dropout, + target_modules=finetune_args.target_modules, + orth_reg_weight=finetune_args.adalora_orth_reg_weight, + bias="none", + task_type=TaskType.CAUSAL_LM, + ) + from optimum.habana.peft.layer import GaudiAdaloraLayerSVDLinearForward + + tuners.adalora.layer.SVDLinear.forward = GaudiAdaloraLayerSVDLinearForward + elif finetune_args.peft_type == "ia3": + peft_config = IA3Config( + target_modules=finetune_args.target_modules, + feedforward_modules=finetune_args.feedforward_modules, + task_type=TaskType.CAUSAL_LM, + ) if training_args.gradient_checkpointing: model.enable_input_require_grads() if training_args.torch_compile: @@ -689,7 +750,7 @@ def compute_metrics(eval_preds): tuners.lora.layer.Linear.forward = GaudiLoraLayerLinearForward lora_model = get_peft_model(model, peft_config) - if training_args.bf16: + if training_args.bf16 and finetune_args.peft_type != "ia3": lora_model = 
def GaudiAdaloraLayerSVDLinearForward(self, x: torch.Tensor, *args: Any, **kwargs: Any) -> torch.Tensor:
    """
    Forward pass meant to be installed in place of AdaLoRA's ``SVDLinear.forward``.

    Copied from SVDLinear.forward: https://github.com/huggingface/peft/blob/v0.9.0/src/peft/tuners/adalora/layer.py#L158
    The only differences are:
    - fix batch_gemm failure for BF16 case

    Args:
        x: Input tensor. It is passed unchanged to ``self.base_layer`` and is cast to the
            dtype of the adapter's ``lora_A`` weight before the adapter update is computed.
        *args: Extra positional arguments forwarded to ``self.base_layer``.
        **kwargs: Extra keyword arguments forwarded to ``self.base_layer``.

    Returns:
        The output of ``self.base_layer``. When adapters are enabled and not merged, the
        update of every active adapter that has an entry in ``self.lora_A`` is added to
        that output in place.
    """
    # Read the flag once so the control flow below mirrors a single if/elif/else decision.
    adapters_disabled = self.disable_adapters
    if adapters_disabled and self.merged:
        # Adapters are switched off but their weights are still merged in: undo the merge
        # before running the base layer.
        self.unmerge()

    result = self.base_layer(x, *args, **kwargs)
    if adapters_disabled or self.merged:
        # Either adapters are off, or they are merged; in both cases the base layer
        # output is returned as-is.
        return result

    for active_adapter in self.active_adapters:
        # Skip adapter names that have no weights registered on this layer.
        if active_adapter not in self.lora_A:
            continue
        lora_A = self.lora_A[active_adapter]
        lora_B = self.lora_B[active_adapter]
        lora_E = self.lora_E[active_adapter]
        dropout = self.lora_dropout[active_adapter]
        scaling = self.scaling[active_adapter]
        # The small epsilon keeps the division below finite when ranknum is zero.
        ranknum = self.ranknum[active_adapter] + 1e-5

        x = x.to(lora_A.dtype)
        # NOTE(review): keep the factor grouped as (scaling / ranknum). The docstring says the
        # only deviation from upstream is the BF16 batch_gemm fix; presumably it is this
        # grouping -- confirm against upstream before regrouping the expression.
        result += (dropout(x) @ (lora_A * lora_E).T @ lora_B.T) * (scaling / ranknum)

    return result
b/optimum/habana/transformers/trainer.py @@ -116,6 +116,7 @@ if is_peft_available(): from peft import PeftModel + from peft.utils import PeftType if is_deepspeed_available(): @@ -849,6 +850,11 @@ def hpu_deepspeed_checkpointing(function, *checkpoint_args): hb_profiler.start() total_batched_samples = 0 + if ( + _is_peft_model(self.model) + and self.model.peft_config[self.trainable_adapter_name].peft_type == PeftType.ADALORA + ): + self.model.base_model.peft_config[self.trainable_adapter_name].total_step = max_steps for epoch in range(epochs_trained, num_train_epochs): epoch_iterator = train_dataloader if hasattr(epoch_iterator, "set_epoch"): @@ -990,7 +996,11 @@ def hpu_deepspeed_checkpointing(function, *checkpoint_args): # Delay optimizer scheduling until metrics are generated if not isinstance(self.lr_scheduler, torch.optim.lr_scheduler.ReduceLROnPlateau): self.lr_scheduler.step() - + if ( + _is_peft_model(self.model) + and self.model.peft_config[self.trainable_adapter_name].peft_type == PeftType.ADALORA + ): + self.model.base_model.update_and_allocate(self.state.global_step) self._zero_model_grad(model) self.state.global_step += 1