Prevent returning partial stop string in vllm worker (#2780)
pandada8 authored Dec 9, 2023
1 parent c842764 commit 173f4de
Showing 1 changed file with 11 additions and 2 deletions.
13 changes: 11 additions & 2 deletions fastchat/serve/vllm_worker.py
@@ -22,7 +22,7 @@
     logger,
     worker_id,
 )
-from fastchat.utils import get_context_length
+from fastchat.utils import get_context_length, is_partial_stop
 
 
 app = FastAPI()
@@ -119,7 +119,12 @@ async def generate_stream(self, params):
             else:
                 text_outputs = [output.text for output in request_output.outputs]
             text_outputs = " ".join(text_outputs)
-            # Note: usage is not supported yet
+
+            partial_stop = any(is_partial_stop(text_outputs, i) for i in stop)
+            # prevent yielding partial stop sequence
+            if partial_stop:
+                continue
+
             prompt_tokens = len(request_output.prompt_token_ids)
             completion_tokens = sum(
                 len(output.token_ids) for output in request_output.outputs
@@ -139,6 +144,10 @@
                 if len(request_output.outputs) == 1
                 else [output.finish_reason for output in request_output.outputs],
             }
+            # Emit twice here to ensure a 'finish_reason' with empty content in the OpenAI API response.
+            # This aligns with the behavior of model_worker.
+            if request_output.finished:
+                yield (json.dumps(ret | {"finish_reason": None}) + "\0").encode()
             yield (json.dumps(ret) + "\0").encode()
 
     async def generate(self, params):
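The new check holds a chunk back whenever the text so far could be the beginning of one of the stop strings. For example, with a stop string of "###", a chunk ending in "##" must not be sent: the client would see part of the stop sequence before the full match is detected on a later token. Below is a minimal sketch of what a helper like is_partial_stop does; the actual fastchat.utils.is_partial_stop may differ in detail.

def is_partial_stop(output: str, stop_str: str) -> bool:
    # True when the tail of `output` is a prefix of `stop_str`, i.e. the stop
    # sequence could still be completed by tokens that have not arrived yet.
    for i in range(1, min(len(output), len(stop_str)) + 1):
        if stop_str.startswith(output[-i:]):
            return True
    return False

# With stop string "###", a chunk ending in "##" is held back; a clean chunk is not.
assert is_partial_stop("The answer is 42 ##", "###")
assert not is_partial_stop("The answer is 42", "###")

In the worker, text_outputs is tested against every configured stop string (any(is_partial_stop(text_outputs, i) for i in stop)); if any of them matches partially, the loop skips to the next vLLM step instead of yielding the chunk.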

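The second change makes the final iteration yield twice: first the payload with finish_reason forced to None, i.e. an ordinary in-progress chunk, then the payload carrying the real finish reason, which mirrors how model_worker ends its stream. Note that ret | {"finish_reason": None} uses the dict union operator introduced in Python 3.9. A rough, hypothetical client-side view of the resulting NUL-delimited stream follows; the payloads are illustrative, not actual worker output.

import json

# Two illustrative chunks, as the worker would emit them at the end of generation.
raw = (
    b'{"text": "The answer is 42", "error_code": 0, "finish_reason": null}\x00'
    b'{"text": "The answer is 42", "error_code": 0, "finish_reason": "stop"}\x00'
)

for chunk in raw.split(b"\x00"):
    if not chunk:
        continue
    data = json.loads(chunk)
    # Prints None for the penultimate chunk and "stop" for the final one.
    print(data["finish_reason"])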