Skip to content

Commit

Permalink
[2/N] Chunked prefill data update (vllm-project#3538)
Browse files Browse the repository at this point in the history
  • Loading branch information
rkooo567 authored Mar 28, 2024
1 parent ce567a2 commit b51c1cc
Show file tree
Hide file tree
Showing 11 changed files with 272 additions and 76 deletions.
14 changes: 13 additions & 1 deletion benchmarks/benchmark_latency.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,9 @@ def main(args: argparse.Namespace):
kv_cache_dtype=args.kv_cache_dtype,
device=args.device,
ray_workers_use_nsight=args.ray_workers_use_nsight,
download_dir=args.download_dir)
enable_chunked_prefill=args.enable_chunked_prefill,
download_dir=args.download_dir,
block_size=args.block_size)

sampling_params = SamplingParams(
n=args.n,
Expand Down Expand Up @@ -145,6 +147,16 @@ def run_to_completion(profile_dir: Optional[str] = None):
default="cuda",
choices=["cuda"],
help='device type for vLLM execution, supporting CUDA only currently.')
parser.add_argument('--block-size',
type=int,
default=16,
help='block size of key/value cache')
parser.add_argument(
'--enable-chunked-prefill',
type=bool,
default=False,
help='If True, the prefill requests can be chunked based on the '
'max_num_batched_tokens')
parser.add_argument(
"--ray-workers-use-nsight",
action='store_true',
Expand Down
4 changes: 4 additions & 0 deletions tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -256,6 +256,8 @@ def __init__(
dtype: str = "half",
disable_log_stats: bool = True,
tensor_parallel_size: int = 1,
block_size: int = 16,
enable_chunked_prefill: bool = False,
**kwargs,
) -> None:
self.model = LLM(
Expand All @@ -266,6 +268,8 @@ def __init__(
swap_space=0,
disable_log_stats=disable_log_stats,
tensor_parallel_size=tensor_parallel_size,
block_size=block_size,
enable_chunked_prefill=enable_chunked_prefill,
**kwargs,
)

Expand Down
22 changes: 13 additions & 9 deletions tests/core/test_scheduler.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,10 @@
from .utils import create_dummy_prompt


def get_sequence_groups(scheduler_output):
return [s.seq_group for s in scheduler_output.scheduled_seq_groups]


def test_scheduler_add_seq_group():
block_size = 4
scheduler_config = SchedulerConfig(100, 64, 1)
Expand Down Expand Up @@ -57,9 +61,9 @@ def test_scheduler_schedule_simple():
cache_config.num_cpu_blocks = 8
cache_config.num_gpu_blocks = 8
scheduler = Scheduler(scheduler_config, cache_config, None)
running: List[SequenceGroup] = []

# Add seq groups to scheduler.
running: List[SequenceGroup] = []
for i in range(num_seq_group):
_, seq_group = create_dummy_prompt(str(i), prompt_length=block_size)
scheduler.add_seq_group(seq_group)
Expand All @@ -68,15 +72,15 @@ def test_scheduler_schedule_simple():
# Schedule seq groups prompts.
num_tokens = block_size * num_seq_group
seq_group_meta, out = scheduler.schedule()
assert set(out.scheduled_seq_groups) == set(running)
assert set(get_sequence_groups(out)) == set(running)
assert out.num_batched_tokens == num_tokens
assert (not out.blocks_to_copy and not out.blocks_to_swap_in
and not out.blocks_to_swap_out)
assert len(seq_group_meta) == num_seq_group

# Schedule seq groups generation.
seq_group_meta, out = scheduler.schedule()
assert set(out.scheduled_seq_groups) == set(running)
assert set(get_sequence_groups(out)) == set(running)
assert out.num_batched_tokens == num_seq_group
assert (not out.blocks_to_copy and not out.blocks_to_swap_in
and not out.blocks_to_swap_out)
Expand All @@ -100,7 +104,7 @@ def test_scheduler_schedule_preempt_abort():

# Schedule seq groups prompts.
seq_group_meta, out = scheduler.schedule()
assert out.scheduled_seq_groups == [seq_group_a, seq_group_b]
assert get_sequence_groups(out) == [seq_group_a, seq_group_b]
assert out.num_batched_tokens == block_size * 2 # seq_a and seq_b
assert (not out.blocks_to_copy and not out.blocks_to_swap_in
and not out.blocks_to_swap_out)
Expand All @@ -115,7 +119,7 @@ def test_scheduler_schedule_preempt_abort():

# Schedule seq groups generation and preempt seq group b.
seq_group_meta, out = scheduler.schedule()
assert out.scheduled_seq_groups == [seq_group_a]
assert get_sequence_groups(out) == [seq_group_a]
assert out.num_batched_tokens == 1
assert (not out.blocks_to_copy and not out.blocks_to_swap_in
and not out.blocks_to_swap_out)
Expand All @@ -125,7 +129,7 @@ def test_scheduler_schedule_preempt_abort():
# Abort seq group a. Re-schedule seq group b prompt with recomputation.
scheduler.abort_seq_group("1")
seq_group_meta, out = scheduler.schedule()
assert out.scheduled_seq_groups == [seq_group_b]
assert get_sequence_groups(out) == [seq_group_b]
assert out.num_batched_tokens == 5 # 4 prompt + 1 generation.
assert (not out.blocks_to_copy and not out.blocks_to_swap_in
and not out.blocks_to_swap_out)
Expand Down Expand Up @@ -155,11 +159,11 @@ def test_scheduler_max_seqs():

# Schedule seq groups prompts.
_, out = scheduler.schedule()
assert set(out.scheduled_seq_groups) == set([all_seq_groups[0]])
assert set(get_sequence_groups(out)) == set([all_seq_groups[0]])

# Schedule seq groups generation.
_, out = scheduler.schedule()
assert set(out.scheduled_seq_groups) == set([all_seq_groups[0]])
assert set(get_sequence_groups(out)) == set([all_seq_groups[0]])

# Append 2 more seq group
scheduler.add_seq_group(all_seq_groups[1])
Expand All @@ -169,7 +173,7 @@ def test_scheduler_max_seqs():
# Only 1 seq group should be scheduled since max_seq_group is 2
# and one is prompting.
_, out = scheduler.schedule()
assert set(out.scheduled_seq_groups) == set([all_seq_groups[1]])
assert set(get_sequence_groups(out)) == set([all_seq_groups[1]])


def test_scheduler_delay_factor():
Expand Down
24 changes: 23 additions & 1 deletion tests/test_sequence.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import pytest

from vllm.sequence import SamplerOutput, SequenceGroupOutput, SequenceOutput
from vllm.sequence import (SamplerOutput, SequenceData, SequenceGroupOutput,
SequenceOutput)


@pytest.fixture
Expand Down Expand Up @@ -48,3 +49,24 @@ def test_sampler_output_eq(sample_outputs):
sampler_output3 = SamplerOutput(outputs=sample_outputs[:-1])
assert sampler_output1 == sampler_output2
assert sampler_output1 != sampler_output3


def test_sequence_data_prefill():
seq_data = SequenceData(prompt_token_ids=[1, 2, 3, 4])
assert seq_data.get_num_uncomputed_tokens() == 4
assert seq_data.get_num_computed_tokens() == 0
# advance by 2
seq_data.update_num_computed_tokens(2)
assert seq_data.get_num_uncomputed_tokens() == 2
assert seq_data.get_num_computed_tokens() == 2

# advance by 1
seq_data.update_num_computed_tokens(1)
assert seq_data.get_num_uncomputed_tokens() == 1
assert seq_data.get_num_computed_tokens() == 3

# append tokens and reset, simulating recompute
seq_data.append_token_id(1, logprob=0.0)
seq_data.reset_num_computed_tokens()
assert seq_data.get_num_uncomputed_tokens() == 5
assert seq_data.get_num_computed_tokens() == 0
37 changes: 20 additions & 17 deletions tests/worker/test_model_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,15 +18,16 @@ def test_prepare_prompt(batch_size):
# make sure all tokens fit into one block
prompt_len = i % (model_runner.block_size - 1) + 1
prompt_lens.append(prompt_len)
seq_data = list(range(prompt_len))
seq_group_metadata_list.append(
SequenceGroupMetadata(
request_id=f"test_{i}",
is_prompt=True,
seq_data={0: SequenceData(seq_data)},
sampling_params=SamplingParams(temperature=0),
block_tables=block_tables,
))
seq_data = SequenceData(list(range(prompt_len)))
seq_group_metadata = SequenceGroupMetadata(
request_id=f"test_{i}",
is_prompt=True,
seq_data={0: seq_data},
sampling_params=SamplingParams(temperature=0),
block_tables=block_tables,
)
assert seq_group_metadata.token_chunk_size == seq_data.get_len()
seq_group_metadata_list.append(seq_group_metadata)

expected_selected_token_indices = []
selected_token_start_idx = 0
Expand Down Expand Up @@ -131,14 +132,16 @@ def test_prepare_decode_cuda_graph(batch_size):
prompt_len = i % (model_runner.block_size - 1) + 1
prompt_lens.append(prompt_len)
seq_data = list(range(prompt_len))
seq_group_metadata_list.append(
SequenceGroupMetadata(
request_id=f"test_{i}",
is_prompt=False,
seq_data={0: SequenceData(seq_data)},
sampling_params=SamplingParams(temperature=0),
block_tables={0: [1]},
))
seq_data = SequenceData(seq_data)
seq_group_metadata = SequenceGroupMetadata(
request_id=f"test_{i}",
is_prompt=False,
seq_data={0: seq_data},
sampling_params=SamplingParams(temperature=0),
block_tables={0: [1]},
)
assert seq_group_metadata.token_chunk_size == 1
seq_group_metadata_list.append(seq_group_metadata)

input_tokens, input_positions, attn_metadata, _, _, _ = (
model_runner._prepare_decode(seq_group_metadata_list))
Expand Down
4 changes: 4 additions & 0 deletions vllm/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -533,6 +533,8 @@ class SchedulerConfig:
delay_factor: Apply a delay (of delay factor multiplied by previous
prompt latency) before scheduling next prompt.
use_v2_block_manager: Whether to use the BlockSpaceManagerV2 or not.
enable_chunked_prefill: If True, prefill requests can be chunked based
on the remaining max_num_batched_tokens.
"""

def __init__(
Expand All @@ -542,6 +544,7 @@ def __init__(
max_model_len: int,
use_v2_block_manager: bool = False,
delay_factor: float = 0.0,
enable_chunked_prefill: bool = False,
) -> None:
if max_num_batched_tokens is not None:
self.max_num_batched_tokens = max_num_batched_tokens
Expand All @@ -553,6 +556,7 @@ def __init__(
self.max_model_len = max_model_len
self.delay_factor = delay_factor
self.use_v2_block_manager = use_v2_block_manager
self.chunked_prefill_enabled = enable_chunked_prefill
self._verify_args()

def _verify_args(self) -> None:
Expand Down
Loading

0 comments on commit b51c1cc

Please sign in to comment.