Skip to content

Commit 1cc0ee7

Browse files
asvetlovmiss-islington
authored andcommitted
bpo-36801: Fix waiting in StreamWriter.drain for closing SSL transport (GH-13098)
https://bugs.python.org/issue36801
1 parent e19a91e commit 1cc0ee7

File tree

4 files changed

+46
-8
lines changed

4 files changed

+46
-8
lines changed

Lib/asyncio/streams.py

Lines changed: 13 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -199,6 +199,9 @@ async def _drain_helper(self):
199199
self._drain_waiter = waiter
200200
await waiter
201201

202+
def _get_close_waiter(self, stream):
203+
raise NotImplementedError
204+
202205

203206
class StreamReaderProtocol(FlowControlMixin, protocols.Protocol):
204207
"""Helper class to adapt between Protocol and StreamReader.
@@ -315,6 +318,9 @@ def eof_received(self):
315318
return False
316319
return True
317320

321+
def _get_close_waiter(self, stream):
322+
return self._closed
323+
318324
def __del__(self):
319325
# Prevent reports about unhandled exceptions.
320326
# Better than self._closed._log_traceback = False hack
@@ -376,7 +382,7 @@ def is_closing(self):
376382
return self._transport.is_closing()
377383

378384
async def wait_closed(self):
379-
await self._protocol._closed
385+
await self._protocol._get_close_waiter(self)
380386

381387
def get_extra_info(self, name, default=None):
382388
return self._transport.get_extra_info(name, default)
@@ -394,13 +400,12 @@ async def drain(self):
394400
if exc is not None:
395401
raise exc
396402
if self._transport.is_closing():
397-
# Yield to the event loop so connection_lost() may be
398-
# called. Without this, _drain_helper() would return
399-
# immediately, and code that calls
400-
# write(...); await drain()
401-
# in a loop would never call connection_lost(), so it
402-
# would not see an error when the socket is closed.
403-
await sleep(0, loop=self._loop)
403+
# Wait for protocol.connection_lost() call
404+
# Raise connection closing error if any,
405+
# ConnectionResetError otherwise
406+
fut = self._protocol._get_close_waiter(self)
407+
await fut
408+
raise ConnectionResetError('Connection lost')
404409
await self._protocol._drain_helper()
405410

406411
async def aclose(self):

Lib/asyncio/subprocess.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@ def __init__(self, limit, loop, *, _asyncio_internal=False):
2626
self._transport = None
2727
self._process_exited = False
2828
self._pipe_fds = []
29+
self._stdin_closed = self._loop.create_future()
2930

3031
def __repr__(self):
3132
info = [self.__class__.__name__]
@@ -80,6 +81,10 @@ def pipe_connection_lost(self, fd, exc):
8081
if pipe is not None:
8182
pipe.close()
8283
self.connection_lost(exc)
84+
if exc is None:
85+
self._stdin_closed.set_result(None)
86+
else:
87+
self._stdin_closed.set_exception(exc)
8388
return
8489
if fd == 1:
8590
reader = self.stdout
@@ -106,6 +111,10 @@ def _maybe_close_transport(self):
106111
self._transport.close()
107112
self._transport = None
108113

114+
def _get_close_waiter(self, stream):
115+
if stream is self.stdin:
116+
return self._stdin_closed
117+
109118

110119
class Process:
111120
def __init__(self, transport, protocol, loop, *, _asyncio_internal=False):

Lib/test/test_asyncio/test_streams.py

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -109,6 +109,29 @@ def test_open_unix_connection_no_loop_ssl(self):
109109

110110
self._basetest_open_connection_no_loop_ssl(conn_fut)
111111

112+
@unittest.skipIf(ssl is None, 'No ssl module')
113+
def test_drain_on_closed_writer_ssl(self):
114+
115+
async def inner(httpd):
116+
reader, writer = await asyncio.open_connection(
117+
*httpd.address,
118+
ssl=test_utils.dummy_ssl_context())
119+
120+
messages = []
121+
self.loop.set_exception_handler(lambda loop, ctx: messages.append(ctx))
122+
writer.write(b'GET / HTTP/1.0\r\n\r\n')
123+
data = await reader.read()
124+
self.assertTrue(data.endswith(b'\r\n\r\nTest message'))
125+
126+
writer.close()
127+
with self.assertRaises(ConnectionResetError):
128+
await writer.drain()
129+
130+
self.assertEqual(messages, [])
131+
132+
with test_utils.run_test_server(use_ssl=True) as httpd:
133+
self.loop.run_until_complete(inner(httpd))
134+
112135
def _basetest_open_connection_error(self, open_connection_fut):
113136
messages = []
114137
self.loop.set_exception_handler(lambda loop, ctx: messages.append(ctx))
Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
Properly handle SSL connection closing in asyncio StreamWriter.drain() call.

0 commit comments

Comments
 (0)