
Commit d17f0fb

QierLi and Qier Li authored
[Core][KVConnector] Propagate all tokens on resumed preemptions (#24926)
Signed-off-by: Qier Li <kevin44036@gmail.com>
Co-authored-by: Qier Li <qier@fb.com>
1 parent 43ab8cf commit d17f0fb

File tree: 4 files changed, +60 -9 lines changed

tests/v1/core/test_scheduler.py

Lines changed: 42 additions & 4 deletions
@@ -1950,7 +1950,7 @@ def test_schedule_skip_tokenizer_init_structured_output_request():
     assert len(scheduler.waiting) == 1
 
 
-def test_priority_scheduling_preemption_when_out_of_kv():
+def test_priority_scheduling_preemption_and_resumption_when_out_of_kv():
     """Test that priority scheduling preempts lower priority requests
     when out of KV cache space."""
     # Create scheduler with very limited memory to force preemption
@@ -1959,6 +1959,7 @@ def test_priority_scheduling_preemption_when_out_of_kv():
         max_num_batched_tokens=200,
         num_blocks=5,  # Can hold 64 tokens (first block is null)
         block_size=16,  # Standard block size
+        use_kv_connector=True,
     )
 
     # Create a request and schedule it
@@ -1970,12 +1971,13 @@ def test_priority_scheduling_preemption_when_out_of_kv():
         starting_idx=0,
     )[0]
     scheduler.add_request(request_low)
+    # 1st schedule
     output = scheduler.schedule()
     assert len(output.scheduled_new_reqs) == 1
     assert len(scheduler.waiting) == 0
     assert len(scheduler.running) == 1
 
-    # Simulate model execution
+    # Simulate model execution - 1st decode
     model_output = ModelRunnerOutput(
         req_ids=[request_low.request_id],
         req_id_to_index={request_low.request_id: 0},
@@ -1996,6 +1998,7 @@ def test_priority_scheduling_preemption_when_out_of_kv():
         starting_idx=1,
     )[0]
     scheduler.add_request(request_high)
+    # 2nd schedule
     output = scheduler.schedule()
     # KV cache should be full at this point
     assert scheduler.kv_cache_manager.block_pool.get_num_free_blocks() == 0
@@ -2004,7 +2007,7 @@ def test_priority_scheduling_preemption_when_out_of_kv():
     assert len(scheduler.waiting) == 0
     assert len(scheduler.running) == 2
 
-    # Simulate model execution
+    # Simulate model execution - 2nd decode
     requests = [request_low, request_high]
     model_output = ModelRunnerOutput(
         req_ids=[req.request_id for req in requests],
@@ -2017,7 +2020,7 @@ def test_priority_scheduling_preemption_when_out_of_kv():
     )
     scheduler.update_from_output(output, model_output)
 
-    # Schedule again - this should trigger preemption
+    # 3rd schedule - this should trigger preemption
     # req_low needs 32 tokens = 2 blocks
     # req_high needs 33 tokens = 3 blocks
     # so doesn't fit in 4 blocks.
@@ -2027,9 +2030,44 @@ def test_priority_scheduling_preemption_when_out_of_kv():
     assert len(output.scheduled_new_reqs) == 0
     assert output.scheduled_cached_reqs.num_reqs == 1
     assert output.scheduled_cached_reqs.req_ids[0] == request_high.request_id
+    assert scheduler.requests[request_low.request_id].status == RequestStatus.PREEMPTED
     assert len(scheduler.waiting) == 1
     assert len(scheduler.running) == 1
 
+    # Simulate model execution - 3rd decode
+    model_output = ModelRunnerOutput(
+        req_ids=[req.request_id for req in requests],
+        req_id_to_index={req.request_id: i for i, req in enumerate(requests)},
+        sampled_token_ids=[[], [100]],
+        # spec_token_ids=None,
+        logprobs=None,
+        prompt_logprobs_dict={},
+        pooler_output=[],
+    )
+    # Finish the requests to make room for the preempted requests to resume
+    scheduler.update_from_output(output, model_output)
+    scheduler.finish_requests(request_high.request_id, RequestStatus.FINISHED_STOPPED)
+
+    # 4th Schedule - this should trigger the resumption
+    output = scheduler.schedule()
+    scheduled_cached_reqs = output.scheduled_cached_reqs
+    resumed_from_preemption = scheduled_cached_reqs.resumed_from_preemption
+
+    assert len(output.scheduled_new_reqs) == 0
+    assert scheduled_cached_reqs.num_reqs == 1
+    assert len(scheduler.waiting) == 0
+    assert len(scheduler.running) == 1
+
+    # Preempted request resumed in scheduled_cached_reqs
+    assert len(resumed_from_preemption) == 1
+    assert len(scheduled_cached_reqs.resumed_req_token_ids) == 1
+    assert resumed_from_preemption[0]
+    assert scheduled_cached_reqs.req_ids[0] == request_low.request_id
+    assert scheduled_cached_reqs.resumed_req_token_ids[0] is not None
+    # Resumed tokens include 30 prompt tokens and 2 decoded tokens
+    assert len(scheduled_cached_reqs.resumed_req_token_ids[0]) == 32
+    assert scheduled_cached_reqs.resumed_req_token_ids[0][31] == 100
+
 
 @pytest.mark.parametrize(
     ("enable_chunked_prefill", "is_encoder_decoder", "expect_enabled"),

tests/v1/worker/test_gpu_model_runner.py

Lines changed: 1 addition & 0 deletions
@@ -257,6 +257,7 @@ def test_update_states_request_resumed(model_runner, dist_init):
         req_ids=[req_id],
         resumed_from_preemption=[False],
         new_token_ids=[[]],
+        resumed_req_token_ids=[None],
         new_block_ids=([[0]],),
         num_computed_tokens=[0],
         num_output_tokens=[0],

vllm/v1/core/sched/output.py

Lines changed: 4 additions & 0 deletions
@@ -98,6 +98,9 @@ class CachedRequestData:
     # NOTE(woosuk): new_token_ids is only used for pipeline parallelism.
     # When PP is not used, new_token_ids will be empty.
     new_token_ids: list[list[int]]
+    # If resumed_from_preemption is True, propagate the token ids to the
+    # connector; otherwise this entry is empty.
+    resumed_req_token_ids: list[list[int] | None]
     new_block_ids: list[tuple[list[int], ...] | None]
     num_computed_tokens: list[int]
     num_output_tokens: list[int]
@@ -112,6 +115,7 @@ def make_empty(cls) -> CachedRequestData:
             req_ids=[],
             resumed_from_preemption=[],
             new_token_ids=[],
+            resumed_req_token_ids=[],
             new_block_ids=[],
             num_computed_tokens=[],
             num_output_tokens=[],
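For code consuming CachedRequestData (for example a KV connector), the new field lines up positionally with resumed_from_preemption: entries for requests that stayed running are None, while entries for resumed requests carry the request's full token history. A minimal consumption sketch, assuming a hypothetical helper function rather than any existing connector API:

from vllm.v1.core.sched.output import CachedRequestData

def collect_resumed_tokens(cached_reqs: CachedRequestData) -> dict[str, list[int]]:
    # Hypothetical helper: map request id -> full token ids for requests that
    # are being resumed after preemption; requests that kept running carry None.
    resumed: dict[str, list[int]] = {}
    for i, req_id in enumerate(cached_reqs.req_ids):
        if cached_reqs.resumed_from_preemption[i]:
            token_ids = cached_reqs.resumed_req_token_ids[i]
            assert token_ids is not None  # populated by the scheduler for resumed requests
            resumed[req_id] = token_ids
    return resumed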

vllm/v1/core/sched/scheduler.py

Lines changed: 13 additions & 5 deletions
@@ -709,10 +709,15 @@ def _make_cached_request_data(
         req_ids: list[str] = []
         new_token_ids: list[list[int]] = []
         new_block_ids: list[tuple[list[int], ...] | None] = []
+        resumed_req_token_ids: list[list[int] | None] = []
         num_computed_tokens: list[int] = []
         num_output_tokens: list[int] = []
 
-        for req in itertools.chain(running_reqs, resumed_reqs):
+        # Because resumed_reqs is usually empty, it is more efficient to do
+        # in-place appending so that we don't need to allocate a new list.
+        resumed_from_preemption = [False] * len(running_reqs)
+        resumed_from_preemption += [True] * len(resumed_reqs)
+        for idx, req in enumerate(itertools.chain(running_reqs, resumed_reqs)):
             req_id = req.request_id
             req_ids.append(req_id)
             num_tokens = num_scheduled_tokens[req_id] - len(
@@ -728,20 +733,23 @@
                 req.num_computed_tokens : req.num_computed_tokens + num_tokens
             ]
             new_token_ids.append(token_ids)
+            resumed_token_ids = None
+            if resumed_from_preemption[idx]:
+                resumed_token_ids = req.all_token_ids[
+                    : req.num_computed_tokens + num_tokens
+                ]
+            resumed_req_token_ids.append(resumed_token_ids)
             new_block_ids.append(
                 req_to_new_blocks[req_id].get_block_ids(allow_none=True)
             )
             num_computed_tokens.append(req.num_computed_tokens)
             num_output_tokens.append(req.num_output_tokens)
-        # Because resumed_reqs is usually empty, it is more efficient to do
-        # in-place appending so that we don't need to allocate a new list.
-        resumed_from_preemption = [False] * len(running_reqs)
-        resumed_from_preemption += [True] * len(resumed_reqs)
 
         return CachedRequestData(
             req_ids=req_ids,
             resumed_from_preemption=resumed_from_preemption,
             new_token_ids=new_token_ids,
+            resumed_req_token_ids=resumed_req_token_ids,
             new_block_ids=new_block_ids,
             num_computed_tokens=num_computed_tokens,
             num_output_tokens=num_output_tokens,
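The new slice mirrors the existing one for new_token_ids: new_token_ids carries only the tokens scheduled in the current step, while the resumed list covers everything from position 0 up to the same end point, so a connector sees the full history of a resumed request. A standalone sketch of the two slices, with hypothetical numbers standing in for req.all_token_ids, req.num_computed_tokens, and the per-step token count:

# Hypothetical values illustrating the two slices built in _make_cached_request_data
# for a single resumed request; the numbers are assumptions, not from the commit.
all_token_ids = list(range(40))   # full prompt + output history of the request
num_computed_tokens = 0           # assumed 0 here; prefix caching can make this nonzero
num_tokens = 32                   # tokens scheduled for the request in this step

# Tokens newly scheduled this step (what new_token_ids carries when PP is used):
new_token_ids = all_token_ids[num_computed_tokens : num_computed_tokens + num_tokens]
# Full history propagated to the connector because the request is resumed:
resumed_token_ids = all_token_ids[: num_computed_tokens + num_tokens]

assert len(new_token_ids) == num_tokens
assert len(resumed_token_ids) == num_computed_tokens + num_tokens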
