[V1] Eagerly remove finished requests from the batch #14388

Merged: 3 commits, Mar 7, 2025
tests/v1/engine/test_engine_core.py (10 additions, 0 deletions)

@@ -102,14 +102,24 @@ def test_engine_core(monkeypatch):
engine_core.add_request(req)
assert len(engine_core.scheduler.waiting) == 1
assert len(engine_core.scheduler.running) == 0
assert engine_core.scheduler.has_unfinished_requests()
assert not engine_core.scheduler.has_finished_requests()

_ = engine_core.step()
assert len(engine_core.scheduler.waiting) == 0
assert len(engine_core.scheduler.running) == 1
assert engine_core.scheduler.has_unfinished_requests()
assert not engine_core.scheduler.has_finished_requests()

engine_core.abort_requests([request_id])
assert len(engine_core.scheduler.waiting) == 0
assert len(engine_core.scheduler.running) == 0
assert not engine_core.scheduler.has_unfinished_requests()
assert engine_core.scheduler.has_finished_requests()

_ = engine_core.step()
assert not engine_core.scheduler.has_unfinished_requests()
assert not engine_core.scheduler.has_finished_requests()

# Add, step, abort 1 of the 3.
req0 = make_request()
tests/v1/engine/test_engine_core_client.py (2 additions, 2 deletions)

@@ -50,7 +50,7 @@ def loop_until_done(client: EngineCoreClient, outputs: dict):
engine_core_outputs = client.get_output().outputs

if len(engine_core_outputs) == 0:
- break
+ continue

all_finished = True
for out in engine_core_outputs:
@@ -68,7 +68,7 @@ async def loop_until_done_async(client: EngineCoreClient, outputs: dict):
engine_core_outputs = (await client.get_output_async()).outputs

if len(engine_core_outputs) == 0:
- break
+ continue

all_finished = True
for out in engine_core_outputs:
vllm/v1/core/scheduler.py (10 additions, 1 deletion)

@@ -625,7 +625,8 @@ def finish_requests(
assert RequestStatus.is_finished(finished_status)
if isinstance(request_ids, str):
request_ids = (request_ids, )
- request_ids = set(request_ids)
+ else:
+     request_ids = set(request_ids)

for req_id in request_ids:
request = self.requests.get(req_id)
@@ -657,6 +658,14 @@ def get_num_unfinished_requests(self) -> int:
def has_unfinished_requests(self) -> bool:
return self.get_num_unfinished_requests() > 0

def has_finished_requests(self) -> bool:
return len(self.finished_req_ids) > 0

def has_requests(self):
"""Returns True if there are unfinished requests, or finished requests
not yet returned in SchedulerOutputs."""
return self.has_unfinished_requests() or self.has_finished_requests()

def get_num_unscheduled_requests(self) -> int:
"""Number of requests that are not being processed by the executor."""
return self.get_num_unfinished_requests() - len(self.scheduled_req_ids)
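As an aside, the request lifecycle these predicates describe can be sketched as follows. This is an illustrative walk-through grounded in the test above and the methods in this diff, not runnable code on its own (it assumes a configured V1 Scheduler with a single running request "req-0"):

```python
from vllm.v1.request import RequestStatus

# Abort the only running request; it leaves waiting/running but its id is
# parked in finished_req_ids until the next schedule() hands it out.
scheduler.finish_requests("req-0", RequestStatus.FINISHED_ABORTED)

assert not scheduler.has_unfinished_requests()  # nothing waiting or running
assert scheduler.has_finished_requests()        # "req-0" still in finished_req_ids
assert scheduler.has_requests()                 # so one more schedule() is needed

# The next schedule() returns the finished ids in its SchedulerOutput so the
# model runner can drop "req-0" from its persistent batch; afterwards both
# predicates are False and the engine can go idle.
scheduler.schedule()
```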
vllm/v1/engine/async_llm.py (3 additions, 3 deletions)

@@ -255,13 +255,14 @@ async def _run_output_handler(self):
while True:
# 1) Pull EngineCoreOutputs from the EngineCore.
outputs = await self.engine_core.get_output_async()
+ num_outputs = len(outputs.outputs)

- iteration_stats = IterationStats() if self.log_stats else None
+ iteration_stats = IterationStats() if (
+     self.log_stats and num_outputs) else None

# Split outputs into chunks of at most
# VLLM_V1_OUTPUT_PROC_CHUNK_SIZE, so that we don't block the
# event loop for too long.
- num_outputs = len(outputs.outputs)
if num_outputs <= VLLM_V1_OUTPUT_PROC_CHUNK_SIZE:
slices = (outputs.outputs, )
else:
@@ -315,7 +316,6 @@ def _record_stats(
return

assert scheduler_stats is not None
- assert iteration_stats is not None
for stat_logger in self.stat_loggers:
stat_logger.record(scheduler_stats=scheduler_stats,
iteration_stats=iteration_stats)
vllm/v1/engine/core.py (4 additions, 2 deletions)

@@ -146,7 +146,9 @@ def abort_requests(self, request_ids: list[str]):
def step(self) -> EngineCoreOutputs:
"""Schedule, execute, and make output."""

- if not self.scheduler.has_unfinished_requests():
+ # Check for any requests remaining in the scheduler - unfinished,
+ # or finished and not yet removed from the batch.
+ if not self.scheduler.has_requests():
Review comment (Collaborator):
Can you please document here why we should use has_requests rather than has_unfinished_requests (like the PR description)?

return EngineCoreOutputs(
outputs=[], scheduler_stats=self.scheduler.make_stats())
scheduler_output = self.scheduler.schedule()
@@ -315,7 +317,7 @@ def run_busy_loop(self):
# Loop until process is sent a SIGINT or SIGTERM
while True:
# 1) Poll the input queue until there is work to do.
- while not self.scheduler.has_unfinished_requests():
+ while not self.scheduler.has_requests():
logger.debug("EngineCore busy loop waiting.")
req = self.input_queue.get()
self._handle_client_request(*req)
vllm/v1/metrics/loggers.py (8 additions, 4 deletions)

@@ -22,7 +22,7 @@ class StatLoggerBase(ABC):

@abstractmethod
def record(self, scheduler_stats: SchedulerStats,
- iteration_stats: IterationStats):
+ iteration_stats: Optional[IterationStats]):
...

def log(self): # noqa
@@ -56,10 +56,11 @@ def _get_throughput(self, tracked_stats: list[int], now: float) -> float:
return float(np.sum(tracked_stats) / (now - self.last_log_time))

def record(self, scheduler_stats: SchedulerStats,
- iteration_stats: IterationStats):
+ iteration_stats: Optional[IterationStats]):
"""Log Stats to standard output."""

- self._track_iteration_stats(iteration_stats)
+ if iteration_stats:
+     self._track_iteration_stats(iteration_stats)

self.prefix_caching_metrics.observe(scheduler_stats.prefix_cache_stats)

@@ -319,7 +320,7 @@ def log_metrics_info(self, type: str, config_obj: SupportsMetricsInfo):
info_gauge.set(1)

def record(self, scheduler_stats: SchedulerStats,
- iteration_stats: IterationStats):
+ iteration_stats: Optional[IterationStats]):
"""Log to prometheus."""
self.gauge_scheduler_running.set(scheduler_stats.num_running_reqs)
self.gauge_scheduler_waiting.set(scheduler_stats.num_waiting_reqs)
@@ -331,6 +332,9 @@ def record(self, scheduler_stats: SchedulerStats,
self.counter_gpu_prefix_cache_hits.inc(
scheduler_stats.prefix_cache_stats.hits)

if iteration_stats is None:
return

self.counter_num_preempted_reqs.inc(iteration_stats.num_preempted_reqs)
self.counter_prompt_tokens.inc(iteration_stats.num_prompt_tokens)
self.counter_generation_tokens.inc(
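Because record() now accepts Optional[IterationStats], any custom stat logger has to tolerate None on steps that produced no outputs. A minimal sketch of a conforming logger, assuming the StatLoggerBase interface and the stats fields shown in this diff (PromptTokenLogger itself is hypothetical, and the import paths are assumed from the V1 metrics layout):

```python
from typing import Optional

from vllm.v1.metrics.loggers import StatLoggerBase
from vllm.v1.metrics.stats import IterationStats, SchedulerStats


class PromptTokenLogger(StatLoggerBase):
    """Hypothetical logger illustrating the Optional[IterationStats] contract."""

    def __init__(self):
        self.total_prompt_tokens = 0

    def record(self, scheduler_stats: SchedulerStats,
               iteration_stats: Optional[IterationStats]):
        # iteration_stats is None on steps that produced no outputs, so guard
        # before touching per-iteration fields.
        if iteration_stats is not None:
            self.total_prompt_tokens += iteration_stats.num_prompt_tokens
        # scheduler_stats is always provided.
        print(f"running={scheduler_stats.num_running_reqs} "
              f"waiting={scheduler_stats.num_waiting_reqs} "
              f"prompt_tokens={self.total_prompt_tokens}")
```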
vllm/v1/outputs.py (10 additions, 0 deletions)

@@ -80,3 +80,13 @@ class ModelRunnerOutput:
# [prompt_len, num_prompt_logprobs]
# [prompt_len]
prompt_logprobs_dict: dict[str, Optional[LogprobsTensors]]


EMPTY_MODEL_RUNNER_OUTPUT = ModelRunnerOutput(
req_ids=[],
req_id_to_index={},
sampled_token_ids=[],
spec_token_ids=None,
logprobs=None,
prompt_logprobs_dict={},
)
vllm/v1/worker/gpu_model_runner.py (6 additions, 3 deletions)

@@ -31,7 +31,8 @@
from vllm.v1.engine.mm_input_cache import MMInputCacheClient
from vllm.v1.kv_cache_interface import (FullAttentionSpec, KVCacheConfig,
KVCacheSpec)
- from vllm.v1.outputs import LogprobsTensors, ModelRunnerOutput
+ from vllm.v1.outputs import (EMPTY_MODEL_RUNNER_OUTPUT, LogprobsTensors,
+     ModelRunnerOutput)
from vllm.v1.sample.metadata import SamplingMetadata
from vllm.v1.sample.rejection_sampler import INVALID_TOKEN_ID, RejectionSampler
from vllm.v1.spec_decode.ngram_proposer import NgramProposer
@@ -867,6 +868,9 @@ def execute_model(
intermediate_tensors: Optional[IntermediateTensors] = None,
) -> Union[ModelRunnerOutput, torch.Tensor]:
self._update_states(scheduler_output)
if not scheduler_output.total_num_scheduled_tokens:
# Return empty ModelRunnerOutput if there's no work to do.
return EMPTY_MODEL_RUNNER_OUTPUT
Review comment (Collaborator):
Just wondering: why not return None?

Review comment (Collaborator):
Also, just curious: is this compatible with PP? cc @comaniac

Reply (Member, PR author):
> Just wondering: why not return None?

I could, but I think that would actually be more intrusive, i.e. it would change the execute_model signature and require an additional check in the step() method. This is in some ways more natural, since it's just an "empty" output which still gets passed to self.scheduler.update_from_output(), and that is fine: it will do nothing apart from returning the EngineCoreOutputs with empty stats.
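For illustration, the control flow described here can be sketched roughly as follows (a simplified sketch of the V1 step() path, with method names taken from the diff and this thread; it is not the actual implementation):

```python
def step(self) -> EngineCoreOutputs:
    # Simplified sketch, not the real EngineCore.step().
    if not self.scheduler.has_requests():
        # Nothing unfinished and no finished ids left to flush.
        return EngineCoreOutputs(
            outputs=[], scheduler_stats=self.scheduler.make_stats())

    scheduler_output = self.scheduler.schedule()

    # May legitimately schedule zero tokens (e.g. only finished requests need
    # to be removed from the batch); the runner then returns
    # EMPTY_MODEL_RUNNER_OUTPUT instead of None.
    output = self.model_executor.execute_model(scheduler_output)

    # update_from_output() handles the empty output by doing nothing beyond
    # returning EngineCoreOutputs with empty stats, so no None check is needed
    # here and the signature stays unchanged.
    return self.scheduler.update_from_output(scheduler_output, output)
```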

Reply (Member, PR author):
I admit I haven't fully digested the PP impl ... so yes, it would be great if @comaniac could comment on this!

Reply (Collaborator, @comaniac, Mar 7, 2025):
It seems compatible, but it could be further cleaned up. In short, when PP > 1, I guarantee that only batches with scheduler_output.total_num_scheduled_tokens > 0 will be executed: https://github.com/vllm-project/vllm/blob/main/vllm/v1/engine/core.py#L181

Specifically, since .schedule() may schedule nothing when PP > 1 (because all requests are in flight), we already have the case where scheduler_output.total_num_scheduled_tokens may be 0.


if self.is_multimodal_model:
# Run the multimodal encoder if any.
@@ -1013,15 +1017,14 @@ def execute_model(
spec_token_ids = self.generate_draft_token_ids(
valid_sampled_token_ids)

- model_runner_output = ModelRunnerOutput(
+ return ModelRunnerOutput(
req_ids=self.input_batch.req_ids,
req_id_to_index=self.input_batch.req_id_to_index,
sampled_token_ids=valid_sampled_token_ids,
spec_token_ids=spec_token_ids,
logprobs=logprobs_lists,
prompt_logprobs_dict=prompt_logprobs_dict,
)
- return model_runner_output

def generate_draft_token_ids(
self,
vllm/v1/worker/tpu_model_runner.py (5 additions, 1 deletion)

@@ -29,7 +29,8 @@
from vllm.v1.core.encoder_cache_manager import compute_encoder_budget
from vllm.v1.kv_cache_interface import (FullAttentionSpec, KVCacheConfig,
KVCacheSpec)
- from vllm.v1.outputs import LogprobsTensors, ModelRunnerOutput
+ from vllm.v1.outputs import (EMPTY_MODEL_RUNNER_OUTPUT, LogprobsTensors,
+     ModelRunnerOutput)
from vllm.v1.utils import bind_kv_cache
from vllm.v1.worker.gpu_input_batch import CachedRequestState, InputBatch

@@ -547,6 +548,9 @@ def execute_model(
) -> ModelRunnerOutput:
# Update cached state
self._update_states(scheduler_output)
if not scheduler_output.total_num_scheduled_tokens:
# Return empty ModelRunnerOutput if there's no work to do.
return EMPTY_MODEL_RUNNER_OUTPUT

if self.is_multimodal_model:
# Run the multimodal encoder if any.