Skip to content

Commit 0fb45ec

Browse files
committed
async runner (#1352)
* have runner return asyncio.Task instead of AsyncFuture
* make tests async and fix them
* delete remaining runner thread code :)
* review changes to tests and server (reverts commit 828eee9)
1 parent 0df9b82 commit 0fb45ec

File tree

5 files changed

+89
-89
lines changed

5 files changed

+89
-89
lines changed

pyproject.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,7 @@ optional-dependencies = { "dev" = [
3535
"pillow",
3636
"pyright==1.1.347",
3737
"pytest",
38+
"pytest-asyncio",
3839
"pytest-httpserver",
3940
"pytest-rerunfailures",
4041
"pytest-xdist",

python/cog/server/http.py

Lines changed: 11 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -233,7 +233,7 @@ async def root() -> Any:
233233

234234
@app.get("/health-check")
235235
async def healthcheck() -> Any:
236-
_check_setup_result()
236+
await _check_setup_task()
237237
if app.state.health == Health.READY:
238238
health = Health.BUSY if runner.is_busy() else Health.READY
239239
else:
@@ -259,7 +259,7 @@ async def predict(request: PredictionRequest = Body(default=None), prefer: Union
259259
# TODO: spec-compliant parsing of Prefer header.
260260
respond_async = prefer == "respond-async"
261261

262-
return _predict(request=request, respond_async=respond_async)
262+
return await _predict(request=request, respond_async=respond_async)
263263

264264
@limited
265265
@app.put(
@@ -294,10 +294,10 @@ async def predict_idempotent(
294294
# TODO: spec-compliant parsing of Prefer header.
295295
respond_async = prefer == "respond-async"
296296

297-
return _predict(request=request, respond_async=respond_async)
297+
return await _predict(request=request, respond_async=respond_async)
298298

299-
def _predict(
300-
*, request: PredictionRequest, respond_async: bool = False
299+
async def _predict(
300+
*, request: Optional[PredictionRequest], respond_async: bool = False
301301
) -> Response:
302302
# [compat] If no body is supplied, assume that this model can be run
303303
# with empty input. This will throw a ValidationError if that's not
@@ -325,7 +325,8 @@ def _predict(
325325
return JSONResponse(jsonable_encoder(initial_response), status_code=202)
326326

327327
try:
328-
response = PredictionResponse(**async_result.get().dict())
328+
prediction = await async_result
329+
response = PredictionResponse(**prediction.dict())
329330
except ValidationError as e:
330331
_log_invalid_output(e)
331332
raise HTTPException(status_code=500, detail=str(e)) from e
@@ -354,14 +355,15 @@ async def cancel(prediction_id: str = Path(..., title="Prediction ID")) -> Any:
354355
else:
355356
return JSONResponse({}, status_code=200)
356357

357-
def _check_setup_result() -> Any:
358+
async def _check_setup_task() -> Any:
358359
if app.state.setup_task is None:
359360
return
360361

361-
if not app.state.setup_task.ready():
362+
if not app.state.setup_task.done():
362363
return
363364

364-
result = app.state.setup_task.get()
365+
# this can raise CancelledError
366+
result = app.state.setup_task.result()
365367

366368
if result.status == schema.Status.SUCCEEDED:
367369
app.state.health = Health.READY

python/cog/server/runner.py

Lines changed: 40 additions & 55 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,9 @@
1+
import asyncio
12
import io
2-
import sys
33
import threading
44
import traceback
55
import typing # TypeAlias, py3.10
66
from datetime import datetime, timezone
7-
from multiprocessing.pool import AsyncResult, ThreadPool
87
from typing import Any, Callable, Optional, Tuple, Union, cast
98

109
import requests
@@ -45,11 +44,8 @@ class SetupResult:
4544
status: schema.Status
4645

4746

48-
PredictionTask: "typing.TypeAlias" = "AsyncResult[schema.PredictionResponse]"
49-
SetupTask: "typing.TypeAlias" = "AsyncResult[SetupResult]"
50-
if sys.version_info < (3, 9):
51-
PredictionTask = AsyncResult
52-
SetupTask = AsyncResult
47+
PredictionTask: "typing.TypeAlias" = "asyncio.Task[schema.PredictionResponse]"
48+
SetupTask: "typing.TypeAlias" = "asyncio.Task[SetupResult]"
5349
RunnerTask: "typing.TypeAlias" = Union[PredictionTask, SetupTask]
5450

5551

@@ -61,38 +57,37 @@ def __init__(
6157
shutdown_event: Optional[threading.Event],
6258
upload_url: Optional[str] = None,
6359
) -> None:
64-
self._thread = None
65-
self._threadpool = ThreadPool(processes=1)
66-
6760
self._response: Optional[schema.PredictionResponse] = None
6861
self._result: Optional[RunnerTask] = None
6962

7063
self._worker = Worker(predictor_ref=predictor_ref)
71-
self._should_cancel = threading.Event()
64+
self._should_cancel = asyncio.Event()
7265

7366
self._shutdown_event = shutdown_event
7467
self._upload_url = upload_url
7568

76-
def setup(self) -> SetupTask:
77-
if self.is_busy():
78-
raise RunnerBusyError()
79-
80-
def handle_error(error: BaseException) -> None:
69+
def make_error_handler(self, activity: str) -> Callable[[RunnerTask], None]:
70+
def handle_error(task: RunnerTask) -> None:
71+
exc = task.exception()
72+
if not exc:
73+
return
8174
# Re-raise the exception in order to more easily capture exc_info,
8275
# and then trigger shutdown, as we have no easy way to resume
8376
# worker state if an exception was thrown.
8477
try:
85-
raise error
78+
raise exc
8679
except Exception:
87-
log.error("caught exception while running setup", exc_info=True)
80+
log.error(f"caught exception while running {activity}", exc_info=True)
8881
if self._shutdown_event is not None:
8982
self._shutdown_event.set()
9083

91-
self._result = self._threadpool.apply_async(
92-
func=setup,
93-
kwds={"worker": self._worker},
94-
error_callback=handle_error,
95-
)
84+
return handle_error
85+
86+
def setup(self) -> SetupTask:
87+
if self.is_busy():
88+
raise RunnerBusyError()
89+
self._result = asyncio.create_task(setup(worker=self._worker))
90+
self._result.add_done_callback(self.make_error_handler("setup"))
9691
return self._result
9792

9893
# TODO: Make the return type AsyncResult[schema.PredictionResponse] when we
@@ -121,52 +116,39 @@ def predict(
121116
upload_url = self._upload_url if upload else None
122117
event_handler = create_event_handler(prediction, upload_url=upload_url)
123118

124-
def cleanup(_: schema.PredictionResponse = None) -> None:
119+
def handle_cleanup(_: PredictionTask) -> None:
125120
input = cast(Any, prediction.input)
126121
if hasattr(input, "cleanup"):
127122
input.cleanup()
128123

129-
def handle_error(error: BaseException) -> None:
130-
# Re-raise the exception in order to more easily capture exc_info,
131-
# and then trigger shutdown, as we have no easy way to resume
132-
# worker state if an exception was thrown.
133-
try:
134-
raise error
135-
except Exception:
136-
log.error("caught exception while running prediction", exc_info=True)
137-
if self._shutdown_event is not None:
138-
self._shutdown_event.set()
139-
140124
self._response = event_handler.response
141-
self._result = self._threadpool.apply_async(
142-
func=predict,
143-
kwds={
144-
"worker": self._worker,
145-
"request": prediction,
146-
"event_handler": event_handler,
147-
"should_cancel": self._should_cancel,
148-
},
149-
callback=cleanup,
150-
error_callback=handle_error,
125+
coro = predict(
126+
worker=self._worker,
127+
request=prediction,
128+
event_handler=event_handler,
129+
should_cancel=self._should_cancel,
151130
)
131+
self._result = asyncio.create_task(coro)
132+
self._result.add_done_callback(handle_cleanup)
133+
self._result.add_done_callback(self.make_error_handler("prediction"))
152134

153135
return (self._response, self._result)
154136

155137
def is_busy(self) -> bool:
156138
if self._result is None:
157139
return False
158140

159-
if not self._result.ready():
141+
if not self._result.done():
160142
return True
161143

162144
self._response = None
163145
self._result = None
164146
return False
165147

166148
def shutdown(self) -> None:
149+
if self._result:
150+
self._result.cancel()
167151
self._worker.terminate()
168-
self._threadpool.terminate()
169-
self._threadpool.join()
170152

171153
def cancel(self, prediction_id: Optional[str] = None) -> None:
172154
if not self.is_busy():
@@ -308,13 +290,15 @@ def _upload_files(self, output: Any) -> Any:
308290
raise FileUploadError("Got error trying to upload output files") from error
309291

310292

311-
def setup(*, worker: Worker) -> SetupResult:
293+
async def setup(*, worker: Worker) -> SetupResult:
312294
logs = []
313295
status = None
314296
started_at = datetime.now(tz=timezone.utc)
315297

316298
try:
299+
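# TODO: worker.setup() is still a synchronous generator; this loop is meant to become async. Until then, yield to the event loop on each event.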
317300
for event in worker.setup():
301+
await asyncio.sleep(0)
318302
if isinstance(event, Log):
319303
logs.append(event.message)
320304
elif isinstance(event, Done):
@@ -344,19 +328,19 @@ def setup(*, worker: Worker) -> SetupResult:
344328
)
345329

346330

347-
def predict(
331+
async def predict(
348332
*,
349333
worker: Worker,
350334
request: schema.PredictionRequest,
351335
event_handler: PredictionEventHandler,
352-
should_cancel: threading.Event,
336+
should_cancel: asyncio.Event,
353337
) -> schema.PredictionResponse:
354338
# Set up logger context within the prediction task.
355339
structlog.contextvars.clear_contextvars()
356340
structlog.contextvars.bind_contextvars(prediction_id=request.id)
357341

358342
try:
359-
return _predict(
343+
return await _predict(
360344
worker=worker,
361345
request=request,
362346
event_handler=event_handler,
@@ -369,12 +353,12 @@ def predict(
369353
raise
370354

371355

372-
def _predict(
356+
async def _predict(
373357
*,
374358
worker: Worker,
375359
request: schema.PredictionRequest,
376360
event_handler: PredictionEventHandler,
377-
should_cancel: threading.Event,
361+
should_cancel: asyncio.Event,
378362
) -> schema.PredictionResponse:
379363
initial_prediction = request.dict()
380364

@@ -391,8 +375,9 @@ def _predict(
391375
event_handler.failed(error=str(e))
392376
log.warn("failed to download url path from input", exc_info=True)
393377
return event_handler.response
394-
378+
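# TODO: worker.predict() is still a synchronous generator; this loop is meant to become async. Until then, yield to the event loop on each event.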
395379
for event in worker.predict(input_dict, poll=0.1):
380+
await asyncio.sleep(0)
396381
if should_cancel.is_set():
397382
worker.cancel()
398383
should_cancel.clear()

python/cog/server/worker.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -114,7 +114,8 @@ def _wait(
114114
if send_heartbeats:
115115
yield Heartbeat()
116116
continue
117-
117+
# this needs aioprocessing.Pipe or similar
118+
# multiprocessing.Pipe is not async
118119
ev = self._events.recv()
119120
yield ev
120121

0 commit comments

Comments
 (0)