Skip to content

Commit caac5c2

Browse files
authored
[Bugfix][Core] fix abort_seq_group and memory leak when n>1 (#14326)
Signed-off-by: courage17340 <courage17340@163.com>
1 parent 6bd1dd9 commit caac5c2

File tree

2 files changed

+31
-10
lines changed

2 files changed

+31
-10
lines changed

vllm/core/scheduler.py

Lines changed: 24 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -16,8 +16,9 @@
1616
from vllm.lora.request import LoRARequest
1717
from vllm.prompt_adapter.request import PromptAdapterRequest
1818
from vllm.sequence import (Sequence, SequenceData, SequenceGroup,
19-
SequenceGroupMetadata, SequenceGroupMetadataDelta,
20-
SequenceStage, SequenceStatus)
19+
SequenceGroupBase, SequenceGroupMetadata,
20+
SequenceGroupMetadataDelta, SequenceStage,
21+
SequenceStatus)
2122
from vllm.utils import Device, PyObjectCache
2223

2324
logger = init_logger(__name__)
@@ -561,7 +562,11 @@ def _add_seq_group_to_swapped(self, seq_group: SequenceGroup) -> None:
561562
# Only for testing purposes.
562563
self.swapped.append(seq_group)
563564

564-
def abort_seq_group(self, request_id: Union[str, Iterable[str]]) -> None:
565+
def abort_seq_group(
566+
self,
567+
request_id: Union[str, Iterable[str]],
568+
seq_id_to_seq_group: Optional[Dict[str, SequenceGroupBase]] = None,
569+
) -> None:
565570
"""Aborts a sequence group with the given ID.
566571
567572
Check if the sequence group with the given ID
@@ -573,21 +578,29 @@ def abort_seq_group(self, request_id: Union[str, Iterable[str]]) -> None:
573578
574579
Args:
575580
request_id: The ID(s) of the sequence group to abort.
581+
seq_id_to_seq_group: helper for groups with n>1
576582
"""
577583
if isinstance(request_id, str):
578584
request_id = (request_id, )
579585
request_ids = set(request_id)
586+
seq_id_to_seq_group = seq_id_to_seq_group or {}
580587
for state_queue in [self.waiting, self.running, self.swapped]:
581588
aborted_groups: List[SequenceGroup] = []
582589
for seq_group in state_queue:
583-
if not request_ids:
584-
# Using 'break' here may add two extra iterations,
585-
# but is acceptable to reduce complexity.
586-
break
587-
if seq_group.request_id in request_ids:
590+
# When n>1, seq_group.request_id looks like
591+
# foo_parallel_sample_0, while request_ids is just foo, and we
592+
# should resolve it as real_request_id to match.
593+
if seq_group.request_id in seq_id_to_seq_group:
594+
real_request_id = seq_id_to_seq_group[
595+
seq_group.request_id].group_id
596+
else:
597+
real_request_id = seq_group.request_id
598+
if real_request_id in request_ids:
588599
# Appending aborted group into pending list.
589600
aborted_groups.append(seq_group)
590-
request_ids.remove(seq_group.request_id)
601+
# We can't remove real_request_id in request_ids here,
602+
# because there may be other seq groups sharing the same
603+
# real_request_id
591604
for aborted_group in aborted_groups:
592605
# Remove the sequence group from the state queue.
593606
state_queue.remove(aborted_group)
@@ -598,6 +611,8 @@ def abort_seq_group(self, request_id: Union[str, Iterable[str]]) -> None:
598611
continue
599612
seq.status = SequenceStatus.FINISHED_ABORTED
600613
self.free_seq(seq)
614+
if aborted_group.request_id in seq_id_to_seq_group:
615+
del seq_id_to_seq_group[aborted_group.request_id]
601616

602617
self._free_seq_group_cross_attn_blocks(aborted_group)
603618

vllm/engine/llm_engine.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -887,7 +887,8 @@ def abort_request(self, request_id: Union[str, Iterable[str]]) -> None:
887887
>>> engine.abort_request(request_id)
888888
"""
889889
for scheduler in self.scheduler:
890-
scheduler.abort_seq_group(request_id)
890+
scheduler.abort_seq_group(
891+
request_id, seq_id_to_seq_group=self.seq_id_to_seq_group)
891892

892893
def get_model_config(self) -> ModelConfig:
893894
"""Gets the model configuration."""
@@ -1354,6 +1355,11 @@ def step(self) -> List[Union[RequestOutput, PoolingRequestOutput]]:
13541355

13551356
finished_requests_ids = self.scheduler[
13561357
virtual_engine].get_and_reset_finished_requests_ids()
1358+
# When n>1, elements in self.seq_id_to_seq_group should be deleted
1359+
# here, otherwise memory leaks.
1360+
for finished_request_id in finished_requests_ids:
1361+
if finished_request_id in self.seq_id_to_seq_group:
1362+
del self.seq_id_to_seq_group[finished_request_id]
13571363

13581364
# Maybe switch from async mode to sync mode
13591365
if not allow_async_output_proc and len(ctx.output_queue) > 0:

0 commit comments

Comments
 (0)