Implementation of LQ-LoRA #8820

Open · wants to merge 6 commits into develop
Conversation

@Liebele commented Jul 28, 2024

Mixed quantization is performed using the quantization algorithms that Paddle's LoRA already supports.
To use LQ-LoRA, first collect the preparation data for the integer linear program (ILP), then solve the ILP to assign a quantization algorithm to each weight matrix, and finally run the iterative initialization to obtain the modified base-model parameters and the LoRA-module parameters.
When fine-tuning with LQ-LoRA, set "weight_quantize_algo" to "lqlora" and supply both "lqlora_quantize_cfg" and "lqlora_state_dict" (the per-matrix quantization algorithms, and the iteratively initialized base-model and LoRA-module parameters, respectively).
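A minimal sketch of that fine-tuning setup, assuming these keyword arguments are forwarded through from_pretrained as the config diff below suggests (the artifact file names are hypothetical placeholders):

    import paddle
    from paddlenlp.transformers import AutoModelForCausalLM

    # Hypothetical paths: both artifacts come from this PR's preparation steps
    # (the ILP search and the iterative initialization, respectively).
    lqlora_quantize_cfg = paddle.load("lqlora_quantize_cfg")  # per-matrix quantization algorithms
    lqlora_state_dict = paddle.load("lqlora_state_dict")      # initialized base-model + LoRA parameters

    model = AutoModelForCausalLM.from_pretrained(
        "meta-llama/Llama-2-7b",  # model used in the experiments reported below
        weight_quantize_algo="lqlora",
        lqlora_quantize_cfg=lqlora_quantize_cfg,
        lqlora_state_dict=lqlora_state_dict,
    )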

paddle-bot bot commented Jul 28, 2024

Thanks for your contribution!

CLAassistant commented Jul 28, 2024

CLA assistant check
All committers have signed the CLA.

@Liebele (Author) commented Jul 28, 2024

The our1 method uses both "lqlora_quantize_cfg" and "lqlora_state_dict".
The our2 method uses only "lqlora_quantize_cfg".
Budget: 6 bits
Model: Llama-2-7B
Results on 8 NLU datasets:
[image: NLU results]
Results on 2 NLG datasets:
[image: NLG results]
GPU memory usage:
[image: memory usage]

codecov bot commented Jul 28, 2024

Codecov Report

Attention: Patch coverage is 5.42636% with 122 lines in your changes missing coverage. Please review.

Project coverage is 53.03%. Comparing base (5ad7a9c) to head (28fd256).
Report is 268 commits behind head on develop.

Files with missing lines                        Patch %   Lines
paddlenlp/peft/lora/lqlora_utils.py             0.00%     89 Missing ⚠️
paddlenlp/quantization/quantization_utils.py    0.00%     22 Missing ⚠️
paddlenlp/transformers/model_utils.py           35.29%    11 Missing ⚠️
Additional details and impacted files
@@             Coverage Diff             @@
##           develop    #8820      +/-   ##
===========================================
+ Coverage    52.99%   53.03%   +0.03%     
===========================================
  Files          671      658      -13     
  Lines       109835   106606    -3229     
===========================================
- Hits         58212    56543    -1669     
+ Misses       51623    50063    -1560     



def get_lqlora_state_dict():
    args = parse_arguments()
    model = AutoModelForCausalLM.from_pretrained(args.model_name_or_path)
Contributor: The dtype should not be hard-coded. Following run_finetune.py, avoid .to(dtype) and use from_pretrained(dtype=dtype) instead.
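Sketched, with args.dtype as a hypothetical CLI argument mirroring run_finetune.py:

    # Create weights in the target dtype directly rather than loading in a
    # fixed dtype and casting afterwards with .to(dtype).
    model = AutoModelForCausalLM.from_pretrained(
        args.model_name_or_path,
        dtype=args.dtype,  # e.g. "bfloat16" or "float16"
    )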

target_modules = get_lora_target_modules(model)
lora_config = LoRAConfig(
    target_modules=target_modules,
    r=8,
Contributor: Why is the config hard-coded here? See run_finetune.py.


state_dict = model.state_dict()
paddle.save(state_dict, args.output_path)

Contributor: I don't follow what this script is meant to do. Why does LQ-LoRA initialization need a standalone script that saves these weights?


model = AutoModelForCausalLM.from_pretrained(args.model_name_or_path)
for name, submodule in model.named_sublayers():
    if "_proj" in name:
Contributor: Filtering by "_proj" does not work for all models. Select layers via lora_target_modules plus a check that the layer is an nn.Linear instead.
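A sketch of that selection logic, reusing the get_lora_target_modules helper this script already calls and treating its entries as regular-expression patterns (which is how LoRA target modules are usually expressed):

    import re
    import paddle.nn as nn

    target_modules = get_lora_target_modules(model)
    for name, submodule in model.named_sublayers():
        # Require a real nn.Linear that matches one of the architecture's LoRA
        # target patterns, instead of relying on "_proj" appearing in the name.
        if isinstance(submodule, nn.Linear) and any(re.fullmatch(p, name) for p in target_modules):
            ...  # collect this layer for the quantization-algorithm search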

qconfigs = ilp_data["qconfigs"]

normalized_costs = costs / paddle.linalg.norm(costs) * 1000.0
normalized_budget = args.budget / GIGABYTES * num_params
Contributor: What exactly does budget correspond to here?

normalized_budget = args.budget / GIGABYTES * num_params
normalized_weights = weights / GIGABYTES
assignments_cost, assignments = compute_qconfig_assignments(
    budget=normalized_budget, costs=normalized_costs, weights=normalized_weights, num_chunks=1
Contributor: With num_chunks set to 1, each parameter is only searched once; what is the point of that?

]
):
raise ValueError

Contributor: Rather than implementing the LQ-LoRA lora_A/lora_B initialization and the quant_algo search as standalone scripts that save a state_dict, consider doing it inside LoRAModel's initialization, as PEFT does for LoftQ: https://github.com/huggingface/peft/blob/8f3970865079ca1ca1a406cc9f3b3870d677dfb4/src/peft/utils/loftq_utils.py#L333

Contributor: Concretely: when the model is loaded, it should still be a plain 16-bit model, and LoRAConfig should gain an lqconfig field for passing the parameters in. For plain LQ-LoRA, replace the original nn.Linear layers with the QuantizationLoRALinear matching the chosen quant_algo; if the search is enabled, run the search first to get each layer's quant_algo, then do the QuantizationLoRALinear replacement. The design must also cover initialization, saving, warm restart, and parameter merging.
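A rough sketch of that flow under the reviewer's assumptions. search_qconfig_assignments and the lq_config fields are hypothetical names; QuantizationLoRALinear is the existing PaddleNLP layer the comment refers to, used schematically here:

    import paddle.nn as nn

    def _locate(model, layer_name):
        # Resolve "a.b.c" to (parent module, "c") so the child layer can be swapped.
        *path, attr = layer_name.split(".")
        parent = model
        for part in path:
            parent = getattr(parent, part)
        return parent, attr

    def apply_lqlora(model, target_layers, lq_config):
        # Step 1: decide each layer's quant_algo, either via the PR's ILP search
        # (hypothetical function name) or a single uniform algorithm.
        if lq_config.search_quant_algo:
            algo_per_layer = search_qconfig_assignments(model, target_layers, lq_config)
        else:
            algo_per_layer = {name: lq_config.quant_algo for name in target_layers}

        # Step 2: swap each target nn.Linear for a QuantizationLoRALinear. The
        # LQ-LoRA iterative init (quantize, then fit a low-rank residual) would
        # run inside the replacement, analogous to PEFT's LoftQ initialization.
        for name, quant_algo in algo_per_layer.items():
            parent, attr = _locate(model, name)
            setattr(parent, attr, QuantizationLoRALinear(getattr(parent, attr), quant_algo=quant_algo))
        return model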

@@ -564,6 +564,8 @@ def __init__(self, **kwargs):
if "quantization_config" in kwargs and isinstance(kwargs["quantization_config"], Dict):
kwargs["quantization_config"] = QuantizationConfig.from_dict(kwargs["quantization_config"])
self.quantization_config = kwargs.pop("quantization_config", QuantizationConfig())
self.lqlora_quantize_cfg = kwargs.pop("lqlora_quantize_cfg", None)
Contributor: The from_pretrained logic for assigning different quantization strategies to different layers can stay, but do not tie it to LoftQ. Add a new model_config.quantization_config.quantize_cfg that specifies the quantization strategy per layer.
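For illustration, the proposed per-layer map might look as follows. The quantize_cfg field is the reviewer's proposed addition, not part of the current QuantizationConfig; layer names are examples, while nf4 and weight_only_int8 are algorithms PaddleNLP already supports:

    from paddlenlp.transformers import AutoModelForCausalLM
    from paddlenlp.quantization.quantization_config import QuantizationConfig  # import path assumed

    quantization_config = QuantizationConfig(
        quantize_cfg={  # proposed field: {layer name: quantization strategy}
            "llama.layers.0.self_attn.q_proj": "nf4",
            "llama.layers.0.mlp.gate_proj": "weight_only_int8",
        }
    )
    model = AutoModelForCausalLM.from_pretrained(
        "meta-llama/Llama-2-7b",
        quantization_config=quantization_config,
    )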

Contributor: The transformers-side logic here is separate from the LoftQ part above; please split this part into its own PR.

This Pull Request is stale because it has been open for 60 days with no activity.

@github-actions github-actions bot added the stale label Dec 15, 2024
@github-actions github-actions bot removed the stale label Jan 15, 2025