Skip to content

Commit

Permalink
GH-74116: Allow multiple drain waiters for asyncio.StreamWriter (GH-9…
Browse files Browse the repository at this point in the history
…4705)

(cherry picked from commit e5b2453)

Co-authored-by: Kumar Aditya <59607654+kumaraditya303@users.noreply.github.com>
  • Loading branch information
miss-islington and kumaraditya303 authored Sep 8, 2022
1 parent 280130f commit f60bbf0
Show file tree
Hide file tree
Showing 3 changed files with 36 additions and 19 deletions.
35 changes: 16 additions & 19 deletions Lib/asyncio/streams.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
'StreamReader', 'StreamWriter', 'StreamReaderProtocol',
'open_connection', 'start_server')

import collections
import socket
import sys
import warnings
Expand Down Expand Up @@ -129,7 +130,7 @@ def __init__(self, loop=None):
else:
self._loop = loop
self._paused = False
self._drain_waiter = None
self._drain_waiters = collections.deque()
self._connection_lost = False

def pause_writing(self):
Expand All @@ -144,38 +145,34 @@ def resume_writing(self):
if self._loop.get_debug():
logger.debug("%r resumes writing", self)

waiter = self._drain_waiter
if waiter is not None:
self._drain_waiter = None
for waiter in self._drain_waiters:
if not waiter.done():
waiter.set_result(None)

def connection_lost(self, exc):
self._connection_lost = True
# Wake up the writer if currently paused.
# Wake up the writer(s) if currently paused.
if not self._paused:
return
waiter = self._drain_waiter
if waiter is None:
return
self._drain_waiter = None
if waiter.done():
return
if exc is None:
waiter.set_result(None)
else:
waiter.set_exception(exc)

for waiter in self._drain_waiters:
if not waiter.done():
if exc is None:
waiter.set_result(None)
else:
waiter.set_exception(exc)

async def _drain_helper(self):
if self._connection_lost:
raise ConnectionResetError('Connection lost')
if not self._paused:
return
waiter = self._drain_waiter
assert waiter is None or waiter.cancelled()
waiter = self._loop.create_future()
self._drain_waiter = waiter
await waiter
self._drain_waiters.append(waiter)
try:
await waiter
finally:
self._drain_waiters.remove(waiter)

def _get_close_waiter(self, stream):
raise NotImplementedError
Expand Down
19 changes: 19 additions & 0 deletions Lib/test/test_asyncio/test_streams.py
Original file line number Diff line number Diff line change
Expand Up @@ -801,6 +801,25 @@ def test_streamreaderprotocol_constructor_use_global_loop(self):
self.assertEqual(cm.warnings[0].filename, __file__)
self.assertIs(protocol._loop, self.loop)

def test_multiple_drain(self):
# See https://github.com/python/cpython/issues/74116
drained = 0

async def drainer(stream):
nonlocal drained
await stream._drain_helper()
drained += 1

async def main():
loop = asyncio.get_running_loop()
stream = asyncio.streams.FlowControlMixin(loop)
stream.pause_writing()
loop.call_later(0.1, stream.resume_writing)
await asyncio.gather(*[drainer(stream) for _ in range(10)])
self.assertEqual(drained, 10)

self.loop.run_until_complete(main())

def test_drain_raises(self):
# See http://bugs.python.org/issue25441

Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Allow :meth:`asyncio.StreamWriter.drain` to be awaited concurrently by multiple tasks. Patch by Kumar Aditya.

0 comments on commit f60bbf0

Please sign in to comment.