
[Bugfix][V0] Another multi-sequence logprobs streaming edge case #16805


Open · wants to merge 6 commits into main from fix-multiseq-logprobs

Conversation

tjohnson31415 (Contributor) commented Apr 17, 2025:

PR #15259 fixed an edge case with multi-sequence log probs but seems to have created another edge case that leads to missing chunks in the streamed response:

Using:

VLLM_USE_V1=0 vllm serve meta-llama/Llama-3.2-3B-Instruct
curl -s http://localhost:8000/v1/completions     \
    -H "Content-Type: application/json"     -d '{
        "model": "meta-llama/Llama-3.2-3B-Instruct",
        "prompt": "Hello!",
        "max_tokens": 2,
        "stream_options": {"include_usage": true},
        "stream": true, 
        "n": 2, 
        "logprobs":0,
        "temperature": 1.0
    }'

Would result in:

data: {"id":"cmpl-23375c79d8744f7ea7bfbe83acbb099f","object":"text_completion","created":1744923754,"model":"meta-llama/Llama-3.2-3B-Instruct","choices":[{"index":0,"text":" ","logprobs":{"text_offset":[0],"token_logprobs":[0.0],"tokens":[" "],"top_logprobs":[{" ":0.0}]},"finish_reason":null,"stop_reason":null}],"usage":null}

data: {"id":"cmpl-23375c79d8744f7ea7bfbe83acbb099f","object":"text_completion","created":1744923754,"model":"meta-llama/Llama-3.2-3B-Instruct","choices":[{"index":1,"text":" ","logprobs":{"text_offset":[0],"token_logprobs":[0.0],"tokens":[" "],"top_logprobs":[{" ":0.0}]},"finish_reason":null,"stop_reason":null}],"usage":null}

data: {"error": {"object": "error", "message": "Did not output logprobs", "type": "BadRequestError", "param": null, "code": 400}}

data: [DONE]

Note that the final chunks that would set finish_reason and return the usage are missing.

The fix proposed in this PR is to only require logprobs if there are non-empty token_ids being output.
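A minimal sketch of that guard (function and parameter names are illustrative, not the PR's actual diff): the "Did not output logprobs" error should only fire when a chunk actually carries new token IDs.

def require_streamed_logprobs(requested_logprobs, delta_token_ids, delta_logprobs):
    """Raise only if new tokens were produced without the logprobs the client asked for."""
    if requested_logprobs is None:
        return  # client did not ask for logprobs at all
    if delta_token_ids and delta_logprobs is None:
        # Tokens were emitted but no logprobs came back with them.
        raise AssertionError("Did not output logprobs")
    # An empty delta (e.g. the final finish_reason/usage chunk) is fine without logprobs.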


👋 Hi! Thank you for contributing to the vLLM project.

💬 Join our developer Slack at https://slack.vllm.ai to discuss your PR in #pr-reviews, coordinate on features in #feat- channels, or join special interest groups in #sig- channels.

Just a reminder: PRs will not trigger a full CI run by default. Instead, only the fastcheck CI will run, covering a small and essential subset of CI tests to quickly catch errors. You can run other CI tests on top of those by going to your fastcheck build on the Buildkite UI (linked in the PR checks section) and unblocking them. If you do not have permission to unblock, ping simon-mo or khluu to add you to our Buildkite org.

Once the PR is approved and ready to go, your PR reviewer(s) can run CI to test the changes comprehensively before merging.

To run full CI, PR reviewers can either add the ready label to the PR or enable auto-merge.

🚀

@mergify mergify bot added the frontend label Apr 17, 2025
DarkLight1337 (Member) commented:

Can we add some tests to avoid regressions?

@tjohnson31415 tjohnson31415 force-pushed the fix-multiseq-logprobs branch from c5130fa to 973b1e7 on April 22, 2025 at 15:43
@mergify mergify bot added the ci/build label Apr 22, 2025
vllm/sequence.py (Outdated)

-first_remaining_id = next(iter(self.to_be_finished))
-if seq_group.request_id == first_remaining_id:
+last_remaining_id = list(self.to_be_finished)[-1]
+if seq_group.request_id == last_remaining_id:
tjohnson31415 (Contributor, Author) commented Apr 22, 2025:

@andylolu2 I saw that you made the change to this code in #11898 to select the first sequence. I don't really understand this code, so please check my fix here:

I found that if multiple sequences finished at the same time, the first sequence in the group would return the assembled_seq_group and then be removed as a finished sequence; the next sequence would then be selected and the assembled_seq_group returned again. This seemed to prevent the last tokens from being streamed back to the client in the final chunk (the chunk would have the finish_reason but no text). My understanding from staring at this code is that we only want one seq to return the assembled_seq_group, so I updated this to check against the last remaining id and rely on the sequences being finished in order. Does that make sense, or is there a better fix?
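A self-contained toy of that double return (plain Python, not the vLLM data structures): when two sequences finish in the same step and each removal happens before the next check, both of them match the "first remaining id" in turn.

to_be_finished = {"seq-0": "group-0", "seq-1": "group-1"}

for req_id in ["seq-0", "seq-1"]:                # both finish in the same step
    first_remaining_id = next(iter(to_be_finished))
    if req_id == first_remaining_id:
        print(f"{req_id} returns assembled_seq_group")   # fires twice
    del to_be_finished[req_id]                   # the next sequence now becomes "first"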

tjohnson31415 (Contributor, Author) commented:

@robertgshaw2-redhat @WoosukKwon @andylolu2 Friendly ping for feedback, Thanks! :)

andylolu2 (Contributor) commented May 12, 2025:

Seems like a nasty bug. I don't think "sequences being finished in order" is reliable at all, given it's just dict keys. You would run into the same bug equally often, I believe.

I think the proper fix would be to delay removing the request from self.to_be_finished until every output of that step is processed, which is currently done here:

vllm/vllm/outputs.py

Lines 182 to 183 in 289199f

if finished:
    group.finish_seq(seq_group)
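A hypothetical sketch of the shape of that suggestion (names are illustrative, not the actual engine code): build every output for the step first, and only then call finish_seq for the groups that finished.

def process_step(step_results, group, outputs):
    """step_results: iterable of (seq_group, finished, request_output) tuples."""
    finished_groups = []
    for seq_group, finished, request_output in step_results:
        if request_output is not None:
            outputs.append(request_output)
        if finished:
            finished_groups.append(seq_group)
    # Deferred cleanup: nothing is removed from to_be_finished mid-step.
    for seq_group in finished_groups:
        group.finish_seq(seq_group)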

tjohnson31415 (Contributor, Author) commented:

@andylolu2 Thanks for the feedback! Yeah, agreed about relying on the ordering...
Though I can't see how to move the finish_seq cleanup to after the loop. The loop over the finished requests happens in the llm_engine:

vllm/vllm/engine/llm_engine.py

Lines 1118 to 1130 in 8568650

for i in finished_now:
    scheduled_seq_group = scheduler_outputs.scheduled_seq_groups[i]
    seq_group = scheduled_seq_group.seq_group
    seq_group.maybe_set_first_token_time(now)
    if not seq_group.is_prefill():
        seq_group.set_last_token_time(now)
    request_output = RequestOutputFactory.create(
        seq_group,
        self.seq_id_to_seq_group,
        use_cache=self.use_cached_outputs)
    if request_output:
        ctx.request_outputs.append(request_output)

And it looks like the non-streaming case relies on sequences being finished during the loop... So I didn't see a way to move the call to finish_seq without doing a bigger set of changes.

After staring at the code longer, I have found an alternative fix which is simple but seems to work: make the streaming case only handle non-finished seqs and use the same logic as non-streaming when all sequences are finished.
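A toy model of that alternative (illustrative class, not vLLM's actual sequence-group code): while streaming and at least one sequence is still running, no assembled group is returned; once every sequence has finished, assembly happens exactly once, just as in the non-streaming path.

class ToyGroupAssembler:
    def __init__(self, request_ids, streaming=True):
        self.to_be_finished = set(request_ids)
        self.streaming = streaming

    def maybe_assemble(self, request_id, finished):
        if finished:
            self.to_be_finished.discard(request_id)
        if self.streaming and self.to_be_finished:
            return None                      # keep streaming per-sequence deltas
        if not self.to_be_finished and finished:
            return "assembled_seq_group"     # emitted once, by the last sequence to finish
        return None

# Two sequences finishing in the same step: only the second call assembles.
group = ToyGroupAssembler(["seq-0", "seq-1"])
print(group.maybe_assemble("seq-0", finished=True))   # None
print(group.maybe_assemble("seq-1", finished=True))   # assembled_seq_group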

Please take a look and see if that makes sense to you.

tjohnson31415 (Contributor, Author) commented Apr 22, 2025:

@DarkLight1337 Added a test. I found the test_regression.py file, moved it into a tests/regressions suite, and added a new file for OpenAI requests, with the thought that more regression tests could be added there.

While writing the test, I also noticed that I can request max_tokens: 3 but the chunks returned only had 2 tokens for each sequence. I came up with a fix for that bug too, but I'm not confident about it (see the comment thread on the change).
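For reference, a hypothetical sketch of the kind of OpenAI-client regression check described above (structure and names are illustrative, not the actual test file added in this PR):

import openai

def check_multiseq_streaming(base_url="http://localhost:8000/v1"):
    client = openai.OpenAI(base_url=base_url, api_key="EMPTY")
    stream = client.completions.create(
        model="meta-llama/Llama-3.2-3B-Instruct",
        prompt="Hello!",
        max_tokens=2,
        n=2,
        logprobs=0,
        temperature=1.0,
        stream=True,
        stream_options={"include_usage": True},
    )
    finished, usage = set(), None
    for chunk in stream:
        for choice in chunk.choices:
            if choice.finish_reason is not None:
                finished.add(choice.index)
        if chunk.usage is not None:
            usage = chunk.usage
    # Every sequence must report a finish_reason, and the usage chunk must arrive.
    assert finished == {0, 1}
    assert usage is not None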

@njhill njhill added the v0 label May 1, 2025
Signed-off-by: Travis Johnson <tsjohnso@us.ibm.com>
Signed-off-by: Travis Johnson <tsjohnso@us.ibm.com>
Signed-off-by: Travis Johnson <tsjohnso@us.ibm.com>
Signed-off-by: Travis Johnson <tsjohnso@us.ibm.com>
Signed-off-by: Travis Johnson <tsjohnso@us.ibm.com>
@tjohnson31415 tjohnson31415 force-pushed the fix-multiseq-logprobs branch 2 times, most recently from a9cd354 to b8ca5e5 on May 15, 2025 at 17:17
Signed-off-by: Travis Johnson <tsjohnso@us.ibm.com>
@tjohnson31415 tjohnson31415 force-pushed the fix-multiseq-logprobs branch from b8ca5e5 to 4a7a17a on May 15, 2025 at 17:22