@@ -1818,7 +1818,7 @@ def test_schedule_skip_tokenizer_init_structured_output_request():
18181818    assert  len (scheduler .waiting ) ==  1 
18191819
18201820
1821- def  test_priority_scheduling_preemption_when_out_of_kv ():
1821+ def  test_priority_scheduling_preemption_and_resumption_when_out_of_kv ():
18221822    """Test that priority scheduling preempts lower priority requests 
18231823    when out of KV cache space.""" 
18241824    # Create scheduler with very limited memory to force preemption 
@@ -1827,6 +1827,7 @@ def test_priority_scheduling_preemption_when_out_of_kv():
18271827        max_num_batched_tokens = 200 ,
18281828        num_blocks = 5 ,  # Can hold 64 tokens (first block is null) 
18291829        block_size = 16 ,  # Standard block size 
1830+         use_kv_connector = True ,
18301831    )
18311832
18321833    # Create a request and schedule it 
@@ -1838,12 +1839,13 @@ def test_priority_scheduling_preemption_when_out_of_kv():
18381839        starting_idx = 0 ,
18391840    )[0 ]
18401841    scheduler .add_request (request_low )
1842+     # 1st schedule 
18411843    output  =  scheduler .schedule ()
18421844    assert  len (output .scheduled_new_reqs ) ==  1 
18431845    assert  len (scheduler .waiting ) ==  0 
18441846    assert  len (scheduler .running ) ==  1 
18451847
1846-     # Simulate model execution 
1848+     # Simulate model execution - 1st decode  
18471849    model_output  =  ModelRunnerOutput (
18481850        req_ids = [request_low .request_id ],
18491851        req_id_to_index = {request_low .request_id : 0 },
@@ -1864,6 +1866,7 @@ def test_priority_scheduling_preemption_when_out_of_kv():
18641866        starting_idx = 1 ,
18651867    )[0 ]
18661868    scheduler .add_request (request_high )
1869+     # 2nd schedule 
18671870    output  =  scheduler .schedule ()
18681871    # KV cache should be full at this point 
18691872    assert  scheduler .kv_cache_manager .block_pool .get_num_free_blocks () ==  0 
@@ -1872,7 +1875,7 @@ def test_priority_scheduling_preemption_when_out_of_kv():
18721875    assert  len (scheduler .waiting ) ==  0 
18731876    assert  len (scheduler .running ) ==  2 
18741877
1875-     # Simulate model execution 
1878+     # Simulate model execution - 2nd decode  
18761879    requests  =  [request_low , request_high ]
18771880    model_output  =  ModelRunnerOutput (
18781881        req_ids = [req .request_id  for  req  in  requests ],
@@ -1888,7 +1891,7 @@ def test_priority_scheduling_preemption_when_out_of_kv():
18881891    )
18891892    scheduler .update_from_output (output , model_output )
18901893
1891-     # Schedule again  - this should trigger preemption 
1894+     # 3rd schedule  - this should trigger preemption 
18921895    # req_low needs 32 tokens = 2 blocks 
18931896    # req_high needs 33 tokens = 3 blocks 
18941897    # so doesn't fit in 4 blocks. 
@@ -1898,5 +1901,45 @@ def test_priority_scheduling_preemption_when_out_of_kv():
18981901    assert  len (output .scheduled_new_reqs ) ==  0 
18991902    assert  output .scheduled_cached_reqs .num_reqs  ==  1 
19001903    assert  output .scheduled_cached_reqs .req_ids [0 ] ==  request_high .request_id 
1904+     assert  scheduler .requests [
1905+         request_low .request_id ].status  ==  RequestStatus .PREEMPTED 
19011906    assert  len (scheduler .waiting ) ==  1 
1902-     assert  len (scheduler .running ) ==  1 
1907+     assert  len (scheduler .running ) ==  1 
1908+ 
1909+     # Simulate model execution - 3rd decode 
1910+     model_output  =  ModelRunnerOutput (
1911+         req_ids = [req .request_id  for  req  in  requests ],
1912+         req_id_to_index = {
1913+             req .request_id : i 
1914+             for  i , req  in  enumerate (requests )
1915+         },
1916+         sampled_token_ids = [[], [100 ]],
1917+         # spec_token_ids=None, 
1918+         logprobs = None ,
1919+         prompt_logprobs_dict = {},
1920+         pooler_output = [],
1921+     )
1922+     # Finish the request to make room for the preempted request to resume 
1923+     scheduler .update_from_output (output , model_output )
1924+     scheduler .finish_requests (request_high .request_id ,
1925+                               RequestStatus .FINISHED_STOPPED )
1926+ 
1927+     # 4th schedule - this should trigger the resumption 
1928+     output  =  scheduler .schedule ()
1929+     scheduled_cached_reqs  =  output .scheduled_cached_reqs 
1930+     resumed_from_preemption  =  scheduled_cached_reqs .resumed_from_preemption 
1931+ 
1932+     assert  len (output .scheduled_new_reqs ) ==  0 
1933+     assert  scheduled_cached_reqs .num_reqs  ==  1 
1934+     assert  len (scheduler .waiting ) ==  0 
1935+     assert  len (scheduler .running ) ==  1 
1936+ 
1937+     # Preempted request resumed in scheduled_cached_reqs 
1938+     assert  len (resumed_from_preemption ) ==  1 
1939+     assert  len (scheduled_cached_reqs .resumed_req_token_ids ) ==  1 
1940+     assert  resumed_from_preemption [0 ]
1941+     assert  scheduled_cached_reqs .req_ids [0 ] ==  request_low .request_id 
1942+     assert  scheduled_cached_reqs .resumed_req_token_ids [0 ] is  not None 
1943+     # Resumed tokens include 30 prompt tokens and 2 decoded tokens 
1944+     assert  len (scheduled_cached_reqs .resumed_req_token_ids [0 ]) ==  32 
1945+     assert  scheduled_cached_reqs .resumed_req_token_ids [0 ][31 ] ==  100 
0 commit comments