Commit f42cb9d

afeldman-nm, Varun Sundar Rabindranath, and abf149 authored and committed
[Core] Combined support for multi-step scheduling, chunked prefill & prefix caching (vllm-project#8804)
Co-authored-by: Varun Sundar Rabindranath <varun@neuralmagic.com>
Co-authored-by: Andrew Feldman <afeld2012@gmail.com>
Signed-off-by: LeiWang1999 <leiwang1999@outlook.com>
1 parent a3e373f commit f42cb9d

File tree: 3 files changed (+180, -17)

tests/multi_step/test_correctness_llm.py (+158)
@@ -1,5 +1,6 @@
 # Test the LLMEngine with multi-step-decoding

+import copy
 from typing import Optional

 import pytest
@@ -196,3 +197,160 @@ def test_multi_step_llm_w_prompt_logprobs(
         name_0="hf",
         name_1="vllm",
     )
+
+
+@pytest.mark.parametrize("model", MODELS)
+@pytest.mark.parametrize("dtype", ["half"])
+@pytest.mark.parametrize("tp_size", [1])
+@pytest.mark.parametrize("max_tokens", [5])
+@pytest.mark.parametrize("enforce_eager", [True])
+@pytest.mark.parametrize("num_scheduler_steps", NUM_SCHEDULER_STEPS)
+@pytest.mark.parametrize("num_prompts", NUM_PROMPTS)
+@pytest.mark.parametrize("num_logprobs", [None, 5])
+def test_multi_step_llm_chunked_prefill_prefix_cache(
+    vllm_runner,
+    example_prompts,
+    model: str,
+    dtype: str,
+    tp_size: int,
+    max_tokens: int,
+    enforce_eager: int,
+    num_scheduler_steps: int,
+    num_prompts: int,
+    num_logprobs: Optional[int],
+) -> None:
+    """Test vLLM engine with multi-step+"single-step chunked prefill"+APC.
+
+    Set up contrived scenario which tests for a possible failure mode of
+    scheduling with multi-step+"single-step chunked prefill"+APC
+
+    "single-step chunked prefill" here refers to the current vLLM multi-step+
+    chunked-prefill implementation, which requires that a prefill may only
+    be scheduled in the same step as decodes if the prefill prompt fits in a
+    single chunk (note that "complete" multi-step+chunked-prefill would allow
+    a prefill to span multiple chunks & multiple steps but that is not yet
+    the case.)
+
+    "APC" is short for "automatic prefix caching".
+
+    This test creates a scenario where the scheduler must decide whether/how
+    to schedule a prefill with a prompt that exceeds the available token budget.
+    The correct behavior for multi-step+"single-step chunked prefill"+APC is to
+    put off scheduling the prefill until a future step.
+
+    Validate that:
+    * Multi-step kernels do not raise an exception due to incorrect scheduler
+      behavior
+    * Generated tokens match between
+      multi-step+"single-step chunked prefill"+APC and
+      single-step scheduling.
+    * (If logprobs are enabled) check logprobs are close enough
+
+    Args:
+      vllm_runner: vLLM model runner fixture
+      example_prompts: test fixture providing example prompts
+      model: model under test (same for single- and multi-step engines)
+      dtype: tensor datatype for engine to utilize
+      tp_size: degree of tensor-parallelism
+      max_tokens: the maximum number of tokens to generate
+      enforce_eager
+      num_scheduler_steps: for multi-step scheduling, GPU-side steps per
+        GPU -> CPU output transfer
+      num_prompts: number of example prompts under test
+      num_logprobs: corresponds to the `logprobs` argument to the OpenAI
+        completions endpoint; `None` -> 1 logprob returned.
+    """
+
+    # Set up contrived test for correct scheduling behavior with
+    # multi-step+"single-step chunked prefill"+APC.
+    #
+    # Assume block_size=16
+    #
+    # Assume max_num_batched_tokens=48
+    #   => Per-step token budget=48
+    #
+    # 1. Scheduler schedules 0th prompt (24 tokens)
+    #      => Remaining token budget=24
+    # 2. Scheduler attempts to schedule 1st prompt (30 tokens)
+    #    * 30 tokens exceeds 24 token remaining budget
+    #    * Correct behavior: do not schedule this prompt in this step
+    #    * Incorrect behavior: schedule prompt chunk
+    #        * `do_sample=False` for this prompt in this step
+    #        * Chunk size = (remaining tokens // block size) * block size
+    #
+    # The incorrect scheduling behavior - if it occurs - will cause an exception
+    # in the model runner resulting from `do_sample=False`.
+    assert len(example_prompts) >= 2
+    challenge_prompts = copy.deepcopy(example_prompts)
+    challenge_prompts[0] = ('vLLM is a high-throughput and memory-efficient '
+                            'inference and serving engine for LLMs.\n'
+                            )  # 24 tok
+    challenge_prompts[1] = (
+        'Briefly describe the major milestones in the '
+        'development of artificial intelligence from 1950 to 2020.\n'
+    )  # 30 tok
+
+    # If necessary, adjust the length of `challenge_prompts` to match
+    # `num_prompts`
+    if len(challenge_prompts) < num_prompts:
+        challenge_prompts = (challenge_prompts *
+                             ((num_prompts // len(challenge_prompts)) + 1))
+        challenge_prompts = challenge_prompts[:num_prompts]
+    assert len(challenge_prompts) == num_prompts
+
+    # Single-step scheduler baseline
+    with vllm_runner(
+            model,
+            dtype=dtype,
+            enforce_eager=enforce_eager,
+            gpu_memory_utilization=0.7,
+            tensor_parallel_size=tp_size,
+            use_v2_block_manager=True,
+            num_scheduler_steps=num_scheduler_steps,
+            max_model_len=48,
+            max_num_batched_tokens=48,
+            max_num_seqs=4,
+            block_size=16,
+    ) as vllm_model:
+        outputs_baseline = (vllm_model.generate_greedy(
+            challenge_prompts, max_tokens) if num_logprobs is None else
+                            vllm_model.generate_greedy_logprobs(
+                                challenge_prompts, max_tokens, num_logprobs))
+
+    # multi-step+"single-step chunked prefill"+APC
+    with vllm_runner(
+            model,
+            dtype=dtype,
+            enforce_eager=enforce_eager,
+            gpu_memory_utilization=0.7,
+            tensor_parallel_size=tp_size,
+            use_v2_block_manager=True,
+            enable_chunked_prefill=True,
+            enable_prefix_caching=True,
+            num_scheduler_steps=num_scheduler_steps,
+            max_model_len=48,
+            max_num_batched_tokens=48,
+            max_num_seqs=4,
+            block_size=16,
+    ) as vllm_model:
+        outputs_w_features = (vllm_model.generate_greedy(
+            challenge_prompts, max_tokens) if num_logprobs is None else
+                              vllm_model.generate_greedy_logprobs(
+                                  challenge_prompts, max_tokens, num_logprobs))
+
+    if num_logprobs is None:
+        # No-logprobs test
+        check_outputs_equal(
+            outputs_0_lst=outputs_baseline,
+            outputs_1_lst=outputs_w_features,
+            name_0="multi-step",
+            name_1="multi-step+features",
+        )
+    else:
+        # Yes-logprobs test
+        check_logprobs_close(
+            outputs_0_lst=outputs_baseline,
+            outputs_1_lst=outputs_w_features,
+            name_0="multi-step",
+            name_1="multi-step+features",
+        )
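
The comment block in the new test walks through the token-budget arithmetic that the scenario relies on. As a standalone illustration of that arithmetic only (a sketch, not vLLM code; the values mirror the test configuration of block_size=16 and max_num_batched_tokens=48):

# Sketch of the budget arithmetic described in the test comments above.
block_size = 16
token_budget = 48  # max_num_batched_tokens => per-step token budget
prompt_lens = [24, 30]  # challenge_prompts[0] and challenge_prompts[1]

remaining = token_budget
for i, prompt_len in enumerate(prompt_lens):
    if prompt_len <= remaining:
        # The prompt fits in a single chunk: schedule it in this step.
        remaining -= prompt_len
        print(f"prompt {i}: scheduled ({prompt_len} tok), {remaining} tok budget left")
    else:
        # Correct multi-step behavior: postpone the whole prefill to a later step.
        # The incorrect behavior the test guards against would instead schedule a
        # partial chunk of (remaining // block_size) * block_size = 16 tokens with
        # do_sample=False, which triggers an exception in the model runner.
        print(f"prompt {i}: postponed ({prompt_len} tok > {remaining} tok remaining)")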

vllm/core/scheduler.py (+22, -13)
@@ -1607,10 +1607,29 @@ def _get_num_new_tokens(self, seq_group: SequenceGroup,
         # in a decode phase. Do not chunk.
         if enable_chunking and len(seqs) == 1:
             remaining_token_budget = budget.remaining_token_budget()
-            if self.cache_config.enable_prefix_caching:
+            if self.scheduler_config.is_multi_step:
+                # The current multi-step + chunked prefill capability does
+                # not actually support chunking prompts.
+                #
+                # Therefore, `num_new_tokens` is computed in the same fashion
+                # for both multi-step+chunked-prefill &
+                # multi-step+chunked-prefill+APC
+                #
+                # Prompts with more tokens than the current remaining budget
+                # are postponed to future scheduler steps
+                if num_new_tokens > self._get_prompt_limit(seq_group):
+                    # If the seq_group is in prompt-stage, pass the
+                    # num_new_tokens as-is so the caller can ignore
+                    # the sequence.
+                    pass
+                else:
+                    num_new_tokens = 0 \
+                        if num_new_tokens > remaining_token_budget \
+                        else num_new_tokens
+            elif self.cache_config.enable_prefix_caching:
                 # When prefix caching is enabled, we always allocate
-                # the number of new tokens that is dividable by the block size
-                # to avoid partial block matching.
+                # the number of new tokens that is dividable by the block
+                # size to avoid partial block matching.
                 block_size = self.cache_config.block_size
                 remainder = budget.token_budget % block_size
                 if remainder != 0:
@@ -1623,16 +1642,6 @@ def _get_num_new_tokens(self, seq_group: SequenceGroup,
                 if remaining_token_budget < num_new_tokens:
                     num_new_tokens = (remaining_token_budget //
                                       block_size) * block_size
-            elif self.scheduler_config.is_multi_step:
-                if num_new_tokens > self._get_prompt_limit(seq_group):
-                    # If the seq_group is in prompt-stage, pass the
-                    # num_new_tokens as-is so the caller can ignore
-                    # the sequence.
-                    pass
-                else:
-                    num_new_tokens = 0 \
-                        if num_new_tokens > remaining_token_budget \
-                        else num_new_tokens
             else:
                 num_new_tokens = min(num_new_tokens, remaining_token_budget)
         return num_new_tokens
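
For readability, the multi-step branch added above can be restated as a standalone sketch (simplified; `num_new_tokens`, `remaining_token_budget`, and `prompt_limit` stand in for the values the real scheduler derives from the sequence group and the scheduling budget):

def multi_step_num_new_tokens(num_new_tokens: int,
                              remaining_token_budget: int,
                              prompt_limit: int) -> int:
    """Simplified restatement of the multi-step branch in _get_num_new_tokens."""
    if num_new_tokens > prompt_limit:
        # Over the prompt limit: return the value as-is so the caller can
        # recognize and ignore the over-long sequence.
        return num_new_tokens
    # Prompts are never chunked under multi-step: postpone the prefill
    # (0 new tokens) if it exceeds the remaining per-step budget.
    return 0 if num_new_tokens > remaining_token_budget else num_new_tokens

# The scenario from the new test: a 30-token prompt against a 24-token
# remaining budget is postponed rather than chunked.
assert multi_step_num_new_tokens(30, 24, prompt_limit=48) == 0
assert multi_step_num_new_tokens(24, 24, prompt_limit=48) == 24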

vllm/engine/arg_utils.py (-4)
@@ -999,10 +999,6 @@ def create_engine_config(self) -> EngineConfig:
             if speculative_config is not None:
                 raise ValueError("Speculative decoding is not supported with "
                                  "multi-step (--num-scheduler-steps > 1)")
-            if self.enable_chunked_prefill and self.enable_prefix_caching:
-                raise ValueError("Multi-Step is not supported with "
-                                 "both Chunked-Prefill and Prefix-Caching "
-                                 "enabled together.")
             if self.enable_chunked_prefill and self.pipeline_parallel_size > 1:
                 raise ValueError("Multi-Step Chunked-Prefill is not supported "
                                  "for pipeline-parallel-size > 1")
