
Commit 8d83783

Mux prediction events (#1405)
* race utility for racing awaitables
* start mux, tag events with id, read pipe in a task, get events from mux
* use async pipe for async child loop
* _shutting_down vs _terminating
* race with shutdown event
* keep reading events during shutdown, but call terminate after the last Done
* emit heartbeats from mux.read
* don't use _wait. instead, setup reads event from the mux too
* worker semaphore and prediction ctx
* where _wait used to raise a fatal error, have _read_events set an error on Mux, and then Mux.read can raise the error in the right context. otherwise, the exception is stuck in a task and doesn't propagate correctly
* fix event loop errors for <3.9
* keep track of predictions in flight explicitly and use that to route logs
* don't wait for executor shutdown
* progress: check for cancelation in task done_handler
* let mux check if child is alive and set mux shutdown after leaving read event loop
* close pipe when exiting
* predict requires IDLE or PROCESSING
* try adding a BUSY state distinct from PROCESSING when we no longer have capacity
* move resetting events to setup() instead of _read_events()

  previously this was in _read_events because it's a coroutine that will have the correct event loop. however, _read_events actually gets created in a task, which can run *after* the first mux.read call by setup. since setup is now the first async entrypoint in worker and in tests, we can safely move it there

* state_from_predictions_in_flight instead of checking the value of semaphore
* make prediction_ctx "private"

Signed-off-by: technillogue <technillogue@gmail.com>
1 parent 4b960e1 commit 8d83783
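The commit message above describes the shape of the change: every prediction's events are tagged with an id, a single task drains the child's pipe, and a mux fans the events out so that setup and each prediction await only their own stream, with fatal errors re-raised from the reading context. Below is a minimal, hypothetical sketch of that routing idea only; ToyMux and its write/read methods are illustrative stand-ins, not the Mux class this commit actually adds.

# Hypothetical sketch only: per-id queues plus a periodic wake-up so read() can
# surface a fatal error (or emit a heartbeat) even when no event has arrived.
import asyncio
from typing import AsyncIterator, Dict, Optional


class ToyMux:
    def __init__(self) -> None:
        self.outs: Dict[str, asyncio.Queue] = {}
        self.fatal: Optional[Exception] = None  # set by the pipe-reading task

    def write(self, id: str, event: str) -> None:
        # the pipe-reading task calls this for every event it pulls off the pipe
        self.outs.setdefault(id, asyncio.Queue()).put_nowait(event)

    async def read(self, id: str, poll: float = 0.1) -> AsyncIterator[str]:
        queue = self.outs.setdefault(id, asyncio.Queue())
        while True:
            try:
                event = await asyncio.wait_for(queue.get(), timeout=poll)
            except asyncio.TimeoutError:
                if self.fatal:
                    raise self.fatal  # re-raised in the caller's context
                continue  # a real mux could yield a heartbeat here instead
            yield event
            if event == "Done":
                return


async def main() -> None:
    mux = ToyMux()
    mux.write("abcd1234", "Log: starting")
    mux.write("abcd1234", "Done")
    async for event in mux.read("abcd1234"):
        print(event)  # only events tagged with this id are ever seen here


asyncio.run(main())

Per the commit message, the real worker also tracks predictions in flight to route untagged logs, emits heartbeats from mux.read, and shuts the mux down once the child process is no longer alive.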

4 files changed: +258 -78 lines changed


python/cog/server/eventtypes.py

Lines changed: 2 additions & 0 deletions
@@ -1,3 +1,4 @@
+import secrets
 from typing import Any, Dict
 
 from attrs import define, field, validators
@@ -8,6 +9,7 @@
 @define
 class PredictionInput:
     payload: Dict[str, Any]
+    id: str = field(factory=lambda: secrets.token_hex(4))
 
 
 @define

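As a quick illustration of the change above: every PredictionInput now carries a random 8-character hex id by default, which is what lets the worker tag and route events per prediction. This snippet simply mirrors the diff (it is not additional code from the commit); the "prompt" payload is an arbitrary example.

import secrets
from typing import Any, Dict

from attrs import define, field


@define
class PredictionInput:
    payload: Dict[str, Any]
    id: str = field(factory=lambda: secrets.token_hex(4))


a = PredictionInput(payload={"prompt": "hello"})
b = PredictionInput(payload={"prompt": "hello"})
print(a.id, b.id)  # two different 8-character hex strings, e.g. "9f3a07c2 5b1de480"
assert a.id != b.id  # ids are independent random tokens, so collisions are vanishingly unlikely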

python/cog/server/helpers.py

Lines changed: 45 additions & 3 deletions
@@ -3,17 +3,20 @@
 import io
 import os
 import selectors
+import sys
 import threading
 import uuid
 from multiprocessing.connection import Connection
 from typing import (
     Any,
     Callable,
+    Coroutine,
     Generic,
     Optional,
     Sequence,
     TextIO,
     TypeVar,
+    Union,
 )
 
 
@@ -160,13 +163,44 @@ def run(self) -> None:
         self.drain_event.set()
         drain_tokens_seen = 0
 
+
 X = TypeVar("X")
+Y = TypeVar("Y")
+
+
+async def race(
+    x: Coroutine[None, None, X],
+    y: Coroutine[None, None, Y],
+    timeout: Optional[float] = None,
+) -> Union[X, Y]:
+    tasks = [asyncio.create_task(x), asyncio.create_task(y)]
+    wait = asyncio.wait(tasks, timeout=timeout, return_when=asyncio.FIRST_COMPLETED)
+    done, pending = await wait
+    for task in pending:
+        task.cancel()
+    if not done:
+        raise TimeoutError
+    # done is an unordered set but we want to preserve original order
+    result_task, *others = (t for t in tasks if t in done)
+    # during shutdown, some of the other completed tasks might be an error
+    # cancel them instead of handling the error to avoid the warning
+    # "Task exception was never retrieved"
+    for task in others:
+        msg = "was completed at the same time as another selected task, canceling"
+        # FIXME: use a logger?
+        print(task, msg, file=sys.stderr)
+        task.cancel()
+    return result_task.result()
+
 
 # functionally this is the exact same thing as aioprocessing but 0.1% the code
 # however it's still worse than just using actual asynchronous io
 class AsyncPipe(Generic[X]):
-    def __init__(self, conn: Connection) -> None:
+    def __init__(
+        self, conn: Connection, alive: Callable[[], bool] = lambda: True
+    ) -> None:
         self.conn = conn
+        self.alive = alive
         self.exiting = threading.Event()
         self.executor = concurrent.futures.ThreadPoolExecutor(1)

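For context, a small usage sketch of the race() helper added in the hunk above, assuming it is importable from cog.server.helpers as in this commit: the first awaitable to finish wins, the loser is cancelled, and a timeout with no winner raises TimeoutError. The shutdown delay and the next_event() stand-in below are made up for illustration.

import asyncio

from cog.server.helpers import race


async def next_event() -> str:
    await asyncio.sleep(10)  # stand-in for a slow pipe read
    return "event"


async def main() -> None:
    shutdown = asyncio.Event()
    asyncio.get_running_loop().call_later(0.1, shutdown.set)
    # shutdown fires long before next_event() completes, so race() returns True
    # (the result of shutdown.wait()) and cancels the pending read
    result = await race(next_event(), shutdown.wait(), timeout=5)
    print("shutting down" if result is True else result)


asyncio.run(main())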

@@ -175,7 +209,7 @@ def send(self, obj: Any) -> None:
 
     def shutdown(self) -> None:
         self.exiting.set()
-        self.executor.shutdown(wait=False)
+        self.executor.shutdown(wait=True)
         # if we ever need cancel_futures (introduced 3.9), we can copy it in from
         # https://github.com/python/cpython/blob/3.11/Lib/concurrent/futures/thread.py#L216-L235

@@ -185,12 +219,20 @@ def poll(self, timeout: float = 0.0) -> bool:
     def _recv(self) -> Optional[X]:
         # this ugly mess could easily be avoided with loop.connect_read_pipe
         # even loop.add_reader would help but we don't want to mess with a thread-local loop
-        while not self.exiting.is_set():
+        while not self.exiting.is_set() and not self.conn.closed and self.alive():
             if self.conn.poll(0.01):
+                if self.conn.closed or not self.alive():
+                    print("caught conn closed or unalive")
+                    return
                 return self.conn.recv()
         return None
 
     async def coro_recv(self) -> Optional[X]:
         loop = asyncio.get_running_loop()
         # believe it or not this can still deadlock!
         return await loop.run_in_executor(self.executor, self._recv)
+
+    async def coro_recv_with_exit(self, exit: asyncio.Event) -> Optional[X]:
+        result = await race(self.coro_recv(), exit.wait())
+        if result is not True:  # wait() would return True
+            return result

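And a rough usage sketch for coro_recv_with_exit(), again assuming cog.server.helpers from this commit is importable: the parent end of a multiprocessing pipe becomes awaitable, and a shared exit event lets the caller stop waiting even when the child never sends anything, in which case the call returns None.

import asyncio
import multiprocessing

from cog.server.helpers import AsyncPipe


async def main() -> None:
    parent_conn, child_conn = multiprocessing.Pipe()
    pipe = AsyncPipe(parent_conn)  # alive defaults to lambda: True
    exit_event = asyncio.Event()

    # something is waiting in the pipe, so the recv side of the race wins
    child_conn.send({"id": "abcd1234", "event": "Done"})
    print(await pipe.coro_recv_with_exit(exit_event))  # the dict sent above

    # nothing to read and exit is set, so the exit side wins and we get None
    exit_event.set()
    print(await pipe.coro_recv_with_exit(exit_event))  # None

    pipe.shutdown()  # with wait=True this also joins the reader thread


asyncio.run(main())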
