
vera-pissa method added #8722

Merged · 11 commits · Jul 23, 2024

Conversation

@TranscenderNing (Contributor) commented Jul 5, 2024

PR types

New features

PR changes

Add vera-pissa in peft/vera

Description

Revised vera-pissa according to review comments.

paddle-bot bot commented Jul 5, 2024

Thanks for your contribution!

@CLAassistant commented Jul 5, 2024

CLA assistant check
All committers have signed the CLA.

codecov bot commented Jul 6, 2024

Codecov Report

Attention: Patch coverage is 80.70740% with 60 lines in your changes missing coverage. Please review.

Project coverage is 55.51%. Comparing base (d8ddba9) to head (d4810c1).
Report is 27 commits behind head on develop.

Files                                  Patch %   Missing lines
paddlenlp/peft/vera/vera_model.py      77.34%    41 Missing ⚠️
paddlenlp/peft/vera/vera_layers.py     77.94%    15 Missing ⚠️
paddlenlp/trainer/trainer.py           50.00%    3 Missing ⚠️
paddlenlp/trainer/integrations.py      50.00%    1 Missing ⚠️
Additional details and impacted files
@@             Coverage Diff             @@
##           develop    #8722      +/-   ##
===========================================
- Coverage    55.73%   55.51%   -0.22%     
===========================================
  Files          623      630       +7     
  Lines        97464    98374     +910     
===========================================
+ Hits         54324    54616     +292     
- Misses       43140    43758     +618     


@@ -0,0 +1,187 @@
out_features = 16 # Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
Contributor:

Remove out_features = 16.

Author:

done

isinstance(self.model, LoRAModel)
or isinstance(self.model, PrefixModelForCausalLM)
or isinstance(self.model, VeRAModel)
):
Contributor:

  1. Test whether VeRAModel works correctly on reload and on warm start (see the sketch below):
  • Reload: set load_best_model_at_end to True during training and check that the best checkpoint is loaded correctly.
  • Warm start: during training, output_dir already contains earlier checkpoints; the trainer can use resume_from_checkpoint to load the last checkpoint and continue training.
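A minimal sketch of the two scenarios, assuming the standard paddlenlp.trainer API; model, train_ds, and eval_ds are placeholders:

from paddlenlp.trainer import Trainer, TrainingArguments

args = TrainingArguments(
    output_dir="./checkpoints",
    evaluation_strategy="steps",
    save_strategy="steps",
    load_best_model_at_end=True,  # reload: restore the best checkpoint when training ends
)
trainer = Trainer(model=model, args=args, train_dataset=train_ds, eval_dataset=eval_ds)

# warm start: pick up the last checkpoint already present in output_dir and continue
trainer.train(resume_from_checkpoint=True)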

Author:

Tested that it reloads correctly. done

Author:

Adapted warm start; tested and it works. done

self.model = self.get_vera_model(model, vera_config)
self.is_pipelinemodel = False
if issubclass(type(self.model), PipelineLayer):
    self.is_pipelinemodel = True
Contributor:

VeRA does not support pipeline parallel yet either; suggest raising NotImplementedError("vera don't support pipeline parallel now").
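Roughly, the suggested guard in the VeRAModel constructor (a sketch, not the exact patch):

if issubclass(type(self.model), PipelineLayer):
    raise NotImplementedError("vera don't support pipeline parallel now")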

Author:

done

vera_model = cls(model, vera_config)

# define vera weight name
if vera_config.tensor_parallel_degree > 1:
Contributor:

Since VeRA doesn't support it yet, the tensor_parallel_degree branches can all be removed for now.

Author:

done

trainable_state_dict = OrderedDict()
for name, weight in self.model.state_dict().items():
    # get vera parameter & QAT scale parameter
    if not weight.stop_gradient or "activation_quanter" in name or "weight_quanter" in name:
Contributor:

Quant is not supported; suggest removing the quant-related code as well.
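A sketch of the simplified collection loop once the QAT branches are dropped, reusing the OrderedDict pattern shown above:

trainable_state_dict = OrderedDict()
for name, weight in self.model.state_dict().items():
    # keep only the trainable VeRA parameters
    if not weight.stop_gradient:
        trainable_state_dict[name] = weight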

Author:

done

# freezeB=False: vera_b and vera_d are trainable
if "vera" in name:
    weight.stop_gradient = False
elif "lora_B" in name and notfreezeB:
Contributor:

Under what circumstances would a weight name contain lora?

Author:

Previously the parameter names in vera_model used lora_; they have all been unified to vera_.
done


def train(self):
    super().train()
    if self.merge_weights and self.merged:
Contributor:

merge_weight has been removed; it is now a standalone merge function, no longer coupled with train and eval. You can refer to this PR: https://github.com//pull/8674/files
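A rough sketch of that decoupled pattern; get_delta_weight() is a hypothetical helper standing in for the actual VeRA delta computation:

def merge(self):
    # fold the low-rank delta into the frozen base weight, on demand
    if not self.merged:
        self.weight.set_value(self.weight + self.get_delta_weight())
        self.merged = True

def unmerge(self):
    # restore the original base weight
    if self.merged:
        self.weight.set_value(self.weight - self.get_delta_weight())
        self.merged = False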

Author:

done


else:
    # Actual trainable parameters
    self.lora_A = self.create_parameter(
Contributor:

Don't call it lora; rename to vera_A and vera_B.

Author:

done

@@ -0,0 +1,104 @@
# Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
Contributor:

Have you verified the correctness of the merged model?

Author:

Verified; the merged model predicts correctly. done
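For instance, a hypothetical numerical check that merging leaves the outputs unchanged (model, vera_layer, and x are placeholders):

import numpy as np

y_before = model(x)   # forward pass with the VeRA delta applied on the fly
vera_layer.merge()    # fold the delta into the base weight
y_after = model(x)
np.testing.assert_allclose(y_before.numpy(), y_after.numpy(), rtol=1e-5)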

#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
Author:

done

"For example, ['q', 'v'] or '.*decoder.*(SelfAttention|EncDecAttention).*(q|v)$' "
},
)
vera_alpha: int = field(default=8, metadata={"help": "Lora alpha"})
Contributor:

Remember to change lora to vera.

Contributor:

Best to sweep the whole PR for this.

Author:

done

r: int = 0,
vera_alpha: int = 1,
vera_dropout: float = 0.0,
merge_weights: bool = True,
Contributor:

Remove merge_weights.

Author:

done

if enable_vera is None:
    if isinstance(module, nn.Linear):
        vera_module = VeRALinear(
            # pass the layer to be replaced
Contributor:

Write comments in English, not Chinese.

Author:

done

isinstance(vera_config.enable_vera_list, List)
and all(isinstance(item, bool) for item in vera_config.enable_vera_list)
):
enable_vera_list = [vera_config.enable_vera_list]
Contributor:

enable_vera_list looks directly reused from LoRA; VeRA has no corresponding feature. Suggest removing everything related to enable_vera_list and just taking the branch where it is None in the code.

Author:

done

Contributor:

enable_vera_list should be removed entirely at the vera_config level, since we don't need this parameter; it looks like the code still keeps it?

@@ -111,82 +111,3 @@ def test_rslora_plus(self):
self.run_predictor({"inference_model": True})

self.run_predictor({"inference_model": False})


# @parameterized_class(
Contributor:

Don't delete this part.

Author:

done

["baichuan"],
],
)
class VeraTest(LLMTest, unittest.TestCase):
Contributor:

Test locally whether it runs correctly:

Contributor:

cd PaddleNLP
python -m pytest tests/llm/test_vera.py

Author:

It runs correctly. done

) and args.device == "cpu":
    raise ValueError("We can not apply bfloat16 or nf4/fp4 vera merge on cpu.")

vera_config.merge_weights = False
Contributor:

vera_config.merge_weights no longer exists now that merge weight is gone; remember to remove it, otherwise it will raise an error.

Author:

done

@lugimzzz (Contributor):

[coverage report screenshot]

Unit test coverage needs to increase; where it falls short, add tests. See the details in the Codecov report for specifics.

@@ -48,7 +48,6 @@ def __init__(
self.merged = False

if pissa_init:
    assert self.vera_alpha == self.r, "pissa method requires vera_alpha=r, scaling=1"
Contributor:

Why was this removed?

Author:

To increase code coverage, I added it back and added a corresponding exception test.
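A hypothetical shape for such an exception test; the import path and constructor arguments are assumptions, not the exact code from this PR:

import unittest

from paddlenlp.peft.vera.vera_layers import VeRALinear

class PissaInitTest(unittest.TestCase):
    def test_pissa_requires_vera_alpha_equal_r(self):
        # pissa_init requires vera_alpha == r (scaling == 1), so a mismatch must raise
        with self.assertRaises(AssertionError):
            VeRALinear(in_features=16, out_features=16, r=4, vera_alpha=8, pissa_init=True)

if __name__ == "__main__":
    unittest.main()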


@@ -0,0 +1,15 @@
{
"base_model_name_or_path": null,
Contributor:

What is this file for?

Author:

测试用的,已删除,done

Author:

enable_vera_list has now been removed entirely at the vera_config level.

@lugimzzz (Contributor) left a comment:

lgtm

@lugimzzz merged commit de7d103 into PaddlePaddle:develop Jul 23, 2024
10 of 12 checks passed