@@ -1471,6 +1471,9 @@ def add_request(request_id: str, engine, params, *args, **kwargs):
1471
1471
def finish_seq(self, seq: SequenceGroup):
    """Record that the sequence `seq` has finished.

    The entry for ``seq.request_id`` is removed from
    ``self.to_be_finished`` and ``seq`` is stored in
    ``self.finished_reqs`` under the same request id.

    The call is idempotent: if the request id is not (or no longer)
    pending in ``self.to_be_finished``, nothing is changed.
    """
    request_id = seq.request_id
    # Idempotent: a sequence that was already recorded (or was never
    # pending) must not raise a KeyError on the `del` below.
    if request_id not in self.to_be_finished:
        return
    del self.to_be_finished[request_id]
    self.finished_reqs[request_id] = seq
1476
1479
@@ -1529,34 +1532,30 @@ def add_request(request_id: str, engine, params, **kwargs):
1529
1532
def maybe_assemble_group(
        self, seq_group: SequenceGroup) -> Optional[SequenceGroup]:
    """Return the assembled sequence group when it should be emitted.

    While streaming and ``seq_group`` is still running, the assembled
    group is returned for exactly one of the pending sequences (the
    first key of ``self.to_be_finished``); every other caller gets
    ``None``.

    Otherwise (non-streaming, or ``seq_group`` has finished), the
    assembled group is returned once: on the first call made after
    ``self.to_be_finished`` has become empty. ``self.output_produced``
    is set at that point so later calls return ``None``. When
    ``sampling_params._real_n`` is set, the group's ``seqs`` are first
    trimmed to the top-n by cumulative logprob.
    """
    # In the streaming mode, we must return the assembled sequence
    # group while sequences are still processing, but only for one of
    # the remaining sequences.
    if self.streaming and not seq_group.is_finished():
        # Default of None avoids a StopIteration when nothing is pending.
        first_remaining_id = next(iter(self.to_be_finished), None)
        if seq_group.request_id == first_remaining_id:
            return self.assembled_seq_group
        return None

    # For non-streaming, and when all streamed sequences are finished,
    # return the assembled group for the last finished sequence only,
    # and only once.
    if self.to_be_finished or self.output_produced:
        return None

    assert self.assembled_seq_group is not None
    params = self.assembled_seq_group.sampling_params
    assert isinstance(params, SamplingParams)
    self.output_produced = True
    if params._real_n is not None:
        # Get the top-n sequences, best cumulative logprob first.
        n = params._real_n or params.n
        ranked_seqs = sorted(
            self.assembled_seq_group.seqs,
            key=lambda seq: seq.get_cumulative_logprob(),
            reverse=True,
        )
        self.assembled_seq_group.seqs = ranked_seqs[:n]
    return self.assembled_seq_group
0 commit comments