@@ -1950,7 +1950,7 @@ def test_schedule_skip_tokenizer_init_structured_output_request():
19501950 assert len (scheduler .waiting ) == 1
19511951
19521952
1953- def test_priority_scheduling_preemption_when_out_of_kv ():
1953+ def test_priority_scheduling_preemption_and_resumption_when_out_of_kv ():
19541954 """Test that priority scheduling preempts lower priority requests
19551955 when out of KV cache space."""
19561956 # Create scheduler with very limited memory to force preemption
@@ -1959,6 +1959,7 @@ def test_priority_scheduling_preemption_when_out_of_kv():
19591959 max_num_batched_tokens = 200 ,
19601960 num_blocks = 5 , # Can hold 64 tokens (first block is null)
19611961 block_size = 16 , # Standard block size
1962+ use_kv_connector = True ,
19621963 )
19631964
19641965 # Create a request and schedule it
@@ -1970,12 +1971,13 @@ def test_priority_scheduling_preemption_when_out_of_kv():
19701971 starting_idx = 0 ,
19711972 )[0 ]
19721973 scheduler .add_request (request_low )
1974+ # 1st schedule
19731975 output = scheduler .schedule ()
19741976 assert len (output .scheduled_new_reqs ) == 1
19751977 assert len (scheduler .waiting ) == 0
19761978 assert len (scheduler .running ) == 1
19771979
1978- # Simulate model execution
1980+ # Simulate model execution - 1st decode
19791981 model_output = ModelRunnerOutput (
19801982 req_ids = [request_low .request_id ],
19811983 req_id_to_index = {request_low .request_id : 0 },
@@ -1996,6 +1998,7 @@ def test_priority_scheduling_preemption_when_out_of_kv():
19961998 starting_idx = 1 ,
19971999 )[0 ]
19982000 scheduler .add_request (request_high )
2001+ # 2nd schedule
19992002 output = scheduler .schedule ()
20002003 # KV cache should be full at this point
20012004 assert scheduler .kv_cache_manager .block_pool .get_num_free_blocks () == 0
@@ -2004,7 +2007,7 @@ def test_priority_scheduling_preemption_when_out_of_kv():
20042007 assert len (scheduler .waiting ) == 0
20052008 assert len (scheduler .running ) == 2
20062009
2007- # Simulate model execution
2010+ # Simulate model execution - 2nd decode
20082011 requests = [request_low , request_high ]
20092012 model_output = ModelRunnerOutput (
20102013 req_ids = [req .request_id for req in requests ],
@@ -2017,7 +2020,7 @@ def test_priority_scheduling_preemption_when_out_of_kv():
20172020 )
20182021 scheduler .update_from_output (output , model_output )
20192022
2020- # Schedule again - this should trigger preemption
2023+ # 3rd schedule - this should trigger preemption
20212024 # req_low needs 32 tokens = 2 blocks
20222025 # req_high needs 33 tokens = 3 blocks
20232026 # so doesn't fit in 4 blocks.
@@ -2027,9 +2030,44 @@ def test_priority_scheduling_preemption_when_out_of_kv():
20272030 assert len (output .scheduled_new_reqs ) == 0
20282031 assert output .scheduled_cached_reqs .num_reqs == 1
20292032 assert output .scheduled_cached_reqs .req_ids [0 ] == request_high .request_id
2033+ assert scheduler .requests [request_low .request_id ].status == RequestStatus .PREEMPTED
20302034 assert len (scheduler .waiting ) == 1
20312035 assert len (scheduler .running ) == 1
20322036
2037+ # Simulate model execution - 3rd decode
2038+ model_output = ModelRunnerOutput (
2039+ req_ids = [req .request_id for req in requests ],
2040+ req_id_to_index = {req .request_id : i for i , req in enumerate (requests )},
2041+ sampled_token_ids = [[], [100 ]],
2042+ # spec_token_ids=None,
2043+ logprobs = None ,
2044+ prompt_logprobs_dict = {},
2045+ pooler_output = [],
2046+ )
2047+ # Finish the running request to make room for the preempted request to resume
2048+ scheduler .update_from_output (output , model_output )
2049+ scheduler .finish_requests (request_high .request_id , RequestStatus .FINISHED_STOPPED )
2050+
2051+ # 4th schedule - this should trigger the resumption
2052+ output = scheduler .schedule ()
2053+ scheduled_cached_reqs = output .scheduled_cached_reqs
2054+ resumed_from_preemption = scheduled_cached_reqs .resumed_from_preemption
2055+
2056+ assert len (output .scheduled_new_reqs ) == 0
2057+ assert scheduled_cached_reqs .num_reqs == 1
2058+ assert len (scheduler .waiting ) == 0
2059+ assert len (scheduler .running ) == 1
2060+
2061+ # Preempted request resumed in scheduled_cached_reqs
2062+ assert len (resumed_from_preemption ) == 1
2063+ assert len (scheduled_cached_reqs .resumed_req_token_ids ) == 1
2064+ assert resumed_from_preemption [0 ]
2065+ assert scheduled_cached_reqs .req_ids [0 ] == request_low .request_id
2066+ assert scheduled_cached_reqs .resumed_req_token_ids [0 ] is not None
2067+ # Resumed tokens include 30 prompt tokens and 2 decoded tokens
2068+ assert len (scheduled_cached_reqs .resumed_req_token_ids [0 ]) == 32
2069+ assert scheduled_cached_reqs .resumed_req_token_ids [0 ][31 ] == 100
2070+
20332071
20342072@pytest .mark .parametrize (
20352073 ("enable_chunked_prefill" , "is_encoder_decoder" , "expect_enabled" ),
0 commit comments