Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account-related emails.

Already on GitHub? Sign in to your account

[Chunked Prefill][4/n] Chunked prefill scheduler. #3853

Merged
merged 21 commits into from
Apr 5, 2024
Merged
Prev Previous commit
Next Next commit
addressed 2
  • Loading branch information
rkooo567 committed Apr 4, 2024
commit 7ee465538e25d869d341d0d7aef7f5a4991cf365
26 changes: 14 additions & 12 deletions tests/core/test_scheduler.py
Original file line number Diff line number Diff line change
Expand Up @@ -328,8 +328,9 @@ def add_token_budget(budget: SchedulingBudget,
num_batched_tokens: int = 0,
num_curr_seqs: int = 0):
mock_seq_group = create_dummy_prompt('10', prompt_length=60)[1]
budget.add_num_batched_tokens(mock_seq_group, num_batched_tokens)
budget.add_num_seqs(mock_seq_group, num_curr_seqs)
budget.add_num_batched_tokens(mock_seq_group.request_id,
num_batched_tokens)
budget.add_num_seqs(mock_seq_group.request_id, num_curr_seqs)


def test_prefill_schedule_max_prompt_len():
Expand Down Expand Up @@ -575,9 +576,10 @@ def test_decode_swap_beam_search():
scheduler._allocate_and_set_running(seq_group, 60)
running.append(seq_group)
append_new_token_seq_group(seq_group, 1)
budget.add_num_seqs(seq_group, seq_group.get_max_num_running_seqs())
budget.add_num_seqs(seq_group.request_id,
seq_group.get_max_num_running_seqs())
budget.add_num_batched_tokens(
seq_group, seq_group.num_seqs(SequenceStatus.RUNNING))
seq_group.request_id, seq_group.num_seqs(SequenceStatus.RUNNING))

# The last request should be swapped out.
scheduler.block_manager.can_append_slots = MagicMock()
Expand Down Expand Up @@ -830,32 +832,32 @@ def test_scheduling_budget():

# Verify add/subtract num batched tokens.
_, seq_group = create_dummy_prompt("1", 3)
budget.add_num_batched_tokens(seq_group, 2)
budget.add_num_batched_tokens(seq_group.request_id, 2)
assert budget.remaining_token_budget() == 2
assert budget.num_batched_tokens == 2
assert budget.can_schedule(num_new_tokens=2, num_new_seqs=1)
assert not budget.can_schedule(num_new_tokens=3, num_new_seqs=1)
# Verify adding another seq group is no-op.
budget.add_num_batched_tokens(seq_group, 2)
budget.add_num_batched_tokens(seq_group.request_id, 2)
assert budget.remaining_token_budget() == 2
assert budget.num_batched_tokens == 2
budget.subtract_num_batched_tokens(seq_group, 2)
budget.subtract_num_batched_tokens(seq_group.request_id, 2)
assert budget.remaining_token_budget() == 4
assert budget.num_batched_tokens == 0
budget.subtract_num_batched_tokens(seq_group, 2)
budget.subtract_num_batched_tokens(seq_group.request_id, 2)
assert budget.remaining_token_budget() == 4
assert budget.num_batched_tokens == 0

# Verify add/subtract max seqs.
_, seq_group = create_dummy_prompt("1", 3)
budget.add_num_seqs(seq_group, 2)
budget.add_num_seqs(seq_group.request_id, 2)
assert budget.can_schedule(num_new_tokens=1, num_new_seqs=2)
assert not budget.can_schedule(num_new_tokens=1, num_new_seqs=3)
assert budget.num_curr_seqs == 2
# Verify adding another seq group is no-op.
budget.add_num_seqs(seq_group, 2)
budget.add_num_seqs(seq_group.request_id, 2)
assert budget.num_curr_seqs == 2
budget.subtract_num_seqs(seq_group, 2)
budget.subtract_num_seqs(seq_group.request_id, 2)
assert budget.num_curr_seqs == 0
budget.subtract_num_seqs(seq_group, 2)
budget.subtract_num_seqs(seq_group.request_id, 2)
assert budget.num_curr_seqs == 0
33 changes: 15 additions & 18 deletions vllm/core/scheduler.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,32 +48,27 @@ def can_schedule(self, *, num_new_tokens: int, num_new_seqs: int):
def remaining_token_budget(self):
return self.token_budget - self.num_batched_tokens

