# Adapted from https://github.com/lm-sys/FastChat/blob/168ccc29d3f7edc50823016105c024fe2282732a/fastchat/serve/openai_api_server.py
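"""An OpenAI-compatible REST API server built on cacheflow's AsyncLLMServer.

The server exposes GET /v1/models and POST /v1/completions.
"""
# Example query (illustrative; assumes the default host/port and that the
# "model" field matches --served-model-name, or --model if it is not set):
#
#   curl http://localhost:8000/v1/completions \
#       -H "Content-Type: application/json" \
#       -d '{"model": "<served-model>", "prompt": "Hello,", "max_tokens": 16}'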

import argparse
from http import HTTPStatus
import json
import time
from typing import AsyncGenerator, Dict, List, Optional

import fastapi
from fastapi.exceptions import RequestValidationError
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import StreamingResponse, JSONResponse
import uvicorn

from cacheflow.outputs import RequestOutput
from cacheflow.server.arg_utils import ServerArgs
from cacheflow.server.async_llm_server import AsyncLLMServer
from cacheflow.server.tokenizer_utils import get_tokenizer
from cacheflow.logger import init_logger
from cacheflow.sampling_params import SamplingParams
from cacheflow.utils import random_uuid
from cacheflow.entrypoints.openai.protocol import (
    CompletionRequest,
    CompletionResponse,
    CompletionResponseChoice,
    CompletionResponseStreamChoice,
    CompletionStreamResponse,
    ErrorResponse,
    LogProbs,
    ModelCard,
    ModelList,
    ModelPermission,
    UsageInfo,
)


logger = init_logger(__name__)
served_model = None
app = fastapi.FastAPI()


def create_error_response(status_code: HTTPStatus,
                          message: str) -> JSONResponse:
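    """Wrap an error message in an OpenAI-style ErrorResponse JSON body."""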
    return JSONResponse(
        ErrorResponse(message=message, type="invalid_request_error").dict(),
        status_code=status_code.value
    )


@app.exception_handler(RequestValidationError)
async def validation_exception_handler(request, exc):
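    """Return request validation failures as OpenAI-style 400 responses."""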
    return create_error_response(HTTPStatus.BAD_REQUEST, str(exc))


async def check_model(request) -> Optional[JSONResponse]:
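    """Error out if the requested model does not match the served model."""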
    if request.model == served_model:
        return
    ret = create_error_response(
        HTTPStatus.NOT_FOUND,
        f"The model `{request.model}` does not exist.",
    )
    return ret


@app.get("/v1/models")
async def show_available_models():
    """Show available models. Right now we only have one model."""
    model_cards = [ModelCard(id=served_model, root=served_model,
                             permission=[ModelPermission()])]
    return ModelList(data=model_cards)


def create_logprobs(token_ids: List[int],
                    id_logprobs: List[Dict[int, float]],
                    initial_text_offset: int = 0) -> LogProbs:
    """Create OpenAI-style logprobs."""
    logprobs = LogProbs()
    last_token_len = 0
    for token_id, id_logprob in zip(token_ids, id_logprobs):
        token = tokenizer.convert_ids_to_tokens(token_id)
        logprobs.tokens.append(token)
        logprobs.token_logprobs.append(id_logprob[token_id])
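        # text_offset records the character position at which each token
        # starts within the generated text.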
        if len(logprobs.text_offset) == 0:
            logprobs.text_offset.append(initial_text_offset)
        else:
            logprobs.text_offset.append(logprobs.text_offset[-1] +
                                        last_token_len)
        last_token_len = len(token)

        logprobs.top_logprobs.append(
            {tokenizer.convert_ids_to_tokens(i): p
             for i, p in id_logprob.items()})
    return logprobs


@app.post("/v1/completions")
async def create_completion(request: CompletionRequest):
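    """Completion API mirroring the OpenAI Completions API.

    Currently unsupported features (requests using them are rejected):
        - echo (the cacheflow server cannot yet return prompt logprobs)
        - suffix (the supported language models do not accept a suffix)
        - logit_bias (not yet supported by the cacheflow server)
    """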
    logger.info(f"Received completion request: {request}")

    error_check_ret = await check_model(request)
    if error_check_ret is not None:
        return error_check_ret

    if request.echo:
        # We do not support echo since the cacheflow server does not
        # currently support getting the logprobs of prompt tokens.
        return create_error_response(HTTPStatus.BAD_REQUEST,
                                     "echo is not currently supported")

    if request.suffix is not None:
        # The language models we currently support do not support suffix.
        return create_error_response(HTTPStatus.BAD_REQUEST,
                                     "suffix is not currently supported")

    if request.logit_bias is not None:
        # TODO: support logit_bias in cacheflow server.
        return create_error_response(HTTPStatus.BAD_REQUEST,
                                     "logit_bias is not currently supported")

    model_name = request.model
    request_id = f"cmpl-{random_uuid()}"
    prompt = request.prompt
    created_time = int(time.time())
    try:
        sampling_params = SamplingParams(
            n=request.n,
            best_of=request.best_of,
            presence_penalty=request.presence_penalty,
            frequency_penalty=request.frequency_penalty,
            temperature=request.temperature,
            top_p=request.top_p,
            top_k=request.top_k,
            stop=request.stop,
            ignore_eos=request.ignore_eos,
            max_tokens=request.max_tokens,
            logprobs=request.logprobs,
            use_beam_search=request.use_beam_search,
        )
    except ValueError as e:
        return create_error_response(HTTPStatus.BAD_REQUEST, str(e))

    result_generator = server.generate(prompt, sampling_params,
                                       request_id=request_id)

    # Similar to the OpenAI API, when n != best_of, we do not stream the
    # results. In addition, we do not stream the results when using
    # beam search.
    stream = (request.stream and
              (request.best_of is None or request.n == request.best_of) and
              not request.use_beam_search)

    def create_stream_response_json(index: int,
                                    text: str,
                                    logprobs: Optional[LogProbs] = None,
                                    finish_reason: Optional[str] = None) -> str:
        choice_data = CompletionResponseStreamChoice(
            index=index,
            text=text,
            logprobs=logprobs,
            finish_reason=finish_reason,
        )
        response = CompletionStreamResponse(
            id=request_id,
            created=created_time,
            model=model_name,
            choices=[choice_data],
        )
        response_json = response.json(ensure_ascii=False)

        return response_json

    async def completion_stream_generator() -> AsyncGenerator[str, None]:
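        # Per-sequence bookkeeping so that each streamed event carries only
        # the text generated since the previous event (a delta).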
        previous_texts = [""] * request.n
        previous_num_tokens = [0] * request.n
        async for res in result_generator:
            res: RequestOutput
            for output in res.outputs:
                i = output.index
                delta_text = output.text[len(previous_texts[i]):]
                if request.logprobs is not None:
                    logprobs = create_logprobs(
                        output.token_ids[previous_num_tokens[i]:],
                        output.logprobs[previous_num_tokens[i]:],
                        len(previous_texts[i]))
                else:
                    logprobs = None
                previous_texts[i] = output.text
                previous_num_tokens[i] = len(output.token_ids)
                response_json = create_stream_response_json(
                    index=i,
                    text=delta_text,
                    logprobs=logprobs,
                )
                yield f"data: {response_json}\n\n"
                if output.finish_reason is not None:
                    logprobs = LogProbs() if request.logprobs is not None else None
                    response_json = create_stream_response_json(
                        index=i,
                        text="",
                        logprobs=logprobs,
                        finish_reason=output.finish_reason,
                    )
                    yield f"data: {response_json}\n\n"
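        # Terminate the event stream with the OpenAI-style "[DONE]" sentinel.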
        yield "data: [DONE]\n\n"

    # Streaming response
    if stream:
        return StreamingResponse(completion_stream_generator(),
                                 media_type="text/event-stream")

    # Non-streaming response
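    # Drain the generator; each yielded RequestOutput supersedes the previous
    # one, so only the final result is used to build the response.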
    final_res: Optional[RequestOutput] = None
    async for res in result_generator:
        final_res = res
    assert final_res is not None
    choices = []
    for output in final_res.outputs:
        if request.logprobs is not None:
            logprobs = create_logprobs(output.token_ids, output.logprobs)
        else:
            logprobs = None
        choice_data = CompletionResponseChoice(
            index=output.index,
            text=output.text,
            logprobs=logprobs,
            finish_reason=output.finish_reason,
        )
        choices.append(choice_data)

    num_prompt_tokens = len(final_res.prompt_token_ids)
    num_generated_tokens = sum(len(output.token_ids)
                               for output in final_res.outputs)
    usage = UsageInfo(
        prompt_tokens=num_prompt_tokens,
        completion_tokens=num_generated_tokens,
        total_tokens=num_prompt_tokens + num_generated_tokens,
    )
    response = CompletionResponse(
        id=request_id,
        created=created_time,
        model=model_name,
        choices=choices,
        usage=usage,
    )

    if request.stream:
        # When the user requests streaming but we do not stream the results,
        # we still return a streaming response with a single event.
        response_json = response.json(ensure_ascii=False)

        async def fake_stream_generator() -> AsyncGenerator[str, None]:
            yield f"data: {response_json}\n\n"
            yield "data: [DONE]\n\n"

        return StreamingResponse(fake_stream_generator(),
                                 media_type="text/event-stream")

    return response


if __name__ == "__main__":
    parser = argparse.ArgumentParser(
        description="CacheFlow OpenAI-Compatible RESTful API server."
    )
    parser.add_argument("--host", type=str, default="localhost",
                        help="host name")
    parser.add_argument("--port", type=int, default=8000, help="port number")
    parser.add_argument(
        "--allow-credentials", action="store_true", help="allow credentials"
    )
    parser.add_argument(
        "--allowed-origins", type=json.loads, default=["*"],
        help="allowed origins"
    )
    parser.add_argument(
        "--allowed-methods", type=json.loads, default=["*"],
        help="allowed methods"
    )
    parser.add_argument(
        "--allowed-headers", type=json.loads, default=["*"],
        help="allowed headers"
    )
    parser.add_argument("--served-model-name", type=str, default=None,
                        help="The model name used in the API. If not specified, "
                             "the model name will be the same as the "
                             "huggingface name.")
    parser = ServerArgs.add_cli_args(parser)
    args = parser.parse_args()

    app.add_middleware(
        CORSMiddleware,
        allow_origins=args.allowed_origins,
        allow_credentials=args.allow_credentials,
        allow_methods=args.allowed_methods,
        allow_headers=args.allowed_headers,
    )

    logger.info(f"args: {args}")

    served_model = args.served_model_name or args.model

    server_args = ServerArgs.from_cli_args(args)
    server = AsyncLLMServer.from_server_args(server_args)

    # A separate tokenizer to map token IDs to strings.
    tokenizer = get_tokenizer(args.model)

    uvicorn.run(app, host=args.host, port=args.port, log_level="info")