@@ -501,9 +501,11 @@ def _process_sequence_group_outputs(self, seq_group: SequenceGroup,
 
         for seq, _ in child_seqs:
             if seq_group.sampling_params.detokenize:
-                self.detokenizer.decode_sequence_inplace(
+                new_char_count = self.detokenizer.decode_sequence_inplace(
                     seq, seq_group.sampling_params)
-            self._check_stop(seq, seq_group.sampling_params)
+            else:
+                new_char_count = 0
+            self._check_stop(seq, new_char_count, seq_group.sampling_params)
 
         # Non-beam search case
         if not seq_group.sampling_params.use_beam_search:
@@ -798,56 +800,86 @@ def _get_stats(self,
             time_e2e_requests=time_e2e_requests,
         )
 
-    def _check_stop(self, seq: Sequence,
+    def _check_stop(self, seq: Sequence, new_char_count: int,
                     sampling_params: SamplingParams) -> None:
-        """Stop the finished sequences."""
-        # Check if the sequence has reached max_model_len.
-        if seq.get_len() > self.scheduler_config.max_model_len:
-            seq.status = SequenceStatus.FINISHED_LENGTH_CAPPED
-            return
+        """Stop the finished sequences.
 
-        # Check if the sequence has reached max_tokens.
-        if seq.get_output_len() == sampling_params.max_tokens:
-            seq.status = SequenceStatus.FINISHED_LENGTH_CAPPED
-            return
+        new_char_count is the number of chars added to the
+        sequence's output text for the newly generated token
+        """
 
         # Check if the minimum number of tokens has been generated yet;
         # skip the stop string/token checks if not
         if seq.get_output_len() < sampling_params.min_tokens:
             return
 
-        if sampling_params.detokenize:
-            for stop_str in sampling_params.stop:
-                if seq.output_text.endswith(stop_str):
-                    self._finalize_sequence(seq, sampling_params, stop_str)
-                    seq.status = SequenceStatus.FINISHED_STOPPED
-                    seq.stop_reason = stop_str
-                    return
+        # Check if the sequence has generated the EOS token.
+        if ((not sampling_params.ignore_eos)
+                and seq.get_last_token_id() == seq.eos_token_id):
+            seq.status = SequenceStatus.FINISHED_STOPPED
+            return
+
+        # Check if a stop token was encountered.
+        # This assumes a single token produced per step.
         last_token_id = seq.get_last_token_id()
         if last_token_id in sampling_params.stop_token_ids:
-            stop_str = self.get_tokenizer_for_seq(seq).convert_ids_to_tokens(
-                last_token_id)
-            self._finalize_sequence(seq, sampling_params, stop_str)
+            if new_char_count and (
+                    not sampling_params.include_stop_str_in_output):
+                # Remove last token
+                seq.output_text = seq.output_text[:-new_char_count]
             seq.status = SequenceStatus.FINISHED_STOPPED
             seq.stop_reason = last_token_id
             return
 
-        # Check if the sequence has generated the EOS token.
-        if ((not sampling_params.ignore_eos)
-                and seq.get_last_token_id() == seq.eos_token_id):
+        # Check if any stop strings are matched.
+        stop_str = self._check_stop_strings(seq, new_char_count,
+                                            sampling_params)
+        if stop_str is not None:
             seq.status = SequenceStatus.FINISHED_STOPPED
+            seq.stop_reason = stop_str
             return
 
-    def _finalize_sequence(self, seq: Sequence,
-                           sampling_params: SamplingParams,
-                           stop_string: str) -> None:
-        if sampling_params.include_stop_str_in_output:
+        # Check if the sequence has reached max_model_len.
+        if seq.get_len() > self.scheduler_config.max_model_len:
+            seq.status = SequenceStatus.FINISHED_LENGTH_CAPPED
             return
 
-        if stop_string and seq.output_text.endswith(stop_string):
-            # Truncate the output text so that the stop string is
-            # not included in the output.
-            seq.output_text = seq.output_text[:-len(stop_string)]
+        # Check if the sequence has reached max_tokens.
+        if seq.get_output_len() == sampling_params.max_tokens:
+            seq.status = SequenceStatus.FINISHED_LENGTH_CAPPED
+            return
+
+    @staticmethod
+    def _check_stop_strings(seq: Sequence, new_char_count: int,
+                            sampling_params: SamplingParams) -> Optional[str]:
+        """Check if any stop strings are matched and truncate sequence
+        output text accordingly.
+
+        Returns the stop string if matched or else None.
+        """
+        if not new_char_count:
+            return None
+
+        for stop_str in sampling_params.stop:
+            stop_string_len = len(stop_str)
+            # Avoid searching already-searched text.
+            stop_index = seq.output_text.find(
+                stop_str, -new_char_count - stop_string_len)
+            if stop_index == -1:
+                continue
+
+            if sampling_params.include_stop_str_in_output:
+                # Truncate to end of stop string.
+                stop_index += stop_string_len
+                if stop_index >= len(seq.output_text):
+                    # No truncation required.
+                    return stop_str
+
+            # Truncate the output text to either the beginning
+            # or end of the stop string.
+            seq.output_text = seq.output_text[:stop_index]
+            return stop_str
+        return None
 
     def add_lora(self, lora_request: LoRARequest) -> bool:
         return self.model_executor.add_lora(lora_request)