Skip to content

Commit

Permalink
Clean up API code and remove the redundant background task (vllm-project#1102)
Browse files: browse the repository at this point in the history
  • Loading branch information
esmeetu authored Sep 21, 2023
1 parent 1ac4ccf commit 2d1e86f
Show file tree
Hide file tree
Showing 2 changed files with 7 additions and 27 deletions.
10 changes: 2 additions & 8 deletions vllm/entrypoints/api_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
import json
from typing import AsyncGenerator

from fastapi import BackgroundTasks, FastAPI, Request
from fastapi import FastAPI, Request
from fastapi.responses import JSONResponse, Response, StreamingResponse
import uvicorn

Expand Down Expand Up @@ -44,14 +44,8 @@ async def stream_results() -> AsyncGenerator[bytes, None]:
ret = {"text": text_outputs}
yield (json.dumps(ret) + "\0").encode("utf-8")

async def abort_request() -> None:
await engine.abort(request_id)

if stream:
background_tasks = BackgroundTasks()
# Abort the request if the client disconnects.
background_tasks.add_task(abort_request)
return StreamingResponse(stream_results(), background=background_tasks)
return StreamingResponse(stream_results())

# Non-streaming case
final_output = None
Expand Down
24 changes: 5 additions & 19 deletions vllm/entrypoints/openai/api_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@

import fastapi
import uvicorn
from fastapi import BackgroundTasks, Request
from fastapi import Request
from fastapi.exceptions import RequestValidationError
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import JSONResponse, StreamingResponse
Expand Down Expand Up @@ -229,9 +229,6 @@ async def create_chat_completion(request: ChatCompletionRequest,
result_generator = engine.generate(prompt, sampling_params, request_id,
token_ids)

async def abort_request() -> None:
await engine.abort(request_id)

def create_stream_response_json(
index: int,
text: str,
Expand Down Expand Up @@ -291,19 +288,15 @@ async def completion_stream_generator() -> AsyncGenerator[str, None]:

# Streaming response
if request.stream:
background_tasks = BackgroundTasks()
# Abort the request if the client disconnects.
background_tasks.add_task(abort_request)
return StreamingResponse(completion_stream_generator(),
media_type="text/event-stream",
background=background_tasks)
media_type="text/event-stream")

# Non-streaming response
final_res: RequestOutput = None
async for res in result_generator:
if await raw_request.is_disconnected():
# Abort the request if the client disconnects.
await abort_request()
await engine.abort(request_id)
return create_error_response(HTTPStatus.BAD_REQUEST,
"Client disconnected")
final_res = res
Expand Down Expand Up @@ -448,9 +441,6 @@ async def create_completion(request: CompletionRequest, raw_request: Request):
and (request.best_of is None or request.n == request.best_of)
and not request.use_beam_search)

async def abort_request() -> None:
await engine.abort(request_id)

def create_stream_response_json(
index: int,
text: str,
Expand Down Expand Up @@ -510,19 +500,15 @@ async def completion_stream_generator() -> AsyncGenerator[str, None]:

# Streaming response
if stream:
background_tasks = BackgroundTasks()
# Abort the request if the client disconnects.
background_tasks.add_task(abort_request)
return StreamingResponse(completion_stream_generator(),
media_type="text/event-stream",
background=background_tasks)
media_type="text/event-stream")

# Non-streaming response
final_res: RequestOutput = None
async for res in result_generator:
if await raw_request.is_disconnected():
# Abort the request if the client disconnects.
await abort_request()
await engine.abort(request_id)
return create_error_response(HTTPStatus.BAD_REQUEST,
"Client disconnected")
final_res = res
Expand Down

0 comments on commit 2d1e86f

Please sign in to comment.