1
+ import asyncio
1
2
import io
2
- import sys
3
3
import threading
4
4
import traceback
5
5
import typing # TypeAlias, py3.10
6
6
from datetime import datetime , timezone
7
- from multiprocessing .pool import AsyncResult , ThreadPool
8
7
from typing import Any , Callable , Optional , Tuple , Union , cast
9
8
10
9
import requests
@@ -46,11 +45,8 @@ class SetupResult:
46
45
status : schema .Status
47
46
48
47
49
# Task aliases are written as strings so the subscripted ``asyncio.Task[...]``
# form is never evaluated at import time; only type checkers resolve it.
PredictionTask: "typing.TypeAlias" = "asyncio.Task[schema.PredictionResponse]"
SetupTask: "typing.TypeAlias" = "asyncio.Task[SetupResult]"

# Either kind of task that the runner may currently be tracking.
RunnerTask: "typing.TypeAlias" = Union[PredictionTask, SetupTask]
55
51
56
52
@@ -62,38 +58,37 @@ def __init__(
62
58
shutdown_event : Optional [threading .Event ],
63
59
upload_url : Optional [str ] = None ,
64
60
) -> None :
65
- self ._thread = None
66
- self ._threadpool = ThreadPool (processes = 1 )
67
-
68
61
self ._response : Optional [schema .PredictionResponse ] = None
69
62
self ._result : Optional [RunnerTask ] = None
70
63
71
64
self ._worker = Worker (predictor_ref = predictor_ref )
72
- self ._should_cancel = threading .Event ()
65
+ self ._should_cancel = asyncio .Event ()
73
66
74
67
self ._shutdown_event = shutdown_event
75
68
self ._upload_url = upload_url
76
69
77
- def setup (self ) -> SetupTask :
78
- if self . is_busy () :
79
- raise RunnerBusyError ()
80
-
81
- def handle_error ( error : BaseException ) -> None :
70
def make_error_handler(self, activity: str) -> Callable[[RunnerTask], None]:
    """Build a task done-callback that logs a failure and requests shutdown.

    Args:
        activity: Label interpolated into the log message; callers in this
            file pass "setup" and "prediction".

    Returns:
        A callable taking the finished task, suitable for
        ``Task.add_done_callback``.
    """

    def handle_error(task: RunnerTask) -> None:
        # Task.exception() raises CancelledError when the task was
        # cancelled, and shutdown() cancels the in-flight task. Bail out
        # first so this callback does not itself raise during shutdown.
        if task.cancelled():
            return

        exc = task.exception()
        if exc is None:
            return

        # Re-raise the exception in order to more easily capture exc_info,
        # and then trigger shutdown, as we have no easy way to resume
        # worker state if an exception was thrown.
        try:
            raise exc
        except Exception:
            log.error(f"caught exception while running {activity}", exc_info=True)
            if self._shutdown_event is not None:
                self._shutdown_event.set()

    return handle_error
86
+
87
def setup(self) -> SetupTask:
    """Schedule worker setup as an asyncio task and return that task.

    Raises:
        RunnerBusyError: if a previously started task has not finished.
    """
    if self.is_busy():
        raise RunnerBusyError()

    # Track the task on the runner before attaching the failure callback.
    task = asyncio.create_task(setup(worker=self._worker))
    self._result = task
    task.add_done_callback(self.make_error_handler("setup"))
    return task
