Commit

[BUGFIX] Do not return ignored sentences twice in async llm engine (v…
zhuohan123 authored Dec 26, 2023
1 parent face83c · commit e0ff920
Showing 2 changed files with 7 additions and 22 deletions.
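For context on the bug: in the async engine, step_async built its own list of
RequestOutputs for ignored sequence groups (requests the scheduler refuses, e.g.
prompts longer than the model's context window) and appended it to the result of
_process_model_outputs, which already reports those same groups, so ignored
requests came back twice. A toy reproduction of the double count, with
placeholder strings standing in for RequestOutputs (illustrative only, not the
vLLM classes):

def process_model_outputs(scheduled, ignored):
    # Stand-in for LLMEngine._process_model_outputs, which is assumed to
    # already emit one output per ignored sequence group as well as one
    # per scheduled group -- the property this fix relies on.
    return scheduled + ignored

scheduled = ["req-A"]
ignored = ["req-B"]  # e.g. a prompt that exceeds the context window

before_fix = process_model_outputs(scheduled, ignored) + ignored
after_fix = process_model_outputs(scheduled, ignored)
print(before_fix)  # ['req-A', 'req-B', 'req-B']  -- req-B reported twice
print(after_fix)   # ['req-A', 'req-B']           -- reported exactly once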
10 changes: 4 additions & 6 deletions vllm/engine/async_llm_engine.py
@@ -183,20 +183,18 @@ async def step_async(self) -> List[RequestOutput]:
         and updates the scheduler with the model outputs. Finally, it decodes
         the sequences and returns the newly generated results.
         """
-        seq_group_metadata_list, scheduler_outputs, ignored = self._schedule()
-        if scheduler_outputs.is_empty():
-            return ignored
+        seq_group_metadata_list, scheduler_outputs = self.scheduler.schedule()
 
         # Execute the model.
-        output = await self._run_workers_async(
+        output = (await self._run_workers_async(
             "execute_model",
             seq_group_metadata_list=seq_group_metadata_list,
             blocks_to_swap_in=scheduler_outputs.blocks_to_swap_in,
             blocks_to_swap_out=scheduler_outputs.blocks_to_swap_out,
             blocks_to_copy=scheduler_outputs.blocks_to_copy,
-        )
+        )) if not scheduler_outputs.is_empty() else []
 
-        return self._process_model_outputs(output, scheduler_outputs) + ignored
+        return self._process_model_outputs(output, scheduler_outputs)
 
     async def _run_workers_async(
         self,
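The other half of the async change is control flow: the early "return ignored"
is gone, and the conditional expression skips the distributed execute_model call
when nothing was scheduled while still funneling every result, including ignored
groups, through the single _process_model_outputs exit. A minimal self-contained
sketch of that post-fix shape (run_model and the list-of-strings types are
stand-ins, not the vLLM API):

import asyncio
from typing import List

async def run_model(batch: List[str]) -> List[str]:
    # Stand-in for the "execute_model" call dispatched to the workers.
    return [f"token-for({req})" for req in batch]

async def step_async(scheduled: List[str], ignored: List[str]) -> List[str]:
    # Skip model execution when nothing is scheduled, but keep one exit
    # so ignored requests are reported exactly once.
    output = (await run_model(scheduled)) if scheduled else []
    return output + ignored  # stand-in for _process_model_outputs

print(asyncio.run(step_async([], ["req-too-long"])))
# ['req-too-long']
print(asyncio.run(step_async(["req-A"], ["req-too-long"])))
# ['token-for(req-A)', 'req-too-long']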
19 changes: 3 additions & 16 deletions vllm/engine/llm_engine.py
@@ -14,8 +14,7 @@
 from vllm.outputs import RequestOutput
 from vllm.sampling_params import SamplingParams
 from vllm.sequence import (SamplerOutput, Sequence, SequenceGroup,
-                           SequenceGroupMetadata, SequenceGroupOutput,
-                           SequenceOutput, SequenceStatus)
+                           SequenceGroupOutput, SequenceOutput, SequenceStatus)
 from vllm.transformers_utils.tokenizer import (detokenize_incrementally,
                                                get_tokenizer)
 from vllm.utils import Counter
@@ -328,16 +327,6 @@ def has_unfinished_requests(self) -> bool:
         """Returns True if there are unfinished requests."""
         return self.scheduler.has_unfinished_seqs()
 
-    def _schedule(
-        self
-    ) -> Tuple[List[SequenceGroupMetadata], SchedulerOutputs,
-               List[RequestOutput]]:
-        seq_group_metadata_list, scheduler_outputs = self.scheduler.schedule()
-        return seq_group_metadata_list, scheduler_outputs, [
-            RequestOutput.from_seq_group(seq_group)
-            for seq_group in scheduler_outputs.ignored_seq_groups
-        ]
-
     def _check_beam_search_early_stopping(
         self,
         early_stopping: Union[bool, str],
@@ -586,9 +575,7 @@ def step(self) -> List[RequestOutput]:
         and updates the scheduler with the model outputs. Finally, it decodes
         the sequences and returns the newly generated results.
         """
-        seq_group_metadata_list, scheduler_outputs, ignored = self._schedule()
-        if scheduler_outputs.is_empty():
-            return ignored
+        seq_group_metadata_list, scheduler_outputs = self.scheduler.schedule()
 
         # Execute the model.
         output = self._run_workers(
@@ -597,7 +584,7 @@ def step(self) -> List[RequestOutput]:
             blocks_to_swap_in=scheduler_outputs.blocks_to_swap_in,
             blocks_to_swap_out=scheduler_outputs.blocks_to_swap_out,
             blocks_to_copy=scheduler_outputs.blocks_to_copy,
-        )
+        ) if not scheduler_outputs.is_empty() else []
 
         return self._process_model_outputs(output, scheduler_outputs)
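As a closing usage note, here is a hedged sketch of the invariant the commit
enforces, written against the synchronous engine since it is easier to drive
directly (the async path is where the duplicate outputs actually occurred; the
model name and over-long prompt below are placeholders):

from vllm import EngineArgs, LLMEngine, SamplingParams

engine = LLMEngine.from_engine_args(EngineArgs(model="facebook/opt-125m"))
# A prompt longer than the model's context window ends up in
# scheduler_outputs.ignored_seq_groups rather than being scheduled.
engine.add_request("0", "word " * 10_000, SamplingParams())

finished = []
while engine.has_unfinished_requests():
    for out in engine.step():
        if out.finished:
            finished.append(out.request_id)

assert finished.count("0") == 1  # each ignored request is reported once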
