refactor(sunjx): refactor loss-filter implementation #17
Jiaxuan-Sun wants to merge 10 commits into opendilab:main from
Conversation
```python
ret = {}
for k in all_keys:
    ret[k] = self.all_reduce(data.get(k, 0.0), op)
return ret
```
Why was this added? Does it cause an error without it?
This is to prevent deadlock in distributed all-reduce operations.
After dynamic sampling, the set of keys in the status dictionary may differ across ranks (some ranks have keys like `kl` and `ptx_loss`, while others do not). The `all_reduce(dict)` operation calls `dist.all_reduce` for each key individually. If the keys or their order differ between ranks, the collective operations will be inconsistent, causing the process to hang.
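The fix described above can be illustrated with a small sketch. This is not the PR's actual implementation: the real code calls `dist.all_reduce` per key, while here the collective is simulated with plain dicts (one per "rank") so the key-union logic is runnable without a process group. The function names are illustrative.

```python
# Hypothetical sketch of the deadlock-safe pattern: before reducing, every
# rank agrees on the same sorted union of keys, and missing keys fall back
# to 0.0, so each per-key collective call lines up across ranks.

def union_keys(per_rank_keys):
    """Union of all ranks' keys, sorted so iteration order is identical."""
    keys = set()
    for rank_keys in per_rank_keys:
        keys |= set(rank_keys)
    return sorted(keys)

def simulated_all_reduce_dict(per_rank_dicts):
    """Sum-reduce dicts across simulated 'ranks'.

    With differing key sets and no union step, one rank would issue an
    all_reduce for a key the other rank never reaches, and both would hang.
    """
    all_keys = union_keys(d.keys() for d in per_rank_dicts)
    return {k: sum(d.get(k, 0.0) for d in per_rank_dicts) for k in all_keys}

# Rank 0 logged a KL term; rank 1 did not (its samples were filtered out).
rank0 = {"loss": 1.0, "kl": 0.2}
rank1 = {"loss": 2.0}
print(simulated_all_reduce_dict([rank0, rank1]))  # {'kl': 0.2, 'loss': 3.0}
```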
lightrft/trainer/fast_exp_maker.py (outdated)
```python
for exp in chunk:
    exp.action_mask = torch.zeros_like(exp.action_mask, dtype=torch.bool)
if config.dynamic_sampling and not use_dynamic_filter:
    # Legacy dynamic sampling (only if not using filter_weight framework)
```
Why has the dynamic sampling logic become so complex? The previous implementation seemed much simpler/cleaner. Could you explain the reasoning behind this change?
The previous version could cause deadlocks, hence the modification.
```python
# If no valid actions or base log-probs are empty, skip KL safely.
if ((experience.action_mask is not None and experience.action_mask.sum().item() == 0)
        or (base_action_log_probs is not None and base_action_log_probs.numel() == 0)):
    kl = torch.zeros_like(
```
Have these null-check branches actually been hit during testing? If it's null, we should probably just throw an error directly.
Yes. During testing, a dimension-mismatch error was raised when `action_mask` summed to 0 or the baseline log-probs were empty (i.e., `compute_approx_kl` was entered with an empty `base_action_log_probs`). So these branches are genuinely triggered in dynamic sampling and filtering scenarios.
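The failure mode and the guard can be shown in miniature. This is a hypothetical stand-in, not the repository's code: lists replace tensors, and `compute_approx_kl` here is a toy per-token KL that only reproduces the shape requirement that caused the crash.

```python
# Toy illustration of why the guard is needed: when dynamic sampling masks
# out a whole chunk, the mask sums to zero and/or the base log-probs are
# empty, and an unguarded per-token KL would fail on mismatched lengths.

def compute_approx_kl(log_probs, base_log_probs):
    """Stand-in per-token KL; requires equal, non-empty lengths."""
    if len(log_probs) != len(base_log_probs):
        raise ValueError("dimension mismatch")
    return [lp - blp for lp, blp in zip(log_probs, base_log_probs)]

def kl_with_guard(action_mask, log_probs, base_log_probs):
    # Mirror of the branch above: skip KL safely when nothing is valid.
    if sum(action_mask) == 0 or len(base_log_probs) == 0:
        return [0.0] * len(log_probs)  # zeros_like stand-in
    return compute_approx_kl(log_probs, base_log_probs)

# A fully filtered sample: empty base log-probs would otherwise raise.
print(kl_with_guard([0, 0], [-1.2, -0.7], []))  # [0.0, 0.0]
```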
Add new lightrft/trainer/filter_weight/ module with:
- metrics.py - Metrics computation layer (entropy, difficulty, staleness, etc.)
- filters.py - Sample filtering layer (length, reward value, entropy, difficulty filters, etc.)
- weights.py - Loss weighting layer (length, entropy, difficulty, staleness weightings, etc.)
- manager.py - Unified management layer (FilterWeightManager)

Note: The dynamic sampling feature has been tested. Other components are reserved for future extension.
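The layering above (filters decide which samples survive, weights scale each survivor's loss, a manager composes both) could look roughly like the sketch below. All class names and method signatures here are illustrative assumptions; the PR's actual `FilterWeightManager` API is not shown in this thread.

```python
# Hypothetical composition of the filter/weight layers. Samples are plain
# dicts for simplicity; LengthFilter and EntropyWeight are made-up examples
# of the filter and weighting layers described above.

class LengthFilter:
    """Filter layer example: drop samples longer than max_len."""
    def __init__(self, max_len):
        self.max_len = max_len
    def keep(self, sample):
        return sample["length"] <= self.max_len

class EntropyWeight:
    """Weighting layer example: up-weight high-entropy samples."""
    def __init__(self, scale=1.0):
        self.scale = scale
    def weight(self, sample):
        return 1.0 + self.scale * sample["entropy"]

class FilterWeightManager:
    """Management layer: apply all filters, then attach a combined weight."""
    def __init__(self, filters, weights):
        self.filters = filters
        self.weights = weights
    def process(self, samples):
        kept = [s for s in samples if all(f.keep(s) for f in self.filters)]
        for s in kept:
            w = 1.0
            for wt in self.weights:
                w *= wt.weight(s)
            s["loss_weight"] = w
        return kept

samples = [
    {"length": 8, "entropy": 0.5},
    {"length": 99, "entropy": 0.1},  # dropped by the length filter
]
manager = FilterWeightManager([LengthFilter(32)], [EntropyWeight(2.0)])
out = manager.process(samples)
print([s["loss_weight"] for s in out])  # [2.0]
```

Keeping the filtering decision separate from the weighting computation means new criteria (staleness, difficulty) can be added to either layer without touching the training loop.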