vllm/entrypoints/api_server.py (21 changes: 13 additions & 8 deletions)
@@ -8,6 +8,7 @@
 
 from vllm.engine.arg_utils import AsyncEngineArgs
 from vllm.engine.async_llm_engine import AsyncLLMEngine
+from vllm.outputs import RequestOutput
 from vllm.sampling_params import SamplingParams
 from vllm.utils import random_uuid
 
@@ -17,17 +18,27 @@
 engine = None
 
 
+def get_text_outputs(request_output: RequestOutput, return_prompt: bool):
+    if return_prompt:
+        prompt = request_output.prompt
+        return [prompt + output.text for output in request_output.outputs]
+    else:
+        return [output.text for output in request_output.outputs]
+
+
 @app.post("/generate")
 async def generate(request: Request) -> Response:
     """Generate completion for the request.
 
     The request should be a JSON object with the following fields:
     - prompt: the prompt to use for the generation.
+    - return_prompt: whether to return the prompt with the results or not.
     - stream: whether to stream the results or not.
     - other fields: the sampling parameters (See `SamplingParams` for details).
     """
     request_dict = await request.json()
     prompt = request_dict.pop("prompt")
+    return_prompt = request_dict.pop("return_prompt", False)
     stream = request_dict.pop("stream", False)
     sampling_params = SamplingParams(**request_dict)
     request_id = random_uuid()
@@ -37,11 +48,7 @@ async def generate(request: Request) -> Response:
     # Streaming case
     async def stream_results() -> AsyncGenerator[bytes, None]:
         async for request_output in results_generator:
-            prompt = request_output.prompt
-            text_outputs = [
-                prompt + output.text for output in request_output.outputs
-            ]
-            ret = {"text": text_outputs}
+            ret = {"text": get_text_outputs(request_output, return_prompt)}
             yield (json.dumps(ret) + "\0").encode("utf-8")
 
     if stream:
@@ -57,9 +64,7 @@ async def stream_results() -> AsyncGenerator[bytes, None]:
         final_output = request_output
 
     assert final_output is not None
-    prompt = final_output.prompt
-    text_outputs = [prompt + output.text for output in final_output.outputs]
-    ret = {"text": text_outputs}
+    ret = {"text": get_text_outputs(final_output, return_prompt)}
     return JSONResponse(ret)
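For anyone who wants to exercise the new flag, a minimal client sketch follows. It assumes the server was started via python -m vllm.entrypoints.api_server and is listening on the default http://localhost:8000; the requests dependency, the prompt text, and the sampling fields are illustrative, not part of this PR.

import requests

# Non-streaming request. Fields other than prompt/return_prompt/stream
# are forwarded to SamplingParams by the server.
payload = {
    "prompt": "The capital of France is",
    "return_prompt": True,  # new flag: prepend the prompt to each output
    "stream": False,
    "temperature": 0.0,
    "max_tokens": 16,
}
response = requests.post("http://localhost:8000/generate", json=payload)
response.raise_for_status()
# The server returns {"text": [...]}; with return_prompt=True each entry
# starts with the original prompt.
print(response.json()["text"])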


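With "stream": true, the handler emits one JSON object per update, each terminated by a NUL byte (the json.dumps(ret) + "\0" in the diff), so a client has to buffer the response and split on b"\0". A sketch under the same localhost:8000 assumption:

import json
import requests

payload = {
    "prompt": "The capital of France is",
    "return_prompt": False,
    "stream": True,
    "max_tokens": 16,
}
with requests.post("http://localhost:8000/generate",
                   json=payload, stream=True) as response:
    response.raise_for_status()
    buffer = b""
    for chunk in response.iter_content(chunk_size=None):
        buffer += chunk
        # Each complete message is terminated by a NUL byte.
        while b"\0" in buffer:
            message, _, buffer = buffer.partition(b"\0")
            if message:
                print(json.loads(message.decode("utf-8"))["text"])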