
Add PPO training. #7305

Merged: 20 commits into PaddlePaddle:develop on Jan 22, 2024

Conversation

@guoshengCS (Contributor) commented Oct 24, 2023

PR types

New features

PR changes

Others

Description

Add PPO training.

@paddle-bot (bot) commented Oct 24, 2023

Thanks for your contribution!

@codecov (bot) commented Oct 24, 2023

Codecov Report

Attention: 35 lines in your changes are missing coverage. Please review.

Comparison is base (16d3c49) 56.68% compared to head (ec150b6) 56.70%.
Report is 1 commit behind head on develop.

Files                                      Patch %   Lines
paddlenlp/generation/utils.py              46.66%    32 Missing ⚠️
paddlenlp/transformers/llama/modeling.py   57.14%    3 Missing ⚠️
Additional details and impacted files
@@             Coverage Diff             @@
##           develop    #7305      +/-   ##
===========================================
+ Coverage    56.68%   56.70%   +0.02%     
===========================================
  Files          588      588              
  Lines        89243    89305      +62     
===========================================
+ Hits         50584    50639      +55     
- Misses       38659    38666       +7     

☔ View full report in Codecov by Sentry.

@guoshengCS (Contributor, Author) commented Dec 27, 2023

Compared with Beaver (DeepSpeed), the PPOTrainer implementation here is more complex. The main reason is that the logic for a complete single training step (forward + backward + opt.step) had to be copied out of the coarse-grained Trainer.train and extracted into full_train_step; in addition, some code was copied and adapted to align with Trainer's functionality and behavior as much as possible and to enable reuse.

If Trainer exposed finer-grained methods covering forward, backward, and opt.step, implementing PPOTrainer and other algorithms with more complex training logic would likely be easier, similar to how Beaver uses DeepSpeed engines (both actor_model and reward_critic_model are DeepSpeed engines): https://github.com/PKU-Alignment/safe-rlhf/blob/main/safe_rlhf/algorithms/ppo/trainer.py#L171

[screenshot of the linked Beaver trainer code]
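For illustration, a minimal sketch of the engine-style pattern described above, loosely following the linked Beaver code. `Engine` and `rl_step` are hypothetical names, not PaddleNLP or DeepSpeed API; the point is that per-model backward()/step() methods reduce a PPO update to two short loss/backward/step pairs:

```python
# Minimal sketch of the fine-grained pattern described above. `Engine` and
# `rl_step` are hypothetical, illustrative names, not PaddleNLP/DeepSpeed API.

class Engine:
    """Stand-in for a DeepSpeed-style engine wrapping one model + optimizer."""

    def __init__(self, model, optimizer):
        self.model = model
        self.optimizer = optimizer

    def backward(self, loss):
        # A real engine would also own grad scaling / accumulation here.
        loss.backward()

    def step(self):
        # A real engine would also own the LR schedule and ZeRO bookkeeping.
        self.optimizer.step()
        self.optimizer.clear_grad()


def rl_step(actor: Engine, critic: Engine, actor_loss, critic_loss):
    # With per-model engines, one PPO update is just two backward/step pairs,
    # rather than logic copied out of a monolithic Trainer.train loop.
    actor.backward(actor_loss)
    actor.step()
    critic.backward(critic_loss)
    critic.step()
```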

@guoshengCS guoshengCS marked this pull request as ready for review January 2, 2024 02:13
@guoshengCS guoshengCS changed the title Add PPO reward model and training. Add PPO training. Jan 8, 2024
```diff
@@ -1214,7 +1217,7 @@ def __init__(self, config: LlamaConfig):

         # Recompute defaults to False and is controlled by Trainer
         self.enable_recompute = False
-        if config.tensor_parallel_degree > 1:
+        if config.tensor_parallel_degree > 1 and config.vocab_size % config.tensor_parallel_degree == 0:
```
Collaborator commented on the diff:

Can this be handled at initialization time instead? I recall there is already a check for this on the embedding side.
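For context on the divisibility check in this diff: tensor parallelism shards the vocabulary dimension of the embedding (and lm_head) across ranks, so the parallel path is only valid when vocab_size splits evenly. A trivial sketch (`can_shard_vocab` is a hypothetical helper, not Paddle API):

```python
def can_shard_vocab(vocab_size: int, tensor_parallel_degree: int) -> bool:
    """True if the vocabulary dimension splits evenly across model-parallel ranks."""
    return vocab_size % tensor_parallel_degree == 0

# LLaMA's 32000-token vocab splits evenly across 8 ranks (4000 rows each),
# but not across 3, so the code above must fall back to the serial path.
assert can_shard_vocab(32000, 8)
assert not can_shard_vocab(32000, 3)
```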

@ZHUI ZHUI self-requested a review January 15, 2024 03:06
@guoshengCS (Contributor, Author) commented Jan 19, 2024

@wj-Mcat The Test CI is failing. How should tests that need to run on multiple GPUs (launched with `python -m paddle.distributed.launch`) be added?

[screenshot of the failing Test CI check]
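One way such a test could be gated (an illustrative sketch only; the script path is hypothetical, and this may not match how the repo's CI ultimately handles it) is to spawn the distributed launcher from a unittest and skip when too few GPUs are available:

```python
import os
import subprocess
import unittest

import paddle


class PPOTrainerDistributedTest(unittest.TestCase):
    # Hypothetical path to a script containing the actual multi-card test body.
    SCRIPT = "tests/trainer/run_ppo_trainer.py"

    @unittest.skipIf(paddle.device.cuda.device_count() < 2, "needs at least 2 GPUs")
    def test_ppo_trainer_multi_gpu(self):
        # Run the test body under the launcher so each rank gets its own
        # process, then assert the whole distributed job exited cleanly.
        cmd = ["python", "-m", "paddle.distributed.launch", "--gpus", "0,1", self.SCRIPT]
        result = subprocess.run(cmd, env=os.environ.copy())
        self.assertEqual(result.returncode, 0)
```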

@wawltor (Collaborator) left a review comment

LGTM

@wawltor wawltor merged commit d4de12c into PaddlePaddle:develop Jan 22, 2024
7 of 9 checks passed