Implementation of lqlora #8820

Open
wants to merge 6 commits into base: develop
implement of lqlora
Liebele committed Jul 28, 2024
commit 9f76236c238be159698b1fb2b94aed937c24cbcc
35 changes: 35 additions & 0 deletions llm/config/llama/lqlora_argument.json
@@ -0,0 +1,35 @@
{
"model_name_or_path": "facebook/llama-7b",
"dataset_name_or_path": "./data",
"output_dir": "./checkpoints/llama_lora_ckpts",
"per_device_train_batch_size": 4,
"gradient_accumulation_steps": 4,
"per_device_eval_batch_size": 8,
"eval_accumulation_steps":16,
"num_train_epochs": 3,
"learning_rate": 3e-04,
"warmup_steps": 30,
"logging_steps": 1,
"evaluation_strategy": "epoch",
"save_strategy": "epoch",
"src_length": 1024,
"max_length": 2048,
"fp16": true,
"fp16_opt_level": "O2",
"do_train": true,
"do_eval": true,
"disable_tqdm": true,
"load_best_model_at_end": true,
"eval_with_do_generation": false,
"metric_for_best_model": "accuracy",
"recompute": true,
"save_total_limit": 1,
"tensor_parallel_degree": 1,
"pipeline_parallel_degree": 1,
"lora": true,
"zero_padding": false,
"use_flash_attention": false,
"weight_quantize_algo": "lqlora",
"lqlora_quantize_cfg": "path of lqlora quantize config",
"lqlora_state_dict": "path of lqlora state dict"
}
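Presumably this config is consumed like the existing LoRA JSON configs under llm/config, i.e. passed to run_finetune.py as its single argument (for example, python run_finetune.py ./config/llama/lqlora_argument.json), with weight_quantize_algo set to "lqlora" and the two lqlora_* entries pointing at the files produced by the new tools scripts further down in this diff.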
36 changes: 20 additions & 16 deletions llm/run_finetune.py
@@ -41,13 +41,11 @@
load_dataset,
)
from paddlenlp.metrics import BLEU, Rouge1, Rouge2, RougeL
from paddlenlp.peft import (
from paddlenlp.peft import ( # VeRAConfig,; VeRAModel,
LoRAConfig,
LoRAModel,
PrefixConfig,
PrefixModelForCausalLM,
VeRAConfig,
VeRAModel,
)
from paddlenlp.trainer import PdArgumentParser, get_last_checkpoint
from paddlenlp.trainer.trainer_callback import TrainerState
@@ -151,6 +149,11 @@ def main():
if model_args.fuse_attention_ffn is not None:
model_config.fuse_attention_ffn = model_args.fuse_attention_ffn

if model_args.lqlora_quantize_cfg is not None:
model_config.lqlora_quantize_cfg = model_args.lqlora_quantize_cfg
if model_args.lqlora_state_dict is not None:
model_config.lqlora_state_dict = model_args.lqlora_state_dict

model_config.seq_length = data_args.max_length

logger.info(f"Final model config: {model_config}")
@@ -463,6 +466,7 @@ def neft_post_hook(module, input, output):
do_qat=quant_args.do_qat,
base_model_name_or_path=model_args.model_name_or_path,
use_quick_lora=model_args.use_quick_lora,
lqlora_state_dict=model_args.lqlora_state_dict,
)
model = LoRAModel(model, lora_config)
else:
@@ -500,19 +504,19 @@ def compute_metrics_do_generation(eval_preds):
"bleu4": bleu4.score(),
}

if model_args.vera:
target_modules = get_lora_target_modules(model)
vera_config = VeRAConfig(
target_modules=target_modules,
r=model_args.vera_rank,
vera_alpha=model_args.vera_rank,
dtype=dtype,
base_model_name_or_path=model_args.model_name_or_path,
pissa_init=True,
)
model = VeRAModel(model, vera_config)
model.mark_only_vera_as_trainable(notfreezeB=True)
model.print_trainable_parameters()
# if model_args.vera:
# target_modules = get_lora_target_modules(model)
# vera_config = VeRAConfig(
# target_modules=target_modules,
# r=model_args.vera_rank,
# vera_alpha=model_args.vera_rank,
# dtype=dtype,
# base_model_name_or_path=model_args.model_name_or_path,
# pissa_init=True,
# )
# model = VeRAModel(model, vera_config)
# model.mark_only_vera_as_trainable(notfreezeB=True)
# model.print_trainable_parameters()

# Create trainer
max_length = (
185 changes: 185 additions & 0 deletions llm/tools/get_lqlora_quantize_cfg.py
@@ -0,0 +1,185 @@
# Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import argparse
import os
from typing import List, Optional, Tuple

import gurobipy as gp
import numpy as np
import paddle
import scipy.optimize._optimize as scipy_optimize

from paddlenlp.peft.lora.lqlora_utils import lowrand_quantized_sparse_decomposition
from paddlenlp.transformers import AutoModelForCausalLM


def parse_arguments():
parser = argparse.ArgumentParser()
parser.add_argument("--model_name_or_path", default=None, required=True, type=str, help="The directory of model.")
parser.add_argument("--qconfigs", default=None, type=str, required=True, help="Quantize methods, use ',' split")
parser.add_argument("--budget", default=None, type=float, required=True, help="Budget")
parser.add_argument("--ranks", default=64, type=int, help="SVD rank")
parser.add_argument("--output_path", default=None, type=str, required=True, help="The directory of saved model ")
return parser.parse_args()


def estimate_storage_from_config(W: paddle.Tensor, quant_algo: str):
if quant_algo in ["weight_only_int8", "llm.int8"]:
return W.numel() * 8.0
elif quant_algo in ["weight_only_int4", "fp4"]:
return W.numel() * 4.0
elif quant_algo in ["nf4"]:
return W.numel() * 4.127
else:
raise NotImplementedError(f"{quant_algo} is not support.")


def prepare_data_for_qconfig(
names: List[str],
parameters: List[paddle.Tensor],
qconfigs: List[str],
num_ranks: Optional[int] = 64,
) -> Tuple[paddle.Tensor, paddle.Tensor]:
costs = paddle.zeros(shape=[len(parameters), len(qconfigs)])
weights = paddle.zeros(shape=[len(parameters), len(qconfigs)])

for i0, param in enumerate(parameters):
for i1, qconfig in enumerate(qconfigs):
print(f"process {names[i0]}. quant_algo: {qconfig}")
Q, L1, L2 = lowrand_quantized_sparse_decomposition(param, num_ranks, quant_algo=qconfig)
param_ = L1 @ L2 + Q
error = paddle.linalg.norm(param - param_, p="fro") ** 2

costs[i0, i1] = error
weights[i0, i1] = estimate_storage_from_config(param, quant_algo=qconfig)
del param_
return costs, weights


def compute_qconfig_assignments(
budget: float,
costs: paddle.Tensor,
weights: paddle.Tensor,
num_chunks: int,
) -> Tuple[float, paddle.Tensor]:
costs_np = costs.numpy(force=True).reshape(costs.shape[0], -1)
weights_np = weights.numpy(force=True).reshape(weights.shape[0], -1)
costs_list = np.split(costs_np, indices_or_sections=num_chunks, axis=0)
weights_list = np.split(weights_np, indices_or_sections=num_chunks, axis=0)

results = []
for _costs, _weights in zip(costs_list, weights_list):
result = mip_solve(budget=budget / float(num_chunks), costs=_costs, weights=_weights, backend="grurobi")
results.append(result)

assignments_cost = sum([r.fun for r in results])
assignments = np.concatenate([r.x for r in results], axis=0)
assignments = assignments.reshape(costs.shape)
return assignments_cost, paddle.to_tensor(assignments)


def mip_solve(
budget: float,
costs: np.ndarray,
weights: np.ndarray,
backend: str,
) -> scipy_optimize.OptimizeResult:
if backend not in ["scipy", "grurobi"]:
raise ValueError(f"Unknown backend: {backend}")

N = costs.shape[0]
coefficients = costs.reshape(-1)
A_upperbound = weights.reshape(1, -1)
A_equality = np.zeros_like(weights, shape=(N,) + weights.shape)
A_equality[np.arange(N), np.arange(N), :] = 1.0
A_equality = A_equality.reshape(N, -1)

if backend == "grurobi":
grurobi_model = gp.Model()
grurobi_model.setParam(paramname="Timelimit", newval=60) # type: ignore
x = grurobi_model.addMVar(shape=coefficients.shape, vtype=gp.GRB.BINARY, name="x")
grurobi_model.setObjective(coefficients @ x, gp.GRB.MINIMIZE)
grurobi_model.addConstr((A_upperbound @ x) <= budget, name="upperbounds")
grurobi_model.addConstr((A_equality @ x) == 1.0, name="equalities")
grurobi_model.optimize()
return scipy_optimize.OptimizeResult(x=x.X, fun=grurobi_model.ObjVal)

raise ValueError


def get_ilp_data(args):
names = []
params = []
qconfigs = args.qconfigs.split(",")
num_ranks = args.ranks

model = AutoModelForCausalLM.from_pretrained(args.model_name_or_path)
for name, submodule in model.named_sublayers():
if "_proj" in name:
Contributor: Filtering by "_proj" in the layer name does not work for every model; it would be better to use lora_target_module and check whether each layer is an nn.Linear.
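A minimal sketch of the suggested filtering, assuming the get_lora_target_modules helper imported elsewhere in this PR and the regex-style matching that lora_model.py already uses (names, params, and model are the surrounding variables of get_ilp_data):

import re

import paddle.nn as nn

from utils.utils import get_lora_target_modules

# Only consider layers that LoRA would actually target, and require each
# candidate to be a plain Linear layer instead of relying on "_proj" names.
target_patterns = get_lora_target_modules(model)
for name, submodule in model.named_sublayers():
    if isinstance(submodule, nn.Linear) and any(
        re.fullmatch(pattern, name) for pattern in target_patterns
    ):
        names.append(name)
        params.append(submodule.weight)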

names.append(name)
params.append(submodule.weight)

costs, weights = prepare_data_for_qconfig(names=names, parameters=params, qconfigs=qconfigs, num_ranks=num_ranks)
ilp_data = {
"names": names,
"shapes": [param.shape for param in params],
"nparams": sum([param.numel() for param in params]),
"costs": costs,
"weights": weights,
"qconfigs": qconfigs,
}

return ilp_data


def get_lqlora_quantize_cfg():
args = parse_arguments()
GIGABYTES = 1024.0**3

ilp_data = get_ilp_data(args)
costs = ilp_data["costs"]
weights = ilp_data["weights"]
num_params = ilp_data["nparams"]
names = ilp_data["names"]
qconfigs = ilp_data["qconfigs"]

normalized_costs = costs / paddle.linalg.norm(costs) * 1000.0
normalized_budget = args.budget / GIGABYTES * num_params
Contributor: What exactly does budget correspond to here?

normalized_weights = weights / GIGABYTES
assignments_cost, assignments = compute_qconfig_assignments(
budget=normalized_budget, costs=normalized_costs, weights=normalized_weights, num_chunks=1
Contributor: With num_chunks set to 1, isn't this just a single search over all the parameters? What is the point of that?

)

if not all(
[
costs.shape == [len(names), len(qconfigs)],
weights.shape == [len(names), len(qconfigs)],
assignments.shape == [len(names), len(qconfigs)],
]
):
raise ValueError

Contributor: I would suggest not implementing the LQ-LoRA lora_A/lora_B initialization and the quant_algo search as standalone scripts that save a state_dict. Consider doing this inside the LoRAModel initialization instead; see how LoftQ is implemented in PEFT: https://github.com/huggingface/peft/blob/8f3970865079ca1ca1a406cc9f3b3870d677dfb4/src/peft/utils/loftq_utils.py#L333

Contributor: Concretely, the model should first be loaded as an ordinary 16-bit model, and an lqconfig should be added to LoRAConfig to pass the parameters in. For plain LQ-LoRA, simply replace the original nn.Linear layers with the QuantizationLoRALinear for the chosen quant_algo; if the search is enabled, run the search first to get each layer's quant_algo and then do the QuantizationLoRALinear replacement. The design also needs to cover initialization, saving, warm restart, and weight merging.
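A rough sketch of that flow, purely to illustrate the suggestion; QuantizationLoRALinear, search_quant_algo, replace_layer, and the lqconfig field are hypothetical names, not existing PaddleNLP APIs:

import paddle.nn as nn

def apply_lqlora(model, lora_config):
    """Hypothetical helper: keep the base model in 16 bit, pick a quant_algo per
    target layer (searched or taken from the config), then swap each matched
    nn.Linear for a QuantizationLoRALinear built for that algorithm."""
    lq_cfg = lora_config.lqconfig  # assumed new field on LoRAConfig
    if lq_cfg.quantize_cfg is not None:
        quantize_cfg = lq_cfg.quantize_cfg  # per-layer algorithms already decided
    else:
        quantize_cfg = search_quant_algo(model, lq_cfg)  # run the ILP search first
    for name, layer in model.named_sublayers():
        if isinstance(layer, nn.Linear) and name in quantize_cfg:
            new_layer = QuantizationLoRALinear(layer, lora_config, quantize_cfg[name])
            replace_layer(model, name, new_layer)  # must also support save/resume/merge
    return model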

qconfig_dict = {}
for i0, (i1, qconfig_index) in enumerate(assignments.nonzero().tolist()):
if i0 != i1:
raise ValueError
key0 = names[i1]
qconfig_dict[key0] = qconfigs[qconfig_index]

paddle.save(qconfig_dict, os.path.join(args.output_path, "lqlora_quantize_cfg"))


if __name__ == "__main__":
get_lqlora_quantize_cfg()
63 changes: 63 additions & 0 deletions llm/tools/get_lqlora_state_dict.py
@@ -0,0 +1,63 @@
# Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import argparse
import os
import sys

sys.path.append(os.path.dirname(os.getcwd()))

import paddle
from utils.utils import get_lora_target_modules

from paddlenlp.peft import LoRAConfig, LoRAModel
from paddlenlp.peft.lora.lqlora_utils import transform_lora_layers
from paddlenlp.transformers import AutoModelForCausalLM


def parse_arguments():
parser = argparse.ArgumentParser()
parser.add_argument("--model_name_or_path", default=None, required=True, type=str, help="The directory of model.")
parser.add_argument(
"--lqlora_quantize_cfg", default=None, type=str, required=True, help="The directory of lqlora quantize config"
)
parser.add_argument("--output_path", default=None, type=str, required=True, help="The directory of saved model ")
return parser.parse_args()


def get_lqlora_state_dict():
args = parse_arguments()
model = AutoModelForCausalLM.from_pretrained(args.model_name_or_path)
Contributor: The dtype should not be fixed here; see run_finetune.py. Prefer from_pretrained(dtype=dtype) over to(dtype).
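A minimal sketch of the suggested change, assuming the dtype is derived from the script's arguments the way run_finetune.py does rather than fixed to float16:

# Load the weights directly in the requested dtype instead of casting afterwards.
dtype = "float16"  # placeholder; would come from a (hypothetical) --dtype argument
model = AutoModelForCausalLM.from_pretrained(args.model_name_or_path, dtype=dtype)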

model = model.to(dtype=paddle.float16)

target_modules = get_lora_target_modules(model)
lora_config = LoRAConfig(
target_modules=target_modules,
r=8,
Contributor: Why is the config hard-coded here? See run_finetune.py.
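One possible sketch, exposing the values as (hypothetical) command-line options instead of hard-coding r=8 and lora_alpha=16:

# In parse_arguments():
parser.add_argument("--lora_rank", default=8, type=int, help="LoRA attention dimension")
parser.add_argument("--lora_alpha", default=16, type=int, help="LoRA alpha scaling factor")
# ...and in the LoRAConfig below: r=args.lora_rank, lora_alpha=args.lora_alpha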

lora_alpha=16,
merge_weights=False,
tensor_parallel_degree=1,
dtype=paddle.float16,
base_model_name_or_path=args.model_name_or_path,
)
model = LoRAModel(model, lora_config)
lqlora_quantize_cfg = paddle.load(args.lqlora_quantize_cfg)
transform_lora_layers(model, lqlora_quantize_cfg)

state_dict = model.state_dict()
paddle.save(state_dict, args.output_path)

Contributor: I don't understand what this script is meant to do. Why does initializing LQ-LoRA need a separate script that stores these weights?


if __name__ == "__main__":
get_lqlora_state_dict()
2 changes: 2 additions & 0 deletions llm/utils/argument.py
@@ -199,6 +199,8 @@ class ModelArgument:
rslora: bool = field(default=False, metadata={"help": "Whether to use RsLoRA"})
lora_plus_scale: float = field(default=1.0, metadata={"help": "Lora B scale in LoRA+ technique"})
pissa: bool = field(default=False, metadata={"help": "Whether to use Pissa: https://arxiv.org/pdf/2404.02948.pdf"})
lqlora_quantize_cfg: str = field(default=None, metadata={"help": "Quantization algorithm for each matrix."})
lqlora_state_dict: str = field(default=None, metadata={"help": "Quantization model use lqlora quantize config."})

# vera related parameters
vera: bool = field(default=False, metadata={"help": "Whether to use vera technique"})
1 change: 1 addition & 0 deletions paddlenlp/peft/lora/lora_config.py
@@ -76,6 +76,7 @@ class LoRAConfig:
do_qat: bool = field(default=False, metadata={"help": "Whether the lora model would do quant-aware training"})
rslora: bool = field(default=False, metadata={"help": "Whether to use RsLoRA"})
pissa: bool = field(default=False, metadata={"help": "Whether to use Pissa: https://arxiv.org/pdf/2404.02948.pdf"})
lqlora_state_dict: str = field(default=None, metadata={"help": "Initialize lq-lora state dict."})
lora_plus_scale: float = field(default=1.0, metadata={"help": "Lora B scale in LoRA+"})
base_model_name_or_path: Optional[str] = field(
default=None, metadata={"help": "The name of the base model to use."}
9 changes: 9 additions & 0 deletions paddlenlp/peft/lora/lora_model.py
@@ -737,6 +737,15 @@ def get_lora_model(self, model: Union[PretrainedModel, nn.Layer], lora_config: LoRAConfig):
module_name = i[0]
if re.fullmatch(target_module, module_name):
self._find_and_replace_module(model, module_name, lora_config, enable_lora)

if lora_config.lqlora_state_dict is not None:
model = self.init_lqlora_model(model, lora_config)
return model

def init_lqlora_model(self, model: Union[PretrainedModel, nn.Layer], lora_config: LoRAConfig):
lqlora_state_dict = paddle.load(lora_config.lqlora_state_dict)
lqlora_state_dict = {k[6:]: v for k, v in lqlora_state_dict.items() if "lora_A" in k or "lora_B" in k}
model.set_state_dict(lqlora_state_dict)
return model

def restore_original_model(self):