Skip to content

Commit e7a5ff4

Browse files
committed
maybe fix the race condition
1 parent 77956e4 commit e7a5ff4

File tree

3 files changed

+39
-7
lines changed

3 files changed

+39
-7
lines changed

python/cog/server/connection.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -25,11 +25,12 @@ async def async_init(self) -> None:
2525
if self.started:
2626
return
2727
fd = self.wrapped_conn.fileno()
28-
# # mp may have handled something already but let's dup so exit is clean
28+
# mp may have handled something already but let's dup so exit is clean
2929
dup_fd = os.dup(fd)
3030
sock = socket.socket(fileno=dup_fd)
3131
# sock = socket.socket(fileno=fd)
3232
# we don't want to see EAGAIN, we'd rather wait
33+
# however, perhaps this is wrong and in some cases this could still block terribly
3334
# sock.setblocking(False)
3435
# TODO: use /proc/sys/net/core/rmem_max, but special-case language models
3536
sz = 65536

python/cog/server/helpers.py

Lines changed: 14 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -114,6 +114,7 @@ def drain(self) -> None:
114114
stream.flush()
115115
debug("wait drain")
116116
if not self.drain_event.wait(timeout=1):
117+
debug("drain timed out")
117118
raise RuntimeError("output streams failed to drain")
118119
debug("drain done")
119120

@@ -154,12 +155,12 @@ def run(self) -> None:
154155
stream = key.data
155156

156157
for line in stream.wrapped:
158+
debug("redirector saw", line)
157159
if not line.endswith("\n"):
158160
# TODO: limit how much we're prepared to buffer on a
159161
# single line
160162
buffers[stream.name].write(line)
161163
continue
162-
debug("redirector saw", line)
163164

164165
full_line = buffers[stream.name].getvalue() + line.strip()
165166

@@ -179,7 +180,9 @@ def run(self) -> None:
179180
# thing in the line was a drain token (or a terminate
180181
# token).
181182
if full_line:
183+
debug("write hook")
182184
self._write_hook(stream.name, stream.original, full_line + "\n")
185+
debug("write hook done")
183186

184187
if drain_tokens_seen >= drain_tokens_needed:
185188
debug("drain event set")
@@ -202,14 +205,20 @@ async def switch_to_async(self) -> None:
202205
"""
203206
debug("switch async, drain")
204207
# Drain the streams to ensure all buffered data is processed
205-
self.drain()
208+
try:
209+
self.drain()
210+
except RuntimeError:
211+
debug("drain failed")
212+
raise
213+
debug("drain done, shutdown")
206214

207215
# Shut down the thread
208216
# we do this before starting a coroutine that will also read from the same fd
209217
# so that shutdown can find the terminate tokens correctly
210218
self.shutdown()
211219
self.stream_tasks = []
212220
self.is_async = True
221+
debug("set is async")
213222

214223
for stream in self._streams:
215224
# Open each stream as a StreamReader
@@ -220,6 +229,9 @@ async def switch_to_async(self) -> None:
220229
task = asyncio.create_task(self.process_stream(stream, reader))
221230
self.stream_tasks.append(task)
222231

232+
# give the tasks a chance to start
233+
await asyncio.sleep(0)
234+
223235
async def process_stream(
224236
self, stream: WrappedStream, reader: asyncio.StreamReader
225237
) -> None:

python/cog/server/worker.py

Lines changed: 23 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -158,9 +158,12 @@ async def _async_init(self) -> None:
158158
if self._events_async:
159159
debug("async_init finished")
160160
return
161+
# if AsyncConnection is created before switch_to_async, a race condition can cause drain to fail
162+
# and write, seemingly, to block
163+
# maybe because we're trying to call StreamWriter.write when no event loop is running?
164+
await self._stream_redirector.switch_to_async()
161165
self._events_async = AsyncConnection(self._events)
162166
await self._events_async.async_init()
163-
await self._stream_redirector.switch_to_async()
164167
debug("async_init done")
165168

166169
def _setup(self) -> None:
@@ -189,15 +192,19 @@ def _setup(self) -> None:
189192
debug("inspect")
190193
if inspect.iscoroutinefunction(self._predictor.setup):
191194

192-
async def setup_async():
195+
async def setup_async() -> None:
193196
# we prefer to not stop-start the event loop between these calls
194-
await sef._async_init() # this creates tasks
197+
await self._async_init() # this creates tasks
198+
debug("async_init done")
195199
await run_setup_async(self._predictor)
200+
debug("ran setup async")
196201

197202
self.loop.run_until_complete(setup_async())
203+
debug("setup_async loop done")
198204
else:
199205
debug("sync setup")
200206
run_setup(self._predictor)
207+
debug("_setup done inside ctx mgr")
201208
debug("_setup done")
202209

203210
@contextlib.contextmanager
@@ -220,8 +227,15 @@ def _handle_setup_error(self) -> Iterator[None]:
220227
done.error_detail = str(e)
221228
raise
222229
finally:
230+
# we can arrive here if there was an error setting up stream_redirector
231+
# for example, because drain failed
232+
# in this case this drain could block or fail
223233
debug("setup done, calling drain")
224-
self._stream_redirector.drain()
234+
try:
235+
self._stream_redirector.drain()
236+
except Exception as e:
237+
debug("exc", e)
238+
raise
225239
debug("sending setup done")
226240
self.send(("SETUP", done))
227241
debug("sent setup done")
@@ -242,6 +256,7 @@ def _loop_sync(self) -> None:
242256
self._stream_redirector.shutdown()
243257

244258
async def _loop_async(self) -> None:
259+
debug("loop async")
245260
await self._async_init()
246261
assert self._events_async
247262
tasks: dict[str, asyncio.Task[None]] = {}
@@ -307,10 +322,14 @@ def _emit_metric(self, name: str, value: "int | float") -> None:
307322

308323
def send(self, obj: Any) -> None:
309324
if self._events_async:
325+
debug("sending on async")
310326
self._events_async.send(obj)
327+
debug("sent on async")
311328
else:
329+
debug("send lock")
312330
with self._sync_events_lock:
313331
self._events.send(obj)
332+
debug("finished sync send")
314333

315334
def _mk_send(self, id: str) -> Callable[[PublicEventType], None]:
316335
def send(event: PublicEventType) -> None:

0 commit comments

Comments
 (0)