@@ -1769,28 +1769,22 @@ def prepare_structured_decoding_input(
17691769 self .grammar_bitmask_cpu .zero_ ()
17701770 self .require_structured_out_cpu .zero_ ()
17711771
1772- # We receive the structured output bitmask from the scheduler, but the
1773- # indices of the requests in the batch may not match the indices of
1774- # the bitmask since the scheduler doesn't know how the tpu runner is
1775- # ordering the requests in the batch. We need to match the order of
1776- # bitmask with the order of requests
1777- struct_out_indices : list [int ] = []
1778- mask_indices : list [int ] = []
1779- for req_id in self .input_batch .req_ids :
1780- mask_index = scheduler_output .structured_output_request_ids .get (
1781- req_id )
1782- if mask_index is None :
1772+ sorted_struct_requests = sorted (
1773+ scheduler_output .structured_output_request_ids .items (),
1774+ key = lambda item : item [1 ])
1775+ cumulative_mask_idx = 0
1776+ for req_id , _ in sorted_struct_requests :
1777+ if req_id not in self .input_batch .req_id_to_index :
17831778 continue
17841779 batch_index = self .input_batch .req_id_to_index [req_id ]
1785- struct_out_indices .append (batch_index )
1786- mask_indices .append (mask_index )
1787- self .grammar_bitmask_cpu [struct_out_indices ] = torch .from_numpy (
1788- grammar_bitmask [mask_indices ])
1789- # It's not guaranteed that all requests in this batch require
1790- # structured output, so create a bool tensor to represent
1791- # the requests that need structured output.
1792- struct_out_indices = torch .tensor (struct_out_indices , dtype = torch .long )
1793- self .require_structured_out_cpu [struct_out_indices ] = True
1780+ self .grammar_bitmask_cpu [batch_index ] = torch .from_numpy (
1781+ grammar_bitmask [cumulative_mask_idx ])
1782+ # It's not guaranteed that all requests in this batch require
1783+ # structured output, so create a bool tensor to represent
1784+ # the requests that need structured output.
1785+ self .require_structured_out_cpu [batch_index ] = True
1786+ cumulative_mask_idx += 1
1787+
17941788 return self .require_structured_out_cpu [:num_reqs ].to (logits .device ), \
17951789 self .grammar_bitmask_cpu [:num_reqs ].to (logits .device ), \
17961790 self .structured_decode_arange .to (logits .device )
0 commit comments