def add_num_batched_tokens(self, seq_group: SequenceGroup,
num_batched_tokens: int):
req_id = seq_group.request_id
def add_num_batched_tokens(self, req_id: str, num_batched_tokens: int):
if req_id in self._requeset_ids_num_batched_tokens:

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

when would this happen? I think this should raise an exception

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

same case as #3853 (comment)

return

self._requeset_ids_num_batched_tokens.add(req_id)
self._num_batched_tokens += num_batched_tokens

def subtract_num_batched_tokens(self, seq_group: SequenceGroup,
def subtract_num_batched_tokens(self, req_id: str,
num_batched_tokens: int):
req_id = seq_group.request_id
if req_id in self._requeset_ids_num_batched_tokens:
self._requeset_ids_num_batched_tokens.remove(req_id)
self._num_batched_tokens -= num_batched_tokens

def add_num_seqs(self, seq_group: SequenceGroup, num_curr_seqs: int):
req_id = seq_group.request_id
def add_num_seqs(self, req_id: str, num_curr_seqs: int):
if req_id in self._requeset_ids_num_curr_seqs:
return

self._requeset_ids_num_curr_seqs.add(req_id)
self._num_curr_seqs += num_curr_seqs

def subtract_num_seqs(self, seq_group: SequenceGroup, num_curr_seqs: int):
req_id = seq_group.request_id
def subtract_num_seqs(self, req_id: str, num_curr_seqs: int):
if req_id in self._requeset_ids_num_curr_seqs:
self._requeset_ids_num_curr_seqs.remove(req_id)
self._num_curr_seqs -= num_curr_seqs
Expand Down Expand Up @@ -376,9 +371,10 @@ def _schedule_running(

running_queue.popleft()
while not self._can_append_slots(seq_group):
budget.subtract_num_batched_tokens(seq_group,
budget.subtract_num_batched_tokens(seq_group.request_id,
num_running_tokens)
budget.subtract_num_seqs(seq_group, num_running_seqs)
budget.subtract_num_seqs(seq_group.request_id,
num_running_seqs)
if curr_loras is not None and seq_group.lora_int_id > 0:
curr_loras.pop(seq_group.lora_int_id)

Expand Down Expand Up @@ -419,8 +415,9 @@ def _schedule_running(
ScheduledSequenceGroup(
seq_group=seq_group,
token_chunk_size=num_running_tokens))
budget.add_num_batched_tokens(seq_group, num_running_tokens)
budget.add_num_seqs(seq_group, num_running_seqs)
budget.add_num_batched_tokens(seq_group.request_id,
num_running_tokens)
budget.add_num_seqs(seq_group.request_id, num_running_seqs)
if curr_loras is not None and seq_group.lora_int_id > 0:
curr_loras.add(seq_group.lora_int_id)

Expand Down Expand Up @@ -525,8 +522,8 @@ def _schedule_swapped(
decode_seq_groups.append(
ScheduledSequenceGroup(seq_group,
token_chunk_size=num_new_tokens))
budget.add_num_batched_tokens(seq_group, num_new_tokens)
budget.add_num_seqs(seq_group, num_new_seqs)
budget.add_num_batched_tokens(seq_group.request_id, num_new_tokens)
budget.add_num_seqs(seq_group.request_id, num_new_seqs)

swapped_queue.extendleft(leftover_swapped)

Expand Down Expand Up @@ -639,8 +636,8 @@ def _schedule_prefills(
seq_groups.append(
ScheduledSequenceGroup(seq_group=seq_group,
token_chunk_size=num_new_tokens))
budget.add_num_batched_tokens(seq_group, num_new_tokens)
budget.add_num_seqs(seq_group, num_new_seqs)
budget.add_num_batched_tokens(seq_group.request_id, num_new_tokens)
budget.add_num_seqs(seq_group.request_id, num_new_seqs)

# Queue requests that couldn't be scheduled.
waiting_queue.extendleft(leftover_waiting_sequences)
Expand Down Expand Up @@ -668,7 +665,7 @@ def _schedule_default(self) -> SchedulerOutputs:
# Make sure we include num running seqs before scheduling prefill,
# so that we don't schedule beyond max_num_seqs for prefill.
for seq_group in self.running:
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

this was a bug (we should not include num_batched tokens in the beginning). it is fixed & regression tests are added.

budget.add_num_seqs(seq_group,
budget.add_num_seqs(seq_group.request_id,
seq_group.get_max_num_running_seqs())
curr_loras = set(
seq_group.lora_int_id
Expand Down
Loading