support GSPO-token #3820
Conversation
Thanks for this. Since GSPO-token is a generalized version of vanilla GSPO, I suggest we fully transition to GSPO-token instead of supporting both versions. Consequently, we would rename/remove …
trl/trainer/grpo_trainer.py (outdated)

elif self.importance_sampling_level == 'sequence_token':
    # GSPO-token: sg[si(θ)] * πθ(yi,t)/sg[πθ(yi,t)]
    seq_level_log_weight = (log_ratio * completion_mask).sum(-1) / completion_mask.sum(-1).clamp(min=1.0)
    seq_level_log_weight = seq_level_log_weight.unsqueeze(-1).detach()  # Stop gradient
Suggested change:
- seq_level_log_weight = seq_level_log_weight.unsqueeze(-1).detach()  # Stop gradient
+ seq_level_log_weight = seq_level_log_weight.detach().unsqueeze(-1)  # Stop gradient
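For reference, the two orderings are numerically identical; a minimal standalone check (toy shapes, not the trainer's actual tensors):

```python
import torch

log_ratio = torch.randn(2, 5, requires_grad=True)
completion_mask = torch.ones(2, 5)
w = (log_ratio * completion_mask).sum(-1) / completion_mask.sum(-1).clamp(min=1.0)
a = w.unsqueeze(-1).detach()  # original ordering
b = w.detach().unsqueeze(-1)  # suggested ordering
assert torch.equal(a, b)
assert not a.requires_grad and not b.requires_grad
```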
(log_ratio * completion_mask).sum(-1) / completion_mask.sum(-1).clamp(min=1.0)

This op is common to both GSPO and GSPO-token; it would be good to have a single variable holding this value under a condition like if self.importance_sampling_level != 'token'.
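A hedged sketch of that factoring, written as a standalone function (the signature and variable names are illustrative, not the trainer's actual API; per_token_logps and old_per_token_logps stand in for the current and old policies' token log-probs):

```python
import torch

def compute_log_importance_weights(
    per_token_logps: torch.Tensor,      # (B, T): log πθ(y_{i,t} | x, y_{i,<t})
    old_per_token_logps: torch.Tensor,  # (B, T): log πθ_old(y_{i,t} | x, y_{i,<t})
    completion_mask: torch.Tensor,      # (B, T): 1.0 for completion tokens, 0.0 for padding
    importance_sampling_level: str,
) -> torch.Tensor:
    log_ratio = per_token_logps - old_per_token_logps
    if importance_sampling_level == "token":
        return log_ratio
    # Shared by both sequence-level schemes: length-normalized log ratio.
    seq_level_log_weight = (log_ratio * completion_mask).sum(-1) / completion_mask.sum(-1).clamp(min=1.0)
    if importance_sampling_level == "sequence":
        # GSPO: one weight per sequence, broadcast over tokens.
        return seq_level_log_weight.unsqueeze(-1)
    if importance_sampling_level == "sequence_token":
        # GSPO-token: sg[w_i] * πθ(y_{i,t}) / sg[πθ(y_{i,t})], in log space.
        return per_token_logps - per_token_logps.detach() + seq_level_log_weight.detach().unsqueeze(-1)
    raise ValueError(f"Unknown importance_sampling_level: {importance_sampling_level!r}")
```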
Makes sense. So shall we move the invalid-value check for importance_sampling_level into the model parameter initialization?
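A minimal sketch of what that init-time check could look like, assuming a dataclass-style config; the class shown is an illustrative subset, not the real GRPOConfig definition:

```python
from dataclasses import dataclass

_VALID_LEVELS = ("token", "sequence", "sequence_token")

@dataclass
class GRPOConfig:  # illustrative subset only
    importance_sampling_level: str = "token"

    def __post_init__(self):
        # Fail fast on an invalid value instead of erroring mid-training.
        if self.importance_sampling_level not in _VALID_LEVELS:
            raise ValueError(
                f"`importance_sampling_level` must be one of {_VALID_LEVELS}, "
                f"got {self.importance_sampling_level!r}."
            )
```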
Agreed to keep GSPO-token. Should we retain this parameter for compatibility with previous usage, or introduce an additional parameter instead? Which is better?
@qgallouedec @lewtun @edbeeching @kashif If there are any concerns or suggestions, please feel free to let me know. Thank you very much in advance.
IMO it should be removed; however, since it's already been published as part of TRL v0.20, we may need to keep it for backward compatibility. I can't speak to it myself, so I'll leave it to someone else to decide.
Support for GSPO-token as described in the GSPO paper, Section 4.3.
Related issue: #3811
GSPO
$w_{i}^{\mathrm{GSPO}} = \left[ \frac{\pi_{\theta}(y_i \mid x)}{\pi_{\theta_{\mathrm{old}}}(y_i \mid x)} \right]^{\frac{1}{|y_i|}} = \exp\left(\frac{1}{|y_i|} \sum_{t=1}^{|y_i|} \log \frac{\pi_{\theta}(y_{i, t} \mid x, y_{i, <t})}{\pi_{\theta_{\mathrm{old}}}(y_{i, t} \mid x, y_{i, <t})}\right)$
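For intuition, a toy computation of this weight in log space (shapes, values, and variable names are made up for illustration, not taken from the trainer):

```python
import torch

torch.manual_seed(0)
per_token_logps = torch.randn(2, 5)      # log πθ(y_{i,t} | x, y_{i,<t})
old_per_token_logps = torch.randn(2, 5)  # log πθ_old(y_{i,t} | x, y_{i,<t})
completion_mask = torch.tensor([[1., 1., 1., 0., 0.],
                                [1., 1., 1., 1., 1.]])

log_ratio = per_token_logps - old_per_token_logps
# Length-normalized sum of per-token log ratios = log w_i^GSPO.
seq_log_w = (log_ratio * completion_mask).sum(-1) / completion_mask.sum(-1).clamp(min=1.0)
w_gspo = seq_log_w.exp()  # shape (2,): one importance weight per sequence
```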
GSPO-token
$w_{i, t}^{\mathrm{GSPO\text{-}token}} = \mathrm{sg}\left[w_i^{\mathrm{GSPO}}\right] \cdot \frac{\pi_{\theta}(y_{i, t} \mid x, y_{i, < t})}{\mathrm{sg}\left[\pi_{\theta}(y_{i, t} \mid x, y_{i, < t})\right]}$
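Continuing the toy example above: sg[·] maps to .detach(), so the forward value of the GSPO-token weight matches w_i^GSPO at every token, while in training the gradient would flow only through the live πθ(y_{i,t}) factor:

```python
# sg[w_i] broadcast over tokens, shape (2, 1).
seq_level_log_weight = seq_log_w.detach().unsqueeze(-1)
# log w_{i,t}^GSPO-token = log πθ(y_{i,t}) - sg[log πθ(y_{i,t})] + sg[log w_i]
log_w_token = per_token_logps - per_token_logps.detach() + seq_level_log_weight
w_gspo_token = log_w_token.exp()  # equals w_gspo at every token in the forward pass
```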