98
93
99
94
# TODO: Make the return type AsyncResult[schema.PredictionResponse] when we
@@ -127,52 +122,39 @@ def predict(
127
122
upload_url = upload_url ,
128
123
)
129
124
130
- def cleanup (_ : Optional [schema .PredictionResponse ] = None ) -> None :
125
+ def handle_cleanup (_ : Optional [schema .PredictionResponse ] = None ) -> None :
131
126
input = cast (Any , prediction .input )
132
127
if hasattr (input , "cleanup" ):
133
128
input .cleanup ()
134
129
135
- def handle_error (error : BaseException ) -> None :
136
- # Re-raise the exception in order to more easily capture exc_info,
137
- # and then trigger shutdown, as we have no easy way to resume
138
- # worker state if an exception was thrown.
139
- try :
140
- raise error
141
- except Exception :
142
- log .error ("caught exception while running prediction" , exc_info = True )
143
- if self ._shutdown_event is not None :
144
- self ._shutdown_event .set ()
145
-
146
130
self ._response = event_handler .response
147
- self ._result = self ._threadpool .apply_async (
148
- func = predict ,
149
- kwds = {
150
- "worker" : self ._worker ,
151
- "request" : prediction ,
152
- "event_handler" : event_handler ,
153
- "should_cancel" : self ._should_cancel ,
154
- },
155
- callback = cleanup ,
156
- error_callback = handle_error ,
131
+ coro = predict (
132
+ worker = self ._worker ,
133
+ request = prediction ,
134
+ event_handler = event_handler ,
135
+ should_cancel = self ._should_cancel ,
157
136
)
137
+ self ._result = asyncio .create_task (coro )
138
+ self ._result .add_done_callback (handle_cleanup )
139
+ self ._result .add_done_callback (self .make_error_handler ("prediction" ))
158
140
159
141
return (self ._response , self ._result )
160
142
161
143
def is_busy(self) -> bool:
    """Report whether a tracked task is still running.

    When the tracked task has finished, the stored task and response are
    cleared as a side effect before returning False.
    """
    task = self._result
    if task is None:
        return False

    if task.done():
        # Finished: forget the task and its response.
        self._response = None
        self._result = None
        return False

    return True
171
153
172
154
def shutdown(self) -> None:
    """Cancel any tracked task, then terminate the worker."""
    pending = self._result
    if pending:
        pending.cancel()
    self._worker.terminate()
176
158
177
159
def cancel (self , prediction_id : Optional [str ] = None ) -> None :
178
160
if not self .is_busy ():
@@ -318,13 +300,15 @@ def _upload_files(self, output: Any) -> Any:
318
300
raise FileUploadError ("Got error trying to upload output files" ) from error
319
301
320
302
321
- def setup (* , worker : Worker ) -> SetupResult :
303
+ async def setup (* , worker : Worker ) -> SetupResult :
322
304
logs = []
323
305
status = None
324
306
started_at = datetime .now (tz = timezone .utc )
325
307
326
308
try :
309
+ # will be async
327
310
for event in worker .setup ():
311
+ await asyncio .sleep (0 )
328
312
if isinstance (event , Log ):
329
313
logs .append (event .message )
330
314
elif isinstance (event , Done ):
@@ -354,19 +338,19 @@ def setup(*, worker: Worker) -> SetupResult:
354
338
)
355
339
356
340
357
- def predict (
341
+ async def predict (
358
342
* ,
359
343
worker : Worker ,
360
344
request : schema .PredictionRequest ,
361
345
event_handler : PredictionEventHandler ,
362
- should_cancel : threading .Event ,
346
+ should_cancel : asyncio .Event ,
363
347
) -> schema .PredictionResponse :
364
348
# Set up logger context within the prediction task.
365
349
structlog .contextvars .clear_contextvars ()
366
350
structlog .contextvars .bind_contextvars (prediction_id = request .id )
367
351
368
352
try :
369
- return _predict (
353
+ return await _predict (
370
354
worker = worker ,
371
355
request = request ,
372
356
event_handler = event_handler ,
@@ -379,12 +363,12 @@ def predict(
379
363
raise
380
364
381
365
382
- def _predict (
366
+ async def _predict (
383
367
* ,
384
368
worker : Worker ,
385
369
request : schema .PredictionRequest ,
386
370
event_handler : PredictionEventHandler ,
387
- should_cancel : threading .Event ,
371
+ should_cancel : asyncio .Event ,
388
372
) -> schema .PredictionResponse :
389
373
initial_prediction = request .dict ()
390
374
@@ -408,7 +392,9 @@ def _predict(
408
392
log .warn ("Failed to download url path from input" , exc_info = True )
409
393
return event_handler .response
410
394
395
+ # will be async
411
396
for event in worker .predict (input_dict , poll = 0.1 ):
397
+ await asyncio .sleep (0 )
412
398
if should_cancel .is_set ():
413
399
worker .cancel ()
414
400
should_cancel .clear ()
0 commit comments