1
+ import asyncio
1
2
import io
2
3
import threading
3
4
import traceback
5
+ from asyncio import Task
4
6
from datetime import datetime , timezone
5
7
from multiprocessing .pool import AsyncResult , ThreadPool
6
8
from typing import Any , Callable , Dict , Optional , Tuple
@@ -46,41 +48,43 @@ def __init__(
46
48
self ._threadpool = ThreadPool (processes = 1 )
47
49
48
50
self ._response : Optional [schema .PredictionResponse ] = None
49
- self ._result : Optional [AsyncResult ] = None
51
+ self ._result : Optional [Task ] = None
50
52
51
53
self ._worker = Worker (predictor_ref = predictor_ref )
52
54
self ._should_cancel = threading .Event ()
53
55
54
56
self ._shutdown_event = shutdown_event
55
57
self ._upload_url = upload_url
56
58
57
- def setup (self ) -> AsyncResult :
58
- if self . is_busy () :
59
- raise RunnerBusyError ()
60
-
61
- def handle_error ( error : BaseException ) -> None :
59
+ def make_error_handler (self , activity : str ) -> Callable :
60
+ def handle_error ( task : Task ) -> None :
61
+ exc = task . exception ()
62
+ if not exc :
63
+ return
62
64
# Re-raise the exception in order to more easily capture exc_info,
63
65
# and then trigger shutdown, as we have no easy way to resume
64
66
# worker state if an exception was thrown.
65
67
try :
66
- raise error
68
+ raise exc
67
69
except Exception :
68
- log .error ("caught exception while running setup " , exc_info = True )
70
+ log .error (f "caught exception while running { activity } " , exc_info = True )
69
71
if self ._shutdown_event is not None :
70
72
self ._shutdown_event .set ()
71
73
72
- self ._result = self ._threadpool .apply_async (
73
- func = setup ,
74
- kwds = {"worker" : self ._worker },
75
- error_callback = handle_error ,
76
- )
74
+ return handle_error
75
+
76
+ def setup (self ) -> Task ["dict[str, Any]" ]:
77
+ if self .is_busy ():
78
+ raise RunnerBusyError ()
79
+ self ._result = asyncio .create_task (setup (worker = self ._worker ))
80
+ self ._result .add_done_callback (self .make_error_handler ("setup" ))
77
81
return self ._result
78
82
79
83
# TODO: Make the return type AsyncResult[schema.PredictionResponse] when we
80
84
# no longer have to support Python 3.8
81
85
def predict (
82
86
self , prediction : schema .PredictionRequest , upload : bool = True
83
- ) -> Tuple [schema .PredictionResponse , AsyncResult ]:
87
+ ) -> Tuple [schema .PredictionResponse , Task [ schema . PredictionResponse ] ]:
84
88
# It's the caller's responsibility to not call us if we're busy.
85
89
if self .is_busy ():
86
90
# If self._result is set, but self._response is not, we're still
@@ -101,41 +105,28 @@ def predict(
101
105
upload_url = self ._upload_url if upload else None
102
106
event_handler = create_event_handler (prediction , upload_url = upload_url )
103
107
104
- def cleanup (_ : Optional [ Any ] = None ) -> None :
108
+ def handle_cleanup (_ : Task ) -> None :
105
109
if hasattr (prediction .input , "cleanup" ):
106
110
prediction .input .cleanup ()
107
111
108
- def handle_error (error : BaseException ) -> None :
109
- # Re-raise the exception in order to more easily capture exc_info,
110
- # and then trigger shutdown, as we have no easy way to resume
111
- # worker state if an exception was thrown.
112
- try :
113
- raise error
114
- except Exception :
115
- log .error ("caught exception while running prediction" , exc_info = True )
116
- if self ._shutdown_event is not None :
117
- self ._shutdown_event .set ()
118
-
119
112
self ._response = event_handler .response
120
- self ._result = self ._threadpool .apply_async (
121
- func = predict ,
122
- kwds = {
123
- "worker" : self ._worker ,
124
- "request" : prediction ,
125
- "event_handler" : event_handler ,
126
- "should_cancel" : self ._should_cancel ,
127
- },
128
- callback = cleanup ,
129
- error_callback = handle_error ,
113
+ coro = predict (
114
+ worker = self ._worker ,
115
+ request = prediction ,
116
+ event_handler = event_handler ,
117
+ should_cancel = self ._should_cancel ,
130
118
)
119
+ self ._result = asyncio .create_task (coro )
120
+ self ._result .add_done_callback (handle_cleanup )
121
+ self ._result .add_done_callback (self .make_error_handler ("prediction" ))
131
122
132
123
return (self ._response , self ._result )
133
124
134
125
def is_busy (self ) -> bool :
135
126
if self ._result is None :
136
127
return False
137
128
138
- if not self ._result .ready ():
129
+ if not self ._result .done ():
139
130
return True
140
131
141
132
self ._response = None
@@ -287,12 +278,13 @@ def _upload_files(self, output: Any) -> Any:
287
278
raise FileUploadError ("Got error trying to upload output files" ) from error
288
279
289
280
290
- def setup (* , worker : Worker ) -> Dict [str , Any ]:
281
+ async def setup (* , worker : Worker ) -> Dict [str , Any ]:
291
282
logs = []
292
283
status = None
293
284
started_at = datetime .now (tz = timezone .utc )
294
285
295
286
try :
287
+ # will be async
296
288
for event in worker .setup ():
297
289
if isinstance (event , Log ):
298
290
logs .append (event .message )
@@ -323,7 +315,7 @@ def setup(*, worker: Worker) -> Dict[str, Any]:
323
315
}
324
316
325
317
326
- def predict (
318
+ async def predict (
327
319
* ,
328
320
worker : Worker ,
329
321
request : schema .PredictionRequest ,
@@ -335,7 +327,7 @@ def predict(
335
327
structlog .contextvars .bind_contextvars (prediction_id = request .id )
336
328
337
329
try :
338
- return _predict (
330
+ return await _predict (
339
331
worker = worker ,
340
332
request = request ,
341
333
event_handler = event_handler ,
@@ -348,7 +340,7 @@ def predict(
348
340
raise
349
341
350
342
351
- def _predict (
343
+ async def _predict (
352
344
* ,
353
345
worker : Worker ,
354
346
request : schema .PredictionRequest ,
@@ -370,7 +362,7 @@ def _predict(
370
362
event_handler .failed (error = str (e ))
371
363
log .warn ("failed to download url path from input" , exc_info = True )
372
364
return event_handler .response
373
-
365
+ # will be async
374
366
for event in worker .predict (input_dict , poll = 0.1 ):
375
367
if should_cancel .is_set ():
376
368
worker .cancel ()
0 commit comments