[V0][Sampler] Use raw logits for greedy argmax #13312


Open

njhill wants to merge 2 commits into main

Conversation

njhill (Member) commented Feb 14, 2025

To hopefully avoid some of the reported precision-related nondeterminism.

Also delete a vestigial intermediate method.
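
In exact arithmetic, softmax and log-softmax are monotonic per row and so preserve the argmax; greedy decoding can therefore take the argmax directly from the logits. A minimal sketch of the idea (not the PR's actual diff):

import torch

def greedy_sample_from_logits(logits: torch.Tensor) -> torch.Tensor:
    # softmax/log-softmax preserve the per-row argmax in exact arithmetic,
    # so skipping them avoids an extra rounding step that can create
    # spurious ties among the top logits
    return logits.argmax(dim=-1)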


👋 Hi! Thank you for contributing to the vLLM project.

💬 Join our developer Slack at https://slack.vllm.ai to discuss your PR in #pr-reviews, coordinate on features in #feat- channels, or join special interest groups in #sig- channels.

Just a reminder: PRs do not trigger a full CI run by default. Instead, they only run fastcheck CI, which runs a small and essential subset of CI tests to quickly catch errors. You can run other CI tests on top of those by going to your fastcheck build on the Buildkite UI (linked in the PR checks section) and unblocking them. If you do not have permission to unblock, ping simon-mo or khluu to add you to our Buildkite org.

Once the PR is approved and ready to go, your PR reviewer(s) can run CI to test the changes comprehensively before merging.

To run CI, PR reviewers can either add the ready label to the PR or enable auto-merge.

🚀

patrickvonplaten (Contributor) commented

Is this fix not necessary for the v1 sampler?

njhill (Member, Author) commented Mar 15, 2025

> Is this fix not necessary for the v1 sampler?

@patrickvonplaten I don't think so, we were already using logits rather than logprobs for this in V1:

def greedy_sample(self, logits: torch.Tensor) -> torch.Tensor:
    return logits.argmax(dim=-1).view(-1)
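
For illustration, the same pattern on toy values (hypothetical numbers; one row of logits per sequence in the batch):

import torch

logits = torch.tensor([[0.1, 2.5, -1.0],
                       [3.0, 0.0, 0.5]])
print(logits.argmax(dim=-1).view(-1))  # tensor([1, 0])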

njhill added the ready label (ONLY add when PR is ready to merge/full CI is needed) on Mar 19, 2025
njhill added the v0 label on Mar 29, 2025
tonyaw commented Apr 9, 2025

@njhill, may I ask when this PR will be merged? Any target release? Thanks in advance!

njhill (Member, Author) commented Apr 9, 2025

@tonyaw there are test failures that need investigating: https://buildkite.com/vllm/ci/builds/15883#0195b060-d27c-4bd5-b435-c495d3709d24, any help would be welcome!

gx16377 commented Apr 9, 2025

@njhill Hi, may I ask why argmax(logits) is more stable than argmax(logprobs)?

tonyaw commented Apr 9, 2025

> @njhill Hi, may I ask why argmax(logits) is more stable than argmax(logprobs)?

@njhill, I also want to know the reason, could you please help to explain? :-)

njhill (Member, Author) commented Apr 9, 2025

@gx16377 @tonyaw The logprobs come from a softmax over the logits, which is a nonlinear projection of the entire floating-point range into [0, 1]. It essentially reduces precision, so many token values end up tied that weren't tied beforehand, including values tied for first place, and argmax may then select among these arbitrarily (e.g. the choice could vary by batch size).
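
A toy illustration of the effect (hypothetical values and dtypes, not vLLM's actual code path): two logits that are clearly distinct in float32 can round to the same probability once squeezed through softmax into a lower-precision representation, after which argmax no longer reflects the true maximum.

import torch

logits = torch.tensor([9.0000, 9.0002, 0.0], dtype=torch.float32)
print(logits.argmax())   # tensor(1) -- the two top logits are distinct in float32

# softmax squeezes both top values toward 0.5; cast to float16 they collide
probs = torch.softmax(logits, dim=-1).to(torch.float16)
print(probs[:2])         # tensor([0.5000, 0.5000], dtype=torch.float16)
print(probs.argmax())    # tensor(0) -- the tie is broken by the lowest index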

gx16377 commented Apr 10, 2025

> The logprobs come from a softmax over the logits, which is a nonlinear projection of the entire floating-point range into [0, 1]. It essentially reduces precision, so many token values end up tied that weren't tied beforehand, including values tied for first place, and argmax may then select among these arbitrarily (e.g. the choice could vary by batch size).

Thank you

tonyaw commented Apr 10, 2025

> > The logprobs come from a softmax over the logits, which is a nonlinear projection of the entire floating-point range into [0, 1]. It essentially reduces precision, so many token values end up tied that weren't tied beforehand, including values tied for first place, and argmax may then select among these arbitrarily (e.g. the choice could vary by batch size).
>
> Thank you

Thanks!
Could you please explain more about "could vary by batch size"?
Does "batch size" here mean the concurrent request count or something else?
If it means the concurrent request count, could you please explain why it's related to the softmax calculation?

njhill (Member, Author) commented Apr 10, 2025

@tonyaw I mean argmax: it may not be deterministic when there are tied max values. And yes, batch size here means the number of requests processed together.

tonyaw commented Apr 11, 2025

@njhill Thanks!
In vLLM, all of the computation is done on matrices.
When multiple requests are processed together, the logits of different requests are computed in one matrix, so the numerical results for one request may be affected by the other requests in the batch.
Is my understanding right? :-)
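
A hypothetical illustration of that batch effect (made-up shapes; requires a CUDA GPU): the same input row multiplied alone versus as part of a larger batch can produce slightly different results, because the kernel and reduction order selected may depend on the batch shape.

import torch

torch.manual_seed(0)
w = torch.randn(4096, 4096, dtype=torch.float16, device="cuda")
x = torch.randn(8, 4096, dtype=torch.float16, device="cuda")

batched = x @ w      # first row computed as part of a batch of 8
alone = x[:1] @ w    # the same row computed as a batch of 1
print((batched[0] - alone[0]).abs().max())  # often nonzero on GPU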

saikrishb commented

> I mean argmax: it may not be deterministic when there are tied max values.

@njhill The torch documentation says argmax will pick the lowest index. So at least in the case of multiple max logits, this commit will make sampling deterministic, instead of arbitrarily picking one of the max values as mentioned here, correct?
