Added echo function to OpenAI API server. #1504

Merged
merged 9 commits on Nov 27, 2023
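
This PR adds support for the OpenAI completions `echo` option on the OpenAI-compatible server, which returns the prompt text as part of the completion. Below is a minimal client-side sketch of exercising it; the host, port, and model name are illustrative assumptions, and the request body simply follows the OpenAI completions schema.

# Minimal sketch of calling the OpenAI-compatible completions endpoint with
# the `echo` option this PR adds. Host, port, and model name are assumptions
# for illustration only.
import requests

resp = requests.post(
    "http://localhost:8000/v1/completions",
    json={
        "model": "facebook/opt-125m",  # whichever model the server was launched with
        "prompt": "vLLM is a",
        "max_tokens": 16,
        "echo": True,  # ask the server to include the prompt in the returned text
    },
    timeout=60,
)
resp.raise_for_status()
print(resp.json()["choices"][0]["text"])  # prompt followed by the generated continuation
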
Changes from 1 commit
Updated to address revision comments.
wanmok committed Nov 25, 2023
commit 7690b6d2c8683311d42df024283a5c250dfa670f
205 changes: 85 additions & 120 deletions vllm/entrypoints/openai/api_server.py
@@ -47,13 +47,11 @@
engine = None


def create_error_response(
status_code: HTTPStatus, message: str
) -> JSONResponse:
return JSONResponse(
ErrorResponse(message=message, type="invalid_request_error").dict(),
status_code=status_code.value,
)
def create_error_response(status_code: HTTPStatus,
message: str) -> JSONResponse:
return JSONResponse(ErrorResponse(message=message,
type="invalid_request_error").dict(),
status_code=status_code.value)


@app.exception_handler(RequestValidationError)
@@ -80,8 +78,7 @@ async def get_gen_prompt(request) -> str:
if version.parse(fastchat.__version__) < version.parse("0.2.23"):
raise ImportError(
f"fastchat version is low. Current version: {fastchat.__version__} "
"Please upgrade fastchat to use: `$ pip install -U fschat`"
)
"Please upgrade fastchat to use: `$ pip install -U fschat`")

conv = get_conversation_template(request.model)
conv = Conversation(
@@ -122,14 +119,13 @@ async def get_gen_prompt(request) -> str:
async def check_length(
request: Union[ChatCompletionRequest, CompletionRequest],
prompt: Optional[str] = None,
prompt_ids: Optional[List[int]] = None,
prompt_ids: Optional[List[int]] = None
) -> Tuple[List[int], Optional[JSONResponse]]:
assert not (prompt is None and prompt_ids is None) and not (
prompt is not None and prompt_ids is not None
), "Either prompt or prompt_ids should be provided."
input_ids = (
prompt_ids if prompt_ids is not None else tokenizer(prompt).input_ids
)
assert (not (prompt is None and prompt_ids is None)
and not (prompt is not None and prompt_ids is not None)
), "Either prompt or prompt_ids should be provided."
input_ids = prompt_ids if prompt_ids is not None else tokenizer(
prompt).input_ids
token_num = len(input_ids)

if request.max_tokens is None:
@@ -157,9 +153,9 @@ async def health() -> Response:
async def show_available_models():
"""Show available models. Right now we only have one model."""
model_cards = [
ModelCard(
id=served_model, root=served_model, permission=[ModelPermission()]
)
ModelCard(id=served_model,
root=served_model,
permission=[ModelPermission()])
]
return ModelList(data=model_cards)

@@ -203,9 +199,8 @@ def create_logprobs(


@app.post("/v1/chat/completions")
async def create_chat_completion(
request: ChatCompletionRequest, raw_request: Request
):
async def create_chat_completion(request: ChatCompletionRequest,
raw_request: Request):
"""Completion API similar to OpenAI's API.

See https://platform.openai.com/docs/api-reference/chat/create
@@ -223,9 +218,8 @@ async def create_chat_completion(

if request.logit_bias is not None and len(request.logit_bias) > 0:
# TODO: support logit_bias in vLLM engine.
return create_error_response(
HTTPStatus.BAD_REQUEST, "logit_bias is not currently supported"
)
return create_error_response(HTTPStatus.BAD_REQUEST,
"logit_bias is not currently supported")

prompt = await get_gen_prompt(request)
token_ids, error_check_ret = await check_length(request, prompt=prompt)
@@ -256,9 +250,8 @@ async def create_chat_completion(
except ValueError as e:
return create_error_response(HTTPStatus.BAD_REQUEST, str(e))

result_generator = engine.generate(
prompt, sampling_params, request_id, token_ids
)
result_generator = engine.generate(prompt, sampling_params, request_id,
token_ids)

def create_stream_response_json(
index: int,
@@ -292,9 +285,9 @@ async def completion_stream_generator() -> AsyncGenerator[str, None]:
delta=DeltaMessage(role="assistant"),
finish_reason=None,
)
chunk = ChatCompletionStreamResponse(
id=request_id, choices=[choice_data], model=model_name
)
chunk = ChatCompletionStreamResponse(id=request_id,
choices=[choice_data],
model=model_name)
data = chunk.json(exclude_unset=True, ensure_ascii=False)
yield f"data: {data}\n\n"

@@ -304,7 +297,7 @@ async def completion_stream_generator() -> AsyncGenerator[str, None]:
res: RequestOutput
for output in res.outputs:
i = output.index
delta_text = output.text[len(previous_texts[i]) :]
delta_text = output.text[len(previous_texts[i]):]
previous_texts[i] = output.text
completion_tokens = len(output.token_ids)
previous_num_tokens[i] = completion_tokens
@@ -331,19 +324,17 @@ async def completion_stream_generator() -> AsyncGenerator[str, None]:

# Streaming response
if request.stream:
return StreamingResponse(
completion_stream_generator(), media_type="text/event-stream"
)
return StreamingResponse(completion_stream_generator(),
media_type="text/event-stream")

# Non-streaming response
final_res: RequestOutput = None
async for res in result_generator:
if await raw_request.is_disconnected():
# Abort the request if the client disconnects.
await engine.abort(request_id)
return create_error_response(
HTTPStatus.BAD_REQUEST, "Client disconnected"
)
return create_error_response(HTTPStatus.BAD_REQUEST,
"Client disconnected")
final_res = res
assert final_res is not None
choices = []
@@ -357,8 +348,7 @@ async def completion_stream_generator() -> AsyncGenerator[str, None]:

num_prompt_tokens = len(final_res.prompt_token_ids)
num_generated_tokens = sum(
len(output.token_ids) for output in final_res.outputs
)
len(output.token_ids) for output in final_res.outputs)
usage = UsageInfo(
prompt_tokens=num_prompt_tokens,
completion_tokens=num_generated_tokens,
@@ -381,9 +371,8 @@ async def fake_stream_generator() -> AsyncGenerator[str, None]:
yield f"data: {response_json}\n\n"
yield "data: [DONE]\n\n"

return StreamingResponse(
fake_stream_generator(), media_type="text/event-stream"
)
return StreamingResponse(fake_stream_generator(),
media_type="text/event-stream")

return response

@@ -411,25 +400,22 @@ async def create_completion(request: CompletionRequest, raw_request: Request):

if request.suffix is not None:
# The language models we currently support do not support suffix.
return create_error_response(
HTTPStatus.BAD_REQUEST, "suffix is not currently supported"
)
return create_error_response(HTTPStatus.BAD_REQUEST,
"suffix is not currently supported")

if request.logit_bias is not None and len(request.logit_bias) > 0:
# TODO: support logit_bias in vLLM engine.
return create_error_response(
HTTPStatus.BAD_REQUEST, "logit_bias is not currently supported"
)
return create_error_response(HTTPStatus.BAD_REQUEST,
"logit_bias is not currently supported")

model_name = request.model
request_id = f"cmpl-{random_uuid()}"

use_token_ids = False
if isinstance(request.prompt, list):
if len(request.prompt) == 0:
return create_error_response(
HTTPStatus.BAD_REQUEST, "please provide at least one prompt"
)
return create_error_response(HTTPStatus.BAD_REQUEST,
"please provide at least one prompt")
first_element = request.prompt[0]
if isinstance(first_element, int):
use_token_ids = True
@@ -439,8 +425,7 @@ async def create_completion(request: CompletionRequest, raw_request: Request):
if len(request.prompt) > 1:
return create_error_response(
HTTPStatus.BAD_REQUEST,
"multiple prompts in a batch is not currently supported",
)
"multiple prompts in a batch is not currently supported")
use_token_ids = not isinstance(first_element, str)
prompt = request.prompt[0]
else:
@@ -478,21 +463,19 @@ async def create_completion(request: CompletionRequest, raw_request: Request):
return create_error_response(HTTPStatus.BAD_REQUEST, str(e))

if use_token_ids:
result_generator = engine.generate(
None, sampling_params, request_id, prompt_token_ids=prompt
)
result_generator = engine.generate(None,
sampling_params,
request_id,
prompt_token_ids=prompt)
else:
result_generator = engine.generate(
prompt, sampling_params, request_id, token_ids
)
result_generator = engine.generate(prompt, sampling_params, request_id,
token_ids)

# Similar to the OpenAI API, when n != best_of, we do not stream the
# results. In addition, we do not stream the results when use beam search.
stream = (
request.stream
and (request.best_of is None or request.n == request.best_of)
and not request.use_beam_search
)
stream = (request.stream
and (request.best_of is None or request.n == request.best_of)
and not request.use_beam_search)

def create_stream_response_json(
index: int,
@@ -561,9 +544,8 @@ async def completion_stream_generator() -> AsyncGenerator[str, None]:
)
yield f"data: {response_json}\n\n"
if output.finish_reason is not None:
logprobs = (
LogProbs() if request.logprobs is not None else None
)
logprobs = (LogProbs()
if request.logprobs is not None else None)
prompt_tokens = len(res.prompt_token_ids)
completion_tokens = len(output.token_ids)
final_usage = UsageInfo(
@@ -583,21 +565,18 @@ async def completion_stream_generator() -> AsyncGenerator[str, None]:

# Streaming response
if stream:
return StreamingResponse(
completion_stream_generator(), media_type="text/event-stream"
)
return StreamingResponse(completion_stream_generator(),
media_type="text/event-stream")

# Non-streaming response
final_res: RequestOutput = None
async for res in result_generator:
if await raw_request.is_disconnected():
# Abort the request if the client disconnects.
await engine.abort(request_id)
return create_error_response(
HTTPStatus.BAD_REQUEST, "Client disconnected"
)
return create_error_response(HTTPStatus.BAD_REQUEST,
"Client disconnected")
final_res = res

assert final_res is not None
choices = []
prompt_token_ids = final_res.prompt_token_ids
@@ -637,8 +616,7 @@ async def completion_stream_generator() -> AsyncGenerator[str, None]:

num_prompt_tokens = len(final_res.prompt_token_ids)
num_generated_tokens = sum(
len(output.token_ids) for output in final_res.outputs
)
len(output.token_ids) for output in final_res.outputs)
usage = UsageInfo(
prompt_tokens=num_prompt_tokens,
completion_tokens=num_generated_tokens,
@@ -661,48 +639,38 @@ async def fake_stream_generator() -> AsyncGenerator[str, None]:
yield f"data: {response_json}\n\n"
yield "data: [DONE]\n\n"

return StreamingResponse(
fake_stream_generator(), media_type="text/event-stream"
)
return StreamingResponse(fake_stream_generator(),
media_type="text/event-stream")

return response


if __name__ == "__main__":
parser = argparse.ArgumentParser(
description="vLLM OpenAI-Compatible RESTful API server."
)
description="vLLM OpenAI-Compatible RESTful API server.")
parser.add_argument("--host", type=str, default=None, help="host name")
parser.add_argument("--port", type=int, default=8000, help="port number")
parser.add_argument(
"--allow-credentials", action="store_true", help="allow credentials"
)
parser.add_argument(
"--allowed-origins",
type=json.loads,
default=["*"],
help="allowed origins",
)
parser.add_argument(
"--allowed-methods",
type=json.loads,
default=["*"],
help="allowed methods",
)
parser.add_argument(
"--allowed-headers",
type=json.loads,
default=["*"],
help="allowed headers",
)
parser.add_argument(
"--served-model-name",
type=str,
default=None,
help="The model name used in the API. If not "
"specified, the model name will be the same as "
"the huggingface name.",
)
parser.add_argument("--allow-credentials",
action="store_true",
help="allow credentials")
parser.add_argument("--allowed-origins",
type=json.loads,
default=["*"],
help="allowed origins")
parser.add_argument("--allowed-methods",
type=json.loads,
default=["*"],
help="allowed methods")
parser.add_argument("--allowed-headers",
type=json.loads,
default=["*"],
help="allowed headers")
parser.add_argument("--served-model-name",
type=str,
default=None,
help="The model name used in the API. If not "
"specified, the model name will be the same as "
"the huggingface name.")

parser = AsyncEngineArgs.add_cli_args(parser)
args = parser.parse_args()
@@ -731,13 +699,10 @@ async def fake_stream_generator() -> AsyncGenerator[str, None]:
tokenizer = get_tokenizer(
engine_model_config.tokenizer,
tokenizer_mode=engine_model_config.tokenizer_mode,
trust_remote_code=engine_model_config.trust_remote_code,
)
trust_remote_code=engine_model_config.trust_remote_code)

uvicorn.run(
app,
host=args.host,
port=args.port,
log_level="info",
timeout_keep_alive=TIMEOUT_KEEP_ALIVE,
)
uvicorn.run(app,
host=args.host,
port=args.port,
log_level="info",
timeout_keep_alive=TIMEOUT_KEEP_ALIVE)
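
For context on the prompt handling in the diff above: when the completions request's prompt is a list of integers, the server treats it as pre-tokenized input and forwards it to the engine as prompt_token_ids rather than tokenizing a string. A minimal client-side sketch, again with an assumed local server, an illustrative model name, and arbitrary token IDs:

# Minimal sketch of the token-id prompt path shown above: a list of integers
# is treated as pre-tokenized input. The server address, model name, and the
# token IDs themselves are illustrative assumptions.
import requests

resp = requests.post(
    "http://localhost:8000/v1/completions",
    json={
        "model": "facebook/opt-125m",
        "prompt": [1, 20, 30, 40],  # token IDs instead of a text prompt
        "max_tokens": 8,
    },
    timeout=60,
)
resp.raise_for_status()
print(resp.json()["choices"][0]["text"])
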
1 change: 0 additions & 1 deletion vllm/entrypoints/openai/protocol.py
@@ -5,7 +5,6 @@

from pydantic import BaseModel, Field

from vllm.sequence import PromptLogprobs
from vllm.utils import random_uuid

