
Conversation

@JartX (Contributor) commented Aug 14, 2025

This issue was introduced by a change in PR #21862.

When processing a batch that contains a mix of requests—some using guided generation (e.g., guided_json) and others being standard, non-guided requests (e.g., chat completions)—the non-guided requests fail.

The typical symptom is that the output for the non-guided request consists of a stream of repetitive characters, such as exclamation marks (!!!!!!!!!!), indicating that its vocabulary has been incorrectly masked. This issue only occurs when both types of requests are present in the same batch; batches containing only one type of request work as expected.
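As an illustration of how such a mixed batch can arise, here is a hypothetical reproduction sketch (assuming a vLLM OpenAI-compatible server at localhost:8000; the model name and JSON schema are placeholders): two concurrent requests, one using guided_json and one plain chat completion, can be scheduled into the same batch.

```python
# Hypothetical reproduction sketch, not part of this PR. Assumes a vLLM
# OpenAI-compatible server running at localhost:8000; model name and schema
# are placeholders.
import asyncio
from openai import AsyncOpenAI

client = AsyncOpenAI(base_url="http://localhost:8000/v1", api_key="EMPTY")
schema = {"type": "object", "properties": {"answer": {"type": "string"}}}

async def main():
    # Guided request: constrained to the JSON schema via guided_json.
    guided = client.chat.completions.create(
        model="my-model",
        messages=[{"role": "user", "content": "Reply as JSON."}],
        extra_body={"guided_json": schema},
    )
    # Plain, non-guided chat completion sent at the same time.
    plain = client.chat.completions.create(
        model="my-model",
        messages=[{"role": "user", "content": "Hello!"}],
    )
    guided_out, plain_out = await asyncio.gather(guided, plain)
    # Before the fix, the plain response could degenerate into "!!!!!!!!!!"
    # whenever both requests landed in the same batch.
    print(plain_out.choices[0].message.content)

asyncio.run(main())
```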

Root Cause Analysis

The bug is located in the apply_grammar_bitmask method within vllm/v1/worker/gpu_model_runner.py. The logical flow that leads to the error is as follows:

1. When a batch includes at least one guided request, the scheduler produces a grammar_bitmask numpy array. This array is compact and only contains masks for the guided requests.

2. Inside apply_grammar_bitmask, a new bitmask tensor, sorted_bitmask, is created to match the full size of the batch logits (i.e., one row for every request in the batch).

3. The error occurs here: sorted_bitmask is initialized with zeros using np.zeros_like. In the bitmasking scheme used by xgrammar, a 0 bit disallows a token, whereas -1 (all bits set) allows all tokens.

4. The method then correctly copies the specific grammar masks from the scheduler's compact array into the appropriate rows of sorted_bitmask for the guided requests.

5. However, the rows corresponding to the non-guided requests are never updated, so they remain filled with zeros.

6. When this final sorted_bitmask is applied to the batch logits, it incorrectly forbids all vocabulary tokens for the non-guided requests, causing the model to produce invalid output.
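To make the 0 vs -1 semantics concrete, here is a minimal numpy sketch (illustrative only; vocab_size, words_per_row, and allowed_tokens are toy names rather than vLLM or xgrammar APIs) showing why a zero-initialized row forbids the entire vocabulary while a row of -1 allows it:

```python
# Toy illustration of packed bitmask semantics: bit == 1 allows a token,
# bit == 0 forbids it. Not vLLM's actual code.
import numpy as np

vocab_size = 64                          # toy vocabulary size
words_per_row = (vocab_size + 31) // 32  # int32 words needed to cover the vocab

def allowed_tokens(bitmask_row: np.ndarray) -> np.ndarray:
    """Expand one packed int32 row into a per-token boolean allow-mask."""
    bits = np.unpackbits(bitmask_row.view(np.uint8), bitorder="little")
    return bits[:vocab_size].astype(bool)

# A row produced by np.zeros_like: every bit is 0, so every token is forbidden.
zero_row = np.zeros(words_per_row, dtype=np.int32)
print(allowed_tokens(zero_row).sum())    # 0 tokens allowed

# A row initialized with -1: every bit is 1, so every token is allowed.
full_row = np.full(words_per_row, -1, dtype=np.int32)
print(allowed_tokens(full_row).sum())    # 64 tokens allowed
```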

Solution

The solution is to initialize the sorted_bitmask with the correct default value that allows all tokens. Instead of creating a tensor of zeros, we now create a tensor filled with -1.
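A paraphrased sketch of the change (not the verbatim diff of gpu_model_runner.py; num_reqs, mask_words, compact, and guided_rows below are toy placeholders):

```python
# Paraphrased sketch of the initialization change; sizes and masks are toy values.
import numpy as np

num_reqs, mask_words = 4, 2     # batch of 4 requests, 2 int32 words per mask
guided_rows = [1, 3]            # batch positions of the guided requests
compact = np.array([[0b1010, 0],
                    [0b0001, 0]], dtype=np.int32)  # scheduler's compact masks

# Before (buggy): zero rows silently forbid every token for the
# non-guided requests, which are never overwritten below.
# sorted_bitmask = np.zeros((num_reqs, mask_words), dtype=np.int32)

# After (fix): -1 (all bits set) keeps the non-guided rows fully permissive.
sorted_bitmask = np.full((num_reqs, mask_words), -1, dtype=np.int32)

# Copy the compact grammar masks into the guided requests' rows only.
for src, dst in enumerate(guided_rows):
    sorted_bitmask[dst] = compact[src]

print(sorted_bitmask)  # rows 0 and 2 stay at -1, i.e. "allow all tokens"
```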

@gemini-code-assist (bot) left a comment

Code Review

This pull request correctly addresses a bug that caused non-guided requests to fail when processed in a mixed batch with guided generation requests. The root cause analysis in the description is excellent and accurately identifies that initializing the sorted_bitmask with zeros was incorrectly masking all tokens for non-guided requests. The proposed solution of using np.full to initialize the bitmask with -1 is the correct approach, as it properly allows all tokens for non-guided requests by default. The change is concise, well-targeted, and effectively resolves the issue.

@github-actions

👋 Hi! Thank you for contributing to the vLLM project.

💬 Join our developer Slack at https://slack.vllm.ai to discuss your PR in #pr-reviews, coordinate on features in #feat- channels, or join special interest groups in #sig- channels.

Just a reminder: PRs do not trigger a full CI run by default. Instead, only the fastcheck CI runs, covering a small, essential subset of CI tests to quickly catch errors. You can run other CI tests on top of those by going to your fastcheck build in the Buildkite UI (linked in the PR checks section) and unblocking them. If you do not have permission to unblock, ping simon-mo or khluu to add you to our Buildkite org.

Once the PR is approved and ready to go, your PR reviewer(s) can run CI to test the changes comprehensively before merging.

To run CI, PR reviewers can either add the ready label to the PR or enable auto-merge.

🚀

@mgoin (Member) left a comment
Nice find, this looks reasonable to me. Is there a unit test we could make to enforce this? cc @aarnphm @russellb @benchislett

@mgoin added the ready label Aug 14, 2025
@russellb requested a review from benchislett August 14, 2025 14:04
@russellb (Member) left a comment
good catch, thank you!

@russellb enabled auto-merge (squash) August 14, 2025 14:14
@russellb (Member)

agree that test coverage would be nice, but the fix is important enough not to block on.

@JartX (Contributor, Author) commented Aug 15, 2025

Hi @russellb @mgoin

I saw this pull request: https://github.com/vllm-project/vllm/pull/22963

I would say that this other approach should also fix the bug.

@russellb (Member)

> Hi @russellb @mgoin
>
> I saw this pull request: https://github.com/vllm-project/vllm/pull/22963
>
> I would say that this other approach should also fix the bug.

Thanks for pointing out #22963. I'd like to merge this change even if we merge the other one as well. It makes sense we should always initialize the mask to accept all.

@russellb (Member)

@JartX it looks like your commit is missing the Signed-off-by header. Would you mind adding it?

Signed-off-by: JartX <sagformas@epdcenter.es>
auto-merge was automatically disabled August 15, 2025 13:39

Head branch was pushed to by a user without write access

@JartX force-pushed the fix/guided-generation-mixed-batch-by-pr-21862 branch from 07c1611 to 44bf28a August 15, 2025 13:39
@JartX (Contributor, Author) commented Aug 15, 2025

@russellb done!

@russellb enabled auto-merge (squash) August 15, 2025 15:09
@facebook-github-bot

@sarckk has imported this pull request. If you are a Meta employee, you can view this in D80351340.

@russellb merged commit 68af77e into vllm-project:main Aug 15, 2025
39 checks passed
666even666 pushed a commit to 666even666/vllm that referenced this pull request Aug 18, 2025
yiliu30 pushed a commit to yiliu30/vllm-fork that referenced this pull request Aug 19, 2025
divakar-amd pushed a commit to divakar-amd/vllm_upstream that referenced this pull request Aug 20, 2025
djmmoss pushed a commit to djmmoss/vllm that referenced this pull request Aug 21, 2025
epwalsh pushed a commit to epwalsh/vllm that referenced this pull request Aug 28, 2025
xiao-llm pushed a commit to xiao-llm/vllm that referenced this pull request Aug 28, 2025
zhewenl pushed a commit to zhewenl/vllm that referenced this pull request Aug 28, 2025

Labels: ready, v1