
Commit 057daef

OpenAI Compatible Frontend (#116)
1 parent e867178 commit 057daef

20 files changed: +645 -170 lines

cacheflow/core/block_manager.py

Lines changed: 3 additions & 3 deletions
@@ -148,7 +148,7 @@ def _get_physical_blocks(self, seq_group: SequenceGroup) -> List[PhysicalTokenBlock]:
         # the sequences in the same group.
         blocks: Set[PhysicalTokenBlock] = set()
         for seq in seq_group.get_seqs():
-            if seq.status == SequenceStatus.FINISHED:
+            if SequenceStatus.is_finished(seq.status):
                 continue
             block_table = self.block_tables[seq.seq_id]
             for block in block_table:
@@ -169,7 +169,7 @@ def swap_in(self, seq_group: SequenceGroup) -> Dict[int, int]:
         # CPU block -> GPU block.
         mapping: Dict[PhysicalTokenBlock, PhysicalTokenBlock] = {}
         for seq in seq_group.get_seqs():
-            if seq.status == SequenceStatus.FINISHED:
+            if SequenceStatus.is_finished(seq.status):
                 continue
             new_block_table: BlockTable = []
             block_table = self.block_tables[seq.seq_id]
@@ -200,7 +200,7 @@ def swap_out(self, seq_group: SequenceGroup) -> Dict[int, int]:
         # GPU block -> CPU block.
         mapping: Dict[PhysicalTokenBlock, PhysicalTokenBlock] = {}
         for seq in seq_group.get_seqs():
-            if seq.status == SequenceStatus.FINISHED:
+            if SequenceStatus.is_finished(seq.status):
                 continue
             new_block_table: BlockTable = []
             block_table = self.block_tables[seq.seq_id]
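
All three hunks above replace a direct comparison against SequenceStatus.FINISHED with a SequenceStatus.is_finished(...) helper, so the status enum can distinguish several terminal states while callers keep a single check. Below is a minimal sketch of what such a helper could look like; the member names are hypothetical and the real enum in the cacheflow codebase may differ.

import enum

class SequenceStatus(enum.Enum):
    # Hypothetical members for illustration only.
    WAITING = enum.auto()
    RUNNING = enum.auto()
    SWAPPED = enum.auto()
    FINISHED_STOPPED = enum.auto()        # hit EOS or a stop condition
    FINISHED_LENGTH_CAPPED = enum.auto()  # reached max_tokens
    FINISHED_ABORTED = enum.auto()        # cancelled by the client

    @staticmethod
    def is_finished(status: "SequenceStatus") -> bool:
        # Any terminal state counts as finished.
        return status in (
            SequenceStatus.FINISHED_STOPPED,
            SequenceStatus.FINISHED_LENGTH_CAPPED,
            SequenceStatus.FINISHED_ABORTED,
        )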

cacheflow/core/scheduler.py

Lines changed: 4 additions & 2 deletions
@@ -292,10 +292,12 @@ def update(
                 # Append a new token to the sequence.
                 output = seq_outputs[seq.seq_id]
                 seq.append_token_id(output.output_token, output.logprobs)
+        # Return a shallow copy of the running queue to prevent the queue
+        # from being modified by the caller.
         return self.running.copy()
 
-    def free_seq(self, seq: Sequence) -> None:
-        seq.status = SequenceStatus.FINISHED
+    def free_seq(self, seq: Sequence, finish_status: SequenceStatus) -> None:
+        seq.status = finish_status
         self.block_manager.free(seq)
 
     def free_finished_seq_groups(self) -> None:
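
With free_seq now taking an explicit finish_status, the caller records why a sequence ended instead of collapsing every outcome into one FINISHED value. A hedged usage sketch, reusing the hypothetical status names from the block_manager example above (the real call sites pick the status elsewhere in the codebase):

# Hypothetical call sites for the new signature.
scheduler.free_seq(seq, SequenceStatus.FINISHED_STOPPED)  # generation hit a stop condition
scheduler.free_seq(seq, SequenceStatus.FINISHED_ABORTED)  # request was cancelled mid-flight
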
Lines changed: 300 additions & 0 deletions
@@ -0,0 +1,300 @@
# Adapted from https://github.com/lm-sys/FastChat/blob/168ccc29d3f7edc50823016105c024fe2282732a/fastchat/serve/openai_api_server.py

import argparse
from http import HTTPStatus
import json
import time
from typing import AsyncGenerator, Dict, List, Optional

import fastapi
from fastapi.exceptions import RequestValidationError
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import StreamingResponse, JSONResponse
import uvicorn

from cacheflow.outputs import RequestOutput
from cacheflow.server.arg_utils import ServerArgs
from cacheflow.server.async_llm_server import AsyncLLMServer
from cacheflow.server.tokenizer_utils import get_tokenizer
from cacheflow.logger import init_logger
from cacheflow.sampling_params import SamplingParams
from cacheflow.utils import random_uuid
from cacheflow.entrypoints.openai.protocol import (
    CompletionRequest,
    CompletionResponse,
    CompletionResponseChoice,
    CompletionResponseStreamChoice,
    CompletionStreamResponse,
    ErrorResponse,
    LogProbs,
    ModelCard,
    ModelList,
    ModelPermission,
    UsageInfo,
)


logger = init_logger(__name__)
served_model = None
app = fastapi.FastAPI()


def create_error_response(status_code: HTTPStatus,
                          message: str) -> JSONResponse:
    return JSONResponse(
        ErrorResponse(message=message, type="invalid_request_error").dict(),
        status_code=status_code.value
    )


@app.exception_handler(RequestValidationError)
async def validation_exception_handler(request, exc):
    return create_error_response(HTTPStatus.BAD_REQUEST, str(exc))


async def check_model(request) -> Optional[JSONResponse]:
    if request.model == served_model:
        return
    ret = create_error_response(
        HTTPStatus.NOT_FOUND,
        f"The model `{request.model}` does not exist.",
    )
    return ret


@app.get("/v1/models")
async def show_available_models():
    """Show available models. Right now we only have one model."""
    model_cards = [ModelCard(id=served_model, root=served_model,
                             permission=[ModelPermission()])]
    return ModelList(data=model_cards)


def create_logprobs(token_ids: List[int],
                    id_logprobs: List[Dict[int, float]],
                    initial_text_offset: int = 0) -> LogProbs:
    """Create OpenAI-style logprobs."""
    logprobs = LogProbs()
    last_token_len = 0
    for token_id, id_logprob in zip(token_ids, id_logprobs):
        token = tokenizer.convert_ids_to_tokens(token_id)
        logprobs.tokens.append(token)
        logprobs.token_logprobs.append(id_logprob[token_id])
        if len(logprobs.text_offset) == 0:
            logprobs.text_offset.append(initial_text_offset)
        else:
            logprobs.text_offset.append(logprobs.text_offset[-1] + last_token_len)
        last_token_len = len(token)

        logprobs.top_logprobs.append(
            {tokenizer.convert_ids_to_tokens(i): p
             for i, p in id_logprob.items()})
    return logprobs


@app.post("/v1/completions")
async def create_completion(request: CompletionRequest):
    logger.info(f"Received completion request: {request}")

    error_check_ret = await check_model(request)
    if error_check_ret is not None:
        return error_check_ret

    if request.echo:
        # We do not support echo since the cacheflow server does not
        # currently support getting the logprobs of prompt tokens.
        return create_error_response(HTTPStatus.BAD_REQUEST,
                                     "echo is not currently supported")

    if request.suffix is not None:
        # The language models we currently support do not support suffix.
        return create_error_response(HTTPStatus.BAD_REQUEST,
                                     "suffix is not currently supported")

    if request.logit_bias is not None:
        # TODO: support logit_bias in cacheflow server.
        return create_error_response(HTTPStatus.BAD_REQUEST,
                                     "logit_bias is not currently supported")

    model_name = request.model
    request_id = f"cmpl-{random_uuid()}"
    prompt = request.prompt
    created_time = int(time.time())
    try:
        sampling_params = SamplingParams(
            n=request.n,
            best_of=request.best_of,
            presence_penalty=request.presence_penalty,
            frequency_penalty=request.frequency_penalty,
            temperature=request.temperature,
            top_p=request.top_p,
            top_k=request.top_k,
            stop=request.stop,
            ignore_eos=request.ignore_eos,
            max_tokens=request.max_tokens,
            logprobs=request.logprobs,
            use_beam_search=request.use_beam_search,
        )
    except ValueError as e:
        return create_error_response(HTTPStatus.BAD_REQUEST, str(e))

    result_generator = server.generate(prompt, sampling_params,
                                       request_id=request_id)

    # Similar to the OpenAI API, when n != best_of, we do not stream the
    # results. In addition, we do not stream the results when use beam search.
    stream = (request.stream and
              (request.best_of is None or request.n == request.best_of) and
              not request.use_beam_search)

    def create_stream_response_json(index: int,
                                    text: str,
                                    logprobs: Optional[LogProbs] = None,
                                    finish_reason: Optional[str] = None) -> str:
        choice_data = CompletionResponseStreamChoice(
            index=index,
            text=text,
            logprobs=logprobs,
            finish_reason=finish_reason,
        )
        response = CompletionStreamResponse(
            id=request_id,
            created=created_time,
            model=model_name,
            choices=[choice_data],
        )
        response_json = response.json(ensure_ascii=False)

        return response_json

    async def completion_stream_generator() -> AsyncGenerator[str, None]:
        previous_texts = [""] * request.n
        previous_num_tokens = [0] * request.n
        async for res in result_generator:
            res: RequestOutput
            for output in res.outputs:
                i = output.index
                delta_text = output.text[len(previous_texts[i]):]
                if request.logprobs is not None:
                    logprobs = create_logprobs(
                        output.token_ids[previous_num_tokens[i]:],
                        output.logprobs[previous_num_tokens[i]:],
                        len(previous_texts[i]))
                else:
                    logprobs = None
                previous_texts[i] = output.text
                previous_num_tokens[i] = len(output.token_ids)
                response_json = create_stream_response_json(
                    index=i,
                    text=delta_text,
                    logprobs=logprobs,
                )
                yield f"data: {response_json}\n\n"
                if output.finish_reason is not None:
                    logprobs = LogProbs() if request.logprobs is not None else None
                    response_json = create_stream_response_json(
                        index=i,
                        text="",
                        logprobs=logprobs,
                        finish_reason=output.finish_reason,
                    )
                    yield f"data: {response_json}\n\n"
        yield "data: [DONE]\n\n"

    # Streaming response
    if stream:
        return StreamingResponse(completion_stream_generator(),
                                 media_type="text/event-stream")

    # Non-streaming response
    final_res: RequestOutput = None
    async for res in result_generator:
        final_res = res
    assert final_res is not None
    choices = []
    for output in final_res.outputs:
        if request.logprobs is not None:
            logprobs = create_logprobs(output.token_ids, output.logprobs)
        else:
            logprobs = None
        choice_data = CompletionResponseChoice(
            index=output.index,
            text=output.text,
            logprobs=logprobs,
            finish_reason=output.finish_reason,
        )
        choices.append(choice_data)

    num_prompt_tokens = len(final_res.prompt_token_ids)
    num_generated_tokens = sum(len(output.token_ids)
                               for output in final_res.outputs)
    usage = UsageInfo(
        prompt_tokens=num_prompt_tokens,
        completion_tokens=num_generated_tokens,
        total_tokens=num_prompt_tokens + num_generated_tokens,
    )
    response = CompletionResponse(
        id=request_id,
        created=created_time,
        model=model_name,
        choices=choices,
        usage=usage,
    )

    if request.stream:
        # When user requests streaming but we don't stream, we still need to
        # return a streaming response with a single event.
        response_json = response.json(ensure_ascii=False)
        async def fake_stream_generator() -> AsyncGenerator[str, None]:
            yield f"data: {response_json}\n\n"
            yield "data: [DONE]\n\n"
        return StreamingResponse(fake_stream_generator(),
                                 media_type="text/event-stream")

    return response


if __name__ == "__main__":
    parser = argparse.ArgumentParser(
        description="CacheFlow OpenAI-Compatible RESTful API server."
    )
    parser.add_argument("--host", type=str, default="localhost", help="host name")
    parser.add_argument("--port", type=int, default=8000, help="port number")
    parser.add_argument(
        "--allow-credentials", action="store_true", help="allow credentials"
    )
    parser.add_argument(
        "--allowed-origins", type=json.loads, default=["*"], help="allowed origins"
    )
    parser.add_argument(
        "--allowed-methods", type=json.loads, default=["*"], help="allowed methods"
    )
    parser.add_argument(
        "--allowed-headers", type=json.loads, default=["*"], help="allowed headers"
    )
    parser.add_argument("--served-model-name", type=str, default=None,
                        help="The model name used in the API. If not specified, "
                             "the model name will be the same as the "
                             "huggingface name.")
    parser = ServerArgs.add_cli_args(parser)
    args = parser.parse_args()

    app.add_middleware(
        CORSMiddleware,
        allow_origins=args.allowed_origins,
        allow_credentials=args.allow_credentials,
        allow_methods=args.allowed_methods,
        allow_headers=args.allowed_headers,
    )

    logger.info(f"args: {args}")

    served_model = args.served_model_name or args.model

    server_args = ServerArgs.from_cli_args(args)
    server = AsyncLLMServer.from_server_args(server_args)

    # A separate tokenizer to map token IDs to strings.
    tokenizer = get_tokenizer(args.model)

    uvicorn.run(app, host=args.host, port=args.port, log_level="info")
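
Once this script is running (host and port default to localhost:8000 per the argparse flags above), any OpenAI-style completions client can call it. A minimal client sketch, assuming the third-party `requests` package and placeholder model name and prompt; none of this is part of the commit itself:

import json

import requests  # assumed third-party dependency

# The model name must match the --model / --served-model-name the server was started with.
payload = {
    "model": "facebook/opt-125m",
    "prompt": "San Francisco is a",
    "max_tokens": 16,
    "temperature": 0.0,
}
resp = requests.post("http://localhost:8000/v1/completions", json=payload)
print(json.dumps(resp.json(), indent=2))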
