Skip to content

bpo-33654: Support BufferedProtocol in set_protocol() and start_tls() #7130

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 9 commits into from
May 28, 2018
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 12 additions & 5 deletions Doc/library/asyncio-protocol.rst
Original file line number Diff line number Diff line change
Expand Up @@ -463,16 +463,23 @@ The idea of BufferedProtocol is that it allows to manually allocate
and control the receive buffer. Event loops can then use the buffer
provided by the protocol to avoid unnecessary data copies. This
can result in noticeable performance improvement for protocols that
receive big amounts of data. Sophisticated protocols can allocate
the buffer only once at creation time.
receive big amounts of data. Sophisticated protocols implementations
can allocate the buffer only once at creation time.

The following callbacks are called on :class:`BufferedProtocol`
instances:

.. method:: BufferedProtocol.get_buffer()
.. method:: BufferedProtocol.get_buffer(sizehint)

Called to allocate a new receive buffer. Must return an object
that implements the :ref:`buffer protocol <bufferobjects>`.
Called to allocate a new receive buffer.

*sizehint* is a recommended minimal size for the returned
buffer. It is acceptable to return smaller or bigger buffers
than what *sizehint* suggests. When set to -1, the buffer size
can be arbitrary. It is an error to return a zero-sized buffer.

Must return an object that implements the
:ref:`buffer protocol <bufferobjects>`.

.. method:: BufferedProtocol.buffer_updated(nbytes)

Expand Down
7 changes: 6 additions & 1 deletion Lib/asyncio/base_events.py
Original file line number Diff line number Diff line change
Expand Up @@ -157,7 +157,6 @@ def _run_until_complete_cb(fut):
futures._get_loop(fut).stop()



class _SendfileFallbackProtocol(protocols.Protocol):
def __init__(self, transp):
if not isinstance(transp, transports._FlowControlMixin):
Expand Down Expand Up @@ -304,6 +303,9 @@ def close(self):

async def start_serving(self):
self._start_serving()
# Skip one loop iteration so that all 'loop.add_reader'
# go through.
await tasks.sleep(0, loop=self._loop)

async def serve_forever(self):
if self._serving_forever_fut is not None:
Expand Down Expand Up @@ -1363,6 +1365,9 @@ async def create_server(
ssl, backlog, ssl_handshake_timeout)
if start_serving:
server._start_serving()
# Skip one loop iteration so that all 'loop.add_reader'
# go through.
await tasks.sleep(0, loop=self)

