vera-pissa method added #8722

Merged
merged 11 commits (Jul 23, 2024)
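This PR adds VeRA (Vector-based Random Matrix Adaptation, arXiv:2310.11454) to the PEFT suite, combined with a PiSSA-style initialization (arXiv:2404.02948) of the projection matrices. It contributes the VeRAConfig/VeRAModel implementation, a LLaMA fine-tuning config, wiring in run_finetune.py, and a tool that merges trained VeRA parameters back into the base weights.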
32 changes: 32 additions & 0 deletions llm/config/llama/vera_argument.json
@@ -0,0 +1,32 @@
{
    "model_name_or_path": "facebook/llama-7b",
    "dataset_name_or_path": "./data",
    "output_dir": "./checkpoints/vera_ckpts",
    "per_device_train_batch_size": 4,
    "gradient_accumulation_steps": 4,
    "per_device_eval_batch_size": 8,
    "eval_accumulation_steps": 16,
    "num_train_epochs": 1,
    "learning_rate": 3e-04,
    "warmup_steps": 30,
    "logging_steps": 1,
    "evaluation_strategy": "epoch",
    "save_strategy": "epoch",
    "src_length": 1024,
    "max_length": 2048,
    "fp16": true,
    "fp16_opt_level": "O2",
    "do_train": true,
    "do_eval": true,
    "disable_tqdm": true,
    "load_best_model_at_end": true,
    "eval_with_do_generation": false,
    "metric_for_best_model": "accuracy",
    "recompute": true,
    "save_total_limit": 10,
    "tensor_parallel_degree": 1,
    "pipeline_parallel_degree": 1,
    "vera": true,
    "zero_padding": false,
    "use_flash_attention": false
}
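As a usage sketch (assuming the standard PaddleNLP LLM fine-tuning entry point, which reads all of the arguments above from a JSON file), training would be launched from the llm/ directory with:

python run_finetune.py ./config/llama/vera_argument.json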
48 changes: 38 additions & 10 deletions llm/run_finetune.py
@@ -14,7 +14,6 @@
import json
import os
import sys
from functools import partial

import paddle
@@ -42,7 +41,14 @@
    load_dataset,
)
from paddlenlp.metrics import BLEU, Rouge1, Rouge2, RougeL
from paddlenlp.peft import (
    LoRAConfig,
    LoRAModel,
    PrefixConfig,
    PrefixModelForCausalLM,
    VeRAConfig,
    VeRAModel,
)
from paddlenlp.trainer import PdArgumentParser, get_last_checkpoint
from paddlenlp.trainer.trainer_callback import TrainerState
from paddlenlp.transformers import (
@@ -51,9 +57,9 @@
    AutoModelForCausalLMPipe,
    AutoTokenizer,
    Llama3Tokenizer,
    LlamaForCausalLM,
    LlamaForCausalLMPipe,
    LlamaTokenizer,
)
from paddlenlp.transformers.configuration_utils import LlmMetaConfig
from paddlenlp.utils.log import logger
@@ -82,7 +88,6 @@ def main():
        raise ValueError(
            "--do_train, --do_ptq, --do_gptq and --do_qat cannot work at the same time. Please choose only one at a time"
        )

    # Setup GPU & distributed training
    paddle.set_device(training_args.device)
@@ -167,9 +172,7 @@ def main():
    model = model_class.from_config(model_config, dtype=dtype)

    if model_args.flash_mask and (not data_args.zero_padding or not model.config.use_flash_attention):
        logger.warning("`flash_mask` must be used with zero padding and flash attention.")
        data_args.zero_padding = True
        model.config.use_flash_attention = True

@@ -345,12 +348,16 @@ def neft_post_hook(module, input, output):
"Zero Padding data stream is only implemented for LLaMA, Bloom, ChatGLM, QWen and Mistral so far."
)
train_ds = (
train_ds.map(partial(trans_func, is_test=False, zero_padding=data_args.zero_padding, flash_mask=model_args.flash_mask))
train_ds.map(
partial(trans_func, is_test=False, zero_padding=data_args.zero_padding, flash_mask=model_args.flash_mask)
)
if train_ds is not None
else None
)
ptq_ds = (
ptq_ds.map(partial(trans_func, is_test=False, zero_padding=data_args.zero_padding, flash_mask=model_args.flash_mask))
ptq_ds.map(
partial(trans_func, is_test=False, zero_padding=data_args.zero_padding, flash_mask=model_args.flash_mask)
)
if ptq_ds is not None
else None
)
@@ -361,7 +368,14 @@ def neft_post_hook(module, input, output):
    )
    eval_zero_padding = False
    dev_ds = (
        dev_ds.map(
            partial(
                trans_func,
                is_test=data_args.eval_with_do_generation,
                zero_padding=eval_zero_padding,
                flash_mask=model_args.flash_mask,
            )
        )
        if dev_ds is not None
        else None
    )
@@ -485,6 +499,20 @@ def compute_metrics_do_generation(eval_preds):
"bleu4": bleu4.score(),
}

    if model_args.vera:
        target_modules = get_lora_target_modules(model)
        vera_config = VeRAConfig(
            target_modules=target_modules,
            r=model_args.vera_rank,
            vera_alpha=model_args.vera_rank,
            dtype=dtype,
            base_model_name_or_path=model_args.model_name_or_path,
            pissa_init=True,
        )
        model = VeRAModel(model, vera_config)
        model.mark_only_vera_as_trainable(notfreezeB=True)
        model.print_trainable_parameters()
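For intuition: VeRA keeps the pair of low-rank projection matrices frozen and trains only small per-layer scaling vectors d and b, so far fewer parameters are updated than in LoRA at the same rank; in this variant `pissa_init=True` initializes the projections from an SVD of the base weight rather than at random, and `notfreezeB=True` additionally leaves vera_B trainable. A minimal numeric sketch of the adapted forward pass, mirroring the merge formula in llm/tools/merge_vera_params.py below (all shapes and values are illustrative assumptions, not the library's internals):

import paddle

in_dim, out_dim, r = 1024, 1024, 8    # r corresponds to vera_rank
x = paddle.randn([2, in_dim])         # toy input batch
W = paddle.randn([in_dim, out_dim])   # frozen base weight
vera_A = paddle.randn([in_dim, r])    # projection (PiSSA-initialized in this PR)
vera_B = paddle.randn([r, out_dim])   # projection (trainable here via notfreezeB=True)
vera_d = paddle.ones([r])             # trainable scaling vector
vera_b = paddle.ones([out_dim])       # trainable scaling vector

scaling = 1.0                         # vera_alpha / r; both equal vera_rank above
delta = vera_A @ paddle.diag(vera_d) @ vera_B @ paddle.diag(vera_b) * scaling
h = x @ (W + delta)                   # same result as x @ W + x @ delta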

    # Create trainer
    max_length = (
        data_args.max_length
103 changes: 103 additions & 0 deletions llm/tools/merge_vera_params.py
@@ -0,0 +1,103 @@
# Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
Review comment (Contributor): Has the correctness of the merged model been verified?

Reply (Contributor Author): Verified: inference with the merged model produces correct predictions. done

#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import argparse
import os

import paddle

from paddlenlp.peft import VeRAConfig, VeRAModel
from paddlenlp.transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer
from paddlenlp.utils.env import CONFIG_NAME


def parse_arguments():
    parser = argparse.ArgumentParser()
    parser.add_argument("--model_name_or_path", default=None, help="The directory of the pretrained model.")
    parser.add_argument("--vera_path", default="", help="The directory of VeRA parameters. Default to None")
    parser.add_argument(
        "--merge_vera_model_path",
        default="",
        help="The directory of merged parameters. Default to None",
    )
    parser.add_argument("--device", type=str, default="gpu", help="Device")
    parser.add_argument(
        "--low_gpu_mem", type=bool, default=True, help="Whether to use low gpu memory. Default to True"
    )
    return parser.parse_args()


def weight_process(name, vera_config, state_dict):
    weight = state_dict.pop(name + ".weight").cuda()
    vera_A = state_dict.pop(name + ".vera_A").cuda()
    vera_B = state_dict.pop(name + ".vera_B").cuda()
    vera_b = state_dict.pop(name + ".vera_b").cuda()
    vera_d = state_dict.pop(name + ".vera_d").cuda()
    diag_b = paddle.diag(vera_b)
    diag_d = paddle.diag(vera_d)

    scaling = vera_config.vera_alpha / vera_config.r
    # Fold the adapter into the base weight: W' = W + A @ diag(d) @ B @ diag(b) * (alpha / r)
    state_dict[name + ".weight"] = (weight + vera_A @ diag_d @ vera_B @ diag_b * scaling).cpu()


def merge():
    args = parse_arguments()
    paddle.set_device(args.device)

    vera_config = VeRAConfig.from_pretrained(args.vera_path)
    if vera_config.base_model_name_or_path is None:
        if args.model_name_or_path is None:
            raise ValueError("We can not find a valid model_name_or_path.")
        else:
            vera_config.base_model_name_or_path = args.model_name_or_path

    if os.path.isfile(os.path.join(args.vera_path, CONFIG_NAME)):
        config = AutoConfig.from_pretrained(args.vera_path)
    elif args.model_name_or_path is not None:
        config = AutoConfig.from_pretrained(args.model_name_or_path)
    else:
        raise ValueError(
            f"We can not find config.json in vera_path: {args.vera_path} or find a valid model_name_or_path."
        )
    config.dtype = vera_config.dtype
    if (
        vera_config.dtype == "bfloat16" or config.quantization_config.weight_quantize_algo in ["nf4", "fp4"]
    ) and args.device == "cpu":
        raise ValueError("We can not apply bfloat16 or nf4/fp4 vera merge on cpu.")

    # Loading with device_guard() would cause the SVD decomposition (PiSSA init) to fail
    model = AutoModelForCausalLM.from_pretrained(
        vera_config.base_model_name_or_path,
        config=config,
        low_cpu_mem_usage=True,
    )
    model = VeRAModel.from_pretrained(model=model, vera_path=args.vera_path, vera_config=vera_config)

    model.eval()
    model_state_dict = model.model.state_dict()
    vera_name_list = []
    for key in model_state_dict.keys():
        if "vera_A" in key:
            # Strip the ".vera_A" suffix (7 characters) to recover the layer name
            vera_name_list.append(key[:-7])

    for name in vera_name_list:
        weight_process(name, vera_config, model_state_dict)

    model.model.save_pretrained(args.merge_vera_model_path, state_dict=model_state_dict)
    tokenizer = AutoTokenizer.from_pretrained(vera_config.base_model_name_or_path)
    tokenizer.save_pretrained(args.merge_vera_model_path)


if __name__ == "__main__":
    merge()
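A usage sketch for the merge tool (the checkpoint paths are illustrative):

python llm/tools/merge_vera_params.py \
    --model_name_or_path facebook/llama-7b \
    --vera_path ./checkpoints/vera_ckpts \
    --merge_vera_model_path ./checkpoints/vera_merged

After merging, the output directory holds a plain base-model checkpoint with the adapter folded into the weights, so it can be loaded with AutoModelForCausalLM.from_pretrained alone, without the VeRAModel wrapper.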
8 changes: 5 additions & 3 deletions llm/utils/argument.py
@@ -196,6 +196,10 @@ class ModelArgument:
    lora_plus_scale: float = field(default=1.0, metadata={"help": "Lora B scale in LoRA+ technique"})
    pissa: bool = field(default=False, metadata={"help": "Whether to use Pissa: https://arxiv.org/pdf/2404.02948.pdf"})

    # VeRA related parameters
    vera: bool = field(default=False, metadata={"help": "Whether to use VeRA technique"})
    vera_rank: int = field(default=8, metadata={"help": "VeRA rank (adapter dimension)"})

    # prefix tuning related parameters
    prefix_tuning: bool = field(default=False, metadata={"help": "Whether to use Prefix technique"})
    prefix_path: str = field(default=None, metadata={"help": "Initialize prefix state dict."})
@@ -209,9 +213,7 @@ class ModelArgument:
    aistudio_token: str = field(default=None, metadata={"help": "The token of aistudio"})
    neftune: bool = field(default=False, metadata={"help": "Whether to apply NEFT"})
    neftune_noise_alpha: float = field(default=5.0, metadata={"help": "NEFT noise alpha"})
    flash_mask: bool = field(default=False, metadata={"help": "Whether to use flash_mask in flash attention."})


@dataclass
1 change: 1 addition & 0 deletions paddlenlp/peft/__init__.py
@@ -14,3 +14,4 @@

from .lora import LoRAConfig, LoRAModel
from .prefix import PrefixConfig, PrefixModelForCausalLM
from .vera import VeRAConfig, VeRAModel
17 changes: 17 additions & 0 deletions paddlenlp/peft/vera/__init__.py
@@ -0,0 +1,17 @@
# Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from .vera_config import VeRAConfig
from .vera_layers import VeRALinear
from .vera_model import VeRAModel
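For inference on an unmerged checkpoint, the exported classes can be used directly; a sketch reusing the from_pretrained calls from the merge tool above (paths are illustrative):

from paddlenlp.peft import VeRAConfig, VeRAModel
from paddlenlp.transformers import AutoModelForCausalLM

model = AutoModelForCausalLM.from_pretrained("facebook/llama-7b")
vera_config = VeRAConfig.from_pretrained("./checkpoints/vera_ckpts")
model = VeRAModel.from_pretrained(model=model, vera_path="./checkpoints/vera_ckpts", vera_config=vera_config)
model.eval()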