Skip to content

Commit

Permalink
[Scheduler] Warning upon preemption and Swapping (vllm-project#4647)
Browse files Browse the repository at this point in the history
Co-authored-by: Robert Shaw <114415538+robertgshaw2-neuralmagic@users.noreply.github.com>
  • Loading branch information
rkooo567 and robertgshaw2-neuralmagic authored May 13, 2024
1 parent 350f9e1 commit e7c46b9
Show file tree
Hide file tree
Showing 7 changed files with 108 additions and 3 deletions.
19 changes: 19 additions & 0 deletions docs/source/models/performance.rst
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,25 @@
Performance and Tuning
======================

Preemption
----------
Due to the auto-regressive nature of transformer architecture, there are times when KV cache space is insufficient to handle all batched requests.
The vLLM can preempt requests to free up KV cache space for other requests. Preempted requests are recomputed when sufficient KV cache space becomes
available again. When this occurs, the following warning is printed:

```
WARNING 05-09 00:49:33 scheduler.py:1057] Sequence group 0 is preempted by PreemptionMode.SWAP mode because there is not enough KV cache space. This can affect the end-to-end performance. Increase gpu_memory_utilization or tensor_parallel_size to provide more KV cache memory. total_cumulative_preemption_cnt=1
```

While this mechanism ensures system robustness, preemption and recomputation can adversely affect end-to-end latency.
If you frequently encounter preemptions from the vLLM engine, consider the following actions:

- Increase `gpu_memory_utilization`. The vLLM pre-allocates GPU cache by using gpu_memory_utilization% of memory. By increasing this utilization, you can provide more KV cache space.
- Decrease `max_num_seqs` or `max_num_batched_tokens`. This can reduce the number of concurrent requests in a batch, thereby requiring less KV cache space.
- Increase `tensor_parallel_size`. This approach shards model weights, so each GPU has more memory available for KV cache.

You can also monitor the number of preemption requests through Prometheus metrics exposed by the vLLM. Additionally, you can log the cumulative number of preemption requests by setting disable_log_stats=False.

Chunked Prefill
---------------
vLLM supports an experimental feature chunked prefill. Chunked prefill allows to chunk large prefills into smaller chunks and batch them together with decode requests.
Expand Down
44 changes: 43 additions & 1 deletion tests/basic_correctness/test_preemption.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
pytest tests/basic_correctness/test_preemption.py`.
"""
import pytest
from prometheus_client import REGISTRY

from vllm import SamplingParams
from vllm.core.scheduler import (ARTIFICIAL_PREEMPTION_MAX_CNT,
Expand Down Expand Up @@ -71,6 +72,7 @@ def test_chunked_prefill_recompute(
@pytest.mark.parametrize("dtype", ["float"])
@pytest.mark.parametrize("max_tokens", [96])
def test_preemption(
caplog_vllm,
hf_runner,
vllm_runner,
example_prompts,
Expand All @@ -87,10 +89,13 @@ def test_preemption(
vllm_model = vllm_runner(
model,
dtype=dtype,
disable_log_stats=False,
)
vllm_outputs = vllm_model.generate_greedy(example_prompts, max_tokens)
assert (vllm_model.model.llm_engine.scheduler.artificial_preempt_cnt <
ARTIFICIAL_PREEMPTION_MAX_CNT)
total_preemption = (
vllm_model.model.llm_engine.scheduler.num_cumulative_preemption)
del vllm_model

for i in range(len(example_prompts)):
Expand All @@ -100,13 +105,28 @@ def test_preemption(
f"Test{i}:\nHF: {hf_output_str!r}\nvLLM: {vllm_output_str!r}")
assert hf_output_ids == vllm_output_ids, (
f"Test{i}:\nHF: {hf_output_ids}\nvLLM: {vllm_output_ids}")
assert ("is preempted by PreemptionMode.RECOMPUTE mode because there "
"is not enough KV cache space." in caplog_vllm.text)
# Ensure the count bucket of request-level histogram metrics matches
# the number of requests as a simple sanity check to ensure metrics are
# generated
preemption_metrics = None
for m in REGISTRY.collect():
if m.name == "vllm:num_preemptions":
preemption_metrics = m
assert preemption_metrics is not None
total_recorded_preemption = 0
for sample in preemption_metrics.samples:
total_recorded_preemption += sample.value
assert total_preemption == total_recorded_preemption


@pytest.mark.parametrize("model", MODELS)
@pytest.mark.parametrize("dtype", ["float"])
@pytest.mark.parametrize("max_tokens", [96])
@pytest.mark.parametrize("beam_width", [4])
def test_swap(
caplog_vllm,
hf_runner,
vllm_runner,
example_prompts,
Expand All @@ -122,11 +142,18 @@ def test_swap(
max_tokens)
del hf_model

vllm_model = vllm_runner(model, dtype=dtype, swap_space=10)
vllm_model = vllm_runner(
model,
dtype=dtype,
swap_space=10,
disable_log_stats=False,
)
vllm_outputs = vllm_model.generate_beam_search(example_prompts, beam_width,
max_tokens)
assert (vllm_model.model.llm_engine.scheduler.artificial_preempt_cnt <
ARTIFICIAL_PREEMPTION_MAX_CNT)
total_preemption = (
vllm_model.model.llm_engine.scheduler.num_cumulative_preemption)
del vllm_model

for i in range(len(example_prompts)):
Expand All @@ -138,6 +165,21 @@ def test_swap(
f"Test{i} output{j}:\nHF: {hf_output_ids}\n"
f"vLLM: {vllm_output_ids}")

assert ("is preempted by PreemptionMode.SWAP mode because there "
"is not enough KV cache space." in caplog_vllm.text)
# Ensure the count bucket of request-level histogram metrics matches
# the number of requests as a simple sanity check to ensure metrics are
# generated
preemption_metrics = None
for m in REGISTRY.collect():
if m.name == "vllm:num_preemptions":
preemption_metrics = m
assert preemption_metrics is not None
total_recorded_preemption = 0
for sample in preemption_metrics.samples:
total_recorded_preemption += sample.value
assert total_preemption == total_recorded_preemption


@pytest.mark.parametrize("model", MODELS)
@pytest.mark.parametrize("dtype", ["float"])
Expand Down
16 changes: 16 additions & 0 deletions tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -499,3 +499,19 @@ def get_tokenizer_pool_config(tokenizer_group_type):
pool_type="ray",
extra_config={})
raise ValueError(f"Unknown tokenizer_group_type: {tokenizer_group_type}")


@pytest.fixture()
def temporary_enable_log_propagate():
import logging
logger = logging.getLogger("vllm")
logger.propagate = True
yield
logger.propagate = False


@pytest.fixture()
def caplog_vllm(temporary_enable_log_propagate, caplog):
# To capture vllm log, we should enable propagate=True temporarily
# because caplog depends on logs propagated to the root logger.
yield caplog
1 change: 1 addition & 0 deletions tests/core/test_scheduler.py
Original file line number Diff line number Diff line change
Expand Up @@ -180,6 +180,7 @@ def test_scheduler_schedule_preempt_abort():
and not out.blocks_to_swap_out)
assert len(seq_group_meta) == 1
assert scheduler.get_num_unfinished_seq_groups() == 2
assert out.preempted == 1

# Abort seq group a. Re-schedule seq group b prompt with recomputation.
scheduler.abort_seq_group("1")
Expand Down
18 changes: 18 additions & 0 deletions vllm/core/scheduler.py
Original file line number Diff line number Diff line change
Expand Up @@ -129,6 +129,7 @@ class SchedulerOutputs:
num_lookahead_slots: int
# The number of requests in the running queue
running_queue_size: int
preempted: int

def __post_init__(self):
# Swap in and swap out should never happen at the same time.
Expand Down Expand Up @@ -310,6 +311,7 @@ def __init__(
self.artificial_preempt_cnt = (ARTIFICIAL_PREEMPTION_MAX_CNT
if self.enable_artificial_preemption
else 0)
self.num_cumulative_preemption: int = 0

@property
def lora_enabled(self) -> bool:
Expand Down Expand Up @@ -785,6 +787,8 @@ def _schedule_default(self) -> SchedulerOutputs:
# Update swapped requests.
self.swapped = remaining_swapped
self.swapped.extend(running_scheduled.swapped_out)
preempted = (len(running_scheduled.preempted) +
len(running_scheduled.swapped_out))

# There should be no prefill from running queue because this policy
# doesn't allow chunked prefills.
Expand All @@ -804,6 +808,7 @@ def _schedule_default(self) -> SchedulerOutputs:
swapped_in.infeasible_seq_groups,
num_lookahead_slots=running_scheduled.num_lookahead_slots,
running_queue_size=len(self.running),
preempted=preempted,
)

def _schedule_chunked_prefill(self):
Expand Down Expand Up @@ -891,6 +896,8 @@ def _schedule_chunked_prefill(self):
ignored_seq_groups=prefills.ignored_seq_groups,
num_lookahead_slots=running_scheduled.num_lookahead_slots,
running_queue_size=len(self.running),
preempted=(len(running_scheduled.preempted) +
len(running_scheduled.swapped_out)),
)

def _schedule(self) -> SchedulerOutputs:
Expand Down Expand Up @@ -1057,6 +1064,17 @@ def _preempt(
preemption_mode = PreemptionMode.RECOMPUTE
else:
preemption_mode = PreemptionMode.SWAP

if self.num_cumulative_preemption % 50 == 0:
logger.warning(
"Sequence group %s is preempted by %s mode because there is "
"not enough KV cache space. This can affect the end-to-end "
"performance. Increase gpu_memory_utilization or "
"tensor_parallel_size to provide more KV cache memory. "
"total_num_cumulative_preemption=%d", seq_group.request_id,
preemption_mode, self.num_cumulative_preemption + 1)
self.num_cumulative_preemption += 1

if preemption_mode == PreemptionMode.RECOMPUTE:
self._preempt_by_recompute(seq_group)
elif preemption_mode == PreemptionMode.SWAP:
Expand Down
4 changes: 3 additions & 1 deletion vllm/engine/llm_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -737,6 +737,8 @@ def _get_stats(
num_generation_tokens_iter = 0
time_to_first_tokens_iter: List[float] = []
time_per_output_tokens_iter: List[float] = []
num_preemption_iter = (0 if scheduler_outputs is None else
scheduler_outputs.preempted)

# Request stats
# Latency
Expand Down Expand Up @@ -830,7 +832,6 @@ def _get_stats(

return Stats(
now=now,

# System stats
# Scheduler State
num_running_sys=num_running_sys,
Expand All @@ -846,6 +847,7 @@ def _get_stats(
time_to_first_tokens_iter=time_to_first_tokens_iter,
time_per_output_tokens_iter=time_per_output_tokens_iter,
spec_decode_metrics=spec_decode_metrics,
num_preemption_iter=num_preemption_iter,

# Request stats
# Latency
Expand Down
9 changes: 8 additions & 1 deletion vllm/engine/metrics.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,10 @@ def __init__(self, labelnames: List[str], max_model_len: int):
labelnames=labelnames)

# Iteration stats
self.counter_num_preemption = Counter(
name="vllm:num_preemptions_total",
documentation="Cumulative number of preemption from the engine.",
labelnames=labelnames)
self.counter_prompt_tokens = Counter(
name="vllm:prompt_tokens_total",
documentation="Number of prefill tokens processed.",
Expand Down Expand Up @@ -181,6 +185,7 @@ class Stats:
num_generation_tokens_iter: int
time_to_first_tokens_iter: List[float]
time_per_output_tokens_iter: List[float]
num_preemption_iter: int

# Request stats (should have _requests suffix)
# Latency
Expand Down Expand Up @@ -244,6 +249,8 @@ def _log_prometheus(self, stats: Stats) -> None:
stats.cpu_cache_usage_sys)

# Iteration level data
self._log_counter(self.metrics.counter_num_preemption,
stats.num_preemption_iter)
self._log_counter(self.metrics.counter_prompt_tokens,
stats.num_prompt_tokens_iter)
self._log_counter(self.metrics.counter_generation_tokens,
Expand Down Expand Up @@ -336,7 +343,7 @@ def log(self, stats: Stats) -> None:
"Avg generation throughput: %.1f tokens/s, "
"Running: %d reqs, Swapped: %d reqs, "
"Pending: %d reqs, GPU KV cache usage: %.1f%%, "
"CPU KV cache usage: %.1f%%",
"CPU KV cache usage: %.1f%%.",
prompt_throughput,
generation_throughput,
stats.num_running_sys,
Expand Down

0 comments on commit e7c46b9

Please sign in to comment.