Skip to content

Commit 974675f

Browse files
committed
glom worker into runner
Signed-off-by: technillogue <technillogue@gmail.com>
1 parent 8bd2fbc commit 974675f

File tree

3 files changed

+181
-231
lines changed

3 files changed

+181
-231
lines changed

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -29,9 +29,9 @@ dependencies = [
2929
optional-dependencies = { "dev" = [
3030
"black",
3131
"build",
32-
"httpx",
3332
'hypothesis<6.80.0; python_version < "3.8"',
3433
'hypothesis; python_version >= "3.8"',
34+
"respx",
3535
'numpy<1.22.0; python_version < "3.8"',
3636
'numpy; python_version >= "3.8"',
3737
"pillow",

python/cog/server/runner.py

Lines changed: 180 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,16 @@
11
import asyncio
2+
import contextlib
3+
import logging
4+
import multiprocessing
5+
import os
6+
import signal
7+
import sys
28
import threading
39
import traceback
410
import typing # TypeAlias, py3.10
511
from datetime import datetime, timezone
6-
from typing import Any, AsyncIterator, Optional, Union
12+
from enum import Enum, auto, unique
13+
from typing import Any, AsyncIterator, Iterator, Optional, Union
714

815
import httpx
916
import structlog
@@ -13,18 +20,26 @@
1320
from .. import schema, types
1421
from .clients import SKIP_START_EVENT, ClientManager
1522
from .eventtypes import (
23+
Cancel,
1624
Done,
1725
Heartbeat,
1826
Log,
1927
PredictionInput,
2028
PredictionOutput,
2129
PredictionOutputType,
2230
PublicEventType,
31+
Shutdown,
2332
)
33+
from .exceptions import (
34+
FatalWorkerException,
35+
InvalidStateException,
36+
)
37+
from .helpers import AsyncPipe
2438
from .probes import ProbeHelper
25-
from .worker import Worker
39+
from .worker import Mux, _ChildWorker
2640

2741
log = structlog.get_logger("cog.server.runner")
42+
_spawn = multiprocessing.get_context("spawn")
2843

2944

3045
class FileUploadError(Exception):
@@ -39,6 +54,16 @@ class UnknownPredictionError(Exception):
3954
pass
4055

4156

57+
@unique
58+
class WorkerState(Enum):
59+
NEW = auto()
60+
STARTING = auto()
61+
IDLE = auto()
62+
PROCESSING = auto()
63+
BUSY = auto()
64+
DEFUNCT = auto()
65+
66+
4267
@define
4368
class SetupResult:
4469
started_at: datetime
@@ -62,9 +87,8 @@ def __init__(
6287
shutdown_event: Optional[threading.Event],
6388
upload_url: Optional[str] = None,
6489
concurrency: int = 1,
90+
tee_output: bool = True,
6591
) -> None:
66-
self._worker = Worker(predictor_ref=predictor_ref, concurrency=concurrency)
67-
6892
# __main__ waits for this event
6993
self._shutdown_event = shutdown_event
7094
self._upload_url = upload_url
@@ -73,9 +97,28 @@ def __init__(
7397
)
7498
self.client_manager = ClientManager() # upload_url)
7599

100+
# worker code
101+
self._state = WorkerState.NEW
102+
self._semaphore = asyncio.Semaphore(concurrency)
103+
self._concurrency = concurrency
104+
105+
# A pipe with which to communicate with the child worker.
106+
events, child_events = _spawn.Pipe()
107+
self._child = _ChildWorker(predictor_ref, child_events, tee_output)
108+
self._events: "AsyncPipe[tuple[str, PublicEventType]]" = AsyncPipe(
109+
events, self._child.is_alive
110+
)
111+
# shutdown requested
112+
self._shutting_down = False
113+
# stop reading events
114+
self._terminating = asyncio.Event()
115+
self._mux = Mux(self._terminating)
116+
self._predictions_in_flight = set()
117+
76118
def setup(self) -> SetupTask:
77-
if not self._worker.setup_is_allowed():
119+
if self._state != WorkerState.NEW:
78120
raise RunnerBusyError
121+
self._state = WorkerState.STARTING
79122

80123
# app is allowed to respond to requests and poll the state of this task
81124
# while it is running
@@ -84,16 +127,27 @@ async def inner() -> SetupResult:
84127
status = None
85128
started_at = datetime.now(tz=timezone.utc)
86129

130+
# in 3.10 Event started doing get_running_loop
131+
# previously it stored the loop when created, which causes an error in tests
132+
if sys.version_info < (3, 10):
133+
self._terminating = self._mux.terminating = asyncio.Event()
134+
135+
self._child.start()
136+
self._ensure_event_reader()
137+
87138
try:
88-
async for event in self._worker.setup():
139+
async for event in self._mux.read("SETUP", poll=0.1):
89140
if isinstance(event, Log):
90141
logs.append(event.message)
91142
elif isinstance(event, Done):
92-
status = (
93-
schema.Status.FAILED
94-
if event.error
95-
else schema.Status.SUCCEEDED
96-
)
143+
if event.error:
144+
raise FatalWorkerException(
145+
"Predictor errored during setup: " + event.error_detail
146+
)
147+
status = schema.Status.FAILED
148+
else:
149+
status = schema.Status.SUCCEEDED
150+
self._state = WorkerState.IDLE
97151
except Exception:
98152
logs.append(traceback.format_exc())
99153
status = schema.Status.FAILED
@@ -134,10 +188,49 @@ def handle_error(task: RunnerTask) -> None:
134188
result.add_done_callback(handle_error)
135189
return result
136190

191+
def state_from_predictions_in_flight(self) -> WorkerState:
192+
valid_states = {WorkerState.IDLE, WorkerState.PROCESSING, WorkerState.BUSY}
193+
if self._state not in valid_states:
194+
raise InvalidStateException(
195+
f"Invalid operation: state is {self._state} (must be IDLE, PROCESSING, or BUSY)"
196+
)
197+
if len(self._predictions_in_flight) == self._concurrency:
198+
return WorkerState.BUSY
199+
if len(self._predictions_in_flight) == 0:
200+
return WorkerState.IDLE
201+
return WorkerState.PROCESSING
202+
203+
def is_busy(self) -> bool:
204+
return self._state not in {WorkerState.PROCESSING, WorkerState.IDLE}
205+
206+
def enter_predict(self, id: str) -> None:
207+
if self.is_busy():
208+
raise InvalidStateException(
209+
f"Invalid operation: state is {self._state} (must be processing or idle)"
210+
)
211+
if self._shutting_down:
212+
raise InvalidStateException(
213+
"cannot accept new predictions because shutdown requested"
214+
)
215+
self._predictions_in_flight.add(id)
216+
self._state = self.state_from_predictions_in_flight()
217+
218+
def exit_predict(self, id: str) -> None:
219+
self._predictions_in_flight.remove(id)
220+
self._state = self.state_from_predictions_in_flight()
221+
222+
@contextlib.contextmanager
223+
def prediction_ctx(self, id: str) -> Iterator[None]:
224+
self.enter_predict(id)
225+
try:
226+
yield
227+
finally:
228+
self.exit_predict(id)
229+
137230
# TODO: Make the return type AsyncResult[schema.PredictionResponse] when we
138231
# no longer have to support Python 3.8
139232
def predict(
140-
self, request: schema.PredictionRequest
233+
self, request: schema.PredictionRequest, poll: Optional[float] = None
141234
) -> "tuple[schema.PredictionResponse, PredictionTask]":
142235
if self.is_busy():
143236
if request.id in self._predictions:
@@ -157,10 +250,10 @@ def predict(
157250
response = event_handler.response
158251

159252
prediction_input = PredictionInput.from_request(request)
160-
predict_ctx = self._worker.good_predict(prediction_input, poll=0.1)
253+
# # predict_ctx = self._worker.good_predict(prediction_input, poll=0.1)
161254
# accept work and change state to get the future event stream,
162255
# but don't enter it yet
163-
event_stream = predict_ctx.__enter__()
256+
self.enter_predict(request.id)
164257
# alternative: self._worker.enter_predict(request.id)
165258

166259
# what if instead we raised parts of worker instead of trying to access private methods?
@@ -179,8 +272,11 @@ async def async_predict_handling_errors() -> schema.PredictionResponse:
179272
if isinstance(v, types.URLTempFile):
180273
real_path = await v.convert(self.client_manager.download_client)
181274
prediction_input.payload[k] = real_path
182-
result = await event_handler.handle_event_stream(event_stream)
183-
return result
275+
async with self._semaphore:
276+
self._events.send(prediction_input)
277+
event_stream = self._mux.read(prediction_input.id, poll=poll)
278+
result = await event_handler.handle_event_stream(event_stream)
279+
return result
184280
except httpx.HTTPError as e:
185281
tb = traceback.format_exc()
186282
await event_handler.append_logs(tb)
@@ -200,7 +296,7 @@ async def async_predict_handling_errors() -> schema.PredictionResponse:
200296
# mark the prediction as done and update state
201297
# ... actually, we might want to mark that part earlier
202298
# even if we're still upload files we can accept new work
203-
predict_ctx.__exit__(None, None, None)
299+
self.exit_predict(prediction_input.id)
204300
# FIXME: use isinstance(BaseInput)
205301
if hasattr(request.input, "cleanup"):
206302
request.input.cleanup() # type: ignore
@@ -216,22 +312,78 @@ async def async_predict_handling_errors() -> schema.PredictionResponse:
216312

217313
return (response, result)
218314

219-
def is_busy(self) -> bool:
220-
return self._worker.is_busy()
221-
222315
def shutdown(self) -> None:
316+
if self._state == WorkerState.DEFUNCT:
317+
return
318+
# shutdown requested, but keep reading events
319+
self._shutting_down = True
320+
321+
if self._child.is_alive():
322+
self._events.send(Shutdown())
323+
324+
def terminate(self) -> None:
223325
for _, task in self._predictions.values():
224326
task.cancel()
225-
self._worker.terminate()
327+
if self._state == WorkerState.DEFUNCT:
328+
return
329+
330+
self._terminating.set()
331+
self._state = WorkerState.DEFUNCT
332+
333+
if self._child.is_alive():
334+
self._child.terminate()
335+
self._child.join()
336+
self._events.shutdown()
337+
if self._read_events_task:
338+
self._read_events_task.cancel()
226339

227340
def cancel(self, prediction_id: str) -> None:
228-
try:
229-
self._worker.cancel(prediction_id)
230-
# if the runner is in an invalid state, predictions_in_flight would just be empty
231-
# and it's a keyerror anyway
232-
except KeyError as e:
233-
print(e)
234-
raise UnknownPredictionError() from e
341+
if id not in self._predictions_in_flight:
342+
print("id not there", prediction_id, self._predictions_in_flight)
343+
raise UnknownPredictionError()
344+
if self._child.is_alive() and self._child.pid is not None:
345+
os.kill(self._child.pid, signal.SIGUSR1)
346+
print("sent cancel")
347+
self._events.send(Cancel(prediction_id))
348+
# maybe this should probably check self._semaphore._value == self._concurrent
349+
350+
_read_events_task: "Optional[asyncio.Task[None]]" = None
351+
352+
def _ensure_event_reader(self) -> None:
353+
def handle_error(task: "asyncio.Task[None]") -> None:
354+
if task.cancelled():
355+
return
356+
exc = task.exception()
357+
if exc:
358+
logging.error("caught exception", exc_info=exc)
359+
360+
if not self._read_events_task:
361+
self._read_events_task = asyncio.create_task(self._read_events())
362+
self._read_events_task.add_done_callback(handle_error)
363+
364+
async def _read_events(self) -> None:
365+
while self._child.is_alive() and not self._terminating.is_set():
366+
# this can still be running when the task is destroyed
367+
result = await self._events.coro_recv_with_exit(self._terminating)
368+
print("reader got", result)
369+
if result is None: # event loop closed or child died
370+
break
371+
id, event = result
372+
if id == "LOG" and self._state == WorkerState.STARTING:
373+
id = "SETUP"
374+
if id == "LOG" and len(self._predictions_in_flight) == 1:
375+
id = list(self._predictions_in_flight)[0]
376+
await self._mux.write(id, event)
377+
# If we dropped off the end off the end of the loop, check if it's
378+
# because the child process died.
379+
if not self._child.is_alive() and not self._terminating.is_set():
380+
exitcode = self._child.exitcode
381+
self._mux.fatal = FatalWorkerException(
382+
f"Prediction failed for an unknown reason. It might have run out of memory? (exitcode {exitcode})"
383+
)
384+
# this is the same event as _terminating
385+
# we need to set it so mux.reads wake up and throw an error if needed
386+
self._mux.terminating.set()
235387

236388

237389
class PredictionEventHandler:

0 commit comments

Comments
 (0)