Convert workers to use PipeSendChannels or FdStreams
richardsheridan committed Dec 18, 2020
1 parent ec32582 commit e20b9e8
Showing 1 changed file with 99 additions and 11 deletions.
110 changes: 99 additions & 11 deletions trio/_worker_processes.py
@@ -2,8 +2,16 @@
from collections import deque
from itertools import count
from multiprocessing import Pipe, Process

from ._core import open_nursery, RunVar, CancelScope, wait_readable
from multiprocessing.reduction import ForkingPickler

from ._core import (
open_nursery,
RunVar,
CancelScope,
wait_readable,
EndOfChannel,
BrokenResourceError,
)
from ._sync import CapacityLimiter
from ._threads import to_thread_run_sync
from ._timeouts import sleep_forever
@@ -19,11 +27,15 @@
_proc_counter = count()

if os.name == "nt":
from trio._windows_pipes import PipeSendChannel, PipeReceiveChannel
from ._wait_for_object import WaitForSingleObject

# TODO: This uses a thread per-process. Can we do better?
wait_sentinel = WaitForSingleObject
else:
from ._unix_pipes import FdStream
import struct

wait_sentinel = wait_readable


@@ -142,19 +154,16 @@ def worker_fn():

async def run_sync(self, sync_fn, *args):
# Neither this nor the child process should be waiting at this point
self._rehabilitate_pipes()
async with open_nursery() as nursery:
try:
# Monitor needed for pypy and other platforms that don't
# promptly raise EOFError
# promptly raise EndOfChannel
await nursery.start(self._child_monitor)

await to_thread_run_sync(
self._send_pipe.send, (sync_fn, args), cancellable=True
)
result = await to_thread_run_sync(
self._recv_pipe.recv, cancellable=True
)
except EOFError:
await self._send(ForkingPickler.dumps((sync_fn, args)))
result = ForkingPickler.loads(await self._recv())
except EndOfChannel:
# Likely the worker died while we were waiting on a pipe
self.kill() # Just make sure
# sleep and let the monitor raise the appropriate error to avoid
@@ -192,6 +201,85 @@ def join(self, timeout=None):
# _proc.join() doesn't report whether the join was successful
return self._proc._popen.wait(timeout) is not None

if os.name == "nt":

def _rehabilitate_pipes(self):
# These must be created in an async context, so defer so
# that this object can be instantiated in e.g. a thread
if not hasattr(self, "_send_chan"):
self._send_chan = PipeSendChannel(self._send_pipe.fileno())
self._recv_chan = PipeReceiveChannel(self._recv_pipe.fileno())
self._send = self._send_chan.send
self._recv = self._recv_chan.receive

def __del__(self):
# Avoid __del__ errors on cleanup: GH#174, GH#1767
# multiprocessing will close them for us
if hasattr(self, "_send_chan"):
self._send_chan._handle_holder.handle = -1
self._recv_chan._handle_holder.handle = -1

else:

def _rehabilitate_pipes(self):
# These must be created in an async context, so defer so
# that this object can be instantiated in e.g. a thread
if not hasattr(self, "_send_stream"):
self._send_stream = FdStream(self._send_pipe.fileno())
self._recv_stream = FdStream(self._recv_pipe.fileno())

async def _recv(self):
buf = await self._recv_exactly(4)
(size,) = struct.unpack("!i", buf)
if size == -1:
buf = await self._recv_exactly(8)
(size,) = struct.unpack("!Q", buf)
return await self._recv_exactly(size)

async def _recv_exactly(self, size):
result_bytes = bytearray()
while size:
partial_result = await self._recv_stream.receive_some(size)
num_recvd = len(partial_result)
if not num_recvd:
raise EndOfChannel("got end of file during message")
result_bytes.extend(partial_result)
if num_recvd > size: # pragma: no cover
raise RuntimeError("Oversized response")
else:
size -= num_recvd
return result_bytes

async def _send(self, buf):
n = len(buf)
if n > 0x7FFFFFFF:
pre_header = struct.pack("!i", -1)
header = struct.pack("!Q", n)
await self._send_stream.send_all(pre_header)
await self._send_stream.send_all(header)
await self._send_stream.send_all(buf)
else:
# For wire compatibility with 3.7 and lower
header = struct.pack("!i", n)
if n > 16384:
# The payload is large so Nagle's algorithm won't be triggered
# and we'd better avoid the cost of concatenation.
await self._send_stream.send_all(header)
await self._send_stream.send_all(buf)
else:
# Issue #20540: concatenate before sending, to avoid delays due
# to Nagle's algorithm on a TCP socket.
# Also note we want to avoid sending a 0-length buffer separately,
# to avoid "broken pipe" errors if the other end closed the pipe.
await self._send_stream.send_all(header + buf)

def __del__(self):
# Avoid __del__ errors on cleanup: GH#174, GH#1767
# multiprocessing will close them for us
if hasattr(self, "_send_stream"):
self._send_stream._fd_holder.fd = -1
self._recv_stream._fd_holder.fd = -1
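
As a reader's aid (not part of the commit), here is a minimal, self-contained sketch of the length-prefix framing that the POSIX _send/_recv methods above implement. It mirrors the struct layouts shown in the diff: a 4-byte big-endian signed length, with -1 acting as an escape that promotes the length to an 8-byte unsigned field for payloads larger than 0x7FFFFFFF bytes. The helper names frame and unframe are illustrative only.

import io
import struct

def frame(buf: bytes) -> bytes:
    # Prefix the payload with its length, using the extended header for huge buffers.
    n = len(buf)
    if n > 0x7FFFFFFF:
        return struct.pack("!i", -1) + struct.pack("!Q", n) + buf
    return struct.pack("!i", n) + buf

def unframe(stream: io.BytesIO) -> bytes:
    # Read the 4-byte header, switch to the 8-byte header if it is -1,
    # then read exactly that many payload bytes.
    (size,) = struct.unpack("!i", stream.read(4))
    if size == -1:
        (size,) = struct.unpack("!Q", stream.read(8))
    return stream.read(size)

payload = b"pickled (sync_fn, args) bytes"
assert unframe(io.BytesIO(frame(payload))) == payload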


async def to_process_run_sync(sync_fn, *args, cancellable=False, limiter=None):
"""Run sync_fn in a separate process
@@ -243,7 +331,7 @@ async def to_process_run_sync(sync_fn, *args, cancellable=False, limiter=None):
try:
with CancelScope(shield=not cancellable):
return await proc.run_sync(sync_fn, *args)
except BrokenPipeError:
except BrokenResourceError:
# Rare case where proc timed out even though it was still alive
# as we popped it. Just retry.
pass
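
A hedged usage sketch, not taken from the commit: it assumes to_process_run_sync is importable from trio._worker_processes (the module this diff modifies) and that sync_fn must be picklable by reference, as the ForkingPickler calls above suggest.

import trio
from trio._worker_processes import to_process_run_sync

def cpu_bound(n):
    # A picklable, CPU-heavy function worth offloading to a worker process.
    return sum(i * i for i in range(n))

async def main():
    result = await to_process_run_sync(cpu_bound, 10_000)
    print(result)

if __name__ == "__main__":
    trio.run(main)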
