
fix send race by handling logs in the main thread when async #1831


Closed · wants to merge 9 commits
7 changes: 6 additions & 1 deletion python/cog/server/connection.py
@@ -19,11 +19,16 @@ def __init__(self, conn: Connection) -> None:
self.started = False

async def async_init(self) -> None:
if self.started:
return
fd = self.wrapped_conn.fileno()
# multiprocessing may have handled this fd already, but dup it so exit is clean
dup_fd = os.dup(fd)
sock = socket.socket(fileno=dup_fd)
sock.setblocking(False)
# we don't want to see EAGAIN, we'd rather wait
# however, perhaps this is wrong and in some cases this could still block terribly
# sock.setblocking(False)
sock.setblocking(True)
Comment on lines +28 to +31
Contributor:

I'm having trouble understanding this in the context of the commented-out code. Are those concerns about setblocking(True)? It'd be nice for this comment to provide enough relevant context for someone to pick this up if we need to revisit this behavior.
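
For background on the trade-off under discussion: with a non-blocking socket, recv raises BlockingIOError (EAGAIN) when no data is ready, while a blocking socket simply waits. A minimal, self-contained sketch (illustrative only, not part of this diff):

```python
import socket

a, b = socket.socketpair()

a.setblocking(False)
try:
    a.recv(1024)  # no data queued yet
except BlockingIOError:
    # errno EAGAIN/EWOULDBLOCK: the caller must retry or wait on a selector
    print("non-blocking recv would have blocked")

a.setblocking(True)
b.sendall(b"hello")
print(a.recv(1024))  # waits until data arrives, then returns b"hello"
```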

# TODO: use /proc/sys/net/core/rmem_max, but special-case language models
sz = 65536
sock.setsockopt(socket.SOL_SOCKET, socket.SO_RCVBUF, sz)
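
Regarding the TODO above, a hedged sketch of what consulting the kernel limit might look like (hypothetical helper, Linux-only, not part of this diff):

```python
def recv_buffer_size(requested: int = 65536) -> int:
    # SO_RCVBUF is capped by /proc/sys/net/core/rmem_max on Linux,
    # so clamp the request to what the kernel will actually grant.
    try:
        with open("/proc/sys/net/core/rmem_max") as f:
            rmem_max = int(f.read().strip())
    except (OSError, ValueError):
        return requested
    return min(requested, rmem_max)
```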
109 changes: 103 additions & 6 deletions python/cog/server/helpers.py
@@ -1,14 +1,18 @@
import asyncio
import io
import os
import selectors
import threading
import uuid
from typing import (
Callable,
Optional,
Sequence,
TextIO,
)
from typing import Callable, Optional, Sequence, TextIO


async def async_fdopen(fd: int) -> asyncio.StreamReader:
loop = asyncio.get_running_loop()
reader = asyncio.StreamReader()
protocol = asyncio.StreamReaderProtocol(reader)
loop.create_task(loop.connect_read_pipe(lambda: protocol, os.fdopen(fd, "rb")))
return reader
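
A minimal usage sketch for async_fdopen, assuming a plain OS pipe (illustrative only). Note that connect_read_pipe is scheduled as a task rather than awaited, so the transport attaches on a later loop iteration; awaiting the reader gives it that chance:

```python
import asyncio
import os

async def demo() -> None:
    read_fd, write_fd = os.pipe()
    reader = await async_fdopen(read_fd)
    os.write(write_fd, b"hello\n")
    os.close(write_fd)
    assert await reader.readline() == b"hello\n"

asyncio.run(demo())
```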


class WrappedStream:
@@ -79,6 +83,7 @@ def __init__(
self.drain_token = uuid.uuid4().hex
self.drain_event = threading.Event()
self.terminate_token = uuid.uuid4().hex
self.is_async = False

if len(self._streams) == 0:
raise ValueError("provide at least one wrapped stream to redirect")
@@ -92,6 +97,10 @@ def __init__(
super().__init__(daemon=True)

def drain(self) -> None:
if self.is_async:
# if we're async, we assume that logs will be processed promptly,
# and we don't want to block the event loop
return
self.drain_event.clear()
for stream in self._streams:
stream.write(self.drain_token + "\n")
@@ -100,12 +109,20 @@ def drain(self) -> None:
raise RuntimeError("output streams failed to drain")

def shutdown(self) -> None:
if not self.is_alive():
return
for stream in self._streams:
stream.write(self.terminate_token + "\n")
stream.flush()
break # only need to write one terminate token
self.join()

async def shutdown_async(self) -> None:
for stream in self._streams:
stream.write(self.terminate_token + "\n")
stream.flush()
await asyncio.gather(*self.stream_tasks)

def run(self) -> None:
selector = selectors.DefaultSelector()

@@ -153,3 +170,83 @@ def run(self) -> None:
if drain_tokens_seen >= drain_tokens_needed:
self.drain_event.set()
drain_tokens_seen = 0

async def switch_to_async(self) -> None:
"""
This function is called when the main thread switches to being async.
It ensures that the behavior stays the same, but write_hook is only called
from the main thread.

1. Open each stream as a StreamReader.
2. Create a task for each stream that will process the results.
3. write_hook is called for each complete log line.
4. Drain and terminate tokens are handled correctly.
5. Once the async tasks are started, shut down the thread.

We must not call write_hook twice for the same data during the switch.
"""
# Drain the streams to ensure all buffered data is processed
try:
self.drain()
except RuntimeError:
raise

# Shut down the thread
# we do this before starting a coroutine that will also read from the same fd
# so that shutdown can find the terminate tokens correctly
self.shutdown()
self.stream_tasks = []
self.is_async = True

for stream in self._streams:
# Open each stream as a StreamReader
fd = stream.wrapped.fileno()
reader = await async_fdopen(fd)

# Create a task for each stream to process the results
task = asyncio.create_task(self.process_stream(stream, reader))
self.stream_tasks.append(task)

# give the tasks a chance to start
await asyncio.sleep(0)

async def process_stream(
self, stream: WrappedStream, reader: asyncio.StreamReader
) -> None:
buffer = io.StringIO()
drain_tokens_seen = 0
should_exit = False

async for line in reader:

if not line:
break

line = line.decode()

if not line.endswith("\n"):
buffer.write(line)
continue

full_line = buffer.getvalue() + line.strip()

# Reset buffer
buffer = io.StringIO()

if full_line.endswith(self.terminate_token):
full_line = full_line[: -len(self.terminate_token)]
should_exit = True

if full_line.endswith(self.drain_token):
drain_tokens_seen += 1
full_line = full_line[: -len(self.drain_token)]

if full_line:
# Call write_hook from the main thread
self._write_hook(stream.name, stream.original, full_line + "\n")

if drain_tokens_seen >= len(self._streams):
self.drain_event.set()
drain_tokens_seen = 0
if should_exit:
break
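
The drain/terminate token protocol that both run() and process_stream implement is easiest to see end to end in isolation. A self-contained sketch of the same idea (hypothetical names, not the StreamRedirector API):

```python
import os
import threading
import uuid

DRAIN_TOKEN = uuid.uuid4().hex

def pump(fd: int, drained: threading.Event) -> None:
    buf = b""
    while chunk := os.read(fd, 4096):
        buf += chunk
        while b"\n" in buf:
            line, _, buf = buf.partition(b"\n")
            if line.decode() == DRAIN_TOKEN:
                # every write issued before the token has been processed
                drained.set()
            else:
                print("log:", line.decode())

r, w = os.pipe()
drained = threading.Event()
threading.Thread(target=pump, args=(r, drained), daemon=True).start()

os.write(w, b"some output\n" + DRAIN_TOKEN.encode() + b"\n")
drained.wait(timeout=5)  # the caller now knows earlier output was handled
```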
100 changes: 67 additions & 33 deletions python/cog/server/worker.py
@@ -4,6 +4,7 @@
import multiprocessing
import signal
import sys
import threading
import traceback
import types
from collections import defaultdict
@@ -98,6 +99,7 @@ def emit_metric(metric_name: str, metric_value: "float | int") -> None:


class _ChildWorker(_spawn.Process): # type: ignore

def __init__(
self,
predictor_ref: str,
@@ -107,12 +109,17 @@ def __init__(
self._predictor_ref = predictor_ref
self._predictor: Optional[BasePredictor] = None
self._events = events
self._events_async: Optional[AsyncConnection[tuple[str, PublicEventType]]] = (
None
)
self._process_logs_task: Optional[asyncio.Task[None]] = None
self._tee_output = tee_output
self._cancelable = False

super().__init__()

def run(self) -> None:
self._sync_events_lock = threading.Lock()
# If we're running at a shell, SIGINT will be sent to every process in
# the process group. We ignore it in the child process and require that
# shutdown is coordinated by the parent process.
@@ -139,10 +146,19 @@ def run(self) -> None:
# </could be moved into StreamRedirector>

self._setup()
self._loop()
self._stream_redirector.shutdown()
self._loop() # shuts down stream redirector the correct way
self._events.close()

async def _async_init(self) -> None:
if self._events_async:
return
# if AsyncConnection is created before switch_to_async, a race condition can cause drain to fail
# and write, seemingly, to block
# maybe because we're trying to call StreamWriter.write when no event loop is running?
await self._stream_redirector.switch_to_async()
self._events_async = AsyncConnection(self._events)
await self._events_async.async_init()

def _setup(self) -> None:
with self._handle_setup_error():
# we need to load the predictor to know if setup is async
@@ -161,6 +177,9 @@ def _setup(self) -> None:
if hasattr(self._predictor, "setup"):
if inspect.iscoroutinefunction(self._predictor.setup):
# we should probably handle Shutdown during this process?
# possibly we prefer to not stop-start the event loop
# between these calls
self.loop.run_until_complete(self._async_init())
self.loop.run_until_complete(run_setup_async(self._predictor))
else:
run_setup(self._predictor)
@@ -182,8 +201,14 @@ def _handle_setup_error(self) -> Iterator[None]:
done.error_detail = str(e)
raise
finally:
self._stream_redirector.drain()
self._events.send(("SETUP", done))
# we can arrive here if there was an error setting up stream_redirector
# for example, because drain failed
# in this case this drain could block or fail
try:
self._stream_redirector.drain()
except Exception:
raise
self.send(("SETUP", done))

def _loop_sync(self) -> None:
while True:
@@ -199,33 +224,34 @@ def _loop_sync(self) -> None:
pass
else:
print(f"Got unexpected event: {ev}", file=sys.stderr)
self._stream_redirector.shutdown()

async def _loop_async(self) -> None:
events: AsyncConnection[tuple[str, PublicEventType]] = AsyncConnection(
self._events
)
with events:
tasks: dict[str, asyncio.Task[None]] = {}
while True:
try:
ev = await events.recv()
except asyncio.CancelledError:
return
if isinstance(ev, Shutdown):
self._log("got shutdown event [async]")
return
if isinstance(ev, PredictionInput):
# keep track of these so they can be cancelled
tasks[ev.id] = asyncio.create_task(self._predict_async(ev))
elif isinstance(ev, Cancel):
# in async mode, cancel signals are ignored
# only Cancel events are handled
if ev.id in tasks:
tasks[ev.id].cancel()
else:
print(f"Got unexpected cancellation: {ev}", file=sys.stderr)
await self._async_init()
assert self._events_async
tasks: dict[str, asyncio.Task[None]] = {}
while True:
try:
ev = await self._events_async.recv()
except asyncio.CancelledError:
break
if isinstance(ev, Shutdown):
self._log("got shutdown event [async]")
break # should this be return?
if isinstance(ev, PredictionInput):
# keep track of these so they can be cancelled
tasks[ev.id] = asyncio.create_task(self._predict_async(ev))
elif isinstance(ev, Cancel):
# in async mode, cancel signals are ignored
# only Cancel events are handled
if ev.id in tasks:
tasks[ev.id].cancel()
else:
print(f"Got unexpected event: {ev}", file=sys.stderr)
print(f"Got unexpected cancellation: {ev}", file=sys.stderr)
else:
print(f"Got unexpected event: {ev}", file=sys.stderr)
await self._stream_redirector.shutdown_async()
self._events_async.close()

def _loop(self) -> None:
if is_async(get_predict(self._predictor)):
@@ -254,17 +280,24 @@ def _handle_predict_error(self, id: str) -> Iterator[None]:
self.prediction_id_context.reset(token)
self._cancelable = False
self._stream_redirector.drain()
self._events.send((id, done))
self.send((id, done))

def _emit_metric(self, name: str, value: "int | float") -> None:
prediction_id = self.prediction_id_context.get(None)
if prediction_id is None:
raise Exception("Tried to emit a metric outside a prediction context")
self._events.send((prediction_id, PredictionMetric(name, value)))
self.send((prediction_id, PredictionMetric(name, value)))

def send(self, obj: Any) -> None:
if self._events_async:
self._events_async.send(obj)
else:
with self._sync_events_lock:
self._events.send(obj)
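
The lock in the sync path matters because the redirector thread (via the write hook) and the main thread can both call Connection.send on the same pipe, and interleaved writes could corrupt a pickled message. The guarded pattern in isolation (hypothetical wrapper, not this class):

```python
import threading
from multiprocessing.connection import Connection

class LockedSender:
    """Serialize send() calls so concurrent writers cannot interleave."""

    def __init__(self, conn: Connection) -> None:
        self._conn = conn
        self._lock = threading.Lock()

    def send(self, obj: object) -> None:
        with self._lock:
            self._conn.send(obj)
```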

def _mk_send(self, id: str) -> Callable[[PublicEventType], None]:
def send(event: PublicEventType) -> None:
self._events.send((id, event))
self.send((id, event))

return send

@@ -309,15 +342,16 @@ def _signal_handler(self, signum: int, frame: Optional[types.FrameType]) -> None

def _log(self, *messages: str, source: str = "stderr") -> None:
id = self.prediction_id_context.get("LOG")
self._events.send((id, Log(" ".join(messages), source=source)))
self.send((id, Log(" ".join(messages), source=source)))

def _stream_write_hook(
self, stream_name: str, original_stream: TextIO, data: str
) -> None:
if self._tee_output:
original_stream.write(data)
original_stream.flush()
# this won't work, this fn gets called from a thread, not the async task
# this won't record prediction_id, because
# this fn gets called from a thread, not the async task
self._log(data, source=stream_name)
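
The comment above boils down to contextvars semantics: a ContextVar set inside an asyncio task is not visible to a separately started thread, which begins with a fresh context. A minimal demonstration (illustrative only, not this module's code):

```python
import asyncio
import contextvars
import threading

prediction_id = contextvars.ContextVar("prediction_id", default="LOG")

async def main() -> None:
    prediction_id.set("pred-123")
    seen: list[str] = []
    t = threading.Thread(target=lambda: seen.append(prediction_id.get()))
    t.start()
    t.join()
    assert prediction_id.get() == "pred-123"  # visible inside the task
    assert seen == ["LOG"]  # the thread only sees the default

asyncio.run(main())
```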

