Skip to content

Commit c7d58f4

Browse files
committed
Add interface with is_reasoning_end_streaming for reasoning parsers
Signed-off-by: hdlj-h <hubert@hcompany.ai>
1 parent a2f7ca7 commit c7d58f4

File tree

6 files changed

+72
-6
lines changed

6 files changed

+72
-6
lines changed

docs/features/reasoning_outputs.md

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -299,6 +299,9 @@ Additionally, to enable structured output, you'll need to create a new `Reasoner
299299

300300
def is_reasoning_end(self, input_ids: list[int]) -> bool:
301301
return self.end_token_id in input_ids
302+
303+
def is_reasoning_end_streaming(self, input_ids: list[int], delta_ids: list[int]) -> bool:
304+
return self.end_token_id in delta_token_ids
302305
...
303306
```
304307

tests/reasoning/test_base_thinking_reasoning_parser.py

Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -132,6 +132,41 @@ def test_is_reasoning_end(self, test_tokenizer):
132132
is False
133133
)
134134

135+
def test_is_reasoning_end_streaming(self, test_tokenizer):
136+
"""Test the is_reasoning_end_streaming method."""
137+
parser = TestThinkingReasoningParser(test_tokenizer)
138+
end_token_id = parser.end_token_id
139+
start_token_id = parser.start_token_id
140+
141+
assert (
142+
parser.is_reasoning_end_streaming([1, 2, end_token_id], [end_token_id])
143+
is True
144+
)
145+
assert parser.is_reasoning_end_streaming([1, 2, 3, 4], [4]) is False
146+
assert parser.is_reasoning_end_streaming([], []) is False
147+
assert (
148+
parser.is_reasoning_end_streaming(
149+
[1, start_token_id, 2, end_token_id], [end_token_id]
150+
)
151+
is True
152+
)
153+
assert (
154+
parser.is_reasoning_end_streaming([1, start_token_id, 2, 3], [3]) is False
155+
)
156+
assert (
157+
parser.is_reasoning_end_streaming(
158+
[1, start_token_id, 2, end_token_id, 2, start_token_id],
159+
[end_token_id, 2, start_token_id],
160+
)
161+
is False
162+
)
163+
assert (
164+
parser.is_reasoning_end_streaming(
165+
[1, start_token_id, 2, end_token_id, 2, 2], [2]
166+
)
167+
is False
168+
)
169+
135170
def test_extract_content_ids(self, test_tokenizer):
136171
"""Test the extract_content_ids method."""
137172
parser = TestThinkingReasoningParser(test_tokenizer)

tests/v1/structured_output/test_reasoning_structured_output.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -70,7 +70,7 @@ def mock_request_with_structured_output(self):
7070
request.use_structured_output = True
7171
request.prompt_token_ids = [1, 2, 3, 4, 5]
7272
request.all_token_ids = [1, 2, 3, 4, 5, 6, 7, 8]
73-
request.num_computed_tokens = 3
73+
request.num_computed_tokens = 5
7474
return request
7575

7676
def test_should_fill_bitmask_with_enable_in_reasoning(

vllm/reasoning/abs_reasoning_parsers.py

Lines changed: 26 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -52,9 +52,7 @@ def is_reasoning_end(self, input_ids: list[int]) -> bool:
5252
Check if the reasoning content ends in the input_ids.
5353
5454
It is used in structured engines like `xgrammar` to check if the
55-
reasoning content ends in the model output. `input_ids` can be
56-
either the entire model output or the last few computed tokens of
57-
the model output (like during a decode step).
55+
reasoning content ends in the model output.
5856
5957
Parameters:
6058
input_ids: list[int]
@@ -65,6 +63,31 @@ def is_reasoning_end(self, input_ids: list[int]) -> bool:
6563
True if the reasoning content ends in the input_ids.
6664
"""
6765

66+
def is_reasoning_end_streaming(
67+
self, input_ids: list[int], delta_ids: list[int]
68+
) -> bool:
69+
"""
70+
Check if the reasoning content ends in the input_ids on a
71+
decode step.
72+
73+
It is used in structured engines like `xgrammar` to check if the
74+
reasoning content ends in the model output during a decode step.
75+
`input_ids` the entire model output and `delta_ids` are the last few
76+
computed tokens of the model output (like during a decode step).
77+
78+
Parameters:
79+
input_ids: list[int]
80+
The entire model output.
81+
delta_ids: list[int]
82+
The last few computed tokens of the model output at the current decode step.
83+
84+
Returns:
85+
bool
86+
True if the reasoning content ends in the `delta_ids` on a
87+
decode step.
88+
"""
89+
return self.is_reasoning_end(input_ids)
90+
6891
@abstractmethod
6992
def extract_content_ids(self, input_ids: list[int]) -> list[int]:
7093
"""

vllm/reasoning/basic_parsers.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -74,6 +74,11 @@ def is_reasoning_end(self, input_ids: list[int]) -> bool:
7474
return True
7575
return False
7676

77+
def is_reasoning_end_streaming(
78+
self, input_ids: list[int], delta_ids: list[int]
79+
) -> bool:
80+
return self.is_reasoning_end(delta_ids)
81+
7782
def extract_content_ids(self, input_ids: list[int]) -> list[int]:
7883
"""
7984
Extract the content after the end tokens

vllm/v1/structured_output/__init__.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -339,8 +339,8 @@ def should_advance(self, request: Request) -> bool:
339339
return True
340340

341341
# Check if reasoning ends in *this* step
342-
if self.reasoner.is_reasoning_end(
343-
request.all_token_ids[request.num_computed_tokens :]
342+
if self.reasoner.is_reasoning_end_streaming(
343+
request.all_token_ids, request.all_token_ids[request.num_computed_tokens :]
344344
):
345345
# Reasoning just ended, so we shouldn't advance til
346346
# next pass

0 commit comments

Comments
 (0)