Skip to content

Commit 2eb10ac

Browse files
committed
Convert FlowControlMixin to StdoutWriterProtocol
Add support for waiting for stream close. And actually use the new class.
1 parent b3330f2 commit 2eb10ac

File tree

2 files changed

+14
-11
lines changed

2 files changed

+14
-11
lines changed

splitgpg2/__init__.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -46,6 +46,8 @@
4646
from typing import Optional, Dict, Callable, Awaitable, Tuple, Pattern, List, \
4747
Union, Any, TypeVar, Set, TYPE_CHECKING, Coroutine, Sequence, cast
4848

49+
from .stdiostream import StdoutWriterProtocol
50+
4951
if TYPE_CHECKING:
5052
from typing_extensions import Protocol
5153
from typing import TypeAlias
@@ -1405,7 +1407,7 @@ def open_stdinout_connection(*,
14051407

14061408
write_transport, write_protocol = loop.run_until_complete(
14071409
loop.connect_write_pipe(
1408-
lambda: asyncio.streams.FlowControlMixin(loop),
1410+
lambda: StdoutWriterProtocol(loop),
14091411
sys.stdout.buffer))
14101412
writer = asyncio.StreamWriter(write_transport, write_protocol, None, loop)
14111413

splitgpg2/stdiostream.py

Lines changed: 11 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -2,19 +2,18 @@
22
# based on asyncio library:
33
# Copyright (C) 2001 Python Software Foundation
44
#
5+
# Copyright (C) 2024 Marek Marczykowski-Górecki
6+
# <marmarek@invisiblethingslab.com>
57
#
68

79
import collections
810
from asyncio import protocols, events
911

10-
11-
class FlowControlMixin(protocols.Protocol):
12+
class StdoutWriterProtocol(protocols.Protocol):
1213
"""Reusable flow control logic for StreamWriter.drain().
13-
1414
This implements the protocol methods pause_writing(),
1515
resume_writing() and connection_lost(). If the subclass overrides
1616
these it must call the super methods.
17-
1817
StreamWriter.drain() must wait for _drain_helper() coroutine.
1918
"""
2019

@@ -26,18 +25,15 @@ def __init__(self, loop=None):
2625
self._paused = False
2726
self._drain_waiters = collections.deque()
2827
self._connection_lost = False
28+
self._closed = self._loop.create_future()
2929

3030
def pause_writing(self):
3131
assert not self._paused
3232
self._paused = True
33-
if self._loop.get_debug():
34-
logger.debug("%r pauses writing", self)
3533

3634
def resume_writing(self):
3735
assert self._paused
3836
self._paused = False
39-
if self._loop.get_debug():
40-
logger.debug("%r resumes writing", self)
4137

4238
for waiter in self._drain_waiters:
4339
if not waiter.done():
@@ -55,6 +51,11 @@ def connection_lost(self, exc):
5551
waiter.set_result(None)
5652
else:
5753
waiter.set_exception(exc)
54+
if not self._closed.done():
55+
if exc is None:
56+
self._closed.set_result(None)
57+
else:
58+
self._closed.set_exception(exc)
5859

5960
async def _drain_helper(self):
6061
if self._connection_lost:
@@ -68,6 +69,6 @@ async def _drain_helper(self):
6869
finally:
6970
self._drain_waiters.remove(waiter)
7071

72+
# pylint: disable=unused-argument
7173
def _get_close_waiter(self, stream):
72-
raise NotImplementedError
73-
74+
return self._closed

0 commit comments

Comments
 (0)