
MoRA Implementation #9562

Merged
merged 11 commits into from Dec 18, 2024
Conversation

@lcykww lcykww (Contributor) commented Dec 4, 2024

PR types

Others

PR changes

Others

Description

Modified paddlenlp/peft/lora/lora_layers.py to add the MoRA method implementation.
Modified paddlenlp/peft/lora/lora_model.py, mainly adding the parameter-freezing logic for MoRA.
Modified paddlenlp/peft/lora/lora_config.py to add the use_mora option.
Added the use_mora option to paddlenlp/trl/model_config.py.
Added use_mora to the lora_config used in llm/run_finetune.py.

Test results:
Model: facebook/llama-7b
Training set: commonsense_170k
All other parameters use PaddleNLP defaults.
[results screenshot]
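
For context, a minimal usage sketch of the new flag (not part of this PR's diff); the target-module patterns and rank below are placeholder assumptions:

# Sketch only: enabling MoRA through the use_mora option added by this PR.
# Module patterns and hyperparameters are placeholder assumptions.
from paddlenlp.peft import LoRAConfig, LoRAModel
from paddlenlp.transformers import AutoModelForCausalLM

model = AutoModelForCausalLM.from_pretrained("facebook/llama-7b")
lora_config = LoRAConfig(
    target_modules=[".*q_proj.*", ".*v_proj.*"],  # hypothetical pattern list
    r=8,
    use_mora=True,  # new option introduced in this PR
)
model = LoRAModel(model=model, lora_config=lora_config)
model.mark_only_lora_as_trainable()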

paddle-bot bot commented Dec 4, 2024

Thanks for your contribution!

codecov bot commented Dec 4, 2024

Codecov Report

Attention: Patch coverage is 90.32258% with 9 lines in your changes missing coverage. Please review.

Project coverage is 52.68%. Comparing base (407b3e6) to head (6181b2b).
Report is 5 commits behind head on develop.

Files with missing lines Patch % Lines
paddlenlp/peft/lora/lora_layers.py 89.88% 9 Missing ⚠️
Additional details and impacted files
@@             Coverage Diff             @@
##           develop    #9562      +/-   ##
===========================================
+ Coverage    52.66%   52.68%   +0.02%     
===========================================
  Files          712      712              
  Lines       111691   111762      +71     
===========================================
+ Hits         58818    58886      +68     
- Misses       52873    52876       +3     

☔ View full report in Codecov by Sentry.

elif "lora" in name:
weight.stop_gradient = False
if layer.use_mora:
if self.lora_config.trainable_bias in ["lora_A", "all"] and "bias" in name:
Contributor commented:

Why add lora_A here?

@@ -0,0 +1,113 @@
lora:
Contributor commented:

Please change lora to mora here.

self.disable_static()
paddle.set_default_dtype("float32")

lora_config = load_test_config(self.config_path, "lora", self.model_dir)
Contributor commented:

Change lora to mora.

if isinstance(layer, paddle.nn.Linear) or isinstance(layer, QuantizationLinear):
weight_process(name, quant_config, lora_config, model_state_dict, args.device)

# To be modified

@@ -65,11 +65,13 @@ def __init__(
lora_plus_scale: float = 1.0,
pissa: bool = False,
lora_use_mixer: bool = False,
use_mora: bool = False,


# create RoPE
if self.cos is None or self.sin is None:
inv_freq = 1.0 / (10000 ** (paddle.arange(0, r, 2, dtype=self._dtype) / r))
Contributor commented:

We recommend doing this computation in float32.


# apply RoPE rotation
rh_in_x = paddle.concat([-in_x[..., r // 2 :], in_x[..., : r // 2]], axis=-1)
# rh_in_x = paddle.cast(rh_in_x, dtype=paddle.paddle.bfloat16)
Contributor commented:

Delete this line.

rh_in_x = paddle.concat([-in_x[..., r // 2 :], in_x[..., : r // 2]], axis=-1)
# rh_in_x = paddle.cast(rh_in_x, dtype=paddle.paddle.bfloat16)
in_x = in_x * self.cos + rh_in_x * self.sin
in_x = paddle.cast(in_x, dtype=self._dtype)
Contributor commented:

Why is the cast still needed here?

rb2 = self.out_features // r if self.out_features % r == 0 else self.out_features // r + 1

# create RoPE
if self.cos is None or self.sin is None:
Contributor commented:

Could you write a helper that initializes cos and sin, call it once in __init__, and avoid repeating this in several places?
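
A sketch of the kind of helper suggested here, assuming the names already used in the diff (r, rb1, self.cos, self.sin); the exact shapes are an assumption, not this PR's code:

import paddle

def rope_init(self, r, rb1):
    # Build the rotary tables once, in float32 as recommended above.
    inv_freq = 1.0 / (10000 ** (paddle.arange(0, r, 2, dtype="float32") / r))
    t = paddle.arange(rb1, dtype="float32")
    freqs = paddle.outer(t, inv_freq)             # assumed shape (rb1, r // 2)
    emb = paddle.concat([freqs, freqs], axis=-1)  # assumed shape (rb1, r)
    self.cos = paddle.cos(emb)
    self.sin = paddle.sin(emb)

__init__ (and any method that currently rebuilds cos/sin) would then just call self.rope_init(r, rb1) once.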

@lugimzzz (Contributor) commented:
Please paste the reproduction results in the Description.

@@ -3,7 +3,7 @@
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# mat
Contributor Author replied:

Deleted.

is_bias=False,
default_initializer=nn.initializer.Constant(value=0.0),
)
self.cos = None
Contributor commented:

Just call self.RoPE_init(r, rb1) here to initialize it.

Contributor Author replied:

Fixed.

out = (weight + lora_A @ lora_AB @ lora_B * scaling).cpu()
else:
out = (weight + lora_A @ lora_B * scaling).cpu()
delta_weight = layer.get_delta_weight()
Contributor commented:

This is problematic: it ends up calling layer.lora_A rather than the lora_A above, and layer.lora_A is still on the CPU. Computing on the CPU is too slow, and the CPU does not support bfloat16. I suggest either rewriting a get_delta_weight here, or making get_delta_weight suitable for external callers, for example:

def get_delta_weight(self, lora_A=None,lora_B=None):
   if use_mora:
       lora_A = lora_A if lora_A is not None else self.lora_A
       ....
   else:
       lora_A = lora_A if lora_A is not None else self.lora_A
       lora_B = lora_B if lora_B is not None else self.lora_B
       ...
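
One possible shape for that refactor, filling in the outline above (a sketch under the assumption that the layer exposes use_mora, lora_A, lora_B and scaling; _mora_delta is a hypothetical helper for the MoRA branch, not code from this PR):

def get_delta_weight(self, lora_A=None, lora_B=None):
    # Accept optional tensors so callers can pass copies that already live on
    # the right device/dtype instead of re-reading self.lora_A from the CPU.
    if self.use_mora:
        lora_A = lora_A if lora_A is not None else self.lora_A
        return self._mora_delta(lora_A)  # hypothetical helper for the MoRA path
    lora_A = lora_A if lora_A is not None else self.lora_A
    lora_B = lora_B if lora_B is not None else self.lora_B
    return lora_A @ lora_B * self.scaling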

Contributor Author replied:

Fixed.

@@ -144,21 +157,102 @@ def pissa_init(self, rank):
weight = res.astype(dtype)
self.weight.set_value(weight)

def RoPE_init(self, r, rb1):
Contributor commented:

Method names are better all lowercase, e.g. def rope_init.

Contributor Author replied:

Fixed.

lugimzzz previously approved these changes Dec 18, 2024

@lugimzzz lugimzzz (Contributor) left a comment:
lgtm

@lugimzzz lugimzzz (Contributor) left a comment:
lgtm

@lugimzzz lugimzzz merged commit 90bc68e into PaddlePaddle:develop Dec 18, 2024
10 of 12 checks passed