Skip to content

Commit ddc2339

Browse files
committed
async runner (#1352)
* have runner return asyncio.Task instead of AsyncFuture
* make tests async and fix them
* delete remaining runner thread code :)
* review changes to tests and server (reverts commit 828eee9)

Signed-off-by: technillogue <technillogue@gmail.com>
1 parent eae972d commit ddc2339

File tree

5 files changed

+89
-89
lines changed

5 files changed

+89
-89
lines changed

pyproject.toml

Lines changed: 1 addition & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -35,6 +35,7 @@ optional-dependencies = { "dev" = [
3535
"pillow",
3636
"pyright==1.1.347",
3737
"pytest",
38+
"pytest-asyncio",
3839
"pytest-httpserver",
3940
"pytest-rerunfailures",
4041
"pytest-xdist",

python/cog/server/http.py

Lines changed: 11 additions & 10 deletions
Original file line number | Diff line number | Diff line change
@@ -257,7 +257,7 @@ async def root() -> Any:
257257

258258
@app.get("/health-check")
259259
async def healthcheck() -> Any:
260-
_check_setup_result()
260+
await _check_setup_task()
261261
if app.state.health == Health.READY:
262262
health = Health.BUSY if runner.is_busy() else Health.READY
263263
else:
@@ -291,7 +291,7 @@ async def predict(
291291
with trace_context(make_trace_context(traceparent, tracestate)):
292292
return _predict(
293293
request=request,
294-
respond_async=respond_async,
294+
respond_async=respond_async
295295
)
296296

297297
@limited
@@ -335,10 +335,9 @@ async def predict_idempotent(
335335
respond_async=respond_async,
336336
)
337337

338-
def _predict(
339-
*,
340-
request: Optional[PredictionRequest],
341-
respond_async: bool = False,
338+
339+
async def _predict(
340+
*, request: Optional[PredictionRequest], respond_async: bool = False
342341
) -> Response:
343342
# [compat] If no body is supplied, assume that this model can be run
344343
# with empty input. This will throw a ValidationError if that's not
@@ -367,7 +366,8 @@ def _predict(
367366
return JSONResponse(jsonable_encoder(initial_response), status_code=202)
368367

369368
try:
370-
response = PredictionResponse(**async_result.get().dict())
369+
prediction = await async_result
370+
response = PredictionResponse(**prediction.dict())
371371
except ValidationError as e:
372372
_log_invalid_output(e)
373373
raise HTTPException(status_code=500, detail=str(e)) from e
@@ -396,14 +396,15 @@ async def cancel(prediction_id: str = Path(..., title="Prediction ID")) -> Any:
396396
else:
397397
return JSONResponse({}, status_code=200)
398398

399-
def _check_setup_result() -> Any:
399+
async def _check_setup_task() -> Any:
400400
if app.state.setup_task is None:
401401
return
402402

403-
if not app.state.setup_task.ready():
403+
if not app.state.setup_task.done():
404404
return
405405

406-
result = app.state.setup_task.get()
406+
# this can raise CancelledError
407+
result = app.state.setup_task.result()
407408

408409
if result.status == schema.Status.SUCCEEDED:
409410
app.state.health = Health.READY

python/cog/server/runner.py

Lines changed: 40 additions & 54 deletions
Original file line number | Diff line number | Diff line change
@@ -1,10 +1,9 @@
1+
import asyncio
12
import io
2-
import sys
33
import threading
44
import traceback
55
import typing # TypeAlias, py3.10
66
from datetime import datetime, timezone
7-
from multiprocessing.pool import AsyncResult, ThreadPool
87
from typing import Any, Callable, Optional, Tuple, Union, cast
98

109
import requests
@@ -46,11 +45,8 @@ class SetupResult:
4645
status: schema.Status
4746

4847

49-
PredictionTask: "typing.TypeAlias" = "AsyncResult[schema.PredictionResponse]"
50-
SetupTask: "typing.TypeAlias" = "AsyncResult[SetupResult]"
51-
if sys.version_info < (3, 9):
52-
PredictionTask = AsyncResult
53-
SetupTask = AsyncResult
48+
PredictionTask: "typing.TypeAlias" = "asyncio.Task[schema.PredictionResponse]"
49+
SetupTask: "typing.TypeAlias" = "asyncio.Task[SetupResult]"
5450
RunnerTask: "typing.TypeAlias" = Union[PredictionTask, SetupTask]
5551

5652

@@ -62,38 +58,37 @@ def __init__(
6258
shutdown_event: Optional[threading.Event],
6359
upload_url: Optional[str] = None,
6460
) -> None:
65-
self._thread = None
66-
self._threadpool = ThreadPool(processes=1)
67-
6861
self._response: Optional[schema.PredictionResponse] = None
6962
self._result: Optional[RunnerTask] = None
7063

7164
self._worker = Worker(predictor_ref=predictor_ref)
72-
self._should_cancel = threading.Event()
65+
self._should_cancel = asyncio.Event()
7366

7467
self._shutdown_event = shutdown_event
7568
self._upload_url = upload_url
7669

77-
def setup(self) -> SetupTask:
78-
if self.is_busy():
79-
raise RunnerBusyError()
80-
81-
def handle_error(error: BaseException) -> None:
70+
def make_error_handler(self, activity: str) -> Callable[[RunnerTask], None]:
71+
def handle_error(task: RunnerTask) -> None:
72+
exc = task.exception()
73+
if not exc:
74+
return
8275
# Re-raise the exception in order to more easily capture exc_info,
8376
# and then trigger shutdown, as we have no easy way to resume
8477
# worker state if an exception was thrown.
8578
try:
86-
raise error
79+
raise exc
8780
except Exception:
88-
log.error("caught exception while running setup", exc_info=True)
81+
log.error(f"caught exception while running {activity}", exc_info=True)
8982
if self._shutdown_event is not None:
9083
self._shutdown_event.set()
9184

92-
self._result = self._threadpool.apply_async(
93-
func=setup,
94-
kwds={"worker": self._worker},
95-
error_callback=handle_error,
96-
)
85+
return handle_error
86+
87+
def setup(self) -> SetupTask:
88+
if self.is_busy():
89+
raise RunnerBusyError()
90+
self._result = asyncio.create_task(setup(worker=self._worker))
91+
self._result.add_done_callback(self.make_error_handler("setup"))
9792
return self._result
9893

9994
# TODO: Make the return type AsyncResult[schema.PredictionResponse] when we
@@ -127,52 +122,39 @@ def predict(
127122
upload_url=upload_url,
128123
)
129124

130-
def cleanup(_: Optional[schema.PredictionResponse] = None) -> None:
125+
def handle_cleanup(_: Optional[schema.PredictionResponse] = None) -> None:
131126
input = cast(Any, prediction.input)
132127
if hasattr(input, "cleanup"):
133128
input.cleanup()
134129

135-
def handle_error(error: BaseException) -> None:
136-
# Re-raise the exception in order to more easily capture exc_info,
137-
# and then trigger shutdown, as we have no easy way to resume
138-
# worker state if an exception was thrown.
139-
try:
140-
raise error
141-
except Exception:
142-
log.error("caught exception while running prediction", exc_info=True)
143-
if self._shutdown_event is not None:
144-
self._shutdown_event.set()
145-
146130
self._response = event_handler.response
147-
self._result = self._threadpool.apply_async(
148-
func=predict,
149-
kwds={
150-
"worker": self._worker,
151-
"request": prediction,
152-
"event_handler": event_handler,
153-
"should_cancel": self._should_cancel,
154-
},
155-
callback=cleanup,
156-
error_callback=handle_error,
131+
coro = predict(
132+
worker=self._worker,
133+
request=prediction,
134+
event_handler=event_handler,
135+
should_cancel=self._should_cancel,
157136
)
137+
self._result = asyncio.create_task(coro)
138+
self._result.add_done_callback(handle_cleanup)
139+
self._result.add_done_callback(self.make_error_handler("prediction"))
158140

159141
return (self._response, self._result)
160142

161143
def is_busy(self) -> bool:
162144
if self._result is None:
163145
return False
164146

165-
if not self._result.ready():
147+
if not self._result.done():
166148
return True
167149

168150
self._response = None
169151
self._result = None
170152
return False
171153

172154
def shutdown(self) -> None:
155+
if self._result:
156+
self._result.cancel()
173157
self._worker.terminate()
174-
self._threadpool.terminate()
175-
self._threadpool.join()
176158

177159
def cancel(self, prediction_id: Optional[str] = None) -> None:
178160
if not self.is_busy():
@@ -318,13 +300,15 @@ def _upload_files(self, output: Any) -> Any:
318300
raise FileUploadError("Got error trying to upload output files") from error
319301

320302

321-
def setup(*, worker: Worker) -> SetupResult:
303+
async def setup(*, worker: Worker) -> SetupResult:
322304
logs = []
323305
status = None
324306
started_at = datetime.now(tz=timezone.utc)
325307

326308
try:
309+
# will be async
327310
for event in worker.setup():
311+
await asyncio.sleep(0)
328312
if isinstance(event, Log):
329313
logs.append(event.message)
330314
elif isinstance(event, Done):
@@ -354,19 +338,19 @@ def setup(*, worker: Worker) -> SetupResult:
354338
)
355339

356340

357-
def predict(
341+
async def predict(
358342
*,
359343
worker: Worker,
360344
request: schema.PredictionRequest,
361345
event_handler: PredictionEventHandler,
362-
should_cancel: threading.Event,
346+
should_cancel: asyncio.Event,
363347
) -> schema.PredictionResponse:
364348
# Set up logger context within prediction thread.
365349
structlog.contextvars.clear_contextvars()
366350
structlog.contextvars.bind_contextvars(prediction_id=request.id)
367351

368352
try:
369-
return _predict(
353+
return await _predict(
370354
worker=worker,
371355
request=request,
372356
event_handler=event_handler,
@@ -379,12 +363,12 @@ def predict(
379363
raise
380364

381365

382-
def _predict(
366+
async def _predict(
383367
*,
384368
worker: Worker,
385369
request: schema.PredictionRequest,
386370
event_handler: PredictionEventHandler,
387-
should_cancel: threading.Event,
371+
should_cancel: asyncio.Event,
388372
) -> schema.PredictionResponse:
389373
initial_prediction = request.dict()
390374

@@ -408,7 +392,9 @@ def _predict(
408392
log.warn("Failed to download url path from input", exc_info=True)
409393
return event_handler.response
410394

395+
# will be async
411396
for event in worker.predict(input_dict, poll=0.1):
397+
await asyncio.sleep(0)
412398
if should_cancel.is_set():
413399
worker.cancel()
414400
should_cancel.clear()

python/cog/server/worker.py

Lines changed: 2 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -118,7 +118,8 @@ def _wait(
118118
if send_heartbeats:
119119
yield Heartbeat()
120120
continue
121-
121+
# this needs aioprocessing.Pipe or similar
122+
# multiprocessing.Pipe is not async
122123
ev = self._events.recv()
123124
yield ev
124125

0 commit comments

Comments
 (0)