support GSPO-token #3820

Open · wants to merge 5 commits into main

Conversation

@hjh0119 commented Jul 31, 2025

Support for GSPO-token, as described in the GSPO paper, Section 4.3.

Related issue: #3811

GSPO
$w_{i}^{\mathrm{GSPO}} = \left[ \frac{\pi_{\theta}(y_i \mid x)}{\pi_{\theta_{\mathrm{old}}}(y_i \mid x)} \right]^{\frac{1}{|y_i|}} = \exp\left(\frac{1}{|y_i|} \sum_{t=1}^{|y_i|} \log \frac{\pi_{\theta}(y_{i, t} \mid x, y_{i, <t})}{\pi_{\theta_{\mathrm{old}}}(y_{i, t} \mid x, y_{i, <t})}\right)$

GSPO-token
$w_{i, t}^{\mathrm{GSPO\text{-}token}} = \mathrm{sg}\left[w_i^{\mathrm{GSPO}}\right] \cdot \frac{\pi_{\theta}(y_{i, t} \mid x, y_{i, < t})}{\mathrm{sg}\left[\pi_{\theta}(y_{i, t} \mid x, y_{i, < t})\right]}$

where $\mathrm{sg}[\cdot]$ denotes the stop-gradient (detach) operation.

💡 NOTE: GSPO-token enables support for fine-grained (token-level) advantages.
However, with the current advantage computation, all tokens within a sequence share the same advantage value; in that case GSPO and GSPO-token are theoretically equivalent, as shown in equations (11) and (18) of the paper.
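
For concreteness, here is a minimal PyTorch sketch of the GSPO-token weight computed in log space. The tensor names (per_token_logps, old_per_token_logps, completion_mask, all shaped batch × seq_len) are assumptions for illustration, not the exact trainer variables.

import torch

def gspo_token_log_weights(per_token_logps, old_per_token_logps, completion_mask):
    # Per-token log ratio: log πθ(y_t | ·) - log πθ_old(y_t | ·)
    log_ratio = per_token_logps - old_per_token_logps
    # Length-normalized sequence-level log weight, i.e. log of w_i^GSPO above
    seq_log_w = (log_ratio * completion_mask).sum(-1) / completion_mask.sum(-1).clamp(min=1.0)
    # GSPO-token in log space: sg[log w_i^GSPO] + log πθ(y_t) - sg[log πθ(y_t)]
    return seq_log_w.detach().unsqueeze(-1) + per_token_logps - per_token_logps.detach()

Numerically, every token's weight equals w_i^GSPO (the second factor is identically 1); the two formulations differ only in how gradients propagate, which is where token-level advantages can enter.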

@LeonEricsson (Collaborator):

Thanks for this. Since GSPO-token is a generalized version of vanilla GSPO, I suggest we fully transition to GSPO-token instead of supporting both versions. Consequently, we would rename/remove importance_sampling_level, as both methods operate at the token level.

elif self.importance_sampling_level == 'sequence_token':
    # GSPO-token: sg[si(θ)] * πθ(yi,t)/sg[πθ(yi,t)]
    seq_level_log_weight = (log_ratio * completion_mask).sum(-1) / completion_mask.sum(-1).clamp(min=1.0)
    seq_level_log_weight = seq_level_log_weight.unsqueeze(-1).detach()  # Stop gradient

Collaborator (review comment), suggested change:

-   seq_level_log_weight = seq_level_log_weight.unsqueeze(-1).detach()  # Stop gradient
+   seq_level_log_weight = seq_level_log_weight.detach().unsqueeze(-1)  # Stop gradient

Contributor (review comment):

(log_ratio * completion_mask).sum(-1) / completion_mask.sum(-1).clamp(min=1.0)

This op is common across GSPO and GSPO-token; it would be good to have a single variable pointing to this value under an if condition like

if self.importance_sampling_level != 'token':

@hjh0119 (Author):

Makes sense. So shall we move the invalid-value check for importance_sampling_level into the model parameter initialization?
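
For illustration, a standalone sketch (not the actual patch) of how the shared term could be factored out; the argument and variable names are assumed from the snippet above, and the trailing check is the invalid-value guard that could instead move to parameter initialization:

def compute_log_importance_weights(level, per_token_logps, old_per_token_logps, completion_mask):
    log_ratio = per_token_logps - old_per_token_logps
    if level == 'token':
        return log_ratio  # GRPO: one weight per token
    # Shared by GSPO ('sequence') and GSPO-token ('sequence_token')
    seq_level_log_weight = (log_ratio * completion_mask).sum(-1) / completion_mask.sum(-1).clamp(min=1.0)
    if level == 'sequence':
        return seq_level_log_weight.unsqueeze(-1)  # GSPO: broadcast over tokens
    if level == 'sequence_token':
        # GSPO-token: sg[s_i(θ)] * πθ(y_t) / sg[πθ(y_t)], in log space
        return seq_level_log_weight.detach().unsqueeze(-1) + per_token_logps - per_token_logps.detach()
    raise ValueError(f"Unknown importance_sampling_level: {level}")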

@hjh0119 (Author) commented Aug 2, 2025

> Thanks for this. Since GSPO-token is a generalized version of vanilla GSPO, I suggest we fully transition to GSPO-token instead of supporting both versions. Consequently, we would rename/remove importance_sampling_level, as both methods operate at the token level.

Agreed to keep GSPO-token. Should we retain this parameter for compatibility with previous usage, or introduce an additional parameter instead? Which is better?
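
For reference, a hypothetical usage sketch of the two options being discussed. GRPOConfig and importance_sampling_level already exist in TRL v0.20 (with 'token' and 'sequence'), 'sequence_token' is the value this PR would add, and the extra argument name in option B is made up purely for illustration.

from trl import GRPOConfig

# Option A: reuse the existing argument and add the new level (what this PR does)
config = GRPOConfig(
    output_dir="grpo-output",                    # illustrative
    importance_sampling_level="sequence_token",  # 'token' | 'sequence' | 'sequence_token'
)

# Option B (hypothetical): keep the old values for backward compatibility and
# expose GSPO-token behind a separate, yet-to-be-named flag, e.g.:
# config = GRPOConfig(importance_sampling_level="sequence", gspo_token=True)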

@hjh0119 (Author) commented Aug 5, 2025

@qgallouedec @lewtun @edbeeching @kashif If there are any concerns or suggestions, please feel free to let me know. Thank you very much in advance.

@LeonEricsson (Collaborator) commented Aug 5, 2025

> Agreed to keep GSPO-token. Should we retain this parameter for compatibility with previous usage, or introduce an additional parameter instead? Which is better?

IMO it should be removed; however, since it has already been published as part of TRL v0.20, we may need to keep it for backward compatibility. I can't speak to that myself, so I'll leave it to someone else to decide.
