Skip to content

Commit ba5111f

Browse files
[Bugfix]: Fix the incompatibility issue with Structured Outputs when Thinking is disabled (#18879)
Signed-off-by: chaunceyjiang <chaunceyjiang@gmail.com>
1 parent 1e12352 commit ba5111f

File tree

2 files changed

+19
-13
lines changed

2 files changed

+19
-13
lines changed

vllm/v1/structured_output/__init__.py

Lines changed: 18 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -149,31 +149,37 @@ def grammar_bitmask(
149149
# NOTE: This outer loop can likely be parallelized to improve
150150
# performance of bitmask generation for large batches.
151151
for req_id, _ in ordered_seq:
152-
request = requests[req_id].structured_output_request
153-
if TYPE_CHECKING:
154-
assert request is not None
155-
assert request.grammar is not None
152+
request = requests[req_id]
153+
structured_output_request = request.structured_output_request
156154

157-
apply_bitmask = (
158-
request.reasoning_ended if self.reasoner is not None else True
159-
) # noqa: E501
155+
if TYPE_CHECKING:
156+
assert structured_output_request is not None
157+
assert structured_output_request.grammar is not None
158+
apply_bitmask: bool = True
159+
if self.reasoner is not None:
160+
if structured_output_request.reasoning_ended is None:
161+
structured_output_request.reasoning_ended = \
162+
self.reasoner.is_reasoning_end(request.prompt_token_ids)
163+
apply_bitmask = structured_output_request.reasoning_ended
160164

161165
state_advancements = 0
162166
req_tokens = scheduled_spec_decode_tokens.get(req_id, []) + [None]
163167
for i, token in enumerate(req_tokens):
164-
if apply_bitmask and not request.grammar.is_terminated():
165-
request.grammar.fill_bitmask(bitmask_tensor,
166-
cumulative_index)
168+
if apply_bitmask and not \
169+
structured_output_request.grammar.is_terminated():
170+
structured_output_request.grammar.fill_bitmask(
171+
bitmask_tensor, cumulative_index)
167172
if token is not None:
168173
# In order to generate the correct bitmask for each
169174
# position in the speculative sequence, we advance
170175
# the FSM state for each speculative token and rollback
171176
# to restore the previous state when we are finished.
172-
assert request.grammar.accept_tokens(req_id, [token])
177+
assert structured_output_request.grammar.accept_tokens(
178+
req_id, [token])
173179
state_advancements += 1
174180
cumulative_index += 1
175181
if state_advancements > 0:
176-
request.grammar.rollback(state_advancements)
182+
structured_output_request.grammar.rollback(state_advancements)
177183

178184
if cumulative_index < bitmask_tensor.shape[0]:
179185
bitmask_tensor = bitmask_tensor[:cumulative_index]

vllm/v1/structured_output/request.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@ class StructuredOutputRequest:
2020
sampling_params: SamplingParams
2121
_grammar: Optional[Union[Future[StructuredOutputGrammar],
2222
StructuredOutputGrammar]] = None
23-
reasoning_ended: bool = False
23+
reasoning_ended: Optional[bool] = None
2424

2525
def _check_grammar_completion(self) -> bool:
2626
# NOTE: We have to lazy import to gate circular imports

0 commit comments

Comments (0)