Skip to content

Commit 90f3f07

Browse files
committed
maybe fix the race condition
1 parent affb023 commit 90f3f07

File tree

3 files changed: 37 additions and 6 deletions

3 files changed: 37 additions and 6 deletions

python/cog/server/connection.py

Lines changed: 2 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -26,11 +26,12 @@ async def async_init(self) -> None:
2626
if self.started:
2727
return
2828
fd = self.wrapped_conn.fileno()
29-
# # mp may have handled something already but let's dup so exit is clean
29+
# mp may have handled something already but let's dup so exit is clean
3030
dup_fd = os.dup(fd)
3131
sock = socket.socket(fileno=dup_fd)
3232
# sock = socket.socket(fileno=fd)
3333
# we don't want to see EAGAIN, we'd rather wait
34+
# however, perhaps this is wrong and in some cases this could still block terribly
3435
# sock.setblocking(False)
3536
# TODO: use /proc/sys/net/core/rmem_max, but special-case language models
3637
sz = 65536

python/cog/server/helpers.py

Lines changed: 14 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -114,6 +114,7 @@ def drain(self) -> None:
114114
stream.flush()
115115
debug("wait drain")
116116
if not self.drain_event.wait(timeout=1):
117+
debug("drain timed out")
117118
raise RuntimeError("output streams failed to drain")
118119
debug("drain done")
119120

@@ -154,12 +155,12 @@ def run(self) -> None:
154155
stream = key.data
155156

156157
for line in stream.wrapped:
158+
debug("redirector saw", line)
157159
if not line.endswith("\n"):
158160
# TODO: limit how much we're prepared to buffer on a
159161
# single line
160162
buffers[stream.name].write(line)
161163
continue
162-
debug("redirector saw", line)
163164

164165
full_line = buffers[stream.name].getvalue() + line.strip()
165166

@@ -179,7 +180,9 @@ def run(self) -> None:
179180
# thing in the line was a drain token (or a terminate
180181
# token).
181182
if full_line:
183+
debug("write hook")
182184
self._write_hook(stream.name, stream.original, full_line + "\n")
185+
debug("write hook done")
183186

184187
if drain_tokens_seen >= drain_tokens_needed:
185188
debug("drain event set")
@@ -202,14 +205,20 @@ async def switch_to_async(self) -> None:
202205
"""
203206
debug("switch async, drain")
204207
# Drain the streams to ensure all buffered data is processed
205-
self.drain()
208+
try:
209+
self.drain()
210+
except RuntimeError:
211+
debug("drain failed")
212+
raise
213+
debug("drain done, shutdown")
206214

207215
# Shut down the thread
208216
# we do this before starting a coroutine that will also read from the same fd
209217
# so that shutdown can find the terminate tokens correctly
210218
self.shutdown()
211219
self.stream_tasks = []
212220
self.is_async = True
221+
debug("set is async")
213222

214223
for stream in self._streams:
215224
# Open each stream as a StreamReader
@@ -220,6 +229,9 @@ async def switch_to_async(self) -> None:
220229
task = asyncio.create_task(self.process_stream(stream, reader))
221230
self.stream_tasks.append(task)
222231

232+
# give the tasks a chance to start
233+
await asyncio.sleep(0)
234+
223235
async def process_stream(
224236
self, stream: WrappedStream, reader: asyncio.StreamReader
225237
) -> None:

python/cog/server/worker.py

Lines changed: 21 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -158,9 +158,12 @@ async def _async_init(self) -> None:
158158
if self._events_async:
159159
debug("async_init finished")
160160
return
161+
# if AsyncConnection is created before switch_to_async, a race condition can cause drain to fail
162+
# and write, seemingly, to block
163+
# maybe because we're trying to call StreamWriter.write when no event loop is running?
164+
await self._stream_redirector.switch_to_async()
161165
self._events_async = AsyncConnection(self._events)
162166
await self._events_async.async_init()
163-
await self._stream_redirector.switch_to_async()
164167
debug("async_init done")
165168

166169
def _setup(self) -> None:
@@ -189,12 +192,15 @@ def _setup(self) -> None:
189192
debug("inspect")
190193
if inspect.iscoroutinefunction(self._predictor.setup):
191194
# we should probably handle Shutdown during this process?
192-
# debug("creating AsyncConn")
195+
# possibly we prefer to not stop-start the event loop
196+
# between these calls
193197
self.loop.run_until_complete(self._async_init())
194198
self.loop.run_until_complete(run_setup_async(self._predictor))
199+
debug("run_setup_async done")
195200
else:
196201
debug("sync setup")
197202
run_setup(self._predictor)
203+
debug("_setup done inside ctx mgr")
198204
debug("_setup done")
199205

200206
@contextlib.contextmanager
@@ -217,8 +223,15 @@ def _handle_setup_error(self) -> Iterator[None]:
217223
done.error_detail = str(e)
218224
raise
219225
finally:
226+
# we can arrive here if there was an error setting up stream_redirector
227+
# for example, because drain failed
228+
# in this case this drain could block or fail
220229
debug("setup done, calling drain")
221-
self._stream_redirector.drain()
230+
try:
231+
self._stream_redirector.drain()
232+
except Exception as e:
233+
debug("exc", str(e))
234+
raise
222235
debug("sending setup done")
223236
self.send(("SETUP", done))
224237
debug("sent setup done")
@@ -240,6 +253,7 @@ def _loop_sync(self) -> None:
240253
self._stream_redirector.shutdown()
241254

242255
async def _loop_async(self) -> None:
256+
debug("loop async")
243257
await self._async_init()
244258
assert self._events_async
245259
tasks: dict[str, asyncio.Task[None]] = {}
@@ -306,10 +320,14 @@ def _emit_metric(self, name: str, value: "int | float") -> None:
306320

307321
def send(self, obj: Any) -> None:
308322
if self._events_async:
323+
debug("sending on async")
309324
self._events_async.send(obj)
325+
debug("sent on async")
310326
else:
327+
debug("send lock")
311328
with self._sync_events_lock:
312329
self._events.send(obj)
330+
debug("finished sync send")
313331

314332
def _mk_send(self, id: str) -> Callable[[PublicEventType], None]:
315333
def send(event: PublicEventType) -> None:

0 commit comments

Comments
 (0)