Prevent returning partial stop string in vllm worker #2780

Merged 1 commit on Dec 9, 2023
Prevent returning partial stop string in vllm worker
pandada8 committed Dec 6, 2023
commit 35d5d148a7bf298857304f84b8598ee5d2ea30d7
13 changes: 11 additions & 2 deletions fastchat/serve/vllm_worker.py
@@ -22,7 +22,7 @@
     logger,
     worker_id,
 )
-from fastchat.utils import get_context_length
+from fastchat.utils import get_context_length, is_partial_stop


app = FastAPI()
@@ -119,7 +119,12 @@ async def generate_stream(self, params):
             else:
                 text_outputs = [output.text for output in request_output.outputs]
                 text_outputs = " ".join(text_outputs)
-            # Note: usage is not supported yet
+
+            partial_stop = any(is_partial_stop(text_outputs, i) for i in stop)
+            # prevent yielding partial stop sequence
+            if partial_stop:
+                continue
+
             prompt_tokens = len(request_output.prompt_token_ids)
             completion_tokens = sum(
                 len(output.token_ids) for output in request_output.outputs
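The guard above relies on `fastchat.utils.is_partial_stop`, which reports whether the current text ends in a prefix of a stop string, i.e. the stream might be mid-way through emitting a stop sequence. A minimal sketch of such a check (assumed behavior, not necessarily the library's exact implementation):

```python
def is_partial_stop(output: str, stop_str: str) -> bool:
    """Return True if `output` ends with a non-empty prefix of `stop_str`.

    A complete occurrence of `stop_str` at the end also returns True here;
    in the worker, fully matched stop strings are handled by vLLM's own
    stop-sequence logic before this check matters.
    """
    for i in range(1, min(len(output), len(stop_str)) + 1):
        if stop_str.startswith(output[-i:]):
            return True
    return False
```

With a stop string such as `"<|end|>"`, text ending in `"<|"` is held back (`True`), while ordinary text streams through (`False`).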
@@ -139,6 +144,10 @@ async def generate_stream(self, params):
                 if len(request_output.outputs) == 1
                 else [output.finish_reason for output in request_output.outputs],
             }
+            # Emit twice here to ensure a 'finish_reason' with empty content in the OpenAI API response.
+            # This aligns with the behavior of model_worker.
+            if request_output.finished:
+                yield (json.dumps(ret | {"finish_reason": None}) + "\0").encode()
             yield (json.dumps(ret) + "\0").encode()
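The `ret | {"finish_reason": None}` expression uses the dict union operator (Python 3.9+, PEP 584), which builds a new dict rather than mutating `ret`, so the second yield still carries the real finish reason. A small illustration:

```python
ret = {"text": "Hello!", "finish_reason": "stop"}

# Dict union returns a new dict with the right-hand keys overriding;
# `ret` itself is left unchanged.
cleared = ret | {"finish_reason": None}

print(cleared["finish_reason"])  # None
print(ret["finish_reason"])      # stop
```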

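On the wire, the worker emits a stream of NUL-delimited JSON chunks, with the final chunk repeated once with `finish_reason` cleared and once with it set, mirroring the OpenAI streaming shape. A hypothetical consumer (names assumed, not part of FastChat) could split such a stream like this:

```python
import json


def iter_chunks(byte_stream):
    """Yield decoded JSON objects from a stream of NUL-delimited chunks."""
    buf = b""
    for data in byte_stream:
        buf += data
        while b"\0" in buf:
            chunk, buf = buf.split(b"\0", 1)
            if chunk:
                yield json.loads(chunk)


# Simulated worker output: intermediate chunks, then the doubled final chunk.
stream = [
    (json.dumps({"text": "Hello", "finish_reason": None}) + "\0").encode(),
    (json.dumps({"text": "Hello!", "finish_reason": None}) + "\0").encode(),
    (json.dumps({"text": "Hello!", "finish_reason": "stop"}) + "\0").encode(),
]
chunks = list(iter_chunks(stream))
```

Buffering on the client side matters because a single `read()` may return a partial chunk or several chunks at once.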