@@ -340,3 +340,84 @@ def test_full_block_prompt():
340340 output = outputs [0 ]
341341 assert output .finish_reason == FinishReason .STOP
342342 assert_scheduler_empty (scheduler )
343+
344+
def test_cannot_schedule_after_recv():
    """
    Test that the scheduler copes when a request whose remote KV
    transfer has finished still cannot be scheduled because not
    enough free KV blocks remain. The request must stay in the
    waiting queue and get scheduled later, once another request
    finishes and frees its blocks.
    """

    # NOTE: the KVCacheManager will use 1 null block.
    # So there are 5 total working blocks.
    TOTAL_NUM_BLOCKS = 6
    vllm_config = create_vllm_config()
    scheduler = create_scheduler(vllm_config, num_blocks=TOTAL_NUM_BLOCKS)

    # Prime the KVCache.
    NUM_PROMPT_BLOCKS = 2
    BLOCK_SIZE = vllm_config.cache_config.block_size
    # Prompt will use 2 blocks + 1 block after we schedule.
    # Pure integer arithmetic (equivalent to 2.5 blocks of tokens,
    # without the float round-trip of int(BLOCK_SIZE * 2.5)).
    NUM_TOKENS_LOCAL = BLOCK_SIZE * NUM_PROMPT_BLOCKS
    NUM_TOKENS_REMOTE = BLOCK_SIZE * NUM_PROMPT_BLOCKS + BLOCK_SIZE // 2

    request_normal = create_request(request_id=1, num_tokens=NUM_TOKENS_LOCAL)
    request_remote = create_request(request_id=2,
                                    num_tokens=NUM_TOKENS_REMOTE,
                                    do_remote_prefill=True)

    # STEP 1: 3 blocks are in use (2 for prompt, 1 for decode).
    scheduler.add_request(request_normal)
    scheduler_output = scheduler.schedule()
    model_runner_output = create_model_runner_output(reqs=[request_normal])
    scheduler.update_from_output(scheduler_output, model_runner_output)
    assert len(scheduler.running) == 1
    assert len(scheduler.waiting) == 0

    # STEP 2: 5 blocks are in use (2 new for remote blocks).
    # The remote request only starts its KV transfer; it stays waiting.
    scheduler.add_request(request_remote)
    scheduler_output = scheduler.schedule()
    model_runner_output = create_model_runner_output(reqs=[request_normal])
    scheduler.update_from_output(scheduler_output, model_runner_output)
    assert len(scheduler.running) == 1
    assert len(scheduler.waiting) == 1

    # STEP 3: finish recving (5 blocks in use).
    # The transfer is done, but the request is still not runnable.
    scheduler_output = scheduler.schedule()
    model_runner_output = create_model_runner_output(
        reqs=[request_normal], finished_recving=[request_remote.request_id])
    scheduler.update_from_output(scheduler_output, model_runner_output)
    assert len(scheduler.running) == 1
    assert len(scheduler.waiting) == 1

    # STEP 4: try to schedule, not enough blocks — the remote request
    # must remain in the waiting queue rather than being scheduled.
    scheduler_output = scheduler.schedule()
    model_runner_output = create_model_runner_output(reqs=[request_normal])
    scheduler.update_from_output(scheduler_output, model_runner_output)
    assert len(scheduler.running) == 1
    assert len(scheduler.waiting) == 1

    # STEP 5: finish the normal request, freeing its blocks.
    scheduler_output = scheduler.schedule()
    model_runner_output = create_model_runner_output(reqs=[request_normal],
                                                     use_eos=True)
    scheduler.update_from_output(scheduler_output, model_runner_output)
    assert len(scheduler.running) == 0
    assert len(scheduler.waiting) == 1

    # STEP 6: now the remote request can be scheduled, and its remotely
    # prefilled prompt blocks count as already computed.
    scheduler_output = scheduler.schedule()
    model_runner_output = create_model_runner_output(reqs=[request_remote])
    assert (scheduler_output.scheduled_new_reqs[0].num_computed_tokens ==
            NUM_PROMPT_BLOCKS * BLOCK_SIZE)
    scheduler.update_from_output(scheduler_output, model_runner_output)
    assert len(scheduler.running) == 1
    assert len(scheduler.waiting) == 0

    # STEP 7: free everything and verify the scheduler is fully drained.
    scheduler_output = scheduler.schedule()
    model_runner_output = create_model_runner_output(reqs=[request_remote],
                                                     use_eos=True)
    scheduler.update_from_output(scheduler_output, model_runner_output)
    _ = scheduler.schedule()
    assert_scheduler_empty(scheduler)
0 commit comments