@@ -908,13 +908,13 @@ def _check_stop(self, seq: Sequence,
908
908
"""Stop the finished sequences."""
909
909
for stop_str in sampling_params .stop :
910
910
if seq .output_text .endswith (stop_str ):
911
- if not sampling_params .include_stop_str_in_output :
912
- # Truncate the output text so that the stop string is
913
- # not included in the output.
914
- seq .output_text = seq .output_text [:- len (stop_str )]
911
+ self ._finalize_sequence (seq , sampling_params , stop_str )
915
912
seq .status = SequenceStatus .FINISHED_STOPPED
916
913
return
917
914
if seq .get_last_token_id () in sampling_params .stop_token_ids :
915
+ stop_str = self .get_tokenizer_for_seq (seq ).convert_ids_to_tokens (
916
+ seq .get_last_token_id ())
917
+ self ._finalize_sequence (seq , sampling_params , stop_str )
918
918
seq .status = SequenceStatus .FINISHED_STOPPED
919
919
return
920
920
@@ -934,6 +934,14 @@ def _check_stop(self, seq: Sequence,
934
934
seq .status = SequenceStatus .FINISHED_STOPPED
935
935
return
936
936
937
+ def _finalize_sequence (self , seq : Sequence ,
938
+ sampling_params : SamplingParams ,
939
+ stop_string : str ) -> None :
940
+ if not sampling_params .include_stop_str_in_output and stop_string :
941
+ # Truncate the output text so that the stop string is
942
+ # not included in the output.
943
+ seq .output_text = seq .output_text [:- len (stop_string )]
944
+
937
945
def add_lora (self , lora_request : LoRARequest ) -> bool :
938
946
assert lora_request .lora_int_id > 0 , "lora_id must be greater than 0."
939
947
return self ._run_workers (
0 commit comments