Skip to content

Fix race between future cancellation and remove_reader|writer #173

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Jun 22, 2018
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
85 changes: 85 additions & 0 deletions tests/test_sockets.py
Original file line number Diff line number Diff line change
Expand Up @@ -189,6 +189,91 @@ def test_socket_sync_remove_and_immediately_close(self):
self.assertEqual(sock.fileno(), -1)
self.loop.run_until_complete(asyncio.sleep(0.01, loop=self.loop))

def test_sock_cancel_add_reader_race(self):
srv_sock_conn = None

async def server():
nonlocal srv_sock_conn
sock_server = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
sock_server.setblocking(False)
with sock_server:
sock_server.bind(('127.0.0.1', 0))
sock_server.listen()
fut = asyncio.ensure_future(
client(sock_server.getsockname()), loop=self.loop)
srv_sock_conn, _ = await self.loop.sock_accept(sock_server)
srv_sock_conn.setsockopt(
socket.IPPROTO_TCP, socket.TCP_NODELAY, 1)
with srv_sock_conn:
await fut

async def client(addr):
sock_client = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
sock_client.setblocking(False)
with sock_client:
await self.loop.sock_connect(sock_client, addr)
_, pending_read_futs = await asyncio.wait(
[self.loop.sock_recv(sock_client, 1)],
timeout=1, loop=self.loop)

async def send_server_data():
# Wait a little bit to let reader future cancel and
# schedule the removal of the reader callback. Right after
# "rfut.cancel()" we will call "loop.sock_recv()", which
# will add a reader. This will make a race between
# remove- and add-reader.
await asyncio.sleep(0.1, loop=self.loop)
await self.loop.sock_sendall(srv_sock_conn, b'1')
self.loop.create_task(send_server_data())

for rfut in pending_read_futs:
rfut.cancel()

data = await self.loop.sock_recv(sock_client, 1)

self.assertEqual(data, b'1')

self.loop.run_until_complete(server())

def test_sock_send_before_cancel(self):
srv_sock_conn = None

async def server():
nonlocal srv_sock_conn
sock_server = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
sock_server.setblocking(False)
with sock_server:
sock_server.bind(('127.0.0.1', 0))
sock_server.listen()
fut = asyncio.ensure_future(
client(sock_server.getsockname()), loop=self.loop)
srv_sock_conn, _ = await self.loop.sock_accept(sock_server)
with srv_sock_conn:
await fut

async def client(addr):
await asyncio.sleep(0.01, loop=self.loop)
sock_client = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
sock_client.setblocking(False)
with sock_client:
await self.loop.sock_connect(sock_client, addr)
_, pending_read_futs = await asyncio.wait(
[self.loop.sock_recv(sock_client, 1)],
timeout=1, loop=self.loop)

# server can send the data in a random time, even before
# the previous result future has cancelled.
await self.loop.sock_sendall(srv_sock_conn, b'1')

for rfut in pending_read_futs:
rfut.cancel()

data = await self.loop.sock_recv(sock_client, 1)

self.assertEqual(data, b'1')

self.loop.run_until_complete(server())


class TestUVSockets(_TestSockets, tb.UVTestCase):

Expand Down
2 changes: 2 additions & 0 deletions uvloop/handles/poll.pxd
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,8 @@ cdef class UVPoll(UVHandle):
cdef int is_active(self)

cdef is_reading(self)
cdef is_writing(self)

cdef start_reading(self, Handle callback)
cdef start_writing(self, Handle callback)
cdef stop_reading(self)
Expand Down
3 changes: 3 additions & 0 deletions uvloop/handles/poll.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -87,6 +87,9 @@ cdef class UVPoll(UVHandle):
cdef is_reading(self):
return self._is_alive() and self.reading_handle is not None

cdef is_writing(self):
return self._is_alive() and self.writing_handle is not None

cdef start_reading(self, Handle callback):
cdef:
int mask = 0
Expand Down
4 changes: 2 additions & 2 deletions uvloop/loop.pxd
Original file line number Diff line number Diff line change
Expand Up @@ -177,12 +177,12 @@ cdef class Loop:
cdef _track_process(self, UVProcess proc)
cdef _untrack_process(self, UVProcess proc)

cdef _new_reader_future(self, sock)
cdef _new_writer_future(self, sock)
cdef _add_reader(self, fd, Handle handle)
cdef _has_reader(self, fd)
cdef _remove_reader(self, fd)

cdef _add_writer(self, fd, Handle handle)
cdef _has_writer(self, fd)
cdef _remove_writer(self, fd)

cdef _sock_recv(self, fut, sock, n)
Expand Down
130 changes: 93 additions & 37 deletions uvloop/loop.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -742,6 +742,20 @@ cdef class Loop:

return result

cdef _has_reader(self, fileobj):
cdef:
UVPoll poll

self._check_closed()
fd = self._fileobj_to_fd(fileobj)

try:
poll = <UVPoll>(self._polls[fd])
except KeyError:
return False

return poll.is_reading()

cdef _add_writer(self, fileobj, Handle handle):
cdef:
UVPoll poll
Expand Down Expand Up @@ -791,6 +805,20 @@ cdef class Loop:

return result

cdef _has_writer(self, fileobj):
cdef:
UVPoll poll

