@@ -2,6 +2,7 @@
 
 from __future__ import annotations
 
+import itertools
 import time
 from collections import deque
 from collections.abc import Iterable
@@ -144,7 +145,7 @@ def schedule(self) -> SchedulerOutput:
         # uses structured decoding.
         structured_output_request_ids: dict[str, int] = {}
 
-        req_to_new_block_ids: dict[str, list[int]] = {}
+        req_to_new_block_ids: dict[str, list[list[int]]] = {}
         num_scheduled_tokens: dict[str, int] = {}
         token_budget = self.max_num_scheduled_tokens
         # Encoder-related.
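The value type of req_to_new_block_ids changes from a flat list of block IDs to one list per KV cache group. A minimal sketch of the new shape, assuming two groups (group count and IDs below are illustrative, not from the diff):

# Hypothetical example of the nested per-group layout:
req_to_new_block_ids: dict[str, list[list[int]]] = {
    "req-0": [
        [3, 7, 9],   # block IDs newly allocated in KV cache group 0
        [12, 15],    # block IDs newly allocated in KV cache group 1
    ],
}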
@@ -165,7 +166,8 @@ def schedule(self) -> SchedulerOutput:
                 req_index += 1
                 continue
 
-            num_new_tokens = (request.num_tokens_with_spec -
+            num_draft_tokens = len(request.draft_token_ids)
+            num_new_tokens = (request.num_tokens + num_draft_tokens -
                               request.num_computed_tokens)
             if (0 < self.scheduler_config.long_prefill_token_threshold <
                     num_new_tokens):
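num_new_tokens is now derived from the committed token count plus the draft (speculative) tokens instead of num_tokens_with_spec. A worked example under assumed numbers (all values illustrative):

num_tokens = 100             # committed prompt + output tokens (assumed)
draft_token_ids = [7, 8, 9]  # three speculative draft tokens (assumed)
num_computed_tokens = 96     # tokens whose KV entries already exist (assumed)

num_draft_tokens = len(draft_token_ids)                                # 3
num_new_tokens = num_tokens + num_draft_tokens - num_computed_tokens   # 7
assert num_new_tokens == 7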
@@ -196,7 +198,8 @@ def schedule(self) -> SchedulerOutput:
             while True:
                 new_blocks = self.kv_cache_manager.allocate_slots(
                     request,
-                    num_new_tokens,
+                    num_new_tokens - num_draft_tokens,
+                    num_draft_tokens=num_draft_tokens,
                     num_lookahead_tokens=self.num_lookahead_tokens)
                 if new_blocks is None:
                     # The request cannot be scheduled.
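At the allocation call the committed and draft portions are now passed separately: the positional argument shrinks to num_new_tokens - num_draft_tokens and the draft count travels in its own keyword argument. A hedged sketch of how an allocator could use that split; this is not the real allocate_slots implementation, only an illustration of the separation, with an assumed block size:

def allocate_slots_sketch(num_new_tokens: int, num_draft_tokens: int,
                          num_lookahead_tokens: int, block_size: int = 16) -> int:
    """Illustrative only: slots are reserved for committed, draft, and
    lookahead tokens, but only committed tokens advance the cached prefix."""
    total = num_new_tokens + num_draft_tokens + num_lookahead_tokens
    # Number of blocks needed to cover all reserved slots (ceil division).
    return -(-total // block_size)

# With the assumed values from the example above: 4 committed + 3 draft
# tokens, no lookahead, 16-token blocks -> 1 block.
print(allocate_slots_sketch(4, 3, 0))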
@@ -233,7 +236,7 @@ def schedule(self) -> SchedulerOutput:
                 # cycle to fill in the bitmask, which could be a big no-op.
                 structured_output_request_ids[request.request_id] = req_index
             req_to_new_block_ids[request.request_id] = [
-                b.block_id for b in new_blocks
+                [b.block_id for b in blocks] for blocks in new_blocks
             ]
             num_scheduled_tokens[request.request_id] = num_new_tokens
             token_budget -= num_new_tokens
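allocate_slots now returns one list of blocks per KV cache group, so the comprehension nests accordingly. A small self-contained illustration of the new shape; FakeBlock stands in for KVCacheBlock since only block_id matters here:

from dataclasses import dataclass

@dataclass
class FakeBlock:          # stand-in for KVCacheBlock, illustration only
    block_id: int

new_blocks = [[FakeBlock(3), FakeBlock(7)], [FakeBlock(12)]]  # two groups

nested = [[b.block_id for b in blocks] for blocks in new_blocks]
assert nested == [[3, 7], [12]]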
@@ -330,7 +333,11 @@ def schedule(self) -> SchedulerOutput:
                 new_encoder_budget = encoder_budget
 
                 new_blocks = self.kv_cache_manager.allocate_slots(
-                    request, num_new_tokens, num_computed_tokens, computed_blocks)
+                    request,
+                    num_new_tokens,
+                    new_computed_tokens=num_computed_tokens,
+                    new_computed_blocks=computed_blocks,
+                    num_lookahead_tokens=self.num_lookahead_tokens)
                 if new_blocks is None:
                     # The request cannot be scheduled.
                     break
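For newly scheduled (waiting) requests the call switches to keyword arguments and also passes the lookahead budget. Taken together with the running-request call above, the two call sites suggest an allocator interface roughly like the stub below; this is an assumption inferred from the diff, not the actual method definition:

# Assumed shape of the allocator interface, inferred from the two call sites:
def allocate_slots(request, num_new_tokens,
                   new_computed_tokens=0, new_computed_blocks=None,
                   num_draft_tokens=0, num_lookahead_tokens=0):
    """Returns per-group lists of newly allocated blocks, or None if the
    request cannot be scheduled (illustrative stub only)."""
    ...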
@@ -355,9 +362,9 @@ def schedule(self) -> SchedulerOutput:
 
                 if self.lora_config and request.lora_request:
                     scheduled_loras.add(request.lora_request.lora_int_id)
-                req_to_new_block_ids[request.request_id] = [
-                    b.block_id for b in computed_blocks + new_blocks
-                ]
+                req_to_new_block_ids[request.request_id] = [[
+                    b.block_id for b in itertools.chain(b1, b2)
+                ] for b1, b2 in zip(computed_blocks, new_blocks)]
                 num_scheduled_tokens[request.request_id] = num_new_tokens
                 token_budget -= num_new_tokens
                 request.status = RequestStatus.RUNNING
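Since computed_blocks and new_blocks are both per-group lists now, the flat concatenation is replaced by a group-wise zip plus itertools.chain. A small illustration with stubbed blocks (values made up):

import itertools
from dataclasses import dataclass

@dataclass
class FakeBlock:          # stand-in for KVCacheBlock, illustration only
    block_id: int

computed_blocks = [[FakeBlock(1)], [FakeBlock(4), FakeBlock(5)]]  # per group
new_blocks = [[FakeBlock(2), FakeBlock(3)], [FakeBlock(6)]]       # per group

merged = [[b.block_id for b in itertools.chain(b1, b2)]
          for b1, b2 in zip(computed_blocks, new_blocks)]
assert merged == [[1, 2, 3], [4, 5, 6]]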
@@ -463,7 +470,7 @@ def _make_cached_request_data(
         request: Request,
         num_scheduled_tokens: int,
         num_scheduled_spec_tokens: int,
-        new_block_ids: list[int],
+        new_block_ids: list[list[int]],
         resumed_from_preemption: bool,
     ) -> CachedRequestData:
         # OPTIMIZATION: Cache the CachedRequestData objects to avoid creating
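The cached-request path sees the same type change: new_block_ids becomes a list of per-group ID lists. A hedged sketch of how a consumer might fold such an update into previously sent block IDs; the helper below is hypothetical, not the actual CachedRequestData handling:

def merge_new_block_ids(existing: list[list[int]],
                        new_block_ids: list[list[int]]) -> list[list[int]]:
    """Hypothetical helper: extend each KV cache group's ID list in place."""
    for group_ids, new_ids in zip(existing, new_block_ids):
        group_ids.extend(new_ids)
    return existing

assert merge_new_block_ids([[1, 2], [10]],
                           [[3], [11, 12]]) == [[1, 2, 3], [10, 11, 12]]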