if self._debug:
logger.info("%r is serving", server)
Expand Down
52 changes: 41 additions & 11 deletions Lib/asyncio/proactor_events.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ def __init__(self, loop, sock, protocol, waiter=None,
super().__init__(extra, loop)
self._set_extra(sock)
self._sock = sock
self._protocol = protocol
self.set_protocol(protocol)
self._server = server
self._buffer = None # None or bytearray.
self._read_fut = None
Expand Down Expand Up @@ -159,16 +159,26 @@ class _ProactorReadPipeTransport(_ProactorBasePipeTransport,

def __init__(self, loop, sock, protocol, waiter=None,
extra=None, server=None):
self._loop_reading_cb = None
self._paused = True
super().__init__(loop, sock, protocol, waiter, extra, server)
self._paused = False

self._reschedule_on_resume = False
self._loop.call_soon(self._loop_reading)
self._paused = False

if protocols._is_buffered_protocol(protocol):
self._loop_reading = self._loop_reading__get_buffer
def set_protocol(self, protocol):
if isinstance(protocol, protocols.BufferedProtocol):
self._loop_reading_cb = self._loop_reading__get_buffer
else:
self._loop_reading = self._loop_reading__data_received
self._loop_reading_cb = self._loop_reading__data_received

self._loop.call_soon(self._loop_reading)
super().set_protocol(protocol)

if self.is_reading():
# reset reading callback / buffers / self._read_fut
self.pause_reading()
self.resume_reading()

def is_reading(self):
return not self._paused and not self._closing
Expand All @@ -179,6 +189,13 @@ def pause_reading(self):
self._paused = True

if self._read_fut is not None and not self._read_fut.done():
# TODO: This is an ugly hack to cancel the current read future
# *and* avoid potential race conditions, as read cancellation
# goes through `future.cancel()` and `loop.call_soon()`.
# We then use this special attribute in the reader callback to
# exit *immediately* without doing any cleanup/rescheduling.
self._read_fut.__asyncio_cancelled_on_pause__ = True

self._read_fut.cancel()
self._read_fut = None
self._reschedule_on_resume = True
Expand Down Expand Up @@ -210,7 +227,14 @@ def _loop_reading__on_eof(self):
if not keep_open:
self.close()

def _loop_reading__data_received(self, fut=None):
def _loop_reading(self, fut=None):
self._loop_reading_cb(fut)

def _loop_reading__data_received(self, fut):
if (fut is not None and
getattr(fut, '__asyncio_cancelled_on_pause__', False)):
return

if self._paused:
self._reschedule_on_resume = True
return
Expand Down Expand Up @@ -253,14 +277,18 @@ def _loop_reading__data_received(self, fut=None):
if not self._closing:
raise
else:
self._read_fut.add_done_callback(self._loop_reading)
self._read_fut.add_done_callback(self._loop_reading__data_received)
finally:
if data:
self._protocol.data_received(data)
elif data == b'':
self._loop_reading__on_eof()

def _loop_reading__get_buffer(self, fut=None):
def _loop_reading__get_buffer(self, fut):
if (fut is not None and
getattr(fut, '__asyncio_cancelled_on_pause__', False)):
return

if self._paused:
self._reschedule_on_resume = True
return
Expand Down Expand Up @@ -310,7 +338,9 @@ def _loop_reading__get_buffer(self, fut=None):
return

try:
buf = self._protocol.get_buffer()
buf = self._protocol.get_buffer(-1)
if not len(buf):
raise RuntimeError('get_buffer() returned an empty buffer')
except Exception as exc:
self._fatal_error(
exc, 'Fatal error: protocol.get_buffer() call failed.')
Expand All @@ -319,7 +349,7 @@ def _loop_reading__get_buffer(self, fut=None):
try:
# schedule a new read
self._read_fut = self._loop._proactor.recv_into(self._sock, buf)
self._read_fut.add_done_callback(self._loop_reading)
self._read_fut.add_done_callback(self._loop_reading__get_buffer)
except ConnectionAbortedError as exc:
if not self._closing:
self._fatal_error(exc, 'Fatal read error on pipe transport')
Expand Down
10 changes: 5 additions & 5 deletions Lib/asyncio/protocols.py
Original file line number Diff line number Diff line change
Expand Up @@ -130,11 +130,15 @@ class BufferedProtocol(BaseProtocol):
* CL: connection_lost()
"""

def get_buffer(self):
def get_buffer(self, sizehint):
"""Called to allocate a new receive buffer.

*sizehint* is a recommended minimal size for the returned
buffer. When set to -1, the buffer size can be arbitrary.

Must return an object that implements the
:ref:`buffer protocol <bufferobjects>`.
It is an error to return a zero-sized buffer.
"""

def buffer_updated(self, nbytes):
Expand Down Expand Up @@ -185,7 +189,3 @@ def pipe_connection_lost(self, fd, exc):

def process_exited(self):
"""Called when subprocess has exited."""


def _is_buffered_protocol(proto):
return hasattr(proto, 'get_buffer') and not hasattr(proto, 'data_received')
28 changes: 20 additions & 8 deletions Lib/asyncio/selector_events.py
Original file line number Diff line number Diff line change
Expand Up @@ -597,8 +597,10 @@ def __init__(self, loop, sock, protocol, extra=None, server=None):
self._extra['peername'] = None
self._sock = sock
self._sock_fd = sock.fileno()
self._protocol = protocol
self._protocol_connected = True

self._protocol_connected = False
self.set_protocol(protocol)

self._server = server
self._buffer = self._buffer_factory()
self._conn_lost = 0 # Set when call to connection_lost scheduled.
Expand Down Expand Up @@ -640,6 +642,7 @@ def abort(self):

def set_protocol(self, protocol):
self._protocol = protocol
self._protocol_connected = True

def get_protocol(self):
return self._protocol
Expand Down Expand Up @@ -721,11 +724,7 @@ class _SelectorSocketTransport(_SelectorTransport):
def __init__(self, loop, sock, protocol, waiter=None,
extra=None, server=None):

if protocols._is_buffered_protocol(protocol):
self._read_ready = self._read_ready__get_buffer
else:
self._read_ready = self._read_ready__data_received

self._read_ready_cb = None
super().__init__(loop, sock, protocol, extra, server)
self._eof = False
self._paused = False
Expand All @@ -745,6 +744,14 @@ def __init__(self, loop, sock, protocol, waiter=None,
self._loop.call_soon(futures._set_result_unless_cancelled,
waiter, None)

def set_protocol(self, protocol):
if isinstance(protocol, protocols.BufferedProtocol):
self._read_ready_cb = self._read_ready__get_buffer
else:
self._read_ready_cb = self._read_ready__data_received

super().set_protocol(protocol)

def is_reading(self):
return not self._paused and not self._closing

Expand All @@ -764,12 +771,17 @@ def resume_reading(self):
if self._loop.get_debug():
logger.debug("%r resumes reading", self)

def _read_ready(self):
self._read_ready_cb()

def _read_ready__get_buffer(self):
if self._conn_lost:
return

try:
buf = self._protocol.get_buffer()
buf = self._protocol.get_buffer(-1)
if not len(buf):
raise RuntimeError('get_buffer() returned an empty buffer')
except Exception as exc:
self._fatal_error(
exc, 'Fatal error: protocol.get_buffer() call failed.')
Expand Down
32 changes: 31 additions & 1 deletion Lib/asyncio/sslproto.py
Original file line number Diff line number Diff line change
Expand Up @@ -441,6 +441,8 @@ def __init__(self, loop, app_protocol, sslcontext, waiter,
self._waiter = waiter
self._loop = loop
self._app_protocol = app_protocol
self._app_protocol_is_buffer = \
isinstance(app_protocol, protocols.BufferedProtocol)
self._app_transport = _SSLProtocolTransport(self._loop, self)
# _SSLPipe instance (None until the connection is made)
self._sslpipe = None
Expand Down Expand Up @@ -522,7 +524,16 @@ def data_received(self, data):

for chunk in appdata:
if chunk:
self._app_protocol.data_received(chunk)
try:
if self._app_protocol_is_buffer:
_feed_data_to_bufferred_proto(
self._app_protocol, chunk)
else:
self._app_protocol.data_received(chunk)
except Exception as ex:
self._fatal_error(
ex, 'application protocol failed to receive SSL data')
return
else:
self._start_shutdown()
break
Expand Down Expand Up @@ -709,3 +720,22 @@ def _abort(self):
self._transport.abort()
finally:
self._finalize()


def _feed_data_to_bufferred_proto(proto, data):
data_len = len(data)
while data_len:
buf = proto.get_buffer(data_len)
buf_len = len(buf)
if not buf_len:
raise RuntimeError('get_buffer() returned an empty buffer')

if buf_len >= data_len:
buf[:data_len] = data
proto.buffer_updated(data_len)
return
else:
buf[:buf_len] = data[:buf_len]
proto.buffer_updated(buf_len)
data = data[buf_len:]
data_len = len(data)
4 changes: 4 additions & 0 deletions Lib/asyncio/unix_events.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
from . import events
from . import futures
from . import selector_events
from . import tasks
from . import transports
from .log import logger

Expand Down Expand Up @@ -308,6 +309,9 @@ async def create_unix_server(
ssl, backlog, ssl_handshake_timeout)
if start_serving:
server._start_serving()
# Skip one loop iteration so that all 'loop.add_reader'
# go through.
await tasks.sleep(0, loop=self)

return server

Expand Down
2 changes: 1 addition & 1 deletion Lib/test/test_asyncio/test_buffered_proto.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ def __init__(self, cb, con_lost_fut):
self.cb = cb
self.con_lost_fut = con_lost_fut

def get_buffer(self):
def get_buffer(self, sizehint):
self.buffer = bytearray(100)
return self.buffer

Expand Down
6 changes: 3 additions & 3 deletions Lib/test/test_asyncio/test_events.py
Original file line number Diff line number Diff line change
Expand Up @@ -2095,7 +2095,7 @@ async def connect(cmd=None, **kwds):

class SendfileBase:

DATA = b"12345abcde" * 16 * 1024 # 160 KiB
DATA = b"12345abcde" * 64 * 1024 # 64 KiB (don't use smaller sizes)

@classmethod
def setUpClass(cls):
Expand Down Expand Up @@ -2452,7 +2452,7 @@ def test_sendfile_ssl_close_peer_after_receiving(self):
self.assertEqual(srv_proto.data, self.DATA)
self.assertEqual(self.file.tell(), len(self.DATA))

def test_sendfile_close_peer_in_middle_of_receiving(self):
def test_sendfile_close_peer_in_the_middle_of_receiving(self):
srv_proto, cli_proto = self.prepare_sendfile(close_after=1024)
with self.assertRaises(ConnectionError):
self.run_loop(
Expand All @@ -2465,7 +2465,7 @@ def test_sendfile_close_peer_in_middle_of_receiving(self):
self.file.tell())
self.assertTrue(cli_proto.transport.is_closing())

def test_sendfile_fallback_close_peer_in_middle_of_receiving(self):
def test_sendfile_fallback_close_peer_in_the_middle_of_receiving(self):

def sendfile_native(transp, file, offset, count):
# to raise SendfileNotAvailableError
Expand Down
Loading