Skip to content

bpo-36801: Fix waiting in StreamWriter.drain for closing SSL transport #13098

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits into from
May 7, 2019
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
21 changes: 13 additions & 8 deletions Lib/asyncio/streams.py
Original file line number Diff line number Diff line change
Expand Up @@ -179,6 +179,9 @@ async def _drain_helper(self):
self._drain_waiter = waiter
await waiter

def _get_close_waiter(self, stream):
raise NotImplementedError


class StreamReaderProtocol(FlowControlMixin, protocols.Protocol):
"""Helper class to adapt between Protocol and StreamReader.
Expand Down Expand Up @@ -293,6 +296,9 @@ def eof_received(self):
return False
return True

def _get_close_waiter(self, stream):
return self._closed

def __del__(self):
# Prevent reports about unhandled exceptions.
# Better than self._closed._log_traceback = False hack
Expand Down Expand Up @@ -348,7 +354,7 @@ def is_closing(self):
return self._transport.is_closing()

async def wait_closed(self):
await self._protocol._closed
await self._protocol._get_close_waiter(self)

def get_extra_info(self, name, default=None):
return self._transport.get_extra_info(name, default)
Expand All @@ -366,13 +372,12 @@ async def drain(self):
if exc is not None:
raise exc
if self._transport.is_closing():
# Yield to the event loop so connection_lost() may be
# called. Without this, _drain_helper() would return
# immediately, and code that calls
# write(...); await drain()
# in a loop would never call connection_lost(), so it
# would not see an error when the socket is closed.
await sleep(0, loop=self._loop)
# Wait for protocol.connection_lost() call
# Raise connection closing error if any,
# ConnectionResetError otherwise
fut = self._protocol._get_close_waiter(self)
await fut
raise ConnectionResetError('Connection lost')
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Do we want to raise an error from drain()? Maybe it should just return? What's the point of this error?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

That's how protocol._drain_helper() works now.
The future is set to exception only if the connection is closed with a failure.
But await writer.drain() raises ConnectionResetError.
I believe this is good: await writer.write(b'data') should fail loudly if the socket is closed.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Maybe the code could be simplified by setting ConnectionResetError by connection_lost() callback handler but I'd like to keep the PR small.
Future improvement worth another pull request.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

OK, got it.

await self._protocol._drain_helper()

async def aclose(self):
Expand Down
9 changes: 9 additions & 0 deletions Lib/asyncio/subprocess.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ def __init__(self, limit, loop):
self._transport = None
self._process_exited = False
self._pipe_fds = []
self._stdin_closed = self._loop.create_future()

def __repr__(self):
info = [self.__class__.__name__]
Expand Down Expand Up @@ -76,6 +77,10 @@ def pipe_connection_lost(self, fd, exc):
if pipe is not None:
pipe.close()
self.connection_lost(exc)
if exc is None:
self._stdin_closed.set_result(None)
else:
self._stdin_closed.set_exception(exc)
return
if fd == 1:
reader = self.stdout
Expand All @@ -102,6 +107,10 @@ def _maybe_close_transport(self):
self._transport.close()
self._transport = None

def _get_close_waiter(self, stream):
if stream is self.stdin:
return self._stdin_closed


class Process:
def __init__(self, transport, protocol, loop):
Expand Down
23 changes: 23 additions & 0 deletions Lib/test/test_asyncio/test_streams.py
Original file line number Diff line number Diff line change
Expand Up @@ -109,6 +109,29 @@ def test_open_unix_connection_no_loop_ssl(self):

self._basetest_open_connection_no_loop_ssl(conn_fut)

@unittest.skipIf(ssl is None, 'No ssl module')
def test_drain_on_closed_writer_ssl(self):

async def inner(httpd):
reader, writer = await asyncio.open_connection(
*httpd.address,
ssl=test_utils.dummy_ssl_context())

messages = []
self.loop.set_exception_handler(lambda loop, ctx: messages.append(ctx))
writer.write(b'GET / HTTP/1.0\r\n\r\n')
data = await reader.read()
self.assertTrue(data.endswith(b'\r\n\r\nTest message'))

writer.close()
with self.assertRaises(ConnectionResetError):
await writer.drain()

self.assertEqual(messages, [])

with test_utils.run_test_server(use_ssl=True) as httpd:
self.loop.run_until_complete(inner(httpd))

def _basetest_open_connection_error(self, open_connection_fut):
messages = []
self.loop.set_exception_handler(lambda loop, ctx: messages.append(ctx))
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Properly handle SSL connection closing in asyncio StreamWriter.drain() call.