[Chunked Prefill][4/n] Chunked prefill scheduler. #3853

Merged
21 commits merged on Apr 5, 2024
557 changes: 557 additions & 0 deletions tests/core/test_chunked_prefill_scheduler.py

Large diffs are not rendered by default.
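
Since the new scheduler tests are collapsed here, the following is a rough, hypothetical sketch of the behavior they exercise: each scheduling step gets a token budget (`max_num_batched_tokens`), running decodes take one token each, and pending prompts are prefilled in chunks that fit whatever budget remains. The names below (`PrefillRequest`, `schedule_step`) are made up for illustration and are not part of vLLM's API.

```python
# Hypothetical sketch of chunked-prefill budgeting, not vLLM's Scheduler.
from dataclasses import dataclass
from typing import Dict, List


@dataclass
class PrefillRequest:
    request_id: str
    prompt_len: int
    num_computed: int = 0  # prompt tokens already prefilled

    @property
    def remaining(self) -> int:
        return self.prompt_len - self.num_computed


def schedule_step(prefills: List[PrefillRequest], num_decodes: int,
                  token_budget: int) -> Dict[str, int]:
    """Spend this step's token budget: decodes cost one token each,
    then pending prefills are chunked into the remaining budget."""
    schedule: Dict[str, int] = {}
    budget = token_budget - num_decodes
    for req in prefills:
        if budget <= 0:
            break
        chunk = min(req.remaining, budget)
        if chunk > 0:
            schedule[req.request_id] = chunk
            req.num_computed += chunk
            budget -= chunk
    return schedule


# A 60-token prompt under a 32-token budget is prefilled over two steps
# instead of requiring max_num_batched_tokens >= prompt length.
reqs = [PrefillRequest("0", prompt_len=60)]
print(schedule_step(reqs, num_decodes=0, token_budget=32))  # {'0': 32}
print(schedule_step(reqs, num_decodes=0, token_budget=32))  # {'0': 28}
```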

220 changes: 174 additions & 46 deletions tests/core/test_scheduler.py

Large diffs are not rendered by default.

58 changes: 55 additions & 3 deletions tests/test_sequence.py
@@ -1,7 +1,36 @@
import time
from typing import Optional

import pytest

-from vllm.sequence import (SamplerOutput, SequenceData, SequenceGroupOutput,
-                           SequenceOutput)
+from vllm import SamplingParams
+from vllm.lora.request import LoRARequest
+from vllm.sequence import (SamplerOutput, Sequence, SequenceData,
+                           SequenceGroup, SequenceGroupOutput, SequenceOutput)


def create_dummy_prompt(
    request_id: str,
    prompt_length: int,
    block_size: Optional[int] = None,
    lora_request: Optional[LoRARequest] = None,
    use_beam_search: bool = False,
    best_of: int = 1,
) -> SequenceGroup:
    if not block_size:
        block_size = prompt_length

    # Create a dummy prompt sequence with tokens 0...prompt_length-1
    # and prompt "0 ... prompt_length-1".
    prompt_tokens = list(range(prompt_length))
    prompt_str = " ".join([str(t) for t in prompt_tokens])
    prompt = Sequence(int(request_id), prompt_str, prompt_tokens, block_size)
    seq_group = SequenceGroup(
        request_id, [prompt],
        SamplingParams(use_beam_search=use_beam_search, best_of=best_of),
        time.time(), lora_request)

    return seq_group


@pytest.fixture
@@ -67,6 +96,29 @@ def test_sequence_data_prefill():

    # append tokens and reset, simulating recompute
    seq_data.append_token_id(1, logprob=0.0)
-    seq_data.reset_num_computed_tokens()
+    seq_data.reset_state_for_recompute()
    assert seq_data.get_num_uncomputed_tokens() == 5
    assert seq_data.get_num_computed_tokens() == 0


def test_sequence_group_stage():
    seq_group = create_dummy_prompt("1", 12)
    assert seq_group.is_prefill() is True
    seq_group.update_num_computed_tokens(6)
    assert seq_group.is_prefill() is True
    seq_group.update_num_computed_tokens(5)
    assert seq_group.is_prefill() is True
    seq_group.update_num_computed_tokens(1)
    assert seq_group.is_prefill() is False
    seqs = seq_group.get_seqs()
    assert len(seqs) == 1
    seqs[0].data.append_token_id(1, logprob=0.0)
    for seq in seq_group.get_seqs():
        seq.reset_state_for_recompute()
    assert seq_group.is_prefill() is True
    seq_group.update_num_computed_tokens(5)
    assert seq_group.is_prefill() is True
    seq_group.update_num_computed_tokens(7)
    assert seq_group.is_prefill() is True
    seq_group.update_num_computed_tokens(1)
    assert seq_group.is_prefill() is False
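
The arithmetic behind these assertions: the group has 12 prompt tokens, so it stays in the prefill stage until all 12 are computed (6 + 5 + 1); after one generated token is appended and the sequence is reset for recompute, 13 tokens are uncomputed again, so 5 + 7 + 1 computed tokens are needed before it returns to the decode stage. A minimal, illustrative model of that bookkeeping (not the actual `vllm.sequence` implementation; the class and method names are made up for the example):

```python
# Toy model of the prefill/decode bookkeeping the test exercises above.
class StageBookkeeping:

    def __init__(self, num_prompt_tokens: int) -> None:
        self.num_total_tokens = num_prompt_tokens  # prompt + generated
        self.num_computed_tokens = 0
        self.prefill = True

    def update_num_computed_tokens(self, n: int) -> None:
        self.num_computed_tokens += n
        if self.num_computed_tokens == self.num_total_tokens:
            # Everything is computed: the next step is a decode step.
            self.prefill = False

    def append_token(self) -> None:
        # A generated token extends what a later recompute must cover.
        self.num_total_tokens += 1

    def reset_state_for_recompute(self) -> None:
        # Preemption by recompute: all tokens become uncomputed again,
        # so the group is back in the prefill stage.
        self.num_computed_tokens = 0
        self.prefill = True

    def is_prefill(self) -> bool:
        return self.prefill


s = StageBookkeeping(12)
s.update_num_computed_tokens(6)
s.update_num_computed_tokens(5)
assert s.is_prefill()              # 11 of 12 computed
s.update_num_computed_tokens(1)
assert not s.is_prefill()          # 12 of 12 computed -> decode
s.append_token()
s.reset_state_for_recompute()      # now 13 tokens, none computed
assert s.is_prefill()
```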
3 changes: 2 additions & 1 deletion vllm/config.py
@@ -576,7 +576,8 @@ def __init__(
        self._verify_args()

    def _verify_args(self) -> None:
-        if self.max_num_batched_tokens < self.max_model_len:
+        if (self.max_num_batched_tokens < self.max_model_len
+                and not self.chunked_prefill_enabled):
            raise ValueError(
                f"max_num_batched_tokens ({self.max_num_batched_tokens}) is "
                f"smaller than max_model_len ({self.max_model_len}). "
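
The relaxed check above reflects the point of chunked prefill: without it, a prompt has to be prefilled in a single step, so the per-step token budget must cover `max_model_len`; with it, a long prompt can be prefilled across several steps. A rough worked example (the concrete numbers are illustrative, not defaults from this PR):

```python
# Illustrative only: with chunked prefill, max_num_batched_tokens may be
# smaller than max_model_len, since a maximum-length prompt is simply
# prefilled over several scheduler steps.
import math

max_model_len = 4096
max_num_batched_tokens = 768  # rejected by _verify_args without chunked prefill

prefill_steps = math.ceil(max_model_len / max_num_batched_tokens)
print(prefill_steps)  # 6 steps to prefill a maximum-length prompt
```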
4 changes: 1 addition & 3 deletions vllm/core/policy.py
@@ -38,9 +38,7 @@ def get_priority(

class PolicyFactory:

-    _POLICY_REGISTRY = {
-        'fcfs': FCFS,
-    }
+    _POLICY_REGISTRY = {'fcfs': FCFS}

    @classmethod
    def get_policy(cls, policy_name: str, **kwargs) -> Policy: