Skip to content

Commit dc5ef44

Browse files
committed
have runner return asyncio.Task instead of AsyncResult
Don't make Worker._wait async in this PR: it isn't strictly necessary and it makes testing more confusing.

Signed-off-by: technillogue <technillogue@gmail.com>
1 parent 6c356ce commit dc5ef44

File tree

3 files changed

+45
-51
lines changed

3 files changed

+45
-51
lines changed

python/cog/server/http.py

Lines changed: 9 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -107,7 +107,7 @@ async def root() -> Any:
107107

108108
@app.get("/health-check")
109109
async def healthcheck() -> Any:
110-
_check_setup_result()
110+
await _check_setup_result()
111111
if app.state.health == Health.READY:
112112
health = Health.BUSY if runner.is_busy() else Health.READY
113113
else:
@@ -137,7 +137,7 @@ async def predict(request: PredictionRequest = Body(default=None), prefer: Union
137137
# TODO: spec-compliant parsing of Prefer header.
138138
respond_async = prefer == "respond-async"
139139

140-
return _predict(request=request, respond_async=respond_async)
140+
return await _predict(request=request, respond_async=respond_async)
141141

142142
@limited
143143
@app.put(
@@ -172,9 +172,9 @@ async def predict_idempotent(
172172
# TODO: spec-compliant parsing of Prefer header.
173173
respond_async = prefer == "respond-async"
174174

175-
return _predict(request=request, respond_async=respond_async)
175+
return await _predict(request=request, respond_async=respond_async)
176176

177-
def _predict(
177+
async def _predict(
178178
*, request: PredictionRequest, respond_async: bool = False
179179
) -> Response:
180180
# [compat] If no body is supplied, assume that this model can be run
@@ -203,7 +203,8 @@ def _predict(
203203
return JSONResponse(jsonable_encoder(initial_response), status_code=202)
204204

205205
try:
206-
response = PredictionResponse(**async_result.get().dict())
206+
res = await async_result
207+
response = PredictionResponse(res.dict())
207208
except ValidationError as e:
208209
_log_invalid_output(e)
209210
raise HTTPException(status_code=500, detail=str(e)) from e
@@ -239,14 +240,14 @@ async def start_shutdown() -> Any:
239240
shutdown_event.set()
240241
return JSONResponse({}, status_code=200)
241242

242-
def _check_setup_result() -> Any:
243+
async def _check_setup_result() -> Any:
243244
if app.state.setup_result is None:
244245
return
245246

246-
if not app.state.setup_result.ready():
247+
if not app.state.setup_result.done():
247248
return
248249

249-
result = app.state.setup_result.get()
250+
result = await app.state.setup_result
250251

251252
if result["status"] == schema.Status.SUCCEEDED:
252253
app.state.health = Health.READY

python/cog/server/runner.py

Lines changed: 34 additions & 42 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,8 @@
1+
import asyncio
12
import io
23
import threading
34
import traceback
5+
from asyncio import Task
46
from datetime import datetime, timezone
57
from multiprocessing.pool import AsyncResult, ThreadPool
68
from typing import Any, Callable, Dict, Optional, Tuple
@@ -46,41 +48,43 @@ def __init__(
4648
self._threadpool = ThreadPool(processes=1)
4749

4850
self._response: Optional[schema.PredictionResponse] = None
49-
self._result: Optional[AsyncResult] = None
51+
self._result: Optional[Task] = None
5052

5153
self._worker = Worker(predictor_ref=predictor_ref)
5254
self._should_cancel = threading.Event()
5355

5456
self._shutdown_event = shutdown_event
5557
self._upload_url = upload_url
5658

57-
def setup(self) -> AsyncResult:
58-
if self.is_busy():
59-
raise RunnerBusyError()
60-
61-
def handle_error(error: BaseException) -> None:
59+
def make_error_handler(self, activity: str) -> Callable:
60+
def handle_error(task: Task) -> None:
61+
exc = task.exception()
62+
if not exc:
63+
return
6264
# Re-raise the exception in order to more easily capture exc_info,
6365
# and then trigger shutdown, as we have no easy way to resume
6466
# worker state if an exception was thrown.
6567
try:
66-
raise error
68+
raise exc
6769
except Exception:
68-
log.error("caught exception while running setup", exc_info=True)
70+
log.error(f"caught exception while running {activity}", exc_info=True)
6971
if self._shutdown_event is not None:
7072
self._shutdown_event.set()
7173

72-
self._result = self._threadpool.apply_async(
73-
func=setup,
74-
kwds={"worker": self._worker},
75-
error_callback=handle_error,
76-
)
74+
return handle_error
75+
76+
def setup(self) -> Task["dict[str, Any]"]:
77+
if self.is_busy():
78+
raise RunnerBusyError()
79+
self._result = asyncio.create_task(setup(worker=self._worker))
80+
self._result.add_done_callback(self.make_error_handler("setup"))
7781
return self._result
7882

7983
# TODO: Make the return type AsyncResult[schema.PredictionResponse] when we
8084
# no longer have to support Python 3.8
8185
def predict(
8286
self, prediction: schema.PredictionRequest, upload: bool = True
83-
) -> Tuple[schema.PredictionResponse, AsyncResult]:
87+
) -> Tuple[schema.PredictionResponse, Task[schema.PredictionResponse]]:
8488
# It's the caller's responsibility to not call us if we're busy.
8589
if self.is_busy():
8690
# If self._result is set, but self._response is not, we're still
@@ -101,41 +105,28 @@ def predict(
101105
upload_url = self._upload_url if upload else None
102106
event_handler = create_event_handler(prediction, upload_url=upload_url)
103107

104-
def cleanup(_: Optional[Any] = None) -> None:
108+
def handle_cleanup(_: Task) -> None:
105109
if hasattr(prediction.input, "cleanup"):
106110
prediction.input.cleanup()
107111

108-
def handle_error(error: BaseException) -> None:
109-
# Re-raise the exception in order to more easily capture exc_info,
110-
# and then trigger shutdown, as we have no easy way to resume
111-
# worker state if an exception was thrown.
112-
try:
113-
raise error
114-
except Exception:
115-
log.error("caught exception while running prediction", exc_info=True)
116-
if self._shutdown_event is not None:
117-
self._shutdown_event.set()
118-
119112
self._response = event_handler.response
120-
self._result = self._threadpool.apply_async(
121-
func=predict,
122-
kwds={
123-
"worker": self._worker,
124-
"request": prediction,
125-
"event_handler": event_handler,
126-
"should_cancel": self._should_cancel,
127-
},
128-
callback=cleanup,
129-
error_callback=handle_error,
113+
coro = predict(
114+
worker=self._worker,
115+
request=prediction,
116+
event_handler=event_handler,
117+
should_cancel=self._should_cancel,
130118
)
119+
self._result = asyncio.create_task(coro)
120+
self._result.add_done_callback(handle_cleanup)
121+
self._result.add_done_callback(self.make_error_handler("prediction"))
131122

132123
return (self._response, self._result)
133124

134125
def is_busy(self) -> bool:
135126
if self._result is None:
136127
return False
137128

138-
if not self._result.ready():
129+
if not self._result.done():
139130
return True
140131

141132
self._response = None
@@ -287,12 +278,13 @@ def _upload_files(self, output: Any) -> Any:
287278
raise FileUploadError("Got error trying to upload output files") from error
288279

289280

290-
def setup(*, worker: Worker) -> Dict[str, Any]:
281+
async def setup(*, worker: Worker) -> Dict[str, Any]:
291282
logs = []
292283
status = None
293284
started_at = datetime.now(tz=timezone.utc)
294285

295286
try:
287+
# TODO: worker.setup() is still iterated synchronously here; switch to `async for` once Worker is async
296288
for event in worker.setup():
297289
if isinstance(event, Log):
298290
logs.append(event.message)
@@ -323,7 +315,7 @@ def setup(*, worker: Worker) -> Dict[str, Any]:
323315
}
324316

325317

326-
def predict(
318+
async def predict(
327319
*,
328320
worker: Worker,
329321
request: schema.PredictionRequest,
@@ -335,7 +327,7 @@ def predict(
335327
structlog.contextvars.bind_contextvars(prediction_id=request.id)
336328

337329
try:
338-
return _predict(
330+
return await _predict(
339331
worker=worker,
340332
request=request,
341333
event_handler=event_handler,
@@ -348,7 +340,7 @@ def predict(
348340
raise
349341

350342

351-
def _predict(
343+
async def _predict(
352344
*,
353345
worker: Worker,
354346
request: schema.PredictionRequest,
@@ -370,7 +362,7 @@ def _predict(
370362
event_handler.failed(error=str(e))
371363
log.warn("failed to download url path from input", exc_info=True)
372364
return event_handler.response
373-
365+
# TODO: worker.predict() is still iterated synchronously here; switch to `async for` once Worker is async
374366
for event in worker.predict(input_dict, poll=0.1):
375367
if should_cancel.is_set():
376368
worker.cancel()

python/cog/server/worker.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -114,7 +114,8 @@ def _wait(
114114
if send_heartbeats:
115115
yield Heartbeat()
116116
continue
117-
117+
# this needs aioprocessing.Pipe or similar
118+
# multiprocessing.Pipe is not async
118119
ev = self._events.recv()
119120
yield ev
120121

0 commit comments

Comments
 (0)