Skip to content

Commit a9cd354

Browse files
committed
refactor: try a different fix
Signed-off-by: Travis Johnson <tsjohnso@us.ibm.com>
1 parent 6e610a7 commit a9cd354

File tree

2 files changed

+28
-29
lines changed

2 files changed

+28
-29
lines changed

vllm/outputs.py

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -178,9 +178,9 @@ def from_seq_group(
178178
if seq_group.request_id in seq_id_to_seq_group:
179179
group: SequenceGroupBase = seq_id_to_seq_group[
180180
seq_group.request_id]
181-
assembled_seq_group = group.maybe_assemble_group(seq_group)
182181
if finished:
183182
group.finish_seq(seq_group)
183+
assembled_seq_group = group.maybe_assemble_group(seq_group)
184184
if assembled_seq_group is None:
185185
return None
186186

vllm/sequence.py

Lines changed: 27 additions & 28 deletions
Original file line number | Diff line number | Diff line change
@@ -1471,6 +1471,9 @@ def add_request(request_id: str, engine, params, *args, **kwargs):
14711471
def finish_seq(self, seq: SequenceGroup):
14721472
"""The sequence `seq` finishes, we should record the information.
14731473
"""
1474+
# idempotent
1475+
if seq.request_id not in self.to_be_finished:
1476+
return
14741477
del self.to_be_finished[seq.request_id]
14751478
self.finished_reqs[seq.request_id] = seq
14761479

@@ -1529,34 +1532,30 @@ def add_request(request_id: str, engine, params, **kwargs):
15291532
def maybe_assemble_group(
15301533
self, seq_group: SequenceGroup) -> Optional[SequenceGroup]:
15311534

1532-
# in the streaming mode, we will return the assembled sequence for the
1533-
# last remaining sequence, and return None for the rest of sequences
1534-
if self.streaming:
1535-
last_remaining_id = list(self.to_be_finished)[-1]
1536-
if seq_group.request_id == last_remaining_id:
1535+
# in the streaming mode, we must return the assembled sequence
1536+
# group while sequences are still processing, but only for one of
1537+
# the remaining sequences
1538+
if self.streaming and not seq_group.is_finished():
1539+
first_remaining_id = next(iter(self.to_be_finished))
1540+
if seq_group.request_id == first_remaining_id:
15371541
return self.assembled_seq_group
15381542
return None
15391543

1540-
# in the non-streaming mode, we will return the assembled sequence
1541-
# when the last sequences finishes, and then return None for the
1542-
# rest of the time
1543-
if (len(self.to_be_finished) == 1
1544-
and seq_group.request_id in self.to_be_finished
1545-
and seq_group.is_finished()):
1546-
assert self.assembled_seq_group is not None
1547-
params = self.assembled_seq_group.sampling_params
1548-
assert isinstance(params, SamplingParams)
1549-
if not self.output_produced:
1550-
self.output_produced = True
1551-
if params._real_n is not None:
1552-
# Get the top-n sequences.
1553-
n = params._real_n or params.n
1554-
seqs = self.assembled_seq_group.seqs
1555-
sorting_key = lambda seq: seq.get_cumulative_logprob()
1556-
sorted_seqs = sorted(seqs, key=sorting_key, reverse=True)
1557-
top_n_seqs = sorted_seqs[:n]
1558-
self.assembled_seq_group.seqs = top_n_seqs
1559-
return self.assembled_seq_group
1560-
if self.output_produced:
1561-
return None
1562-
return None
1544+
# for non-streaming and when all streamed sequences are finished,
1545+
# we will return the assembled sequence for the last finished sequence
1546+
if len(self.to_be_finished) > 0 or self.output_produced:
1547+
return None
1548+
1549+
assert self.assembled_seq_group is not None
1550+
params = self.assembled_seq_group.sampling_params
1551+
assert isinstance(params, SamplingParams)
1552+
self.output_produced = True
1553+
if params._real_n is not None:
1554+
# Get the top-n sequences.
1555+
n = params._real_n or params.n
1556+
seqs = self.assembled_seq_group.seqs
1557+
sorting_key = lambda seq: seq.get_cumulative_logprob()
1558+
sorted_seqs = sorted(seqs, key=sorting_key, reverse=True)
1559+
top_n_seqs = sorted_seqs[:n]
1560+
self.assembled_seq_group.seqs = top_n_seqs
1561+
return self.assembled_seq_group

0 commit comments

Comments
 (0)