1
1
import asyncio
2
+ import contextlib
2
3
import inspect
3
4
import multiprocessing
4
5
import os
8
9
import types
9
10
from enum import Enum , auto , unique
10
11
from multiprocessing .connection import Connection
11
- from typing import Any , Dict , Iterable , Optional , TextIO , Union
12
+ from typing import Any , Dict , Iterator , Optional , TextIO , Union
12
13
13
14
from ..json import make_encodeable
14
15
from ..predictor import (
@@ -58,7 +59,7 @@ def __init__(self, predictor_ref: str, tee_output: bool = True) -> None:
58
59
self ._child = _ChildWorker (predictor_ref , child_events , tee_output )
59
60
self ._terminating = False
60
61
61
- def setup (self ) -> Iterable [_PublicEventType ]:
62
+ def setup (self ) -> Iterator [_PublicEventType ]:
62
63
self ._assert_state (WorkerState .NEW )
63
64
self ._state = WorkerState .STARTING
64
65
self ._child .start ()
@@ -67,7 +68,7 @@ def setup(self) -> Iterable[_PublicEventType]:
67
68
68
69
def predict (
69
70
self , payload : Dict [str , Any ], poll : Optional [float ] = None
70
- ) -> Iterable [_PublicEventType ]:
71
+ ) -> Iterator [_PublicEventType ]:
71
72
self ._assert_state (WorkerState .READY )
72
73
self ._state = WorkerState .PROCESSING
73
74
self ._allow_cancel = True
@@ -108,7 +109,7 @@ def _assert_state(self, state: WorkerState) -> None:
108
109
109
110
def _wait (
110
111
self , poll : Optional [float ] = None , raise_on_error : Optional [str ] = None
111
- ) -> Iterable [_PublicEventType ]:
112
+ ) -> Iterator [_PublicEventType ]:
112
113
done = None
113
114
114
115
if poll :
@@ -178,15 +179,21 @@ def run(self) -> None:
178
179
[ws_stdout , ws_stderr ], self ._stream_write_hook
179
180
)
180
181
self ._stream_redirector .start ()
181
-
182
182
self ._setup ()
183
- asyncio . run ( self ._loop () )
183
+ self ._loop ()
184
184
self ._stream_redirector .shutdown ()
185
185
186
186
def _setup (self ) -> None :
187
- done = Done ()
188
- try :
187
+ with self . _handle_setup_error ():
188
+ # we need to load the predictor to know if setup is async
189
189
self ._predictor = load_predictor_from_ref (self ._predictor_ref )
190
+ # if users want to access the same event loop from setup and predict,
191
+ # both have to be async. if setup isn't async, it doesn't matter if we
192
+ # create the event loop here or after setup
193
+ #
194
+ # otherwise, if setup is sync and the user does new_event_loop to use a ClientSession,
195
+ # then tries to use the same session from async predict, they would get an error.
196
+ # that's significant if connections are open and would need to be discarded
190
197
if is_async_predictor (self ._predictor ):
191
198
self .loop = get_loop ()
192
199
# Could be a function or a class
@@ -195,6 +202,12 @@ def _setup(self) -> None:
195
202
self .loop .run_until_complete (run_setup_async (self ._predictor ))
196
203
else :
197
204
run_setup (self ._predictor )
205
+
206
+ @contextlib .contextmanager
207
+ def _handle_setup_error (self ) -> Iterator [None ]:
208
+ done = Done ()
209
+ try :
210
+ yield
198
211
except Exception as e :
199
212
traceback .print_exc ()
200
213
done .error = True
@@ -210,50 +223,76 @@ def _setup(self) -> None:
210
223
self ._stream_redirector .drain ()
211
224
self ._events .send (done )
212
225
213
- async def _loop (self ) -> None :
226
+ def _loop_sync (self ) -> None :
214
227
while True :
215
228
ev = self ._events .recv ()
216
229
if isinstance (ev , Shutdown ):
217
230
break
218
231
if isinstance (ev , PredictionInput ):
219
- await self ._predict (ev .payload )
232
+ self ._predict_sync (ev .payload )
220
233
else :
221
234
print (f"Got unexpected event: { ev } " , file = sys .stderr )
222
235
223
- async def _predict (self , payload : Dict [str , Any ]) -> None :
236
+ async def _loop_async (self ) -> None :
237
+ while True :
238
+ ev = self ._events .recv ()
239
+ if isinstance (ev , Shutdown ):
240
+ break
241
+ if isinstance (ev , PredictionInput ):
242
+ await self ._predict_async (ev .payload )
243
+ else :
244
+ print (f"Got unexpected event: { ev } " , file = sys .stderr )
245
+
246
+ def _loop (self ) -> None :
247
+ if is_async (get_predict (self ._predictor )):
248
+ self .loop .run_until_complete (self ._loop_async ())
249
+ else :
250
+ self ._loop_sync ()
251
+
252
+ @contextlib .contextmanager
253
+ def _handle_predict_error (self ) -> Iterator [None ]:
224
254
assert self ._predictor
225
255
done = Done ()
226
256
self ._cancelable = True
227
257
try :
258
+ yield
259
+ except CancelationException :
260
+ done .canceled = True
261
+ except Exception as e :
262
+ traceback .print_exc ()
263
+ done .error = True
264
+ done .error_detail = str (e )
265
+ finally :
266
+ self ._cancelable = False
267
+ self ._stream_redirector .drain ()
268
+ self ._events .send (done )
269
+
270
+ async def _predict_async (self , payload : Dict [str , Any ]) -> None :
271
+ with self ._handle_predict_error ():
228
272
predict = get_predict (self ._predictor )
229
273
result = predict (** payload )
230
-
231
274
if result :
232
275
if inspect .isasyncgen (result ):
233
276
self ._events .send (PredictionOutputType (multi = True ))
234
277
async for r in result :
235
278
self ._events .send (PredictionOutput (payload = make_encodeable (r )))
236
- elif inspect .isgenerator (result ):
237
- self ._events .send (PredictionOutputType (multi = True ))
238
- for r in result :
239
- self ._events .send (PredictionOutput (payload = make_encodeable (r )))
240
279
elif inspect .isawaitable (result ):
241
280
result = await result
242
281
self ._events .send (PredictionOutputType (multi = False ))
243
282
self ._events .send (PredictionOutput (payload = make_encodeable (result )))
283
+
284
+ def _predict_sync (self , payload : Dict [str , Any ]) -> None :
285
+ with self ._handle_predict_error ():
286
+ predict = get_predict (self ._predictor )
287
+ result = predict (** payload )
288
+ if result :
289
+ if inspect .isgenerator (result ):
290
+ self ._events .send (PredictionOutputType (multi = True ))
291
+ for r in result :
292
+ self ._events .send (PredictionOutput (payload = make_encodeable (r )))
244
293
else :
245
294
self ._events .send (PredictionOutputType (multi = False ))
246
295
self ._events .send (PredictionOutput (payload = make_encodeable (result )))
247
- except CancelationException :
248
- done .canceled = True
249
- except Exception as e :
250
- traceback .print_exc ()
251
- done .error = True
252
- done .error_detail = str (e )
253
- finally :
254
- self ._cancelable = False
255
- self ._stream_redirector .drain ()
256
- self ._events .send (done )
257
296
258
297
def _signal_handler (self , signum : int , frame : Optional [types .FrameType ]) -> None :
259
298
if signum == signal .SIGUSR1 and self ._cancelable :
@@ -270,13 +309,16 @@ def _stream_write_hook(
270
309
271
310
def get_loop () -> asyncio .AbstractEventLoop :
272
311
try :
312
+ # just in case something else created an event loop already
273
313
return asyncio .get_running_loop ()
274
314
except RuntimeError :
275
315
return asyncio .new_event_loop ()
276
316
277
317
318
+ def is_async (fn : Any ) -> bool :
319
+ return inspect .iscoroutinefunction (fn ) or inspect .isasyncgenfunction (fn )
320
+
321
+
278
322
def is_async_predictor (predictor : BasePredictor ) -> bool :
279
- predict = get_predict (predictor )
280
- if inspect .iscoroutinefunction (predict ) or inspect .isasyncgenfunction (predict ):
281
- return True
282
- return inspect .iscoroutinefunction (getattr (predictor , "setup" , None ))
323
+ setup = getattr (predictor , "setup" , None )
324
+ return is_async (setup ) or is_async (get_predict (predictor ))
0 commit comments