[Bug fix][Core] assert num_new_tokens == 1 fails when SamplingParams.n is not 1 and max_tokens is large & Add tests for preemption (vllm-project#4451)
rkooo567 authored May 2, 2024
1 parent b8afa8b commit 0d62fe5
Showing 6 changed files with 172 additions and 13 deletions.
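
The assertion named in the title lived in the scheduler's swapped path: a sequence group created with SamplingParams.n > 1 has several running sequences, so when it is swapped back in, a decode step schedules one new token per sequence rather than exactly one, and with a large max_tokens the run lasts long enough for swapping to actually occur. A rough sketch of the kind of workload that could hit it (illustrative only; the model, swap_space value, prompts, and sampling settings are assumptions, not taken from this commit):

```python
# Illustrative only. A tiny model, a small CPU swap space, and long decodes
# make preemption by swapping likely; n > 1 gives the sequence group several
# running sequences, so more than one new token is scheduled per decode step.
from vllm import LLM, SamplingParams

llm = LLM(model="facebook/opt-125m", swap_space=1)            # assumed settings
params = SamplingParams(n=4, temperature=0.8, max_tokens=2048)

outputs = llm.generate(["Tell me a very long story."] * 32, params)
for out in outputs:
    print(len(out.outputs))  # n completions per prompt
```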
1 change: 1 addition & 0 deletions .buildkite/test-pipeline.yaml
@@ -17,6 +17,7 @@ steps:
- VLLM_ATTENTION_BACKEND=FLASH_ATTN pytest -v -s basic_correctness/test_basic_correctness.py
- VLLM_ATTENTION_BACKEND=XFORMERS pytest -v -s basic_correctness/test_chunked_prefill.py
- VLLM_ATTENTION_BACKEND=FLASH_ATTN pytest -v -s basic_correctness/test_chunked_prefill.py
+ - VLLM_TEST_ENABLE_ARTIFICIAL_PREEMPT=1 pytest -v -s basic_correctness/test_preemption.py

- label: Core Test
command: pytest -v -s core
1 change: 0 additions & 1 deletion tests/basic_correctness/test_chunked_prefill.py
@@ -55,7 +55,6 @@ def test_models(
)
vllm_outputs = vllm_model.generate_greedy(example_prompts, max_tokens)
del vllm_model
- print(vllm_outputs[0])

for i in range(len(example_prompts)):
hf_output_ids, hf_output_str = hf_outputs[i]
138 changes: 138 additions & 0 deletions tests/basic_correctness/test_preemption.py
@@ -0,0 +1,138 @@
"""Compare the short outputs of HF and vLLM when using greedy sampling.
VLLM_TEST_ENABLE_ARTIFICIAL_PREEMPT=1 has to be set before running this test.
Run `VLLM_TEST_ENABLE_ARTIFICIAL_PREEMPT=1
pytest tests/basic_correctness/test_preemption.py`.
"""
import pytest

from vllm.core.scheduler import (ARTIFICIAL_PREEMPTION_MAX_CNT,
ENABLE_ARTIFICIAL_PREEMPT)

MODELS = [
"facebook/opt-125m",
]

assert ENABLE_ARTIFICIAL_PREEMPT is True, (
"Use an env var VLLM_TEST_ENABLE_ARTIFICIAL_PREEMPT=1. "
"`VLLM_TEST_ENABLE_ARTIFICIAL_PREEMPT=1 pytest "
"tests/basic_correctness/test_preemption.py`")


@pytest.mark.parametrize("model", MODELS)
@pytest.mark.parametrize("dtype", ["half"])
@pytest.mark.parametrize("max_tokens", [96])
@pytest.mark.parametrize("chunked_prefill_token_size", [16])
def test_chunked_prefill_recompute(
hf_runner,
vllm_runner,
example_prompts,
model: str,
dtype: str,
max_tokens: int,
chunked_prefill_token_size: int,
) -> None:
"""Ensure that chunked prefill works with preemption."""
max_num_seqs = min(chunked_prefill_token_size, 256)
enable_chunked_prefill = False
max_num_batched_tokens = None
if chunked_prefill_token_size != -1:
enable_chunked_prefill = True
max_num_batched_tokens = chunked_prefill_token_size

hf_model = hf_runner(model, dtype=dtype)
hf_outputs = hf_model.generate_greedy(example_prompts, max_tokens)
del hf_model

vllm_model = vllm_runner(
model,
dtype=dtype,
max_num_batched_tokens=max_num_batched_tokens,
enable_chunked_prefill=enable_chunked_prefill,
max_num_seqs=max_num_seqs,
)
vllm_outputs = vllm_model.generate_greedy(example_prompts, max_tokens)
assert (vllm_model.model.llm_engine.scheduler.artificial_preempt_cnt <
ARTIFICIAL_PREEMPTION_MAX_CNT)
del vllm_model

for i in range(len(example_prompts)):
hf_output_ids, hf_output_str = hf_outputs[i]
vllm_output_ids, vllm_output_str = vllm_outputs[i]
assert hf_output_str == vllm_output_str, (
f"Test{i}:\nHF: {hf_output_str!r}\nvLLM: {vllm_output_str!r}")
assert hf_output_ids == vllm_output_ids, (
f"Test{i}:\nHF: {hf_output_ids}\nvLLM: {vllm_output_ids}")


@pytest.mark.parametrize("model", MODELS)
@pytest.mark.parametrize("dtype", ["float"])
@pytest.mark.parametrize("max_tokens", [96])
def test_preemption(
hf_runner,
vllm_runner,
example_prompts,
model: str,
dtype: str,
max_tokens: int,
) -> None:
"""By default, recompute preemption is enabled"""

hf_model = hf_runner(model, dtype=dtype)
hf_outputs = hf_model.generate_greedy(example_prompts, max_tokens)
del hf_model

vllm_model = vllm_runner(
model,
dtype=dtype,
)
vllm_outputs = vllm_model.generate_greedy(example_prompts, max_tokens)
assert (vllm_model.model.llm_engine.scheduler.artificial_preempt_cnt <
ARTIFICIAL_PREEMPTION_MAX_CNT)
del vllm_model

for i in range(len(example_prompts)):
hf_output_ids, hf_output_str = hf_outputs[i]
vllm_output_ids, vllm_output_str = vllm_outputs[i]
assert hf_output_str == vllm_output_str, (
f"Test{i}:\nHF: {hf_output_str!r}\nvLLM: {vllm_output_str!r}")
assert hf_output_ids == vllm_output_ids, (
f"Test{i}:\nHF: {hf_output_ids}\nvLLM: {vllm_output_ids}")


@pytest.mark.parametrize("model", MODELS)
@pytest.mark.parametrize("dtype", ["float"])
@pytest.mark.parametrize("max_tokens", [96])
@pytest.mark.parametrize("beam_width", [4])
def test_swap(
hf_runner,
vllm_runner,
example_prompts,
model: str,
dtype: str,
max_tokens: int,
beam_width: int,
) -> None:
"""Use beam search enables swapping."""
example_prompts = example_prompts[:1]
hf_model = hf_runner(model, dtype=dtype)
hf_outputs = hf_model.generate_beam_search(example_prompts, beam_width,
max_tokens)
del hf_model

vllm_model = vllm_runner(model, dtype=dtype, swap_space=10)
vllm_outputs = vllm_model.generate_beam_search(example_prompts, beam_width,
max_tokens)
assert (vllm_model.model.llm_engine.scheduler.artificial_preempt_cnt <
ARTIFICIAL_PREEMPTION_MAX_CNT)
del vllm_model

for i in range(len(example_prompts)):
hf_output_ids, _ = hf_outputs[i]
vllm_output_ids, _ = vllm_outputs[i]
assert len(hf_output_ids) == len(vllm_output_ids)
for j in range(len(hf_output_ids)):
assert hf_output_ids[j] == vllm_output_ids[j], (
f"Test{i} output{j}:\nHF: {hf_output_ids}\n"
f"vLLM: {vllm_output_ids}")
3 changes: 2 additions & 1 deletion tests/conftest.py
@@ -296,14 +296,15 @@ def __init__(
tensor_parallel_size: int = 1,
block_size: int = 16,
enable_chunked_prefill: bool = False,
+ swap_space=4,
**kwargs,
) -> None:
self.model = LLM(
model=model_name,
tokenizer=tokenizer_name,
trust_remote_code=True,
dtype=dtype,
- swap_space=0,
+ swap_space=swap_space,
disable_log_stats=disable_log_stats,
tensor_parallel_size=tensor_parallel_size,
max_model_len=max_model_len,
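
The conftest change above replaces the hard-coded `swap_space=0` with a constructor parameter defaulting to 4 (vLLM interprets swap_space as GiB of CPU swap per GPU), which is what lets the new `test_swap` request room for swapped-out KV blocks. A hedged usage sketch; the test name and concrete values are illustrative, only the `swap_space` keyword comes from this diff:

```python
# Sketch: a test opting into a larger CPU swap space through the fixture.
# `vllm_runner` and `example_prompts` are the fixtures from tests/conftest.py;
# the model, dtype, and token counts below are arbitrary.
def test_uses_more_swap(vllm_runner, example_prompts):
    vllm_model = vllm_runner("facebook/opt-125m", dtype="float", swap_space=10)
    outputs = vllm_model.generate_greedy(example_prompts, 32)
    del vllm_model
    assert len(outputs) == len(example_prompts)
```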
6 changes: 3 additions & 3 deletions tests/spec_decode/test_spec_decode_worker.py
@@ -33,7 +33,7 @@ def test_correctly_calls_draft_model(k: int, batch_size: int):
worker = SpecDecodeWorker(draft_worker, target_worker, rejection_sampler,
metrics_collector)

- exception_secret = 'artifical stop'
+ exception_secret = 'artificial stop'
draft_worker.get_spec_proposals.side_effect = ValueError(exception_secret)

execute_model_data, _, _ = create_batch(batch_size, k)
@@ -101,7 +101,7 @@ def test_correctly_calls_target_model(k: int, batch_size: int):
proposal_probs=proposal_probs,
proposal_lens=proposal_lens)

- exception_secret = 'artifical stop'
+ exception_secret = 'artificial stop'
target_worker.execute_model.side_effect = ValueError(exception_secret)

with pytest.raises(ValueError, match=exception_secret):
@@ -197,7 +197,7 @@ def test_correctly_calls_rejection_sampler(k: int, batch_size: int):

target_worker.execute_model.return_value = [target_output[0]]

- exception_secret = 'artifical stop'
+ exception_secret = 'artificial stop'
rejection_sampler.side_effect = ValueError(exception_secret)

with pytest.raises(ValueError, match=exception_secret):
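
The three edits above only fix the spelling of the sentinel string used in the spec-decode tests; because the same string feeds both the mock's side_effect and pytest.raises(match=...), the tests behaved identically before and after the rename. For reference, a minimal standalone sketch of that "exception secret" pattern (the mock attribute and test name here are invented):

```python
import pytest
from unittest.mock import MagicMock


def test_exception_secret_pattern():
    # Make a mocked collaborator raise a unique message, then assert the code
    # under test surfaces exactly that error, proving the mock was called.
    draft_worker = MagicMock()
    exception_secret = "artificial stop"
    draft_worker.get_spec_proposals.side_effect = ValueError(exception_secret)

    with pytest.raises(ValueError, match=exception_secret):
        draft_worker.get_spec_proposals()
```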
36 changes: 28 additions & 8 deletions vllm/core/scheduler.py
@@ -1,4 +1,6 @@
import enum
+ import os
+ import random
import time
from collections import deque
from dataclasses import dataclass, field
@@ -15,6 +17,13 @@

logger = init_logger(__name__)

+ # Test-only. If configured, decode is preempted with
+ # ARTIFICIAL_PREEMPTION_PROB% probability.
+ ENABLE_ARTIFICIAL_PREEMPT = bool(
+ os.getenv("VLLM_TEST_ENABLE_ARTIFICIAL_PREEMPT", False))  # noqa
+ ARTIFICIAL_PREEMPTION_PROB = 0.5
+ ARTIFICIAL_PREEMPTION_MAX_CNT = 500


class PreemptionMode(enum.Enum):
"""Preemption modes.
@@ -286,6 +295,13 @@ def __init__(
# Latency of the last prompt step
self.last_prompt_latency = 0.0

+ # The following field is test-only. It is used to inject artificial
+ # preemption.
+ self.enable_artificial_preemption = ENABLE_ARTIFICIAL_PREEMPT
+ self.artificial_preempt_cnt = (ARTIFICIAL_PREEMPTION_MAX_CNT
+ if self.enable_artificial_preemption
+ else 0)

@property
def lora_enabled(self) -> bool:
return bool(self.lora_config)
@@ -386,15 +402,13 @@ def _schedule_running(
# groups to preempt.
now = time.time()
running_queue = policy.sort_by_priority(now, running_queue)

while running_queue:
seq_group = running_queue[0]
num_running_tokens = self._get_num_new_tokens(
seq_group, SequenceStatus.RUNNING, enable_chunking, budget)

- # We can have up to 1 running prefill at any given time in running
- # queue, which means we can guarantee chunk size is at least 1.
- assert num_running_tokens != 0
+ if num_running_tokens == 0:
+ break

running_queue.popleft()
while not self._can_append_slots(seq_group):
@@ -449,9 +463,6 @@ def _schedule_running(
if curr_loras is not None and seq_group.lora_int_id > 0:
curr_loras.add(seq_group.lora_int_id)

- # Make sure all queues are updated.
- assert len(running_queue) == 0

return running_queue, SchedulerRunningOutputs(
decode_seq_groups=decode_seq_groups,
prefill_seq_groups=prefill_seq_groups,
@@ -545,7 +556,6 @@ def _schedule_swapped(
ScheduledSequenceGroup(seq_group,
token_chunk_size=num_new_tokens))
else:
- assert num_new_tokens == 1
decode_seq_groups.append(
ScheduledSequenceGroup(seq_group, token_chunk_size=1))
budget.add_num_batched_tokens(seq_group.request_id, num_new_tokens)
@@ -868,6 +878,13 @@ def _can_append_slots(self, seq_group: SequenceGroup) -> bool:
"""Determine whether or not we have enough space in the KV cache to
continue generation of the sequence group.
"""
+ # It is True only for testing case to trigger artificial preemption.
+ if (self.enable_artificial_preemption
+ and random.uniform(0, 1) < ARTIFICIAL_PREEMPTION_PROB
+ and self.artificial_preempt_cnt > 0):
+ self.artificial_preempt_cnt -= 1
+ return False

# Appending slots only occurs in decoding.
is_prefill = False

@@ -1116,11 +1133,14 @@ def _get_num_new_tokens(self, seq_group: SequenceGroup,
if `enable_chunking` is True. If a sequence group has multiple
sequences (e.g., running beam search), it means it is in decoding
phase, so chunking doesn't happen.
+ Returns 0 if the new token cannot be computed due to token budget.
"""
num_new_tokens = 0
seqs = seq_group.get_seqs(status=status)
for seq in seqs:
num_new_tokens += seq.get_num_new_tokens()
assert num_new_tokens > 0
# Chunk if a running request cannot fit in.
# If number of seq > 1, it means it is doing beam search in a
# decode phase. Do not chunk in that case.
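
Taken together, the scheduler changes above replace assertions that could fire under legitimate schedules (n > 1 in the swapped path, an exhausted token budget in the running path) with early exits, and add a test-only hook that randomly pretends the KV cache has no room for another slot. A self-contained sketch of how that hook behaves (the class and method below are stand-ins for illustration, not the real Scheduler):

```python
import random

ARTIFICIAL_PREEMPTION_PROB = 0.5
ARTIFICIAL_PREEMPTION_MAX_CNT = 500


class FakeScheduler:
    """Minimal stand-in showing the test-only preemption hook's behavior."""

    def __init__(self, enable_artificial_preemption: bool) -> None:
        self.enable_artificial_preemption = enable_artificial_preemption
        # The counter bounds how many preemptions the hook may inject, so a
        # test run cannot get stuck endlessly re-preempting sequences.
        self.artificial_preempt_cnt = (ARTIFICIAL_PREEMPTION_MAX_CNT
                                       if enable_artificial_preemption else 0)

    def can_append_slots(self, has_free_blocks: bool = True) -> bool:
        # Test-only: report "no space" with ~50% probability until the budget
        # of artificial preemptions is spent, then defer to the real answer.
        if (self.enable_artificial_preemption
                and random.uniform(0, 1) < ARTIFICIAL_PREEMPTION_PROB
                and self.artificial_preempt_cnt > 0):
            self.artificial_preempt_cnt -= 1
            return False
        return has_free_blocks


# The new tests assert artificial_preempt_cnt < ARTIFICIAL_PREEMPTION_MAX_CNT,
# i.e. at least one artificial preemption actually happened during the run.
sched = FakeScheduler(enable_artificial_preemption=True)
preempted = sum(not sched.can_append_slots() for _ in range(100))
print(f"injected {preempted} artificial preemptions")
```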
