Skip to content

Commit 97cccb9

Browse files
Chenyaaangxuebwang-amd
authored andcommitted
[TPU] Fix tpu structured decoding in mixed batches (vllm-project#24458)
Signed-off-by: Chenyaaang <chenyangli@google.com> Signed-off-by: xuebwang-amd <xuebwang@amd.com>
1 parent 1296dee commit 97cccb9

File tree

1 file changed

+14
-20
lines changed

1 file changed

+14
-20
lines changed

vllm/v1/worker/tpu_model_runner.py

Lines changed: 14 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -1769,28 +1769,22 @@ def prepare_structured_decoding_input(
17691769
self.grammar_bitmask_cpu.zero_()
17701770
self.require_structured_out_cpu.zero_()
17711771

1772-
# We receive the structured output bitmask from the scheduler, but the
1773-
# indices of the requests in the batch may not match the indices of
1774-
# the bitmask since the scheduler doesn't know how the tpu runner is
1775-
# ordering the requests in the batch. We need to match the order of
1776-
# bitmask with the order of requests
1777-
struct_out_indices: list[int] = []
1778-
mask_indices: list[int] = []
1779-
for req_id in self.input_batch.req_ids:
1780-
mask_index = scheduler_output.structured_output_request_ids.get(
1781-
req_id)
1782-
if mask_index is None:
1772+
sorted_struct_requests = sorted(
1773+
scheduler_output.structured_output_request_ids.items(),
1774+
key=lambda item: item[1])
1775+
cumulative_mask_idx = 0
1776+
for req_id, _ in sorted_struct_requests:
1777+
if req_id not in self.input_batch.req_id_to_index:
17831778
continue
17841779
batch_index = self.input_batch.req_id_to_index[req_id]
1785-
struct_out_indices.append(batch_index)
1786-
mask_indices.append(mask_index)
1787-
self.grammar_bitmask_cpu[struct_out_indices] = torch.from_numpy(
1788-
grammar_bitmask[mask_indices])
1789-
# It's not guaranteed that all requests in this batch require
1790-
# structured output, so create a bool tensor to represent
1791-
# the requests that need structured output.
1792-
struct_out_indices = torch.tensor(struct_out_indices, dtype=torch.long)
1793-
self.require_structured_out_cpu[struct_out_indices] = True
1780+
self.grammar_bitmask_cpu[batch_index] = torch.from_numpy(
1781+
grammar_bitmask[cumulative_mask_idx])
1782+
# It's not guaranteed that all requests in this batch require
1783+
# structured output, so create a bool tensor to represent
1784+
# the requests that need structured output.
1785+
self.require_structured_out_cpu[batch_index] = True
1786+
cumulative_mask_idx += 1
1787+
17941788
return self.require_structured_out_cpu[:num_reqs].to(logits.device), \
17951789
self.grammar_bitmask_cpu[:num_reqs].to(logits.device), \
17961790
self.structured_decode_arange.to(logits.device)

0 commit comments

Comments
 (0)