import asyncio
+import contextlib
+import logging
+import multiprocessing
+import os
+import signal
+import sys
import threading
import traceback
import typing  # TypeAlias, py3.10
from datetime import datetime, timezone
-from typing import Any, AsyncIterator, Optional, Union
+from enum import Enum, auto, unique
+from typing import Any, AsyncIterator, Iterator, Optional, Union

import httpx
import structlog
from .. import schema, types
from .clients import SKIP_START_EVENT, ClientManager
from .eventtypes import (
+    Cancel,
    Done,
    Heartbeat,
    Log,
    PredictionInput,
    PredictionOutput,
    PredictionOutputType,
    PublicEventType,
+    Shutdown,
)
+from .exceptions import (
+    FatalWorkerException,
+    InvalidStateException,
+)
+from .helpers import AsyncPipe
from .probes import ProbeHelper
-from .worker import Worker
+from .worker import Mux, _ChildWorker

log = structlog.get_logger("cog.server.runner")
+_spawn = multiprocessing.get_context("spawn")


class FileUploadError(Exception):
@@ -39,6 +54,16 @@ class UnknownPredictionError(Exception):
    pass


+@unique
+class WorkerState(Enum):
+    NEW = auto()
+    STARTING = auto()
+    IDLE = auto()
+    PROCESSING = auto()
+    BUSY = auto()
+    DEFUNCT = auto()
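+    # Lifecycle: NEW -> STARTING during setup(); then IDLE, PROCESSING, or BUSY
+    # depending on how many predictions are in flight; DEFUNCT after terminate().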
+
+

@define
class SetupResult:
    started_at: datetime
@@ -62,9 +87,8 @@ def __init__(
        shutdown_event: Optional[threading.Event],
        upload_url: Optional[str] = None,
        concurrency: int = 1,
+        tee_output: bool = True,
    ) -> None:
-        self._worker = Worker(predictor_ref=predictor_ref, concurrency=concurrency)
-
        # __main__ waits for this event
        self._shutdown_event = shutdown_event
        self._upload_url = upload_url
@@ -73,9 +97,28 @@ def __init__(
        )
        self.client_manager = ClientManager()  # upload_url)

+        # worker code
+        self._state = WorkerState.NEW
+        self._semaphore = asyncio.Semaphore(concurrency)
+        self._concurrency = concurrency
+
+        # A pipe with which to communicate with the child worker.
+        events, child_events = _spawn.Pipe()
+        self._child = _ChildWorker(predictor_ref, child_events, tee_output)
+        self._events: "AsyncPipe[tuple[str, PublicEventType]]" = AsyncPipe(
+            events, self._child.is_alive
+        )
+        # shutdown requested
+        self._shutting_down = False
+        # stop reading events
+        self._terminating = asyncio.Event()
+        self._mux = Mux(self._terminating)
+        self._predictions_in_flight = set()
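+        # Events flow child worker -> Pipe -> _read_events() -> Mux, keyed by
+        # prediction id (or "SETUP"), so each caller awaits only its own stream.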
+
    def setup(self) -> SetupTask:
-        if not self._worker.setup_is_allowed():
+        if self._state != WorkerState.NEW:
            raise RunnerBusyError
+        self._state = WorkerState.STARTING

        # app is allowed to respond to requests and poll the state of this task
        # while it is running
@@ -84,16 +127,27 @@ async def inner() -> SetupResult:
            status = None
            started_at = datetime.now(tz=timezone.utc)

+            # in 3.10 Event started doing get_running_loop
+            # previously it stored the loop when created, which causes an error in tests
+            if sys.version_info < (3, 10):
+                self._terminating = self._mux.terminating = asyncio.Event()
+
+            self._child.start()
+            self._ensure_event_reader()
+
            try:
-                async for event in self._worker.setup():
+                async for event in self._mux.read("SETUP", poll=0.1):
                    if isinstance(event, Log):
                        logs.append(event.message)
                    elif isinstance(event, Done):
-                        status = (
-                            schema.Status.FAILED
-                            if event.error
-                            else schema.Status.SUCCEEDED
-                        )
+                        if event.error:
+                            raise FatalWorkerException(
+                                "Predictor errored during setup: " + event.error_detail
+                            )
+                            status = schema.Status.FAILED
+                        else:
+                            status = schema.Status.SUCCEEDED
+                            self._state = WorkerState.IDLE
            except Exception:
                logs.append(traceback.format_exc())
                status = schema.Status.FAILED
@@ -134,10 +188,49 @@ def handle_error(task: RunnerTask) -> None:
        result.add_done_callback(handle_error)
        return result

+    def state_from_predictions_in_flight(self) -> WorkerState:
+        valid_states = {WorkerState.IDLE, WorkerState.PROCESSING, WorkerState.BUSY}
+        if self._state not in valid_states:
+            raise InvalidStateException(
+                f"Invalid operation: state is {self._state} (must be IDLE, PROCESSING, or BUSY)"
+            )
+        if len(self._predictions_in_flight) == self._concurrency:
+            return WorkerState.BUSY
+        if len(self._predictions_in_flight) == 0:
+            return WorkerState.IDLE
+        return WorkerState.PROCESSING
+
+    def is_busy(self) -> bool:
+        return self._state not in {WorkerState.PROCESSING, WorkerState.IDLE}
+
+    def enter_predict(self, id: str) -> None:
+        if self.is_busy():
+            raise InvalidStateException(
+                f"Invalid operation: state is {self._state} (must be processing or idle)"
+            )
+        if self._shutting_down:
+            raise InvalidStateException(
+                "cannot accept new predictions because shutdown requested"
+            )
+        self._predictions_in_flight.add(id)
+        self._state = self.state_from_predictions_in_flight()
+
+    def exit_predict(self, id: str) -> None:
+        self._predictions_in_flight.remove(id)
+        self._state = self.state_from_predictions_in_flight()
+
+    @contextlib.contextmanager
+    def prediction_ctx(self, id: str) -> Iterator[None]:
+        self.enter_predict(id)
+        try:
+            yield
+        finally:
+            self.exit_predict(id)
+
    # TODO: Make the return type AsyncResult[schema.PredictionResponse] when we
    # no longer have to support Python 3.8
    def predict(
-        self, request: schema.PredictionRequest
+        self, request: schema.PredictionRequest, poll: Optional[float] = None
    ) -> "tuple[schema.PredictionResponse, PredictionTask]":
        if self.is_busy():
            if request.id in self._predictions:
@@ -157,10 +250,10 @@ def predict(
        response = event_handler.response

        prediction_input = PredictionInput.from_request(request)
-        predict_ctx = self._worker.good_predict(prediction_input, poll=0.1)
+        # # predict_ctx = self._worker.good_predict(prediction_input, poll=0.1)
        # accept work and change state to get the future event stream,
        # but don't enter it yet
-        event_stream = predict_ctx.__enter__()
+        self.enter_predict(request.id)
        # alternative: self._worker.enter_predict(request.id)

        # what if instead we raised parts of worker instead of trying to access private methods?
@@ -179,8 +272,11 @@ async def async_predict_handling_errors() -> schema.PredictionResponse:
                    if isinstance(v, types.URLTempFile):
                        real_path = await v.convert(self.client_manager.download_client)
                        prediction_input.payload[k] = real_path
-                result = await event_handler.handle_event_stream(event_stream)
-                return result
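+                # The semaphore caps how many predictions are inside the child at
+                # once; send the input, then await this prediction's own event stream.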
+                async with self._semaphore:
+                    self._events.send(prediction_input)
+                    event_stream = self._mux.read(prediction_input.id, poll=poll)
+                    result = await event_handler.handle_event_stream(event_stream)
+                return result
            except httpx.HTTPError as e:
                tb = traceback.format_exc()
                await event_handler.append_logs(tb)
@@ -200,7 +296,7 @@ async def async_predict_handling_errors() -> schema.PredictionResponse:
                # mark the prediction as done and update state
                # ... actually, we might want to mark that part earlier
                # even if we're still uploading files we can accept new work
-                predict_ctx.__exit__(None, None, None)
+                self.exit_predict(prediction_input.id)
                # FIXME: use isinstance(BaseInput)
                if hasattr(request.input, "cleanup"):
                    request.input.cleanup()  # type: ignore
@@ -216,22 +312,78 @@ async def async_predict_handling_errors() -> schema.PredictionResponse:

        return (response, result)

-    def is_busy(self) -> bool:
-        return self._worker.is_busy()
-
    def shutdown(self) -> None:
+        if self._state == WorkerState.DEFUNCT:
+            return
+        # shutdown requested, but keep reading events
+        self._shutting_down = True
+
+        if self._child.is_alive():
+            self._events.send(Shutdown())
+
+    def terminate(self) -> None:
        for _, task in self._predictions.values():
            task.cancel()
-        self._worker.terminate()
+        if self._state == WorkerState.DEFUNCT:
+            return
+
+        self._terminating.set()
+        self._state = WorkerState.DEFUNCT
+
+        if self._child.is_alive():
+            self._child.terminate()
+            self._child.join()
+        self._events.shutdown()
+        if self._read_events_task:
+            self._read_events_task.cancel()

    def cancel(self, prediction_id: str) -> None:
-        try:
-            self._worker.cancel(prediction_id)
-            # if the runner is in an invalid state, predictions_in_flight would just be empty
-            # and it's a keyerror anyway
-        except KeyError as e:
-            print(e)
-            raise UnknownPredictionError() from e
+        if prediction_id not in self._predictions_in_flight:
+            print("id not there", prediction_id, self._predictions_in_flight)
+            raise UnknownPredictionError()
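+        # Cancellation is signalled two ways: SIGUSR1 interrupts the child
+        # process, and a Cancel event names the prediction to cancel.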
+        if self._child.is_alive() and self._child.pid is not None:
+            os.kill(self._child.pid, signal.SIGUSR1)
+            print("sent cancel")
+        self._events.send(Cancel(prediction_id))
+        # maybe this should check self._semaphore._value == self._concurrency
+
+    _read_events_task: "Optional[asyncio.Task[None]]" = None
+
+    def _ensure_event_reader(self) -> None:
+        def handle_error(task: "asyncio.Task[None]") -> None:
+            if task.cancelled():
+                return
+            exc = task.exception()
+            if exc:
+                logging.error("caught exception", exc_info=exc)
+
+        if not self._read_events_task:
+            self._read_events_task = asyncio.create_task(self._read_events())
+            self._read_events_task.add_done_callback(handle_error)
+
+    async def _read_events(self) -> None:
+        while self._child.is_alive() and not self._terminating.is_set():
+            # this can still be running when the task is destroyed
+            result = await self._events.coro_recv_with_exit(self._terminating)
+            print("reader got", result)
+            if result is None:  # event loop closed or child died
+                break
+            id, event = result
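+            # LOG events carry no prediction id; attribute them to SETUP while
+            # starting up, or to the only prediction when exactly one is in flight.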
+            if id == "LOG" and self._state == WorkerState.STARTING:
+                id = "SETUP"
+            if id == "LOG" and len(self._predictions_in_flight) == 1:
+                id = list(self._predictions_in_flight)[0]
+            await self._mux.write(id, event)
+        # If we dropped off the end of the loop, check if it's
+        # because the child process died.
+        if not self._child.is_alive() and not self._terminating.is_set():
+            exitcode = self._child.exitcode
+            self._mux.fatal = FatalWorkerException(
+                f"Prediction failed for an unknown reason. It might have run out of memory? (exitcode {exitcode})"
+            )
+        # this is the same event as _terminating
+        # we need to set it so mux.reads wake up and throw an error if needed
+        self._mux.terminating.set()


class PredictionEventHandler: