1
+ import asyncio
1
2
import io
2
- import sys
3
3
import threading
4
4
import traceback
5
5
import typing # TypeAlias, py3.10
6
6
from datetime import datetime , timezone
7
- from multiprocessing .pool import AsyncResult , ThreadPool
8
7
from typing import Any , Callable , Optional , Tuple , Union , cast
9
8
10
9
import requests
@@ -45,11 +44,8 @@ class SetupResult:
45
44
status : schema .Status
46
45
47
46
48
- PredictionTask : "typing.TypeAlias" = "AsyncResult[schema.PredictionResponse]"
49
- SetupTask : "typing.TypeAlias" = "AsyncResult[SetupResult]"
50
- if sys .version_info < (3 , 9 ):
51
- PredictionTask = AsyncResult
52
- SetupTask = AsyncResult
47
+ PredictionTask : "typing.TypeAlias" = "asyncio.Task[schema.PredictionResponse]"
48
+ SetupTask : "typing.TypeAlias" = "asyncio.Task[SetupResult]"
53
49
RunnerTask : "typing.TypeAlias" = Union [PredictionTask , SetupTask ]
54
50
55
51
@@ -61,38 +57,37 @@ def __init__(
61
57
shutdown_event : Optional [threading .Event ],
62
58
upload_url : Optional [str ] = None ,
63
59
) -> None :
64
- self ._thread = None
65
- self ._threadpool = ThreadPool (processes = 1 )
66
-
67
60
self ._response : Optional [schema .PredictionResponse ] = None
68
61
self ._result : Optional [RunnerTask ] = None
69
62
70
63
self ._worker = Worker (predictor_ref = predictor_ref )
71
- self ._should_cancel = threading .Event ()
64
+ self ._should_cancel = asyncio .Event ()
72
65
73
66
self ._shutdown_event = shutdown_event
74
67
self ._upload_url = upload_url
75
68
76
- def setup (self ) -> SetupTask :
77
- if self . is_busy () :
78
- raise RunnerBusyError ()
79
-
80
- def handle_error ( error : BaseException ) -> None :
69
+ def make_error_handler (self , activity : str ) -> Callable [[ RunnerTask ], None ] :
70
+ def handle_error ( task : RunnerTask ) -> None :
71
+ exc = task . exception ()
72
+ if not exc :
73
+ return
81
74
# Re-raise the exception in order to more easily capture exc_info,
82
75
# and then trigger shutdown, as we have no easy way to resume
83
76
# worker state if an exception was thrown.
84
77
try :
85
- raise error
78
+ raise exc
86
79
except Exception :
87
- log .error ("caught exception while running setup " , exc_info = True )
80
+ log .error (f "caught exception while running { activity } " , exc_info = True )
88
81
if self ._shutdown_event is not None :
89
82
self ._shutdown_event .set ()
90
83
91
- self ._result = self ._threadpool .apply_async (
92
- func = setup ,
93
- kwds = {"worker" : self ._worker },
94
- error_callback = handle_error ,
95
- )
84
+ return handle_error
85
+
86
+ def setup (self ) -> SetupTask :
87
+ if self .is_busy ():
88
+ raise RunnerBusyError ()
89
+ self ._result = asyncio .create_task (setup (worker = self ._worker ))
90
+ self ._result .add_done_callback (self .make_error_handler ("setup" ))
96
91
return self ._result
97
92
98
93
# TODO: Make the return type AsyncResult[schema.PredictionResponse] when we
@@ -121,52 +116,39 @@ def predict(
121
116
upload_url = self ._upload_url if upload else None
122
117
event_handler = create_event_handler (prediction , upload_url = upload_url )
123
118
124
- def cleanup (_ : schema . PredictionResponse = None ) -> None :
119
+ def handle_cleanup (_ : PredictionTask ) -> None :
125
120
input = cast (Any , prediction .input )
126
121
if hasattr (input , "cleanup" ):
127
122
input .cleanup ()
128
123
129
- def handle_error (error : BaseException ) -> None :
130
- # Re-raise the exception in order to more easily capture exc_info,
131
- # and then trigger shutdown, as we have no easy way to resume
132
- # worker state if an exception was thrown.
133
- try :
134
- raise error
135
- except Exception :
136
- log .error ("caught exception while running prediction" , exc_info = True )
137
- if self ._shutdown_event is not None :
138
- self ._shutdown_event .set ()
139
-
140
124
self ._response = event_handler .response
141
- self ._result = self ._threadpool .apply_async (
142
- func = predict ,
143
- kwds = {
144
- "worker" : self ._worker ,
145
- "request" : prediction ,
146
- "event_handler" : event_handler ,
147
- "should_cancel" : self ._should_cancel ,
148
- },
149
- callback = cleanup ,
150
- error_callback = handle_error ,
125
+ coro = predict (
126
+ worker = self ._worker ,
127
+ request = prediction ,
128
+ event_handler = event_handler ,
129
+ should_cancel = self ._should_cancel ,
151
130
)
131
+ self ._result = asyncio .create_task (coro )
132
+ self ._result .add_done_callback (handle_cleanup )
133
+ self ._result .add_done_callback (self .make_error_handler ("prediction" ))
152
134
153
135
return (self ._response , self ._result )
154
136
155
137
def is_busy (self ) -> bool :
156
138
if self ._result is None :
157
139
return False
158
140
159
- if not self ._result .ready ():
141
+ if not self ._result .done ():
160
142
return True
161
143
162
144
self ._response = None
163
145
self ._result = None
164
146
return False
165
147
166
148
def shutdown (self ) -> None :
149
+ if self ._result :
150
+ self ._result .cancel ()
167
151
self ._worker .terminate ()
168
- self ._threadpool .terminate ()
169
- self ._threadpool .join ()
170
152
171
153
def cancel (self , prediction_id : Optional [str ] = None ) -> None :
172
154
if not self .is_busy ():
@@ -308,13 +290,15 @@ def _upload_files(self, output: Any) -> Any:
308
290
raise FileUploadError ("Got error trying to upload output files" ) from error
309
291
310
292
311
- def setup (* , worker : Worker ) -> SetupResult :
293
+ async def setup (* , worker : Worker ) -> SetupResult :
312
294
logs = []
313
295
status = None
314
296
started_at = datetime .now (tz = timezone .utc )
315
297
316
298
try :
299
+ # will be async
317
300
for event in worker .setup ():
301
+ await asyncio .sleep (0 )
318
302
if isinstance (event , Log ):
319
303
logs .append (event .message )
320
304
elif isinstance (event , Done ):
@@ -344,19 +328,19 @@ def setup(*, worker: Worker) -> SetupResult:
344
328
)
345
329
346
330
347
- def predict (
331
+ async def predict (
348
332
* ,
349
333
worker : Worker ,
350
334
request : schema .PredictionRequest ,
351
335
event_handler : PredictionEventHandler ,
352
- should_cancel : threading .Event ,
336
+ should_cancel : asyncio .Event ,
353
337
) -> schema .PredictionResponse :
354
338
# Set up logger context within prediction thread.
355
339
structlog .contextvars .clear_contextvars ()
356
340
structlog .contextvars .bind_contextvars (prediction_id = request .id )
357
341
358
342
try :
359
- return _predict (
343
+ return await _predict (
360
344
worker = worker ,
361
345
request = request ,
362
346
event_handler = event_handler ,
@@ -369,12 +353,12 @@ def predict(
369
353
raise
370
354
371
355
372
- def _predict (
356
+ async def _predict (
373
357
* ,
374
358
worker : Worker ,
375
359
request : schema .PredictionRequest ,
376
360
event_handler : PredictionEventHandler ,
377
- should_cancel : threading .Event ,
361
+ should_cancel : asyncio .Event ,
378
362
) -> schema .PredictionResponse :
379
363
initial_prediction = request .dict ()
380
364
@@ -391,8 +375,9 @@ def _predict(
391
375
event_handler .failed (error = str (e ))
392
376
log .warn ("failed to download url path from input" , exc_info = True )
393
377
return event_handler .response
394
-
378
+ # will be async
395
379
for event in worker .predict (input_dict , poll = 0.1 ):
380
+ await asyncio .sleep (0 )
396
381
if should_cancel .is_set ():
397
382
worker .cancel ()
398
383
should_cancel .clear ()
0 commit comments