@@ -68,10 +68,11 @@ def __init__(
         # Priority queues for requests.
         self.waiting: Deque[Request] = deque()
         self.running: List[Request] = []
+
         # req_id -> Number of times the request has been scheduled.
-        # We can only schedule a request more then once before the previous
-        # scheduling step is finished when PP is enabled and the request
-        # prompt is chunked.
+        # With PP, when the input prompt is divided into chunks, we can
+        # schedule a new chunk even before the previous chunk has completed
+        # all pipeline stages. This helps reduce TTFT.
         self.scheduled_req_ids: dict[str, int] = defaultdict(int)

         # The request IDs that are finished in between the previous and the
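To see why this counter is an int rather than a boolean flag, here is a minimal standalone sketch; the request ID and step sequence are hypothetical, not vLLM code:

```python
from collections import defaultdict

# Hypothetical sketch: with pipeline parallelism (PP), a chunked-prefill
# request can be scheduled again before its previous chunk has drained all
# pipeline stages, so a boolean "is scheduled" flag is not enough; we need
# a per-request count of in-flight scheduling steps.
scheduled_req_ids: dict[str, int] = defaultdict(int)

# Step 1: schedule the first prompt chunk of request "r0".
scheduled_req_ids["r0"] += 1

# Step 2: the first chunk is still in the pipeline, but under PP the next
# chunk may be scheduled anyway, which reduces time-to-first-token (TTFT).
scheduled_req_ids["r0"] += 1
assert scheduled_req_ids["r0"] == 2  # two scheduling steps in flight

# When a step's model output is processed, the count is decremented.
scheduled_req_ids["r0"] -= 1
```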
@@ -150,8 +151,7 @@ def schedule(self) -> SchedulerOutput:
             if (request.num_computed_tokens >= request.num_tokens
                     and self.scheduled_req_ids.get(request.request_id, 0) > 0):
                 # We avoid re-scheduling the decoding requests because
-                # the number of new decoded output tokens is unknown due
-                # to speculative decoding or jump decoding.
+                # there are no tokens for the decoding requests to be scheduled.
                 req_index += 1
                 continue
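The skip condition above can be illustrated in isolation. A sketch with hypothetical values, using plain variables as stand-ins for the `Request` fields:

```python
# Hypothetical values, for illustration only: a decode request whose prompt
# is fully computed and which already has one scheduling step in flight.
num_computed_tokens = 32   # request.num_computed_tokens
num_tokens = 32            # request.num_tokens
in_flight = 1              # self.scheduled_req_ids.get(request_id, 0)

# Until the in-flight step produces the next output token, this request has
# nothing new to schedule, so the scheduler skips it and moves on.
if num_computed_tokens >= num_tokens and in_flight > 0:
    print("skip: no tokens available to schedule for this decode request")
```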
@@ -426,14 +426,17 @@ def schedule(self) -> SchedulerOutput:
             grammar_bitmask=grammar_bitmask,
         )

-        # Update the number of computed tokens for the request right after
-        # the request is scheduled. This allows the request doing chunk prefill
-        # to be scheduled again immediately in the next scheduling step.
-        # If some tokens (e.g. spec tokens) are rejected later, the number of
-        # computed tokens will be adjusted in update_from_output.
-        for req in (scheduled_new_reqs + scheduled_resumed_reqs +
-                    scheduled_running_reqs):
-            req.num_computed_tokens += num_scheduled_tokens[req.request_id]
+        # Advance the number of computed tokens for the request AFTER
+        # the request is scheduled.
+        # 1. The scheduler_output of the current step has to include the
+        #    original number of scheduled tokens to determine input IDs.
+        # 2. Advancing the number of computed tokens here allows us to
+        #    schedule the (prefill) request again immediately in the next
+        #    scheduling step.
+        # 3. If some tokens (e.g. spec tokens) are rejected later, the number
+        #    of computed tokens will be adjusted in update_from_output.
+        for req_id, num_scheduled_token in num_scheduled_tokens.items():
+            self.requests[req_id].num_computed_tokens += num_scheduled_token

         self.finished_req_ids = set()
         return scheduler_output
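To make the ordering concrete, here is a hedged sketch of a chunked prefill advancing across two back-to-back scheduling steps; the `SimpleNamespace` request is a stand-in for vLLM's `Request`, and the chunk sizes are made up:

```python
from types import SimpleNamespace

# Hypothetical stand-in for self.requests: one request with a 1000-token
# prompt, prefilled in 512-token chunks.
requests = {"r0": SimpleNamespace(num_computed_tokens=0)}

# Step 1 builds scheduler_output from the ORIGINAL count (0), and only
# afterwards advances the counter, mirroring the loop in the hunk above.
num_scheduled_tokens = {"r0": 512}
for req_id, num_scheduled_token in num_scheduled_tokens.items():
    requests[req_id].num_computed_tokens += num_scheduled_token

# Step 2 can schedule the next chunk immediately, without waiting for
# step 1's model output, because the counter already reads 512.
assert requests["r0"].num_computed_tokens == 512
num_scheduled_tokens = {"r0": 488}
for req_id, num_scheduled_token in num_scheduled_tokens.items():
    requests[req_id].num_computed_tokens += num_scheduled_token
assert requests["r0"].num_computed_tokens == 1000
```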
@@ -549,6 +552,9 @@ def update_from_output(
         logprobs = model_runner_output.logprobs
         prompt_logprobs_dict = model_runner_output.prompt_logprobs_dict
         num_scheduled_tokens = scheduler_output.num_scheduled_tokens
+
+        # We cannot use num_computed_tokens from self.requests because their
+        # values were already advanced when the requests were scheduled.
         num_computed_tokens_at_schedule = {
             req_data.req_id: req_data.num_computed_tokens
             for req_data in (scheduler_output.scheduled_cached_reqs +
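The new comment captures the key subtlety: by the time `update_from_output` runs, the live counters in `self.requests` already include this step's scheduled tokens, so the pre-schedule values must be rebuilt from `scheduler_output`. A small sketch of that mismatch, with illustrative shapes rather than the real vLLM types:

```python
from types import SimpleNamespace

# Hypothetical shapes, for illustration: the scheduled request data records
# the token count as it was at schedule time, while the live counter in
# self.requests has already been advanced by this step's scheduled tokens.
scheduled_reqs = [SimpleNamespace(req_id="r0", num_computed_tokens=0)]
live_requests = {"r0": SimpleNamespace(num_computed_tokens=512)}

# Rebuild the pre-schedule snapshot from the scheduler_output-style data;
# reading the live counters instead would be one step ahead.
num_computed_tokens_at_schedule = {
    r.req_id: r.num_computed_tokens for r in scheduled_reqs
}
assert num_computed_tokens_at_schedule["r0"] == 0
assert live_requests["r0"].num_computed_tokens == 512
```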
@@ -598,7 +604,7 @@ def update_from_output(
                     start_pos = request.mm_positions[input_id]["offset"]
                     num_tokens = request.mm_positions[input_id]["length"]
                     if (start_pos + num_tokens
-                            ) <= num_computed_tokens_at_schedule[req_id]:
+                            <= num_computed_tokens_at_schedule[req_id]):
                         # The encoder output is already processed and stored
                         # in the decoder's KV cache.
                         self.encoder_cache_manager.free_encoder_input(
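The reformatted condition frees an encoder input once its whole span falls within the tokens computed at schedule time. A sketch with hypothetical multimodal positions (the dict layout mirrors `mm_positions`; the values are made up):

```python
# Hypothetical multimodal placement: each encoder output (e.g. an image's
# embeddings) occupies positions [offset, offset + length) of the prompt.
mm_positions = [{"offset": 10, "length": 40}, {"offset": 80, "length": 40}]
num_computed_at_schedule = 64

for pos in mm_positions:
    start_pos, num_tokens = pos["offset"], pos["length"]
    # Once the whole span lies within the tokens computed at schedule time,
    # the encoder output has been consumed into the decoder's KV cache and
    # can be freed; the second span here is not yet covered.
    if start_pos + num_tokens <= num_computed_at_schedule:
        print(f"free encoder input spanning [{start_pos}, {start_pos + num_tokens})")
```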