@@ -7,13 +7,14 @@
 from typing import AsyncGenerator, Dict, List, Optional

 import fastapi
+from fastapi import BackgroundTasks, Request
 from fastapi.exceptions import RequestValidationError
 from fastapi.middleware.cors import CORSMiddleware
 from fastapi.responses import StreamingResponse, JSONResponse
 import uvicorn

 from cacheflow.outputs import RequestOutput
-from cacheflow.server.arg_utils import ServerArgs
+from cacheflow.server.arg_utils import AsyncServerArgs
 from cacheflow.server.async_llm_server import AsyncLLMServer
 from cacheflow.server.tokenizer_utils import get_tokenizer
 from cacheflow.logger import init_logger
@@ -33,6 +34,7 @@
     UsageInfo,
 )

+TIMEOUT_KEEP_ALIVE = 5  # seconds

 logger = init_logger(__name__)
 served_model = None
@@ -93,7 +95,8 @@ def create_logprobs(token_ids: List[int],


 @app.post("/v1/completions")
-async def create_completion(request: CompletionRequest):
+async def create_completion(raw_request: Request):
+    request = CompletionRequest(**await raw_request.json())
     logger.info(f"Received completion request: {request}")

     error_check_ret = await check_model(request)
@@ -139,14 +142,17 @@ async def create_completion(request: CompletionRequest):
         return create_error_response(HTTPStatus.BAD_REQUEST, str(e))

     result_generator = server.generate(prompt, sampling_params,
-                                       request_id=request_id)
+                                       request_id)

     # Similar to the OpenAI API, when n != best_of, we do not stream the
     # results. In addition, we do not stream the results when use beam search.
     stream = (request.stream and
               (request.best_of is None or request.n == request.best_of) and
               not request.use_beam_search)

+    async def abort_request() -> None:
+        await server.abort(request_id)
+
     def create_stream_response_json(index: int,
                                     text: str,
                                     logprobs: Optional[LogProbs] = None,
@@ -203,12 +209,21 @@ async def completion_stream_generator() -> AsyncGenerator[str, None]:

     # Streaming response
     if stream:
+        background_tasks = BackgroundTasks()
+        # Abort the request if the client disconnects.
+        background_tasks.add_task(abort_request)
         return StreamingResponse(completion_stream_generator(),
-                                 media_type="text/event-stream")
+                                 media_type="text/event-stream",
+                                 background=background_tasks)

     # Non-streaming response
     final_res: RequestOutput = None
     async for res in result_generator:
+        if await raw_request.is_disconnected():
+            # Abort the request if the client disconnects.
+            await server.abort(request_id)
+            return create_error_response(HTTPStatus.BAD_REQUEST,
+                                         "Client disconnected")
         final_res = res
     assert final_res is not None
     choices = []
@@ -276,7 +291,7 @@ async def fake_stream_generator() -> AsyncGenerator[str, None]:
                        help="The model name used in the API. If not specified, "
                             "the model name will be the same as the "
                             "huggingface name.")
-    parser = ServerArgs.add_cli_args(parser)
+    parser = AsyncServerArgs.add_cli_args(parser)
     args = parser.parse_args()

     app.add_middleware(
@@ -291,10 +306,11 @@ async def fake_stream_generator() -> AsyncGenerator[str, None]:

     served_model = args.served_model_name or args.model

-    server_args = ServerArgs.from_cli_args(args)
+    server_args = AsyncServerArgs.from_cli_args(args)
     server = AsyncLLMServer.from_server_args(server_args)

     # A separate tokenizer to map token IDs to strings.
     tokenizer = get_tokenizer(args.model)

-    uvicorn.run(app, host=args.host, port=args.port, log_level="info")
+    uvicorn.run(app, host=args.host, port=args.port, log_level="info",
+                timeout_keep_alive=TIMEOUT_KEEP_ALIVE)
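
As a rough illustration of the `/v1/completions` route this diff touches, here is a minimal client sketch for the streaming path. The host/port, the model name, and the exact shape of the streamed chunks (OpenAI-style `data:` events ending with `[DONE]`, each carrying `choices[i].text`) are assumptions for illustration, not something this diff specifies.

# Hypothetical client sketch (not part of this diff). Assumes the server is
# reachable at localhost:8000 and streams OpenAI-style completion chunks.
import json

import requests

response = requests.post(
    "http://localhost:8000/v1/completions",  # host/port are assumptions
    json={
        "model": "facebook/opt-125m",   # placeholder model name
        "prompt": "San Francisco is a",
        "stream": True,                 # takes the StreamingResponse path above
    },
    stream=True,
)

for raw_line in response.iter_lines():
    if not raw_line:
        continue
    line = raw_line.decode("utf-8")
    if not line.startswith("data: "):
        continue
    payload = line[len("data: "):]
    if payload == "[DONE]":            # assumed end-of-stream sentinel
        break
    chunk = json.loads(payload)
    # Each chunk carries the newly generated text for one choice.
    print(chunk["choices"][0]["text"], end="", flush=True)
print()

Closing the connection mid-stream should exercise the new abort path: the BackgroundTasks hook (streaming) or the raw_request.is_disconnected() check (non-streaming) calls server.abort(request_id).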