Add PPO training. #7305
Thanks for your contribution!
Codecov Report
Additional details and impacted files

@@            Coverage Diff             @@
##           develop    #7305     +/-   ##
===========================================
+ Coverage    56.68%   56.70%   +0.02%
===========================================
  Files          588      588
  Lines        89243    89305     +62
===========================================
+ Hits         50584    50639     +55
- Misses       38659    38666      +7
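Compared with Beaver (DeepSpeed), the PPOTrainer implementation here is more complex, mainly because the logic for a complete single training step (forward + backward + opt.step) has to be copied and extracted from the coarse-grained Trainer.train. If Trainer exposed finer-grained forward, backward, and opt.step methods, PPOTrainer and other algorithms with more involved training logic would probably be easier to implement, similar to the way Beaver uses DeepSpeed engines in the code linked below (both actor_model and reward_critic_model are DeepSpeed engines): https://github.com/PKU-Alignment/safe-rlhf/blob/main/safe_rlhf/algorithms/ppo/trainer.py#L171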
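To make the point about finer-grained steps concrete, here is a minimal sketch in plain Python. `ToyEngine` and `rl_step` are hypothetical names invented for illustration; they are not part of PaddleNLP's Trainer or of DeepSpeed. The sketch only shows the shape of an interface where forward, backward, and the optimizer step are separate calls, so that one combined update can interleave two models.

```python
from dataclasses import dataclass


@dataclass
class ToyEngine:
    """Hypothetical stand-in for a trainer that exposes fine-grained steps.

    Fits y = w * x with squared error. Illustrative only; this is not
    the PaddleNLP Trainer API or the DeepSpeed engine API.
    """

    w: float = 0.0
    lr: float = 0.1
    grad: float = 0.0

    def forward(self, x: float, y: float) -> float:
        # Cache the inputs so backward() can compute the gradient.
        self._x, self._y = x, y
        return (self.w * x - y) ** 2

    def backward(self) -> None:
        # d/dw (w*x - y)^2 = 2 * (w*x - y) * x
        self.grad = 2.0 * (self.w * self._x - self._y) * self._x

    def step(self) -> None:
        # Plain SGD update, then clear the gradient.
        self.w -= self.lr * self.grad
        self.grad = 0.0


def rl_step(actor: ToyEngine, critic: ToyEngine, batch: dict) -> dict:
    """One combined update that drives two models in sequence."""
    actor_loss = actor.forward(batch["x"], batch["actor_target"])
    actor.backward()
    actor.step()

    critic_loss = critic.forward(batch["x"], batch["critic_target"])
    critic.backward()
    critic.step()
    return {"actor_loss": actor_loss, "critic_loss": critic_loss}


actor, critic = ToyEngine(), ToyEngine()
batch = {"x": 1.0, "actor_target": 2.0, "critic_target": -1.0}
for _ in range(50):
    losses = rl_step(actor, critic, batch)
```

With separate calls like these, the caller decides the order and frequency of each model's update, which is what a coarse-grained `train` loop makes hard to express.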
Fix AutoModelForScore and update reward training usage.
@@ -1214,7 +1217,7 @@ def __init__(self, config: LlamaConfig):

         # Recompute defaults to False and is controlled by Trainer
         self.enable_recompute = False
-        if config.tensor_parallel_degree > 1:
+        if config.tensor_parallel_degree > 1 and config.vocab_size % config.tensor_parallel_degree == 0:
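The added condition can be looked at in isolation. The sketch below is an assumption-laden illustration, not the code from this PR: the diff does not show what either branch does, so `can_shard_vocab` and `vocab_shard_sizes` are hypothetical helpers, and the unsharded fallback is only one plausible behavior when the vocabulary does not divide evenly.

```python
from typing import List


def can_shard_vocab(vocab_size: int, tensor_parallel_degree: int) -> bool:
    """Same boolean condition as the changed line in the diff."""
    return tensor_parallel_degree > 1 and vocab_size % tensor_parallel_degree == 0


def vocab_shard_sizes(vocab_size: int, tensor_parallel_degree: int) -> List[int]:
    """Hypothetical helper: per-rank vocabulary sizes.

    Assumption: when the guard fails, the vocabulary is kept whole
    instead of being split unevenly across ranks.
    """
    if can_shard_vocab(vocab_size, tensor_parallel_degree):
        per_rank = vocab_size // tensor_parallel_degree
        return [per_rank] * tensor_parallel_degree
    return [vocab_size]


even = vocab_shard_sizes(32000, 4)    # divides evenly, so it is split
uneven = vocab_shard_sizes(32001, 4)  # does not divide, so it stays whole
single = vocab_shard_sizes(32000, 1)  # no tensor parallelism
```

Without the modulo check, integer division would silently drop the remainder rows for a vocabulary such as 32001 across 4 ranks.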
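Can this be handled at initialization time instead? I recall the embedding side already has a check for this.
@wj-Mcat It looks like the Test CI failed; it needs to run on multiple GPUs.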
LGTM
PR types
New features
PR changes
Others
Description
Add PPO training.