
Commit a6c50fd

[Inference] ADD async and sync Api server using FastAPI (#5396)
* add api server
* fix
* add
* add completion service and fix bug
* add generation config
* revise shardformer
* fix bugs
* add docstrings and fix some bugs
* fix bugs and add choices for prompt template
1 parent bc1da87 commit a6c50fd
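
The section below expands only part of the commit; the FastAPI server itself lives in files not shown here. As rough orientation, a hypothetical sketch of how an async engine can back a streaming FastAPI endpoint; the /generate path, the JSON payload shape, and the engine wiring are assumptions for illustration, not the commit's actual server code:

from fastapi import FastAPI, Request
from fastapi.responses import StreamingResponse

app = FastAPI()
engine = ...  # an AsyncInferenceEngine instance, created at startup (wiring assumed)

@app.post("/generate")  # endpoint name assumed, not taken from this commit
async def generate(request: Request) -> StreamingResponse:
    payload = await request.json()

    async def stream():
        # engine.generate is the async generator defined later in this commit
        async for output in engine.generate(
            request_id=payload["request_id"], prompt=payload["prompt"]
        ):
            yield output

    return StreamingResponse(stream())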

File tree

13 files changed: +848 -33 lines changed


colossalai/inference/batch_bucket.py

Lines changed: 3 additions & 0 deletions
@@ -59,6 +59,9 @@ def is_empty(self):
     def current_batch_size(self):
         return self._current_batch_size
 
+    def __len__(self):
+        return self._current_batch_size
+
     @property
     def available_batch_size(self):
         return self.max_batch_size - self._current_batch_size
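
With __len__ in place, a bucket's occupancy can be read with the built-in len(), mirroring current_batch_size(). A minimal usage sketch; the has_capacity helper is ours, and BatchBucket construction is omitted:

from colossalai.inference.batch_bucket import BatchBucket

def has_capacity(bucket: BatchBucket) -> bool:
    # len(bucket) now reports the current batch size directly
    return len(bucket) < bucket.max_batch_size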

colossalai/inference/config.py

Lines changed: 17 additions & 2 deletions
@@ -1,10 +1,10 @@
 """
 Our config contains various options for inference optimization, it is a unified API that wraps all the configurations for inference.
 """
-
+import dataclasses
 import logging
 from dataclasses import dataclass
-from typing import Optional, Union
+from typing import Any, Dict, Optional, Union
 
 import torch
 import torch.distributed as dist
@@ -140,3 +140,18 @@ def to_generation_config(self, model_config) -> GenerationConfig:
                 meta_config[type] = getattr(model_config, type)
 
         return GenerationConfig.from_dict(meta_config)
+
+    @classmethod
+    def from_dict(cls, config_dict: Dict[str, Any]) -> "InferenceConfig":
+        # Get the list of attributes of this dataclass.
+        attrs = [attr.name for attr in dataclasses.fields(cls)]
+        inference_config_args = {}
+        for attr in attrs:
+            if attr in config_dict:
+                inference_config_args[attr] = config_dict[attr]
+            else:
+                inference_config_args[attr] = getattr(cls, attr)
+
+        # Set the attributes from the parsed arguments.
+        inference_config = cls(**inference_config_args)
+        return inference_config
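
A short sketch of from_dict in use; "max_batch_size" is an illustrative field name, not a verified list of InferenceConfig fields. Keys present in the dict override the dataclass defaults, and missing keys fall back to the class-level default via getattr(cls, attr):

from colossalai.inference.config import InferenceConfig

config = InferenceConfig.from_dict({"max_batch_size": 8})  # field name assumed

Note that the getattr(cls, attr) fallback only works for fields with a plain class-level default; a field declared with dataclasses.field(default_factory=...) exposes no class attribute and would raise AttributeError here.
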
Lines changed: 318 additions & 0 deletions

@@ -0,0 +1,318 @@
+import asyncio
+import logging
+from functools import partial
+from typing import AsyncIterator, Dict, Iterable, List, Optional, Set, Tuple, Type
+
+from colossalai.inference.core.engine import InferenceEngine
+
+logger = logging.getLogger(__name__)
+
+
+class AsyncEngineDeadError(RuntimeError):
+    pass
+
+
+def _raise_exception_on_finish(task: asyncio.Task, request_tracker: "RequestTracker") -> None:
+    msg = "Task finished unexpectedly. This should never happen! "
+    try:
+        try:
+            task.result()
+        except asyncio.CancelledError:
+            return
+        except Exception as exc:
+            raise AsyncEngineDeadError(msg + "See stack trace above for the actual cause.") from exc
+        raise AsyncEngineDeadError(msg)
+    except Exception as exc:
+        request_tracker.propagate_exception(exc)
+        raise exc
+
+
+class AsyncStream:
+    """A stream of Output for a request that can be
+    iterated over asynchronously."""
+
+    def __init__(self, request_id: str) -> None:
+        self.request_id = request_id
+        self._queue = asyncio.Queue()
+        self._finished = False
+
+    def put(self, item) -> None:
+        if self._finished:
+            return
+        self._queue.put_nowait(item)
+
+    def finish(self) -> None:
+        self._queue.put_nowait(StopIteration)
+        self._finished = True
+
+    @property
+    def finished(self) -> bool:
+        return self._finished
+
+    def __aiter__(self):
+        return self
+
+    async def __anext__(self):
+        result = await self._queue.get()
+        if result is StopIteration:
+            raise StopAsyncIteration
+        elif isinstance(result, Exception):
+            raise result
+        return result
+
+
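
A self-contained sketch of the producer/consumer handshake AsyncStream implements, with toy strings standing in for engine outputs (it assumes AsyncStream is importable from this new module):

import asyncio

async def demo_stream():
    stream = AsyncStream(request_id="req-0")

    async def producer():
        for token in ["Hello", ",", " world"]:
            stream.put(token)
            await asyncio.sleep(0)  # yield control so the consumer can drain
        stream.finish()  # enqueue the StopIteration sentinel

    asyncio.get_running_loop().create_task(producer())
    async for item in stream:  # stops once the sentinel is dequeued
        print(item)

asyncio.run(demo_stream())
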
+class RequestTracker:
+    """Synchronous abstraction for tracking requests."""
+
+    def __init__(self) -> None:
+        self._request_streams: Dict[str, AsyncStream] = {}
+        self._finished_requests: asyncio.Queue[int] = asyncio.Queue()
+        self._new_requests: asyncio.Queue[Tuple[AsyncStream, dict]] = asyncio.Queue()
+        self.new_requests_event = None
+
+    def __contains__(self, item):
+        return item in self._request_streams
+
+    def init_event(self):
+        self.new_requests_event = asyncio.Event()
+
+    def propagate_exception(self, exc: Exception, request_id: Optional[int] = None) -> None:
+        """
+        Propagate an exception to request streams (all if request_id is None).
+        """
+        if request_id is not None:
+            self._request_streams[request_id].put(exc)
+        else:
+            for stream in self._request_streams.values():
+                stream.put(exc)
+
+    def process_finished_request(self, finished_request) -> None:
+        """Process a finished request from the engine."""
+        request_id = finished_request.request_id
+
+        self._request_streams[request_id].put(finished_request)
+        self.abort_request(request_id)
+
+    def add_request(self, request_id: int, **engine_add_request_kwargs) -> AsyncStream:
+        """
+        Add a request to be sent to the engine on the next background
+        loop iteration.
+        """
+        if request_id in self._request_streams:
+            raise KeyError(f"Request {request_id} already exists.")
+
+        stream = AsyncStream(request_id)
+        self._new_requests.put_nowait((stream, {"request_id": request_id, **engine_add_request_kwargs}))
+
+        self.new_requests_event.set()
+
+        return stream
+
+    def abort_request(self, request_id: int, *, verbose: bool = False) -> None:
+        """Abort a request during the next background loop iteration."""
+        if verbose:
+            logger.info(f"Aborted request {request_id}.")
+
+        self._finished_requests.put_nowait(request_id)
+
+        if request_id not in self._request_streams or self._request_streams[request_id].finished:
+            # The request has already finished or been aborted.
+            return
+
+        self._request_streams[request_id].finish()
+
+    def get_new_requests(self):
+        """
+        Get new requests from the http server.
+        """
+        new_requests: List[Dict] = []
+
+        while not self._new_requests.empty():
+            stream, new_request = self._new_requests.get_nowait()
+            self._request_streams[stream.request_id] = stream
+            new_requests.append(new_request)
+
+        self.new_requests_event.clear()
+
+        return new_requests
+
+    def get_new_and_finished_requests(self) -> Tuple[List[Dict], Set[int]]:
+        """Get the new requests and finished requests to be
+        sent to the engine."""
+        new_requests: List[Dict] = []
+        finished_requests: Set[int] = set()
+
+        while not self._finished_requests.empty():
+            request_id = self._finished_requests.get_nowait()
+            finished_requests.add(request_id)
+            self._request_streams.pop(request_id, None)
+
+        while not self._new_requests.empty():
+            stream, new_request = self._new_requests.get_nowait()
+            if stream.request_id in finished_requests:
+                # The request has already been aborted.
+                stream.finish()
+                continue
+            self._request_streams[stream.request_id] = stream
+            new_requests.append(new_request)
+
+        self.new_requests_event.clear()
+
+        return new_requests, finished_requests
+
+    async def wait_for_new_requests(self):
+        await self.new_requests_event.wait()
+
+
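
A minimal sketch of the tracker's handshake between the HTTP-facing side and the engine loop; the "prompt" payload key matches what add_request passes through later in this file:

import asyncio

async def demo_tracker():
    tracker = RequestTracker()
    tracker.init_event()  # the event must be created inside a running loop

    # HTTP side: register a request and keep the stream for the caller.
    stream = tracker.add_request(0, prompt="hello")

    # Engine-loop side: drain pending requests to hand them to the engine.
    new_requests = tracker.get_new_requests()
    assert new_requests[0]["request_id"] == 0

    # Aborting finishes the stream and queues the id for cleanup.
    tracker.abort_request(0)
    assert stream.finished

asyncio.run(demo_tracker())
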
+class _AsyncInferenceEngine(InferenceEngine):
+    """
+    Async methods for the Inference Engine.
+    """
+
+    async def async_step(self) -> Tuple[List, bool]:
+        """
+        The async version of Engine.step().
+        Performs one decoding iteration and returns newly generated results.
+
+        It first schedules the sequences to be executed in the next iteration.
+        Then, it executes the model and updates the scheduler with the model
+        outputs. Finally, it decodes the sequences and returns the newly
+        generated results.
+        """
+        batch = self.request_handler.schedule()
+        loop = asyncio.get_running_loop()
+
+        # Use run_in_executor to asynchronously run the sync method model.forward().
+        logits = await loop.run_in_executor(
+            None,
+            self.model,
+            batch,
+            self.k_cache,
+            self.v_cache,
+        )
+
+        if self.inference_config.pad_input:
+            logits = logits[:, -1, :]
+        self.request_handler.search_tokens(self.generation_config, logits)
+        # Return: List[Sequence]
+        finished_sequences = self.request_handler.update()
+
+        return finished_sequences, self.request_handler.current_requests_in_batch() > 0
+
+
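
The essential trick in async_step is pushing the blocking forward pass onto a worker thread so the event loop stays responsive. A standalone illustration of the same run_in_executor pattern, with time.sleep standing in for model.forward():

import asyncio
import time

def blocking_forward(x: int) -> int:
    time.sleep(0.1)  # stands in for a GPU-bound forward pass
    return x * 2

async def main():
    loop = asyncio.get_running_loop()
    # None selects the default ThreadPoolExecutor; args follow the callable.
    result = await loop.run_in_executor(None, blocking_forward, 21)
    print(result)  # 42

asyncio.run(main())
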
+class AsyncInferenceEngine:
+    """An asynchronous wrapper for InferenceEngine.
+
+    This class is used to wrap the InferenceEngine class to make it asynchronous.
+    It uses asyncio to create a background loop that keeps processing incoming
+    requests. The InferenceEngine is kicked by the generate method when there
+    are requests in the waiting queue. The generate method yields the outputs
+    from the InferenceEngine to the caller.
+    """
+
+    _engine_class: Type[_AsyncInferenceEngine] = _AsyncInferenceEngine
+
+    def __init__(self, start_engine_loop: bool = True, **kwargs):
+        self.engine = self._init_engine(**kwargs)
+        self.background_loop = None
+        # reference to the unshielded loop
+        self._background_loop_unshielded = None
+        self.start_engine_loop = start_engine_loop
+        self._request_tracker = RequestTracker()
+
+    @property
+    def background_loop_status(self):
+        return self.background_loop is not None and not self.background_loop.done()
+
+    def start_background_loop(self):
+        if self.background_loop_status:
+            raise RuntimeError("Existing loop is running")
+
+        self._request_tracker.init_event()
+
+        self._background_loop_unshielded = asyncio.get_event_loop().create_task(self.run_engine_loop())
+        self._background_loop_unshielded.add_done_callback(
+            partial(_raise_exception_on_finish, request_tracker=self._request_tracker)
+        )
+        self.background_loop = asyncio.shield(self._background_loop_unshielded)
+
+    def _init_engine(self, **kwargs):
+        return self._engine_class(**kwargs)
+
+    async def step(self):
+        """
+        Run the engine to process requests.
+
+        Returns True if there are in-progress requests.
+        """
+        new_requests = self._request_tracker.get_new_requests()
+        for new_request in new_requests:
+            self.engine.add_single_request(**new_request)
+        newly_finished_seqs, has_running_requests = await self.engine.async_step()
+        for seq in newly_finished_seqs:
+            self._request_tracker.process_finished_request(seq)
+
+        return has_running_requests
+
+    async def _engine_abort(self, request_ids: Iterable[int]):
+        self.engine.abort_request(request_ids)
+
+    async def abort(self, request_id: int):
+        """
+        Abort a single request.
+        """
+        if not self.background_loop_status:
+            raise RuntimeError("Background loop is not running or was not launched correctly.")
+        return self._abort(request_id)
+
+    def _abort(self, request_id: int):
+        self._request_tracker.abort_request(request_id)
+
+    async def run_engine_loop(self):
+        processing_requests = False
+        while True:
+            if not processing_requests:
+                await self._request_tracker.wait_for_new_requests()
+            processing_requests = await self.step()
+            await asyncio.sleep(0)
+
+    async def add_request(
+        self,
+        request_id: int,
+        prompt: Optional[str],
+        prompt_token_ids: Optional[List[int]] = None,
+    ) -> AsyncStream:
+        """
+        Add a request to the background tracker (waiting queue) and start the
+        background loop if needed.
+        """
+        if not self.background_loop_status:
+            if self.start_engine_loop:
+                self.start_background_loop()
+            else:
+                raise RuntimeError("Background loop is not running.")
+        stream = self._request_tracker.add_request(
+            request_id,
+            prompt=prompt,
+            prompt_token_ids=prompt_token_ids,
+        )
+        return stream
+
+    async def generate(
+        self,
+        request_id: int,
+        prompt: Optional[str],
+        prompt_token_ids: Optional[List[int]] = None,
+    ) -> AsyncIterator[str]:
+        """
+        Generate output from a request. It receives the request from the http
+        server, adds it into the waiting queue of the async engine, and streams
+        the output sequence.
+        """
+        try:
+            stream = await self.add_request(request_id, prompt, prompt_token_ids=prompt_token_ids)
+            async for request_output in stream:
+                yield request_output
+
+        except (Exception, asyncio.CancelledError) as e:
+            # If there is an exception or the coroutine is cancelled, abort the request.
+            self._abort(request_id)
+            raise e
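
Putting the pieces together, a hedged end-to-end sketch of the streaming path; the engine's constructor arguments are omitted, since whatever InferenceEngine requires is forwarded through **kwargs:

import asyncio

async def serve_one(engine: AsyncInferenceEngine):
    # generate() registers the request, starts the background loop if
    # needed, and yields outputs as the engine produces them.
    async for output in engine.generate(request_id=0, prompt="An example prompt"):
        print(output)

# engine = AsyncInferenceEngine(start_engine_loop=True, ...)  # engine kwargs omitted
# asyncio.run(serve_one(engine))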
