feature(luyd): add partial rollout in training process #29
AltmanD wants to merge 5 commits into opendilab:main from
Conversation
lightrft/trainer/spmd_ppo_trainer.py
Outdated
| """ | ||
| Process a batch of experiences: add to replay buffer, train, and update metrics. | ||
|
|
||
| Args: |
Use the param-type style format for this docstring's Args section.
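For reference, a hedged sketch of a param-type style Args section for this method. Only the summary line comes from the diff; the parameter name `experiences` and the return annotation are assumptions for illustration:

```python
class SPMDPPOTrainerBase:
    def _process_experiences(self, experiences: list) -> dict:
        """
        Process a batch of experiences: add to replay buffer, train, and update metrics.

        Args:
            experiences (list): Rollout experiences to add to the replay
                buffer before the training step.

        Returns:
            dict: Updated training metrics for this batch.
        """
        # Body elided; this sketch only illustrates the docstring format.
        return {}
```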
lightrft/trainer/spmd_ppo_trainer.py
Outdated
# Then initialize our base class
assert "processor" in kwargs and kwargs["processor"] is not None, "processor is required for SPMDPPOTrainerVL"
SPMDPPOTrainerBase.__init__(self, *args, VLM=True, **kwargs)
if getattr(self.args, 'use_partial', False):
You can use self.args.use_partial directly; it defaults to False.
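A minimal sketch of why the getattr fallback is redundant, assuming (as the reviewer implies) the flag is registered with an explicit default in the argument parser:

```python
import argparse

# If --use_partial is registered with default=False, the attribute always
# exists on the parsed args, so getattr with a fallback adds nothing.
parser = argparse.ArgumentParser()
parser.add_argument("--use_partial", action="store_true", default=False)
args = parser.parse_args([])

assert getattr(args, "use_partial", False) == args.use_partial  # equivalent reads
assert args.use_partial is False  # default when the flag is not passed
```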
This subclass adds two key features:
1. Partial rollout: only a fraction (partial_percent) of the total rollout batch is generated
   in each call; the rest is kept in buffers.
2. Token‑budget regeneration: samples whose generation reaches max_token_budget are flagged
Introduce the meaning of --partial_percent and --max_budget, and add an overview of our partial_rollout implementation to the top of this file and to the PR description.
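For reference, the batch split that --partial_percent controls could be sketched as follows. The helper name and the list-based buffering are illustrative, not the trainer's actual code:

```python
def split_rollout(prompts: list, partial_percent: float) -> tuple:
    """Generate only a fraction of the rollout batch now; buffer the rest.

    Illustrative only: partial_percent mirrors the --partial_percent flag,
    i.e. the fraction of the rollout batch generated per call.
    """
    num_now = max(1, round(len(prompts) * partial_percent))
    return prompts[:num_now], prompts[num_now:]


generate_now, keep_buffered = split_rollout(list(range(10)), partial_percent=0.3)
# 3 prompts are generated in this call; 7 remain buffered for later calls.
```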
Add k1.5 and MiMo references.
Add partial_rollout monitoring metrics to the log and wandb for debugging and analysis.
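A sketch of what such monitoring metrics could look like. The metric names and the buffer/count inputs are assumptions, not keys from the PR; the returned dict could be passed to wandb.log or a text logger:

```python
def partial_rollout_metrics(regen_buffer: list, num_generated: int, num_reused: int) -> dict:
    """Build illustrative monitoring metrics for partial rollout.

    Hypothetical helper: tracks buffer occupancy and how much of the batch
    was freshly generated versus resumed from the regeneration buffer.
    """
    total = num_generated + num_reused
    return {
        "partial/regen_buffer_size": len(regen_buffer),
        "partial/num_generated": num_generated,
        "partial/num_reused": num_reused,
        "partial/reuse_ratio": num_reused / total if total else 0.0,
    }


metrics = partial_rollout_metrics(["s1", "s2", "s3"], num_generated=8, num_reused=2)
# e.g. wandb.log(metrics) at each rollout step.
```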
@torch.no_grad()
def _regenerate_from_buffer(self, num_needed: int, **kwargs) -> dict:
    """Regenerate outputs for samples that reached token budget."""
Use prefix caching or reuse the Session/Request ID mechanism (not prioritized now).
(Once the other items are complete, this mechanism can be considered if the performance impact proves significant.)
Implement a Staleness Threshold Mechanism (not prioritized now)
To prevent off-policy instability caused by outdated samples lingering in self.regen_buffer, we recommend enforcing a staleness threshold.
- Mechanism: If a sample's staleness exceeds the limit, it must be either discarded or prioritized for immediate completion in the next round.
- Goal: This ensures data remains consistent with the current policy, minimizing distribution shift and improving training stability.
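The discard variant of the proposed threshold could be sketched as follows. The entry layout (a per-sample "gen_step" recording the policy version that produced the partial output) and all names are assumptions for illustration:

```python
def filter_stale(regen_buffer, current_step, staleness_threshold):
    """Partition buffered samples by staleness relative to the current policy step.

    Hypothetical sketch: staleness is the gap between the current training
    step and the step at which the partial output was generated.
    """
    fresh, stale = [], []
    for entry in regen_buffer:
        staleness = current_step - entry["gen_step"]
        (stale if staleness > staleness_threshold else fresh).append(entry)
    # Stale samples are discarded here; the alternative in the comment above is
    # to prioritize them for immediate completion in the next round instead.
    return fresh, stale
```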
| """ | ||
| args = self.strategy.args | ||
| is_multimodal = all_images is not None | ||
| internvl = "internvl" in self.actor.pretrain_or_model.lower() if is_multimodal else False |
Delete the internvl-related code, and add the partial_rollout function from the latest lightrft/trainer/fast_exp_maker.py.
No description provided.