
Commit b796205

Author: Qier Li (committed)

Propagate entire tokens to connector for resumed preemptions

1 parent 60bc25e · commit b796205

File tree: 4 files changed, +61 -9 lines changed

tests/v1/core/test_scheduler.py
Lines changed: 42 additions & 4 deletions

@@ -1860,7 +1860,7 @@ def test_schedule_skip_tokenizer_init_structured_output_request():
     assert len(scheduler.waiting) == 1


-def test_priority_scheduling_preemption_when_out_of_kv():
+def test_priority_scheduling_preemption_and_resumption_when_out_of_kv():
     """Test that priority scheduling preempts lower priority requests
     when out of KV cache space."""
     # Create scheduler with very limited memory to force preemption
@@ -1869,6 +1869,7 @@ def test_priority_scheduling_preemption_when_out_of_kv():
         max_num_batched_tokens=200,
         num_blocks=5,  # Can hold 64 tokens (first block is null)
         block_size=16,  # Standard block size
+        use_kv_connector=True,
     )

     # Create a request and schedule it
@@ -1880,12 +1881,13 @@ def test_priority_scheduling_preemption_when_out_of_kv():
         starting_idx=0,
     )[0]
     scheduler.add_request(request_low)
+    # 1st schedule
     output = scheduler.schedule()
     assert len(output.scheduled_new_reqs) == 1
     assert len(scheduler.waiting) == 0
     assert len(scheduler.running) == 1

-    # Simulate model execution
+    # Simulate model execution - 1st decode
     model_output = ModelRunnerOutput(
         req_ids=[request_low.request_id],
         req_id_to_index={request_low.request_id: 0},
@@ -1906,6 +1908,7 @@ def test_priority_scheduling_preemption_when_out_of_kv():
         starting_idx=1,
     )[0]
     scheduler.add_request(request_high)
+    # 2nd schedule
     output = scheduler.schedule()
     # KV cache should be full at this point
     assert scheduler.kv_cache_manager.block_pool.get_num_free_blocks() == 0
@@ -1914,7 +1917,7 @@ def test_priority_scheduling_preemption_when_out_of_kv():
     assert len(scheduler.waiting) == 0
     assert len(scheduler.running) == 2

-    # Simulate model execution
+    # Simulate model execution - 2nd decode
     requests = [request_low, request_high]
     model_output = ModelRunnerOutput(
         req_ids=[req.request_id for req in requests],
@@ -1927,7 +1930,7 @@ def test_priority_scheduling_preemption_when_out_of_kv():
     )
     scheduler.update_from_output(output, model_output)

-    # Schedule again - this should trigger preemption
+    # 3rd schedule - this should trigger preemption
     # req_low needs 32 tokens = 2 blocks
     # req_high needs 33 tokens = 3 blocks
     # so doesn't fit in 4 blocks.
@@ -1937,9 +1940,44 @@ def test_priority_scheduling_preemption_when_out_of_kv():
     assert len(output.scheduled_new_reqs) == 0
     assert output.scheduled_cached_reqs.num_reqs == 1
     assert output.scheduled_cached_reqs.req_ids[0] == request_high.request_id
+    assert scheduler.requests[request_low.request_id].status == RequestStatus.PREEMPTED
     assert len(scheduler.waiting) == 1
     assert len(scheduler.running) == 1

+    # Simulate model execution - 3rd decode
+    model_output = ModelRunnerOutput(
+        req_ids=[req.request_id for req in requests],
+        req_id_to_index={req.request_id: i for i, req in enumerate(requests)},
+        sampled_token_ids=[[], [100]],
+        # spec_token_ids=None,
+        logprobs=None,
+        prompt_logprobs_dict={},
+        pooler_output=[],
+    )
+    # Finish the requests to make room for the preempted requests to resume
+    scheduler.update_from_output(output, model_output)
+    scheduler.finish_requests(request_high.request_id, RequestStatus.FINISHED_STOPPED)
+
+    # 4th schedule - this should trigger the resumption
+    output = scheduler.schedule()
+    scheduled_cached_reqs = output.scheduled_cached_reqs
+    resumed_from_preemption = scheduled_cached_reqs.resumed_from_preemption
+
+    assert len(output.scheduled_new_reqs) == 0
+    assert scheduled_cached_reqs.num_reqs == 1
+    assert len(scheduler.waiting) == 0
+    assert len(scheduler.running) == 1
+
+    # Preempted request resumed in scheduled_cached_reqs
+    assert len(resumed_from_preemption) == 1
+    assert len(scheduled_cached_reqs.resumed_req_token_ids) == 1
+    assert resumed_from_preemption[0]
+    assert scheduled_cached_reqs.req_ids[0] == request_low.request_id
+    assert scheduled_cached_reqs.resumed_req_token_ids[0] is not None
+    # Resumed tokens include 30 prompt tokens and 2 decoded tokens
+    assert len(scheduled_cached_reqs.resumed_req_token_ids[0]) == 32
+    assert scheduled_cached_reqs.resumed_req_token_ids[0][31] == 100
+

 @pytest.mark.parametrize(
     ("enable_chunked_prefill", "is_encoder_decoder", "expect_enabled"),

tests/v1/worker/test_gpu_model_runner.py
Lines changed: 1 addition & 0 deletions

@@ -257,6 +257,7 @@ def test_update_states_request_resumed(model_runner, dist_init):
         req_ids=[req_id],
         resumed_from_preemption=[False],
         new_token_ids=[[]],
+        resumed_req_token_ids=[None],
         new_block_ids=([[0]],),
         num_computed_tokens=[0],
         num_output_tokens=[0],

vllm/v1/core/sched/output.py
Lines changed: 4 additions & 0 deletions

@@ -98,6 +98,9 @@ class CachedRequestData:
     # NOTE(woosuk): new_token_ids is only used for pipeline parallelism.
     # When PP is not used, new_token_ids will be empty.
     new_token_ids: list[list[int]]
+    # If resumed_from_preemption is True, propagate the token ids to the
+    # connector; otherwise the entry is None.
+    resumed_req_token_ids: list[list[int] | None]
     new_block_ids: list[tuple[list[int], ...] | None]
     num_computed_tokens: list[int]
     num_output_tokens: list[int]
@@ -112,6 +115,7 @@ def make_empty(cls) -> CachedRequestData:
             req_ids=[],
             resumed_from_preemption=[],
             new_token_ids=[],
+            resumed_req_token_ids=[],
             new_block_ids=[],
             num_computed_tokens=[],
             num_output_tokens=[],
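To show what the new field is for, here is a sketch of a connector-side consumer. This is not vLLM's actual KVConnector interface; CachedRequestDataSketch and tokens_for_resumed_requests are illustrative stand-ins that mirror only the three CachedRequestData fields relevant here.

from dataclasses import dataclass, field

@dataclass
class CachedRequestDataSketch:
    # Minimal stand-in for CachedRequestData; the real class carries more
    # fields (new_token_ids, new_block_ids, num_computed_tokens, ...).
    req_ids: list[str] = field(default_factory=list)
    resumed_from_preemption: list[bool] = field(default_factory=list)
    resumed_req_token_ids: list[list[int] | None] = field(default_factory=list)

def tokens_for_resumed_requests(data: CachedRequestDataSketch) -> dict[str, list[int]]:
    """Collect the full token lists of requests that were just resumed.

    A connector that dropped its per-request state when the request was
    preempted can rebuild it from these token ids instead of tracking the
    request's history itself.
    """
    resumed: dict[str, list[int]] = {}
    for req_id, is_resumed, token_ids in zip(
        data.req_ids, data.resumed_from_preemption, data.resumed_req_token_ids
    ):
        if is_resumed and token_ids is not None:
            resumed[req_id] = token_ids
    return resumed

# Example mirroring the scheduler test: one resumed request with 32 tokens.
data = CachedRequestDataSketch(
    req_ids=["req_low"],
    resumed_from_preemption=[True],
    resumed_req_token_ids=[list(range(30)) + [99, 100]],
)
assert tokens_for_resumed_requests(data)["req_low"][31] == 100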

vllm/v1/core/sched/scheduler.py
Lines changed: 14 additions & 5 deletions

@@ -709,11 +709,16 @@ def _make_cached_request_data(
         req_ids: list[str] = []
         new_token_ids: list[list[int]] = []
         new_block_ids: list[tuple[list[int], ...] | None] = []
+        resumed_req_token_ids: list[list[int] | None] = []
         num_computed_tokens: list[int] = []
         num_output_tokens: list[int] = []

         use_connector = self.connector is not None
-        for req in itertools.chain(running_reqs, resumed_reqs):
+        # Because resumed_reqs is usually empty, it is more efficient to do
+        # in-place appending so that we don't need to allocate a new list.
+        resumed_from_preemption = [False] * len(running_reqs)
+        resumed_from_preemption += [True] * len(resumed_reqs)
+        for idx, req in enumerate(itertools.chain(running_reqs, resumed_reqs)):
             req_id = req.request_id
             req_ids.append(req_id)
             num_tokens = num_scheduled_tokens[req_id] - len(
@@ -729,25 +734,29 @@ def _make_cached_request_data(
                     req.num_computed_tokens : req.num_computed_tokens + num_tokens
                 ]
                 new_token_ids.append(token_ids)
+                resumed_req_token_ids.append(None)
             elif use_connector:
                 # When using a KVConnector, we add a placeholder to avoid index
                 # out of bounds errors. TODO: Remove this once the KVConnector
                 # is updated to handle token IDs properly.
                 new_token_ids.append([])
+                resumed_token_ids = None
+                if resumed_from_preemption[idx]:
+                    resumed_token_ids = req.all_token_ids[
+                        : req.num_computed_tokens + num_tokens
+                    ]
+                resumed_req_token_ids.append(resumed_token_ids)
             new_block_ids.append(
                 req_to_new_blocks[req_id].get_block_ids(allow_none=True)
             )
             num_computed_tokens.append(req.num_computed_tokens)
             num_output_tokens.append(len(req.output_token_ids))
-        # Because resumed_reqs is usually empty, it is more efficient to do
-        # in-place appending so that we don't need to allocate a new list.
-        resumed_from_preemption = [False] * len(running_reqs)
-        resumed_from_preemption += [True] * len(resumed_reqs)

         return CachedRequestData(
             req_ids=req_ids,
             resumed_from_preemption=resumed_from_preemption,
             new_token_ids=new_token_ids,
+            resumed_req_token_ids=resumed_req_token_ids,
             new_block_ids=new_block_ids,
             num_computed_tokens=num_computed_tokens,
             num_output_tokens=num_output_tokens,
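For reference, the population logic introduced above can be summarized in a self-contained sketch. Req and build_resumption_fields are illustrative stand-ins, a single num_tokens argument replaces the per-request num_scheduled_tokens lookup, and only the connector-enabled path is modeled.

import itertools
from dataclasses import dataclass

@dataclass
class Req:
    request_id: str
    all_token_ids: list[int]
    num_computed_tokens: int

def build_resumption_fields(
    running_reqs: list[Req],
    resumed_reqs: list[Req],
    num_tokens: int,
) -> tuple[list[bool], list[list[int] | None]]:
    # Running requests come first, then resumed ones, matching the order of
    # itertools.chain(running_reqs, resumed_reqs) in the scheduler.
    resumed_from_preemption = [False] * len(running_reqs)
    resumed_from_preemption += [True] * len(resumed_reqs)
    resumed_req_token_ids: list[list[int] | None] = []
    for idx, req in enumerate(itertools.chain(running_reqs, resumed_reqs)):
        if resumed_from_preemption[idx]:
            # Resumed request: hand the connector every token computed so far
            # plus the tokens scheduled in this step.
            resumed_req_token_ids.append(
                req.all_token_ids[: req.num_computed_tokens + num_tokens]
            )
        else:
            resumed_req_token_ids.append(None)
    return resumed_from_preemption, resumed_req_token_ids

# Numbers mirror the scheduler test: req_low resumes with 30 prompt tokens
# plus 2 decoded tokens, the last of which is 100.
running = [Req("req_high", list(range(33)), 32)]
resumed = [Req("req_low", list(range(30)) + [99, 100], 31)]
flags, tokens = build_resumption_fields(running, resumed, num_tokens=1)
assert flags == [False, True]
assert tokens[0] is None
assert len(tokens[1]) == 32 and tokens[1][31] == 100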
