
Commit 1a956e1

Fix various issues of async servers (#135)
1 parent 8274ca2 commit 1a956e1

File tree: 11 files changed, +289 additions, -121 deletions
Lines changed: 58 additions & 0 deletions (new file)

@@ -0,0 +1,58 @@
import argparse
import json
import threading
import time

import requests


def main(args: argparse.Namespace):
    prompts = [f"Tell me a story with more than {''.join([str(i+1)] * 5)} words"
               for i in range(args.n_threads)]

    headers = {"User-Agent": "CacheFlow Benchmark Client"}
    ploads = [{
        "prompt": p,
        "max_tokens": args.max_tokens,
        "temperature": 0.0,
        "ignore_eos": True,
    } for p in prompts]

    def send_request(results, i):
        response = requests.post(args.api_url, headers=headers,
                                 json=ploads[i], stream=True)
        results[i] = response

    # use args.n_threads to prompt the backend
    tik = time.time()
    threads = []
    results = [None] * args.n_threads
    for i in range(args.n_threads):
        t = threading.Thread(target=send_request, args=(results, i))
        t.start()
        threads.append(t)

    for t in threads:
        t.join()

    print(f"Time (POST): {time.time() - tik} s")
    n_words = 0

    for i, response in enumerate(results):
        k = list(response.iter_lines(chunk_size=8192, decode_unicode=False, delimiter=b"\0"))
        response_new_words = json.loads(k[-2].decode("utf-8"))["text"][0]
        n_words += len(response_new_words.split(" ")) - len(prompts[i].split(" "))

    time_seconds = time.time() - tik
    print(f"Time (total): {time_seconds:.3f}s to finish, n_threads: {args.n_threads}, "
          f"throughput: {n_words / time_seconds} words/s.")


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--api-url", type=str, default="http://localhost:8001/generate")
    parser.add_argument("--max-tokens", type=int, default=128)
    parser.add_argument("--n-threads", type=int, default=128)
    args = parser.parse_args()

    main(args)
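Note: the client above relies on the framing used by the simple frontend below: the server streams JSON chunks separated by NUL bytes, and the chunk at index -2 (the last one before the trailing delimiter) carries the full cumulative text, which is why the benchmark reads k[-2]. A minimal single-request sketch of that protocol, assuming a compatible server is listening on the benchmark's default URL:

import json

import requests

# One streaming request against the default /generate endpoint.
resp = requests.post(
    "http://localhost:8001/generate",
    json={"prompt": "Hello", "max_tokens": 16, "temperature": 0.0,
          "ignore_eos": True},
    stream=True,
)
# Chunks are separated by b"\0"; the last non-empty chunk holds the final text.
chunks = list(resp.iter_lines(chunk_size=8192, decode_unicode=False,
                              delimiter=b"\0"))
print(json.loads(chunks[-2].decode("utf-8"))["text"][0])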

cacheflow/config.py

Lines changed: 3 additions & 3 deletions

@@ -116,15 +116,15 @@ def __init__(
         self,
         pipeline_parallel_size: int,
         tensor_parallel_size: int,
-        use_ray: bool,
+        worker_use_ray: bool,
     ) -> None:
         self.pipeline_parallel_size = pipeline_parallel_size
         self.tensor_parallel_size = tensor_parallel_size
-        self.use_ray = use_ray
+        self.worker_use_ray = worker_use_ray

         self.world_size = pipeline_parallel_size * tensor_parallel_size
         if self.world_size > 1:
-            self.use_ray = True
+            self.worker_use_ray = True
         self._verify_args()

     def _verify_args(self) -> None:
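A minimal sketch of the renamed flag in use. The class name ParallelConfig is an assumption (the hunk only shows the __init__ body); the forcing of worker_use_ray for multi-GPU world sizes comes straight from the diff above:

from cacheflow.config import ParallelConfig  # class name assumed, not shown in the hunk

parallel_config = ParallelConfig(
    pipeline_parallel_size=1,
    tensor_parallel_size=2,
    worker_use_ray=False,  # renamed from use_ray
)
# world_size = 1 * 2 > 1, so __init__ forces worker_use_ray to True.
assert parallel_config.worker_use_ray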

cacheflow/core/block_manager.py

Lines changed: 6 additions & 3 deletions

@@ -148,7 +148,7 @@ def _get_physical_blocks(self, seq_group: SequenceGroup) -> List[PhysicalTokenBl
         # the sequences in the same group.
         blocks: Set[PhysicalTokenBlock] = set()
         for seq in seq_group.get_seqs():
-            if SequenceStatus.is_finished(seq.status):
+            if seq.is_finished():
                 continue
             block_table = self.block_tables[seq.seq_id]
             for block in block_table:

@@ -169,7 +169,7 @@ def swap_in(self, seq_group: SequenceGroup) -> Dict[int, int]:
         # CPU block -> GPU block.
         mapping: Dict[PhysicalTokenBlock, PhysicalTokenBlock] = {}
         for seq in seq_group.get_seqs():
-            if SequenceStatus.is_finished(seq.status):
+            if seq.is_finished():
                 continue
             new_block_table: BlockTable = []
             block_table = self.block_tables[seq.seq_id]

@@ -200,7 +200,7 @@ def swap_out(self, seq_group: SequenceGroup) -> Dict[int, int]:
         # GPU block -> CPU block.
         mapping: Dict[PhysicalTokenBlock, PhysicalTokenBlock] = {}
         for seq in seq_group.get_seqs():
-            if SequenceStatus.is_finished(seq.status):
+            if seq.is_finished():
                 continue
             new_block_table: BlockTable = []
             block_table = self.block_tables[seq.seq_id]

@@ -231,6 +231,9 @@ def _free_block_table(self, block_table: BlockTable) -> None:
                 self.cpu_allocator.free(block)

     def free(self, seq: Sequence) -> None:
+        if seq.seq_id not in self.block_tables:
+            # Already freed or haven't been scheduled yet.
+            return
         block_table = self.block_tables[seq.seq_id]
         self._free_block_table(block_table)
         del self.block_tables[seq.seq_id]
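The guard added to free() matters once requests can be aborted while still waiting: such sequences were never scheduled, so no block table exists for them, and free() should be a no-op rather than a KeyError. A self-contained sketch of the pattern, with a plain dict standing in for the real block tables:

# Toy stand-in for the block manager's block_tables; not the real class.
block_tables = {7: ["block-a", "block-b"]}

def free(seq_id: int) -> None:
    if seq_id not in block_tables:
        # Already freed, or never scheduled (e.g. aborted while waiting).
        return
    del block_tables[seq_id]

free(7)  # releases the (toy) blocks
free(7)  # second call is now a harmless no-op instead of a KeyError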

cacheflow/core/scheduler.py

Lines changed: 13 additions & 1 deletion

@@ -12,7 +12,7 @@

 logger = init_logger(__name__)

-_LOGGING_INTERVAL_SEC = 10
+_LOGGING_INTERVAL_SEC = 5


 class PreemptionMode(enum.Enum):

@@ -84,6 +84,18 @@ def add_seq_group(self, seq_group: SequenceGroup) -> None:
         # Add sequence groups to the waiting queue.
         self.waiting.append(seq_group)

+    def abort_seq_group(self, request_id: str) -> None:
+        for state_queue in [self.waiting, self.running, self.swapped]:
+            for seq_group in state_queue:
+                if seq_group.request_id == request_id:
+                    # Remove the sequence group from the state queue.
+                    state_queue.remove(seq_group)
+                    for seq in seq_group.seqs:
+                        if seq.is_finished():
+                            continue
+                        self.free_seq(seq, SequenceStatus.FINISHED_ABORTED)
+                    return
+
     def has_unfinished_seqs(self) -> bool:
         return self.waiting or self.running or self.swapped
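abort_seq_group() scans the waiting, running, and swapped queues for the matching request_id, removes the group from whichever queue holds it, and frees its unfinished sequences with the new FINISHED_ABORTED status. A self-contained toy of that scan (constructing a real Scheduler needs full configs, so plain lists and dicts stand in for the queues and groups):

# Plain lists stand in for self.waiting / self.running / self.swapped.
waiting = [{"request_id": "req-1", "seqs": []}]
running, swapped = [], []

def abort_seq_group(request_id: str) -> None:
    for state_queue in (waiting, running, swapped):
        for seq_group in state_queue:
            if seq_group["request_id"] == request_id:
                state_queue.remove(seq_group)
                # The real method also frees unfinished seqs as FINISHED_ABORTED.
                return

abort_seq_group("req-1")
assert not waiting  # the aborted group is gone from the waiting queue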

cacheflow/entrypoints/openai/openai_frontend.py

Lines changed: 23 additions & 7 deletions

@@ -7,13 +7,14 @@
 from typing import AsyncGenerator, Dict, List, Optional

 import fastapi
+from fastapi import BackgroundTasks, Request
 from fastapi.exceptions import RequestValidationError
 from fastapi.middleware.cors import CORSMiddleware
 from fastapi.responses import StreamingResponse, JSONResponse
 import uvicorn

 from cacheflow.outputs import RequestOutput
-from cacheflow.server.arg_utils import ServerArgs
+from cacheflow.server.arg_utils import AsyncServerArgs
 from cacheflow.server.async_llm_server import AsyncLLMServer
 from cacheflow.server.tokenizer_utils import get_tokenizer
 from cacheflow.logger import init_logger

@@ -33,6 +34,7 @@
     UsageInfo,
 )

+TIMEOUT_KEEP_ALIVE = 5  # seconds

 logger = init_logger(__name__)
 served_model = None

@@ -93,7 +95,8 @@ def create_logprobs(token_ids: List[int],


 @app.post("/v1/completions")
-async def create_completion(request: CompletionRequest):
+async def create_completion(raw_request: Request):
+    request = CompletionRequest(**await raw_request.json())
     logger.info(f"Received completion request: {request}")

     error_check_ret = await check_model(request)

@@ -139,14 +142,17 @@ async def create_completion(request: CompletionRequest):
         return create_error_response(HTTPStatus.BAD_REQUEST, str(e))

     result_generator = server.generate(prompt, sampling_params,
-                                       request_id=request_id)
+                                       request_id)

     # Similar to the OpenAI API, when n != best_of, we do not stream the
     # results. In addition, we do not stream the results when use beam search.
     stream = (request.stream and
               (request.best_of is None or request.n == request.best_of) and
               not request.use_beam_search)

+    async def abort_request() -> None:
+        await server.abort(request_id)
+
     def create_stream_response_json(index: int,
                                     text: str,
                                     logprobs: Optional[LogProbs] = None,

@@ -203,12 +209,21 @@ async def completion_stream_generator() -> AsyncGenerator[str, None]:

     # Streaming response
     if stream:
+        background_tasks = BackgroundTasks()
+        # Abort the request if the client disconnects.
+        background_tasks.add_task(abort_request)
         return StreamingResponse(completion_stream_generator(),
-                                 media_type="text/event-stream")
+                                 media_type="text/event-stream",
+                                 background=background_tasks)

     # Non-streaming response
     final_res: RequestOutput = None
     async for res in result_generator:
+        if await raw_request.is_disconnected():
+            # Abort the request if the client disconnects.
+            await server.abort(request_id)
+            return create_error_response(HTTPStatus.BAD_REQUEST,
+                                         "Client disconnected")
         final_res = res
     assert final_res is not None
     choices = []

@@ -276,7 +291,7 @@ async def fake_stream_generator() -> AsyncGenerator[str, None]:
                         help="The model name used in the API. If not specified, "
                              "the model name will be the same as the "
                              "huggingface name.")
-    parser = ServerArgs.add_cli_args(parser)
+    parser = AsyncServerArgs.add_cli_args(parser)
     args = parser.parse_args()

     app.add_middleware(

@@ -291,10 +306,11 @@ async def fake_stream_generator() -> AsyncGenerator[str, None]:

     served_model = args.served_model_name or args.model

-    server_args = ServerArgs.from_cli_args(args)
+    server_args = AsyncServerArgs.from_cli_args(args)
     server = AsyncLLMServer.from_server_args(server_args)

     # A separate tokenizer to map token IDs to strings.
     tokenizer = get_tokenizer(args.model)

-    uvicorn.run(app, host=args.host, port=args.port, log_level="info")
+    uvicorn.run(app, host=args.host, port=args.port, log_level="info",
+                timeout_keep_alive=TIMEOUT_KEEP_ALIVE)
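With the raw Request handle and BackgroundTasks in place, a client that drops a streaming completion early now triggers server.abort(request_id) instead of leaving generation running. A hedged client-side sketch; the host, port, model name, and request fields are assumptions, since the OpenAI-style request schema is not part of this diff:

import requests

resp = requests.post(
    "http://localhost:8000/v1/completions",  # host/port assumed
    json={"model": "my-served-model", "prompt": "Hello",
          "max_tokens": 256, "stream": True},
    stream=True,
)
for i, line in enumerate(resp.iter_lines()):
    if i >= 3:
        break  # stop reading after a few chunks
resp.close()  # client disconnect -> the abort_request background task fires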

cacheflow/entrypoints/simple_fastapi_frontend.py

Lines changed: 17 additions & 8 deletions

@@ -2,15 +2,16 @@
 import json
 from typing import AsyncGenerator

-from fastapi import FastAPI, Request
+from fastapi import BackgroundTasks, FastAPI, Request
 from fastapi.responses import StreamingResponse
 import uvicorn

 from cacheflow.sampling_params import SamplingParams
-from cacheflow.server.arg_utils import ServerArgs
+from cacheflow.server.arg_utils import AsyncServerArgs
 from cacheflow.server.async_llm_server import AsyncLLMServer
-from cacheflow.server.ray_utils import initialize_cluster
+from cacheflow.utils import random_uuid

+TIMEOUT_KEEP_ALIVE = 5  # seconds.
 TIMEOUT_TO_PREVENT_DEADLOCK = 1  # seconds
 app = FastAPI()

@@ -20,7 +21,8 @@ async def generate_stream(request: Request) -> StreamingResponse:
     request_dict = await request.json()
     prompt = request_dict.pop("prompt")
     sampling_params = SamplingParams(**request_dict)
-    results_generator = server.generate(prompt, sampling_params)
+    request_id = random_uuid()
+    results_generator = server.generate(prompt, sampling_params, request_id)

     async def stream_results() -> AsyncGenerator[bytes, None]:
         async for request_output in results_generator:

@@ -35,17 +37,24 @@ async def stream_results() -> AsyncGenerator[bytes, None]:
             }
             yield (json.dumps(ret) + "\0").encode("utf-8")

-    return StreamingResponse(stream_results())
+    async def abort_request() -> None:
+        await server.abort(request_id)
+
+    background_tasks = BackgroundTasks()
+    # Abort the request if the client disconnects.
+    background_tasks.add_task(abort_request)
+    return StreamingResponse(stream_results(), background=background_tasks)


 if __name__ == "__main__":
     parser = argparse.ArgumentParser()
     parser.add_argument("--host", type=str, default="localhost")
     parser.add_argument("--port", type=int, default=8001)
-    parser = ServerArgs.add_cli_args(parser)
+    parser = AsyncServerArgs.add_cli_args(parser)
     args = parser.parse_args()

-    server_args = ServerArgs.from_cli_args(args)
+    server_args = AsyncServerArgs.from_cli_args(args)
     server = AsyncLLMServer.from_server_args(server_args)

-    uvicorn.run(app, host=args.host, port=args.port, log_level="info")
+    uvicorn.run(app, host=args.host, port=args.port, log_level="debug",
+                timeout_keep_alive=TIMEOUT_KEEP_ALIVE)

cacheflow/sequence.py

Lines changed: 9 additions & 1 deletion

@@ -12,12 +12,14 @@ class SequenceStatus(enum.Enum):
     SWAPPED = enum.auto()
     FINISHED_STOPPED = enum.auto()
     FINISHED_LENGTH_CAPPED = enum.auto()
+    FINISHED_ABORTED = enum.auto()

     @staticmethod
     def is_finished(status: "SequenceStatus") -> bool:
         return status in [
             SequenceStatus.FINISHED_STOPPED,
             SequenceStatus.FINISHED_LENGTH_CAPPED,
+            SequenceStatus.FINISHED_ABORTED,
         ]

     @staticmethod

@@ -26,10 +28,13 @@ def get_finished_reason(status: "SequenceStatus") -> Union[str, None]:
             finish_reason = "stop"
         elif status == SequenceStatus.FINISHED_LENGTH_CAPPED:
             finish_reason = "length"
+        elif status == SequenceStatus.FINISHED_ABORTED:
+            finish_reason = "abort"
         else:
             finish_reason = None
         return finish_reason

+
 class SequenceData:

     def __init__(

@@ -137,6 +142,9 @@ def get_output_token_ids(self) -> List[int]:
     def get_cumulative_logprob(self) -> float:
         return self.data.cumulative_logprob

+    def is_finished(self) -> bool:
+        return SequenceStatus.is_finished(self.status)
+
     def fork(self, child_seq: 'Sequence') -> None:
         child_seq.logical_token_blocks = copy.deepcopy(self.logical_token_blocks)
         child_seq.output_logprobs = copy.deepcopy(self.output_logprobs)

@@ -182,7 +190,7 @@ def find(self, seq_id: int) -> Sequence:
         raise ValueError(f'Sequence {seq_id} not found.')

     def is_finished(self) -> bool:
-        return all(SequenceStatus.is_finished(seq.status) for seq in self.seqs)
+        return all(seq.is_finished() for seq in self.seqs)

     def __repr__(self) -> str:
         return (f"SequenceGroup(request_id={self.request_id}, "
