[V0][Sampler] Use raw logits for greedy argmax #13312
Conversation
To hopefully avoid some of the reported precision-related nondeterminism. Also delete a vestigial intermediate method. Signed-off-by: Nick Hill <nhill@redhat.com>
Is this fix not necessary for the v1 sampler?
@patrickvonplaten I don't think so, we were already using logits rather than logprobs for this in V1: see vllm/v1/sample/sampler.py, lines 77–78 at 9f37422.
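For reference, the V1 greedy path referenced above boils down to an argmax over the raw logits. A minimal sketch of the idea (assumed shapes, not the verbatim vLLM code):

```python
import torch

def greedy_sample(logits: torch.Tensor) -> torch.Tensor:
    # logits: [num_seqs, vocab_size]; pick the highest-logit token id
    # per sequence, with no intermediate softmax/log-softmax step.
    return logits.argmax(dim=-1).view(-1)
```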
@njhill, may I ask when this PR will be merged? Any target release? Thanks in advance!
@tonyaw there are test failures that need investigating: https://buildkite.com/vllm/ci/builds/15883#0195b060-d27c-4bd5-b435-c495d3709d24. Any help would be welcome!
@njhill Hi, may I ask why argmax(logits) is more stable than argmax(logprobs)?
@gx16377 @tonyaw logprob is the log-softmax of the logits; softmax is a nonlinear projection of the entire floating-point range into the range [0, 1]. So it essentially reduces precision, and many token values will end up tied that weren't beforehand, including values tied for first place, and argmax may select from these arbitrarily (e.g. could vary by batch size).
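To make the precision loss concrete, here is a small hypothetical demonstration (float32 and a 32k vocabulary are assumed; this is not vLLM code). When the logits sit much closer to zero than the roughly -10.37 logprobs they map to, a gap that float32 can represent between logits falls below the float32 spacing of the logprobs, and log_softmax collapses the values into a tie:

```python
import torch

vocab_size = 32_000
logits = torch.zeros(vocab_size, dtype=torch.float32)
logits[100] = 1e-7  # strictly the largest logit, representable in float32

logprobs = torch.log_softmax(logits, dim=-1)

# Every logprob is near log(1/32000) ~= -10.37, where float32 spacing is
# ~9.5e-7. The 1e-7 gap is below that, so token 100 ties with the rest.
print(logits.argmax())    # tensor(100): unambiguous on the raw logits
print(logprobs.argmax())  # typically tensor(0): a tied index, not token 100
```

Argmaxing the raw logits sidesteps this: ties can still happen, but only when the logits themselves are equal.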
Thank you |
Thanks! |
@tonyaw I mean argmax. It may not be deterministic when there are tied max values. And yes, batch size matters when multiple requests are processed together.
@njhill Thanks! |
@njhill The torch documentation says argmax will pick the lowest index. At least for the case of multiple max logits, this commit will make sampling deterministic instead of picking one of the max values as mentioned here, correct?
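For what it's worth, the documented first-index behavior is easy to check with a toy tensor (a hypothetical example, not from this PR):

```python
import torch

x = torch.tensor([1.0, 3.0, 3.0, 2.0])
print(torch.argmax(x))  # tensor(1): the first occurrence of the tied max
```

The residual nondeterminism described above is upstream of the tie-break rule: which values end up tied can itself vary with batch size and kernel reduction order, so the lowest tied index is not stable across runs either. Using the raw logits reduces how often ties form in the first place.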