Add top-p sampling #23
Conversation
- added tests
pcuenca left a comment
This is fantastic, @jkrukowski!
Let me try it out and check the implementation against the current one in transformers in case there are some details that could be incorporated, but this looks great already.
As a side comment, we could potentially implement the costly cumsum operation in Core ML as part of the model conversion, or using a Core ML pipeline. But using Accelerate should be more than enough for now!
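For context, nucleus (top-p) sampling sorts the distribution, keeps the smallest prefix whose cumulative probability reaches `topP`, and samples from that prefix. A minimal sketch in plain Swift (the function name and the loop-based cumulative sum are illustrative, not this PR's implementation):

```swift
/// Minimal top-p (nucleus) sampling sketch; illustrative only.
/// `probs` is assumed to hold softmaxed probabilities (sum > 0) and `topP` > 0.
func sampleTopP(probs: [Float], topP: Float) -> Int {
    // Sort token indices by probability, descending.
    let sorted = probs.enumerated().sorted { $0.element > $1.element }

    // Keep the smallest prefix whose cumulative probability reaches topP.
    var nucleus: [(offset: Int, element: Float)] = []
    var cumulative: Float = 0
    for pair in sorted {
        nucleus.append(pair)
        cumulative += pair.element
        if cumulative >= topP { break }
    }

    // Draw one token from the nucleus; sampling within [0, cumulative)
    // renormalizes implicitly.
    let r = Float.random(in: 0..<cumulative)
    var acc: Float = 0
    for pair in nucleus {
        acc += pair.element
        if r < acc { return pair.offset }
    }
    return nucleus[nucleus.count - 1].offset  // guard against float rounding
}
```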
pcuenca left a comment
I tested it and it works fine! As expected, it's a bit slow but we can try to optimize later. Additionally, top-k and top-p could potentially coexist as pointed out below, but we can also handle that in a new PR unless you want to tackle it now :)
if config.topK > 0 {
    let topK = Math.topK(arr: logits, k: config.topK)
    nextToken = Math.sample(indexes: topK.indexes, probs: topK.probs)
} else if config.topP < 1.0 {
If my understanding of this is correct, top-k can coexist with top-p: https://github.com/huggingface/transformers/blob/42017d82baa083da2bee3055fdac80c81ee97b8a/src/transformers/generation/utils.py#L805-L808
However, it could make sense to merge this PR now and make them coexist in a future one. What do you think?
I'd say let's merge it now; it seems logical to create a separate PR with a common interface to these two.
| fatalError("topP not implemented yet") | ||
| fatalError("not implemented yet") |
If we make top-k compatible with top-p, we'd do a single sample call on the selected tokens and remove this fatalError.
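A rough sketch of what that could look like, assuming the `Math.topK` / `Math.sample` helpers and `config` from the diff above, plus hypothetical `softmax` and `filterTopP` helpers (not existing library code):

```swift
/// Hypothetical top-p filter: sorts the distribution descending and keeps
/// the smallest prefix whose probability mass reaches `p`.
func filterTopP(indexes: [Int], probs: [Float], p: Float) -> (indexes: [Int], probs: [Float]) {
    let order = probs.indices.sorted { probs[$0] > probs[$1] }
    var keptIndexes: [Int] = []
    var keptProbs: [Float] = []
    var cumulative: Float = 0
    for i in order {
        keptIndexes.append(indexes[i])
        keptProbs.append(probs[i])
        cumulative += probs[i]
        if cumulative >= p { break }
    }
    return (keptIndexes, keptProbs)
}

// In the generation step: apply top-k first, then top-p (the same order
// transformers uses), and finish with a single sample call on whatever
// survives both filters -- which is what removes the fatalError.
var indexes = Array(0..<logits.count)
var probs = softmax(logits)  // assumed helper producing the full distribution
if config.topK > 0 {
    let topK = Math.topK(arr: logits, k: config.topK)
    (indexes, probs) = (topK.indexes, topK.probs)
}
if config.topP < 1.0 {
    (indexes, probs) = filterTopP(indexes: indexes, probs: probs, p: Float(config.topP))
}
nextToken = Math.sample(indexes: indexes, probs: probs)  // assumed to renormalize
```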
In this PR I've compared two different implementations here: https://github.com/jkrukowski/topp -- looks like using Accelerate to compute the cumulative sum gives it a speed boost.
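For reference, the Accelerate variant boils down to an inclusive prefix sum. A sketch of one way to get it with `vDSP_vrsum` (my reconstruction, not necessarily the benchmarked repo's actual code):

```swift
import Accelerate

/// Inclusive cumulative sum via Accelerate; hypothetical helper.
/// vDSP_vrsum computes C[n] = weight * (A[1] + ... + A[n]) with C[0] = 0,
/// i.e. a running sum that skips the first element; adding A[0] to every
/// output element turns it into the inclusive prefix sum we want.
func cumsum(_ x: [Float]) -> [Float] {
    guard !x.isEmpty else { return [] }
    var result = [Float](repeating: 0, count: x.count)
    var weight: Float = 1
    vDSP_vrsum(x, 1, &weight, &result, 1, vDSP_Length(x.count))
    vDSP.add(x[0], result, result: &result)
    return result
}
```

The vectorized running sum is presumably where the advantage over a scalar Swift loop comes from, especially at vocabulary-sized inputs.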