self._check_closed()
fd = self._fileobj_to_fd(fileobj)

try:
poll = <UVPoll>(self._polls[fd])
except KeyError:
return False

return poll.is_writing()

cdef _getaddrinfo(self, object host, object port,
int family, int type,
int proto, int flags,
Expand Down Expand Up @@ -845,35 +873,17 @@ cdef class Loop:
nr.query(addr, flags)
return fut

cdef _new_reader_future(self, sock):
def _on_cancel(fut):
# Check if the future was cancelled and if the socket
# is still open, i.e.
#
# loop.remove_reader(sock)
# sock.close()
# fut.cancel()
#
# wasn't called by the user.
if fut.cancelled() and sock.fileno() != -1:
self._remove_reader(sock)

fut = self._new_future()
fut.add_done_callback(_on_cancel)
return fut

cdef _new_writer_future(self, sock):
def _on_cancel(fut):
if fut.cancelled() and sock.fileno() != -1:
self._remove_writer(sock)

fut = self._new_future()
fut.add_done_callback(_on_cancel)
return fut

cdef _sock_recv(self, fut, sock, n):
cdef:
Handle handle
if UVLOOP_DEBUG:
if fut.cancelled():
# Shouldn't happen with _SyncSocketReaderFuture.
raise RuntimeError(
f'_sock_recv is called on a cancelled Future')

if not self._has_reader(sock):
raise RuntimeError(
f'socket {sock!r} does not have a reader '
f'in the _sock_recv callback')

try:
data = sock.recv(n)
Expand All @@ -889,8 +899,16 @@ cdef class Loop:
self._remove_reader(sock)

cdef _sock_recv_into(self, fut, sock, buf):
cdef:
Handle handle
if UVLOOP_DEBUG:
if fut.cancelled():
# Shouldn't happen with _SyncSocketReaderFuture.
raise RuntimeError(
f'_sock_recv_into is called on a cancelled Future')

if not self._has_reader(sock):
raise RuntimeError(
f'socket {sock!r} does not have a reader '
f'in the _sock_recv_into callback')

try:
data = sock.recv_into(buf)
Expand All @@ -910,6 +928,17 @@ cdef class Loop:
Handle handle
int n

if UVLOOP_DEBUG:
if fut.cancelled():
# Shouldn't happen with _SyncSocketReaderFuture.
raise RuntimeError(
f'_sock_sendall is called on a cancelled Future')

if not self._has_writer(sock):
raise RuntimeError(
f'socket {sock!r} does not have a writer '
f'in the _sock_sendall callback')

try:
n = sock.send(data)
except (BlockingIOError, InterruptedError):
Expand Down Expand Up @@ -940,9 +969,6 @@ cdef class Loop:
self._add_writer(sock, handle)

cdef _sock_accept(self, fut, sock):
cdef:
Handle handle

try:
conn, address = sock.accept()
conn.setblocking(False)
Expand Down Expand Up @@ -2261,7 +2287,7 @@ cdef class Loop:
if self._debug and sock.gettimeout() != 0:
raise ValueError("the socket must be non-blocking")

fut = self._new_reader_future(sock)
fut = _SyncSocketReaderFuture(sock, self)
handle = new_MethodHandle3(
self,
"Loop._sock_recv",
Expand All @@ -2287,7 +2313,7 @@ cdef class Loop:
if self._debug and sock.gettimeout() != 0:
raise ValueError("the socket must be non-blocking")

fut = self._new_reader_future(sock)
fut = _SyncSocketReaderFuture(sock, self)
handle = new_MethodHandle3(
self,
"Loop._sock_recv_into",
Expand Down Expand Up @@ -2338,7 +2364,7 @@ cdef class Loop:
data = memoryview(data)
data = data[n:]

fut = self._new_writer_future(sock)
fut = _SyncSocketWriterFuture(sock, self)
handle = new_MethodHandle3(
self,
"Loop._sock_sendall",
Expand Down Expand Up @@ -2368,7 +2394,7 @@ cdef class Loop:
if self._debug and sock.gettimeout() != 0:
raise ValueError("the socket must be non-blocking")

fut = self._new_reader_future(sock)
fut = _SyncSocketReaderFuture(sock, self)
handle = new_MethodHandle2(
self,
"Loop._sock_accept",
Expand Down Expand Up @@ -2952,6 +2978,36 @@ cdef inline void __loop_free_buffer(Loop loop):
loop._recv_buffer_in_use = 0


class _SyncSocketReaderFuture(aio_Future):

def __init__(self, sock, loop):
aio_Future.__init__(self, loop=loop)
self.__sock = sock
self.__loop = loop

def cancel(self):
if self.__sock is not None and self.__sock.fileno() != -1:
self.__loop.remove_reader(self.__sock)
self.__sock = None

aio_Future.cancel(self)


class _SyncSocketWriterFuture(aio_Future):

def __init__(self, sock, loop):
aio_Future.__init__(self, loop=loop)
self.__sock = sock
self.__loop = loop

def cancel(self):
if self.__sock is not None and self.__sock.fileno() != -1:
self.__loop.remove_writer(self.__sock)
self.__sock = None

aio_Future.cancel(self)


include "cbhandles.pyx"
include "pseudosock.pyx"

Expand Down