Skip to content

Commit 6e7ceb2

Browse files
committed
Add unit test for preemption resumption
Signed-off-by: Qier Li <kevin44036@gmail.com>
1 parent 2b2ec15 commit 6e7ceb2

File tree

1 file changed

+48
-5
lines changed

1 file changed

+48
-5
lines changed

tests/v1/core/test_scheduler.py

Lines changed: 48 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1818,7 +1818,7 @@ def test_schedule_skip_tokenizer_init_structured_output_request():
18181818
assert len(scheduler.waiting) == 1
18191819

18201820

1821-
def test_priority_scheduling_preemption_when_out_of_kv():
1821+
def test_priority_scheduling_preemption_and_resumption_when_out_of_kv():
18221822
"""Test that priority scheduling preempts lower priority requests
18231823
when out of KV cache space."""
18241824
# Create scheduler with very limited memory to force preemption
@@ -1827,6 +1827,7 @@ def test_priority_scheduling_preemption_when_out_of_kv():
18271827
max_num_batched_tokens=200,
18281828
num_blocks=5, # Can hold 64 tokens (first block is null)
18291829
block_size=16, # Standard block size
1830+
use_kv_connector=True,
18301831
)
18311832

18321833
# Create a request and schedule it
@@ -1838,12 +1839,13 @@ def test_priority_scheduling_preemption_when_out_of_kv():
18381839
starting_idx=0,
18391840
)[0]
18401841
scheduler.add_request(request_low)
1842+
# 1st schedule
18411843
output = scheduler.schedule()
18421844
assert len(output.scheduled_new_reqs) == 1
18431845
assert len(scheduler.waiting) == 0
18441846
assert len(scheduler.running) == 1
18451847

1846-
# Simulate model execution
1848+
# Simulate model execution - 1st decode
18471849
model_output = ModelRunnerOutput(
18481850
req_ids=[request_low.request_id],
18491851
req_id_to_index={request_low.request_id: 0},
@@ -1864,6 +1866,7 @@ def test_priority_scheduling_preemption_when_out_of_kv():
18641866
starting_idx=1,
18651867
)[0]
18661868
scheduler.add_request(request_high)
1869+
# 2nd schedule
18671870
output = scheduler.schedule()
18681871
# KV cache should be full at this point
18691872
assert scheduler.kv_cache_manager.block_pool.get_num_free_blocks() == 0
@@ -1872,7 +1875,7 @@ def test_priority_scheduling_preemption_when_out_of_kv():
18721875
assert len(scheduler.waiting) == 0
18731876
assert len(scheduler.running) == 2
18741877

1875-
# Simulate model execution
1878+
# Simulate model execution - 2nd decode
18761879
requests = [request_low, request_high]
18771880
model_output = ModelRunnerOutput(
18781881
req_ids=[req.request_id for req in requests],
@@ -1888,7 +1891,7 @@ def test_priority_scheduling_preemption_when_out_of_kv():
18881891
)
18891892
scheduler.update_from_output(output, model_output)
18901893

1891-
# Schedule again - this should trigger preemption
1894+
# 3rd schedule - this should trigger preemption
18921895
# req_low needs 32 tokens = 2 blocks
18931896
# req_high needs 33 tokens = 3 blocks
18941897
# so doesn't fit in 4 blocks.
@@ -1898,5 +1901,45 @@ def test_priority_scheduling_preemption_when_out_of_kv():
18981901
assert len(output.scheduled_new_reqs) == 0
18991902
assert output.scheduled_cached_reqs.num_reqs == 1
19001903
assert output.scheduled_cached_reqs.req_ids[0] == request_high.request_id
1904+
assert scheduler.requests[
1905+
request_low.request_id].status == RequestStatus.PREEMPTED
19011906
assert len(scheduler.waiting) == 1
1902-
assert len(scheduler.running) == 1
1907+
assert len(scheduler.running) == 1
1908+
1909+
# Simulate model execution - 3rd decode
1910+
model_output = ModelRunnerOutput(
1911+
req_ids=[req.request_id for req in requests],
1912+
req_id_to_index={
1913+
req.request_id: i
1914+
for i, req in enumerate(requests)
1915+
},
1916+
sampled_token_ids=[[], [100]],
1917+
# spec_token_ids=None,
1918+
logprobs=None,
1919+
prompt_logprobs_dict={},
1920+
pooler_output=[],
1921+
)
1922+
# Finish the requests to make room for the preempted requests to resume
1923+
scheduler.update_from_output(output, model_output)
1924+
scheduler.finish_requests(request_high.request_id,
1925+
RequestStatus.FINISHED_STOPPED)
1926+
1927+
# 4th Schedule again - this should trigger the resumption
1928+
output = scheduler.schedule()
1929+
scheduled_cached_reqs = output.scheduled_cached_reqs
1930+
resumed_from_preemption = scheduled_cached_reqs.resumed_from_preemption
1931+
1932+
assert len(output.scheduled_new_reqs) == 0
1933+
assert scheduled_cached_reqs.num_reqs == 1
1934+
assert len(scheduler.waiting) == 0
1935+
assert len(scheduler.running) == 1
1936+
1937+
# Preempted request resumed in scheduled_cached_reqs
1938+
assert len(resumed_from_preemption) == 1
1939+
assert len(scheduled_cached_reqs.resumed_req_token_ids) == 1
1940+
assert resumed_from_preemption[0]
1941+
assert scheduled_cached_reqs.req_ids[0] == request_low.request_id
1942+
assert scheduled_cached_reqs.resumed_req_token_ids[0] is not None
1943+
# Resumed tokens include 30 prompt tokens and 2 decoded tokens
1944+
assert len(scheduled_cached_reqs.resumed_req_token_ids[0]) == 32
1945+
assert scheduled_cached_reqs.resumed_req_token_ids[0][31] == 100

0 commit comments

Comments
 (0)