16
16
from vllm .lora .request import LoRARequest
17
17
from vllm .prompt_adapter .request import PromptAdapterRequest
18
18
from vllm .sequence import (Sequence , SequenceData , SequenceGroup ,
19
- SequenceGroupMetadata , SequenceGroupMetadataDelta ,
20
- SequenceStage , SequenceStatus )
19
+ SequenceGroupBase , SequenceGroupMetadata ,
20
+ SequenceGroupMetadataDelta , SequenceStage ,
21
+ SequenceStatus )
21
22
from vllm .utils import Device , PyObjectCache
22
23
23
24
logger = init_logger (__name__ )
@@ -561,7 +562,11 @@ def _add_seq_group_to_swapped(self, seq_group: SequenceGroup) -> None:
561
562
# Only for testing purposes.
562
563
self .swapped .append (seq_group )
563
564
564
- def abort_seq_group (self , request_id : Union [str , Iterable [str ]]) -> None :
565
+ def abort_seq_group (
566
+ self ,
567
+ request_id : Union [str , Iterable [str ]],
568
+ seq_id_to_seq_group : Optional [Dict [str , SequenceGroupBase ]] = None ,
569
+ ) -> None :
565
570
"""Aborts a sequence group with the given ID.
566
571
567
572
Check if the sequence group with the given ID
@@ -573,21 +578,29 @@ def abort_seq_group(self, request_id: Union[str, Iterable[str]]) -> None:
573
578
574
579
Args:
575
580
request_id: The ID(s) of the sequence group to abort.
581
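+ seq_id_to_seq_group: Optional mapping from a sequence group's request ID to its parent group, used when n>1 to resolve the parent's group_id for matching against request_id.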
576
582
"""
577
583
if isinstance (request_id , str ):
578
584
request_id = (request_id , )
579
585
request_ids = set (request_id )
586
+ seq_id_to_seq_group = seq_id_to_seq_group or {}
580
587
for state_queue in [self .waiting , self .running , self .swapped ]:
581
588
aborted_groups : List [SequenceGroup ] = []
582
589
for seq_group in state_queue :
583
- if not request_ids :
584
- # Using 'break' here may add two extra iterations,
585
- # but is acceptable to reduce complexity.
586
- break
587
- if seq_group .request_id in request_ids :
590
+ # When n>1, seq_group.request_id looks like
591
+ # foo_parallel_sample_0 while request_ids contains just foo, so we
592
+ # look up the parent's group_id as real_request_id before matching.
593
+ if seq_group .request_id in seq_id_to_seq_group :
594
+ real_request_id = seq_id_to_seq_group [
595
+ seq_group .request_id ].group_id
596
+ else :
597
+ real_request_id = seq_group .request_id
598
+ if real_request_id in request_ids :
588
599
# Appending aborted group into pending list.
589
600
aborted_groups .append (seq_group )
590
- request_ids .remove (seq_group .request_id )
601
+ # We can't remove real_request_id from request_ids here,
602
+ # because there may be other seq groups sharing the same
603
+ # real_request_id
591
604
for aborted_group in aborted_groups :
592
605
# Remove the sequence group from the state queue.
593
606
state_queue .remove (aborted_group )
@@ -598,6 +611,8 @@ def abort_seq_group(self, request_id: Union[str, Iterable[str]]) -> None:
598
611
continue
599
612
seq .status = SequenceStatus .FINISHED_ABORTED
600
613
self .free_seq (seq )
614
+ if aborted_group .request_id in seq_id_to_seq_group :
615
+ del seq_id_to_seq_group [aborted_group .request_id ]
601
616
602
617
self ._free_seq_group_cross_attn_blocks (aborted_group )
603
618
0 commit comments