
refactor(sunjx): refactor loss-filter implementation #17

Open

Jiaxuan-Sun wants to merge 10 commits into opendilab:main from Jiaxuan-Sun:refactor/loss-filter

Conversation

@Jiaxuan-Sun (Contributor) commented Jan 1, 2026

Add new lightrft/trainer/filter_weight/ module with:

  • metrics.py - Metrics computation layer (entropy, difficulty, staleness, etc.)
  • filters.py - Sample filtering layer (length, reward value, entropy, difficulty filters, etc.)
  • weights.py - Loss weighting layer (length, entropy, difficulty, staleness weightings, etc.)
  • manager.py - Unified management layer (FilterWeightManager)

Note: The dynamic sampling feature has been tested. Other components are reserved for future extension.
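To make the layered design concrete, here is a minimal sketch of how the four layers could compose. This is illustrative only: the class and method names (`EntropyFilter`, `LengthWeight`, `keep`, `weight`, `process`) are hypothetical stand-ins, not the actual lightrft API; only `FilterWeightManager` is named in the PR.

```python
# Hypothetical sketch of the layered filter/weight design. All names except
# FilterWeightManager are illustrative; the real lightrft API may differ.

class EntropyFilter:
    """Filtering layer: drop samples whose entropy is below a threshold."""
    def __init__(self, min_entropy: float):
        self.min_entropy = min_entropy

    def keep(self, metrics: dict) -> bool:
        return metrics.get("entropy", 0.0) >= self.min_entropy


class LengthWeight:
    """Weighting layer: down-weight samples longer than a reference length."""
    def __init__(self, ref_len: int):
        self.ref_len = ref_len

    def weight(self, metrics: dict) -> float:
        return min(1.0, self.ref_len / max(metrics.get("length", 1), 1))


class FilterWeightManager:
    """Management layer: apply all filters, then multiply all loss weights."""
    def __init__(self, filters, weights):
        self.filters = filters
        self.weights = weights

    def process(self, samples):
        out = []
        for metrics in samples:  # metrics come from the metrics layer
            if all(f.keep(metrics) for f in self.filters):
                w = 1.0
                for wt in self.weights:
                    w *= wt.weight(metrics)
                out.append((metrics, w))
        return out


manager = FilterWeightManager(
    filters=[EntropyFilter(min_entropy=0.5)],
    weights=[LengthWeight(ref_len=100)],
)
kept = manager.process([
    {"entropy": 0.9, "length": 50},   # kept, weight 1.0
    {"entropy": 0.2, "length": 50},   # filtered out (low entropy)
    {"entropy": 0.8, "length": 200},  # kept, weight 0.5
])
print(kept)
```

The point of the layering is that metrics are computed once, while filters (hard drop) and weights (soft scaling) stay independent and composable.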

@puyuan1996 added the enhancement (New feature or request) and refactor (Cleanup, formatting, or restructuring of existing code) labels Jan 4, 2026
ret = {}
for k in all_keys:
    ret[k] = self.all_reduce(data.get(k, 0.0), op)
return ret
Collaborator:

Why was this added? Does it cause an error without it?

Contributor (Author):

This is to prevent deadlock in distributed all-reduce operations.
After dynamic sampling, the set of keys in the status dictionary may differ across ranks (some ranks have keys like kl and ptx_loss, while others do not). The all_reduce(dict) operation calls dist.all_reduce for each key individually. If the keys or their order differ between ranks, the collective operations will be inconsistent, causing the process to hang.
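The fix described above can be illustrated with a small simulation (plain Python, no real `torch.distributed` calls; per-rank dicts stand in for each rank's status dictionary). If each rank iterated only over its own keys, ranks would issue a different number and order of collective calls and hang; aligning on the sorted union of keys, with 0.0 for missing entries, keeps every rank's call sequence identical. In real code the key union would itself be gathered collectively (e.g. via an `all_gather_object`-style step).

```python
# Simulated key-aligned all-reduce: every rank reduces the same keys in the
# same order, so the collective call sequences match across ranks.

def aligned_all_reduce(per_rank_stats):
    # In distributed code this union must be agreed on by all ranks;
    # here we compute it directly from the simulated per-rank dicts.
    all_keys = sorted({k for stats in per_rank_stats for k in stats})
    world_size = len(per_rank_stats)
    # Missing keys default to 0.0, mirroring data.get(k, 0.0) in the PR.
    return {
        k: sum(stats.get(k, 0.0) for stats in per_rank_stats) / world_size
        for k in all_keys
    }

# Rank 0 has "kl" after dynamic sampling; rank 1 does not.
rank0 = {"loss": 1.0, "kl": 0.2}
rank1 = {"loss": 3.0}
print(aligned_all_reduce([rank0, rank1]))  # {'kl': 0.1, 'loss': 2.0}
```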

for exp in chunk:
    exp.action_mask = torch.zeros_like(exp.action_mask, dtype=torch.bool)
if config.dynamic_sampling and not use_dynamic_filter:
    # Legacy dynamic sampling (only if not using filter_weight framework)
Collaborator:

Why has the dynamic sampling logic become so complex? The previous implementation seemed much simpler/cleaner. Could you explain the reasoning behind this change?

Contributor (Author):

The previous version could cause deadlocks, hence the modification.

# If no valid actions or base log-probs are empty, skip KL safely.
if ((experience.action_mask is not None and experience.action_mask.sum().item() == 0)
        or (base_action_log_probs is not None and base_action_log_probs.numel() == 0)):
    kl = torch.zeros_like(
Collaborator:

Have these null-check branches actually been hit during testing? If it's null, we should probably just throw an error directly.

Contributor (Author):

Yes. A dimension-mismatch error was reported when the action_mask summed to 0 or the baseline log-probs were empty (i.e., compute_approx_kl was entered with an empty base_action_log_probs), so these branches are actually triggered in dynamic sampling and filtering scenarios.
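The failure mode can be shown with a minimal illustration (plain Python, no torch; function name and the simple log-prob-difference KL are stand-ins for the PR's compute_approx_kl): a masked reduction over zero valid actions, or a KL against an empty baseline, has nothing well-defined to compute, so the guard short-circuits to zeros instead of letting a shape mismatch surface downstream.

```python
# Illustrative guard, mirroring the PR's early-exit: if no action is valid
# or the baseline log-probs are empty, return zeros rather than computing KL.

def approx_kl(log_probs, base_log_probs, mask):
    # Guard: no valid actions, or empty baseline -> skip KL safely.
    if sum(mask) == 0 or len(base_log_probs) == 0:
        return [0.0] * len(log_probs)
    # Simple per-token log-prob difference, masked to valid actions only.
    return [m * (lp - blp) for lp, blp, m in zip(log_probs, base_log_probs, mask)]


# All actions masked out after dynamic filtering: guard returns zeros.
print(approx_kl([0.1, 0.2], [0.3, 0.4], mask=[0, 0]))  # [0.0, 0.0]
# Empty baseline (base log-probs never produced for this batch): same.
print(approx_kl([0.1, 0.2], [], mask=[1, 1]))          # [0.0, 0.0]
```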

@puyuan1996 mentioned this pull request Jan 21, 2026
