@@ -149,31 +149,37 @@ def grammar_bitmask(
149
149
# NOTE: This outer loop can likely be parallelized to improve
150
150
# performance of bitmask generation for large batches.
151
151
for req_id , _ in ordered_seq :
152
- request = requests [req_id ].structured_output_request
153
- if TYPE_CHECKING :
154
- assert request is not None
155
- assert request .grammar is not None
152
+ request = requests [req_id ]
153
+ structured_output_request = request .structured_output_request
156
154
157
- apply_bitmask = (
158
- request .reasoning_ended if self .reasoner is not None else True
159
- ) # noqa: E501
155
+ if TYPE_CHECKING :
156
+ assert structured_output_request is not None
157
+ assert structured_output_request .grammar is not None
158
+ apply_bitmask : bool = True
159
+ if self .reasoner is not None :
160
+ if structured_output_request .reasoning_ended is None :
161
+ structured_output_request .reasoning_ended = \
162
+ self .reasoner .is_reasoning_end (request .prompt_token_ids )
163
+ apply_bitmask = structured_output_request .reasoning_ended
160
164
161
165
state_advancements = 0
162
166
req_tokens = scheduled_spec_decode_tokens .get (req_id , []) + [None ]
163
167
for i , token in enumerate (req_tokens ):
164
- if apply_bitmask and not request .grammar .is_terminated ():
165
- request .grammar .fill_bitmask (bitmask_tensor ,
166
- cumulative_index )
168
+ if apply_bitmask and not \
169
+ structured_output_request .grammar .is_terminated ():
170
+ structured_output_request .grammar .fill_bitmask (
171
+ bitmask_tensor , cumulative_index )
167
172
if token is not None :
168
173
# In order to generate the correct bitmask for each
169
174
# position in the speculative sequence, we advance
170
175
# the FSM state for each speculative token and rollback
171
176
# to restore the previous state when we are finished.
172
- assert request .grammar .accept_tokens (req_id , [token ])
177
+ assert structured_output_request .grammar .accept_tokens (
178
+ req_id , [token ])
173
179
state_advancements += 1
174
180
cumulative_index += 1
175
181
if state_advancements > 0 :
176
- request .grammar .rollback (state_advancements )
182
+ structured_output_request .grammar .rollback (state_advancements )
177
183
178
184
if cumulative_index < bitmask_tensor .shape [0 ]:
179
185
bitmask_tensor = bitmask_tensor [:cumulative_index ]
0 commit comments