Skip to content

Commit

Permalink
allow multiple waiters
Browse files Browse the repository at this point in the history
  • Loading branch information
kumaraditya303 authored Aug 1, 2022
1 parent de388c0 commit f6ede6e
Show file tree
Hide file tree
Showing 3 changed files with 37 additions and 18 deletions.
35 changes: 17 additions & 18 deletions Lib/asyncio/streams.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
'StreamReader', 'StreamWriter', 'StreamReaderProtocol',
'open_connection', 'start_server')

import collections
import socket
import sys
import weakref
Expand Down Expand Up @@ -128,7 +129,7 @@ def __init__(self, loop=None):
else:
self._loop = loop
self._paused = False
self._drain_waiter = None
self._drain_waiters = collections.deque()
self._connection_lost = False

def pause_writing(self):
Expand All @@ -143,9 +144,8 @@ def resume_writing(self):
if self._loop.get_debug():
logger.debug("%r resumes writing", self)

waiter = self._drain_waiter
if waiter is not None:
self._drain_waiter = None
while self._drain_waiters:
waiter = self._drain_waiters.popleft()
if not waiter.done():
waiter.set_result(None)

Expand All @@ -154,27 +154,26 @@ def connection_lost(self, exc):
# Wake up the writer if currently paused.
if not self._paused:
return
waiter = self._drain_waiter
if waiter is None:
return
self._drain_waiter = None
if waiter.done():
return
if exc is None:
waiter.set_result(None)
else:
waiter.set_exception(exc)

while self._drain_waiters:
waiter = self._drain_waiters.popleft()
if not waiter.done():
if exc is None:
waiter.set_result(None)
else:
waiter.set_exception(exc)

async def _drain_helper(self):
if self._connection_lost:
raise ConnectionResetError('Connection lost')
if not self._paused:
return
waiter = self._drain_waiter
assert waiter is None or waiter.cancelled()
waiter = self._loop.create_future()
self._drain_waiter = waiter
await waiter
self._drain_waiters.append(waiter)
try:
await waiter
finally:
self._drain_waiters.remove(waiter)

def _get_close_waiter(self, stream):
raise NotImplementedError
Expand Down
19 changes: 19 additions & 0 deletions Lib/test/test_asyncio/test_streams.py
Original file line number Diff line number Diff line change
Expand Up @@ -864,6 +864,25 @@ def test_streamreaderprotocol_constructor_use_global_loop(self):
self.assertEqual(cm.filename, __file__)
self.assertIs(protocol._loop, self.loop)

def test_multiple_drain(self):
# See https://github.com/python/cpython/issues/74116
drained = 0

async def drainer(stream):
nonlocal drained
await stream._drain_helper()
drained += 1

async def main():
loop = asyncio.get_running_loop()
stream = asyncio.streams.FlowControlMixin(loop)
stream.pause_writing()
loop.call_later(0.1, stream.resume_writing)
await asyncio.gather(*[drainer(stream) for _ in range(10)])
self.assertEqual(drained, 10)

self.loop.run_until_complete(main())

def test_drain_raises(self):
# See http://bugs.python.org/issue25441

Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Fix :meth:`asyncio.StreamWriter.drain` to be awaited concurrently by multiple tasks. Patch by Kumar Aditya.

0 comments on commit f6ede6e

Please sign in to comment.