[Chunked Prefill][4/n] Chunked prefill scheduler. #3853

Merged: 21 commits, Apr 5, 2024
557 changes: 557 additions & 0 deletions tests/core/test_chunked_prefill_scheduler.py


218 changes: 172 additions & 46 deletions tests/core/test_scheduler.py


58 changes: 55 additions & 3 deletions tests/test_sequence.py
@@ -1,7 +1,36 @@
import time
from typing import Optional

import pytest

from vllm.sequence import (SamplerOutput, SequenceData, SequenceGroupOutput,
SequenceOutput)
from vllm import SamplingParams
from vllm.lora.request import LoRARequest
from vllm.sequence import (SamplerOutput, Sequence, SequenceData,
SequenceGroup, SequenceGroupOutput, SequenceOutput)


def create_dummy_prompt(
request_id: str,
prompt_length: int,
block_size: Optional[int] = None,
lora_request: Optional[LoRARequest] = None,
use_beam_search: bool = False,
best_of: int = 1,
) -> SequenceGroup:
if not block_size:
block_size = prompt_length

# Create a dummy prompt sequence with tokens 0...prompt_length-1
# and prompt string "0 1 ... prompt_length-1".
prompt_tokens = list(range(prompt_length))
prompt_str = " ".join([str(t) for t in prompt_tokens])
prompt = Sequence(int(request_id), prompt_str, prompt_tokens, block_size)
seq_group = SequenceGroup(
request_id, [prompt],
SamplingParams(use_beam_search=use_beam_search, best_of=best_of),
time.time(), lora_request)

return seq_group


@pytest.fixture
@@ -67,6 +96,29 @@ def test_sequence_data_prefill():

# append tokens and reset, simulating recompute
seq_data.append_token_id(1, logprob=0.0)
seq_data.reset_num_computed_tokens()
seq_data.reset_state_for_recompute()
assert seq_data.get_num_uncomputed_tokens() == 5
assert seq_data.get_num_computed_tokens() == 0


def test_sequence_group_stage():
seq_group = create_dummy_prompt("1", 12)
assert seq_group.is_prefill() is True
seq_group.update_num_computed_tokens(6)
assert seq_group.is_prefill() is True
seq_group.update_num_computed_tokens(5)
assert seq_group.is_prefill() is True
seq_group.update_num_computed_tokens(1)
assert seq_group.is_prefill() is False
seqs = seq_group.get_seqs()
assert len(seqs) == 1
seqs[0].data.append_token_id(1, logprob=0.0)
for seq in seq_group.get_seqs():
seq.reset_state_for_recompute()
assert seq_group.is_prefill() is True
seq_group.update_num_computed_tokens(5)
assert seq_group.is_prefill() is True
seq_group.update_num_computed_tokens(7)
assert seq_group.is_prefill() is True
seq_group.update_num_computed_tokens(1)
assert seq_group.is_prefill() is False
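
The new test_sequence_group_stage test above exercises how a sequence group's prefill/decode stage tracks computed tokens. A minimal conceptual sketch of that behavior, not the vLLM implementation and with illustrative names: a group stays in prefill until the computed-token count catches up with the total token count, and a recompute reset returns it to prefill with zero tokens computed.

class StageSketch:
    """Illustrative stand-in for the computed-token bookkeeping tested above."""

    def __init__(self, num_prompt_tokens: int):
        self.total_tokens = num_prompt_tokens
        self.computed_tokens = 0

    def append_token(self) -> None:
        # A generated token adds to the total that must eventually be computed.
        self.total_tokens += 1

    def update_num_computed_tokens(self, n: int) -> None:
        self.computed_tokens += n

    def reset_state_for_recompute(self) -> None:
        # Recomputation (e.g., after preemption) starts over from zero.
        self.computed_tokens = 0

    def is_prefill(self) -> bool:
        return self.computed_tokens < self.total_tokens

With a 12-token prompt this reproduces the assertions: 6 + 5 computed tokens still leave the group in prefill, the 12th token moves it to decode, and after appending one token and resetting, 5 + 7 + 1 = 13 computed tokens are needed before it leaves prefill again.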
3 changes: 2 additions & 1 deletion vllm/config.py
@@ -576,7 +576,8 @@ def __init__(
self._verify_args()

def _verify_args(self) -> None:
if self.max_num_batched_tokens < self.max_model_len:
if (self.max_num_batched_tokens < self.max_model_len
and not self.chunked_prefill_enabled):
raise ValueError(
f"max_num_batched_tokens ({self.max_num_batched_tokens}) is "
f"smaller than max_model_len ({self.max_model_len}). "
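The relaxed check reflects how chunked prefill schedules long prompts: a prompt larger than the per-iteration token budget is split across several scheduler steps, so max_num_batched_tokens no longer has to cover max_model_len in a single iteration. A rough illustration with hypothetical numbers (none of these values come from the PR):

# Hypothetical configuration, for illustration only.
max_model_len = 4096          # model context length
max_num_batched_tokens = 512  # per-iteration token budget with chunked prefill

prompt_len = 3000
num_prefill_steps = -(-prompt_len // max_num_batched_tokens)  # ceil -> 6 steps

# Without chunked prefill this configuration is rejected, since no single
# iteration could ever fit the full prompt.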
21 changes: 18 additions & 3 deletions vllm/core/policy.py
@@ -36,11 +36,26 @@ def get_priority(
return now - seq_group.metrics.arrival_time


class MaximalDecoding(Policy):
"""Policy to prioritize decoding requests as much as possible when a
queue contains both prefill and decode requests.

It prioritizes: 1. seq_groups with a small number of uncomputed tokens
(i.e., decode requests), then 2. FCFS order.
"""

def get_priority(
self,
now: float,
seq_group: SequenceGroup,
) -> float:
return (-seq_group.get_num_uncomputed_tokens(),
now - seq_group.metrics.arrival_time)


class PolicyFactory:

_POLICY_REGISTRY = {
'fcfs': FCFS,
}
_POLICY_REGISTRY = {'fcfs': FCFS, 'maximal_decoding': MaximalDecoding}

@classmethod
def get_policy(cls, policy_name: str, **kwargs) -> Policy:
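
A hedged usage sketch of the new registry entry (the seq_groups list and its contents are placeholders, not part of the PR): get_priority returns a tuple, so sorting in descending priority order puts groups with fewer uncomputed tokens (decodes) first and breaks ties FCFS by wait time.

import time

from vllm.core.policy import PolicyFactory

policy = PolicyFactory.get_policy(policy_name="maximal_decoding")
now = time.time()

# seq_groups is assumed to hold SequenceGroup objects built elsewhere, e.g.
# with a helper like create_dummy_prompt above.
seq_groups = []
ordered = sorted(seq_groups,
                 key=lambda group: policy.get_priority(now, group),
                 reverse=True)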