Skip to content

Commit bc3a002

Browse files
miss-islington1st1
andauthored
bpo-33654: Support BufferedProtocol in set_protocol() and start_tls() (GH-7130)
In this commit: * Support BufferedProtocol in set_protocol() and start_tls() * Fix proactor to cancel readers reliably * Update tests to be compatible with OpenSSL 1.1.1 * Clarify BufferedProtocol docs * Bump TLS tests timeouts to 60 seconds; eliminate possible race from start_serving * Rewrite test_start_tls_server_1 (cherry picked from commit dbf1022) Co-authored-by: Yury Selivanov <yury@magic.io>
1 parent f8fdb36 commit bc3a002

13 files changed

+379
-66
lines changed

Doc/library/asyncio-protocol.rst

Lines changed: 12 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -463,16 +463,23 @@ The idea of BufferedProtocol is that it allows to manually allocate
463463
and control the receive buffer. Event loops can then use the buffer
464464
provided by the protocol to avoid unnecessary data copies. This
465465
can result in noticeable performance improvement for protocols that
466-
receive big amounts of data. Sophisticated protocols can allocate
467-
the buffer only once at creation time.
466+
receive big amounts of data. Sophisticated protocols implementations
467+
can allocate the buffer only once at creation time.
468468

469469
The following callbacks are called on :class:`BufferedProtocol`
470470
instances:
471471

472-
.. method:: BufferedProtocol.get_buffer()
472+
.. method:: BufferedProtocol.get_buffer(sizehint)
473473

474-
Called to allocate a new receive buffer. Must return an object
475-
that implements the :ref:`buffer protocol <bufferobjects>`.
474+
Called to allocate a new receive buffer.
475+
476+
*sizehint* is a recommended minimal size for the returned
477+
buffer. It is acceptable to return smaller or bigger buffers
478+
than what *sizehint* suggests. When set to -1, the buffer size
479+
can be arbitrary. It is an error to return a zero-sized buffer.
480+
481+
Must return an object that implements the
482+
:ref:`buffer protocol <bufferobjects>`.
476483

477484
.. method:: BufferedProtocol.buffer_updated(nbytes)
478485

Lib/asyncio/base_events.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -157,7 +157,6 @@ def _run_until_complete_cb(fut):
157157
futures._get_loop(fut).stop()
158158

159159

160-
161160
class _SendfileFallbackProtocol(protocols.Protocol):
162161
def __init__(self, transp):
163162
if not isinstance(transp, transports._FlowControlMixin):
@@ -304,6 +303,9 @@ def close(self):
304303

305304
async def start_serving(self):
306305
self._start_serving()
306+
# Skip one loop iteration so that all 'loop.add_reader'
307+
# go through.
308+
await tasks.sleep(0, loop=self._loop)
307309

308310
async def serve_forever(self):
309311
if self._serving_forever_fut is not None:
@@ -1363,6 +1365,9 @@ async def create_server(
13631365
ssl, backlog, ssl_handshake_timeout)
13641366
if start_serving:
13651367
server._start_serving()
1368+
# Skip one loop iteration so that all 'loop.add_reader'
1369+
# go through.
1370+
await tasks.sleep(0, loop=self)
13661371

13671372
if self._debug:
13681373
logger.info("%r is serving", server)

Lib/asyncio/proactor_events.py

Lines changed: 41 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,7 @@ def __init__(self, loop, sock, protocol, waiter=None,
3030
super().__init__(extra, loop)
3131
self._set_extra(sock)
3232
self._sock = sock
33-
self._protocol = protocol
33+
self.set_protocol(protocol)
3434
self._server = server
3535
self._buffer = None # None or bytearray.
3636
self._read_fut = None
@@ -159,16 +159,26 @@ class _ProactorReadPipeTransport(_ProactorBasePipeTransport,
159159

160160
def __init__(self, loop, sock, protocol, waiter=None,
161161
extra=None, server=None):
162+
self._loop_reading_cb = None
163+
self._paused = True
162164
super().__init__(loop, sock, protocol, waiter, extra, server)
163-
self._paused = False
165+
164166
self._reschedule_on_resume = False
167+
self._loop.call_soon(self._loop_reading)
168+
self._paused = False
165169

166-
if protocols._is_buffered_protocol(protocol):
167-
self._loop_reading = self._loop_reading__get_buffer
170+
def set_protocol(self, protocol):
171+
if isinstance(protocol, protocols.BufferedProtocol):
172+
self._loop_reading_cb = self._loop_reading__get_buffer
168173
else:
169-
self._loop_reading = self._loop_reading__data_received
174+
self._loop_reading_cb = self._loop_reading__data_received
170175

171-
self._loop.call_soon(self._loop_reading)
176+
super().set_protocol(protocol)
177+
178+
if self.is_reading():
179+
# reset reading callback / buffers / self._read_fut
180+
self.pause_reading()
181+
self.resume_reading()
172182

173183
def is_reading(self):
174184
return not self._paused and not self._closing
@@ -179,6 +189,13 @@ def pause_reading(self):
179189
self._paused = True
180190

181191
if self._read_fut is not None and not self._read_fut.done():
192+
# TODO: This is an ugly hack to cancel the current read future
193+
# *and* avoid potential race conditions, as read cancellation
194+
# goes through `future.cancel()` and `loop.call_soon()`.
195+
# We then use this special attribute in the reader callback to
196+
# exit *immediately* without doing any cleanup/rescheduling.
197+
self._read_fut.__asyncio_cancelled_on_pause__ = True
198+
182199
self._read_fut.cancel()
183200
self._read_fut = None
184201
self._reschedule_on_resume = True
@@ -210,7 +227,14 @@ def _loop_reading__on_eof(self):
210227
if not keep_open:
211228
self.close()
212229

213-
def _loop_reading__data_received(self, fut=None):
230+
def _loop_reading(self, fut=None):
231+
self._loop_reading_cb(fut)
232+
233+
def _loop_reading__data_received(self, fut):
234+
if (fut is not None and
235+
getattr(fut, '__asyncio_cancelled_on_pause__', False)):
236+
return
237+
214238
if self._paused:
215239
self._reschedule_on_resume = True
216240
return
@@ -253,14 +277,18 @@ def _loop_reading__data_received(self, fut=None):
253277
if not self._closing:
254278
raise
255279
else:
256-
self._read_fut.add_done_callback(self._loop_reading)
280+
self._read_fut.add_done_callback(self._loop_reading__data_received)
257281
finally:
258282
if data:
259283
self._protocol.data_received(data)
260284
elif data == b'':
261285
self._loop_reading__on_eof()
262286

263-
def _loop_reading__get_buffer(self, fut=None):
287+
def _loop_reading__get_buffer(self, fut):
288+
if (fut is not None and
289+
getattr(fut, '__asyncio_cancelled_on_pause__', False)):
290+
return
291+
264292
if self._paused:
265293
self._reschedule_on_resume = True
266294
return
@@ -310,7 +338,9 @@ def _loop_reading__get_buffer(self, fut=None):
310338
return
311339

312340
try:
313-
buf = self._protocol.get_buffer()
341+
buf = self._protocol.get_buffer(-1)
342+
if not len(buf):
343+
raise RuntimeError('get_buffer() returned an empty buffer')
314344
except Exception as exc:
315345
self._fatal_error(
316346
exc, 'Fatal error: protocol.get_buffer() call failed.')
@@ -319,7 +349,7 @@ def _loop_reading__get_buffer(self, fut=None):
319349
try:
320350
# schedule a new read
321351
self._read_fut = self._loop._proactor.recv_into(self._sock, buf)
322-
self._read_fut.add_done_callback(self._loop_reading)
352+
self._read_fut.add_done_callback(self._loop_reading__get_buffer)
323353
except ConnectionAbortedError as exc:
324354
if not self._closing:
325355
self._fatal_error(exc, 'Fatal read error on pipe transport')

Lib/asyncio/protocols.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -130,11 +130,15 @@ class BufferedProtocol(BaseProtocol):
130130
* CL: connection_lost()
131131
"""
132132

133-
def get_buffer(self):
133+
def get_buffer(self, sizehint):
134134
"""Called to allocate a new receive buffer.
135135
136+
*sizehint* is a recommended minimal size for the returned
137+
buffer. When set to -1, the buffer size can be arbitrary.
138+
136139
Must return an object that implements the
137140
:ref:`buffer protocol <bufferobjects>`.
141+
It is an error to return a zero-sized buffer.
138142
"""
139143

140144
def buffer_updated(self, nbytes):
@@ -185,7 +189,3 @@ def pipe_connection_lost(self, fd, exc):
185189

186190
def process_exited(self):
187191
"""Called when subprocess has exited."""
188-
189-
190-
def _is_buffered_protocol(proto):
191-
return hasattr(proto, 'get_buffer') and not hasattr(proto, 'data_received')

Lib/asyncio/selector_events.py

Lines changed: 20 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -597,8 +597,10 @@ def __init__(self, loop, sock, protocol, extra=None, server=None):
597597
self._extra['peername'] = None
598598
self._sock = sock
599599
self._sock_fd = sock.fileno()
600-
self._protocol = protocol
601-
self._protocol_connected = True
600+
601+
self._protocol_connected = False
602+
self.set_protocol(protocol)
603+
602604
self._server = server
603605
self._buffer = self._buffer_factory()
604606
self._conn_lost = 0 # Set when call to connection_lost scheduled.
@@ -640,6 +642,7 @@ def abort(self):
640642

641643
def set_protocol(self, protocol):
642644
self._protocol = protocol
645+
self._protocol_connected = True
643646

644647
def get_protocol(self):
645648
return self._protocol
@@ -721,11 +724,7 @@ class _SelectorSocketTransport(_SelectorTransport):
721724
def __init__(self, loop, sock, protocol, waiter=None,
722725
extra=None, server=None):
723726

724-
if protocols._is_buffered_protocol(protocol):
725-
self._read_ready = self._read_ready__get_buffer
726-
else:
727-
self._read_ready = self._read_ready__data_received
728-
727+
self._read_ready_cb = None
729728
super().__init__(loop, sock, protocol, extra, server)
730729
self._eof = False
731730
self._paused = False
@@ -745,6 +744,14 @@ def __init__(self, loop, sock, protocol, waiter=None,
745744
self._loop.call_soon(futures._set_result_unless_cancelled,
746745
waiter, None)
747746

747+
def set_protocol(self, protocol):
748+
if isinstance(protocol, protocols.BufferedProtocol):
749+
self._read_ready_cb = self._read_ready__get_buffer
750+
else:
751+
self._read_ready_cb = self._read_ready__data_received
752+
753+
super().set_protocol(protocol)
754+
748755
def is_reading(self):
749756
return not self._paused and not self._closing
750757

@@ -764,12 +771,17 @@ def resume_reading(self):
764771
if self._loop.get_debug():
765772
logger.debug("%r resumes reading", self)
766773

774+
def _read_ready(self):
775+
self._read_ready_cb()
776+
767777
def _read_ready__get_buffer(self):
768778
if self._conn_lost:
769779
return
770780

771781
try:
772-
buf = self._protocol.get_buffer()
782+
buf = self._protocol.get_buffer(-1)
783+
if not len(buf):
784+
raise RuntimeError('get_buffer() returned an empty buffer')
773785
except Exception as exc:
774786
self._fatal_error(
775787
exc, 'Fatal error: protocol.get_buffer() call failed.')

Lib/asyncio/sslproto.py

Lines changed: 31 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -441,6 +441,8 @@ def __init__(self, loop, app_protocol, sslcontext, waiter,
441441
self._waiter = waiter
442442
self._loop = loop
443443
self._app_protocol = app_protocol
444+
self._app_protocol_is_buffer = \
445+
isinstance(app_protocol, protocols.BufferedProtocol)
444446
self._app_transport = _SSLProtocolTransport(self._loop, self)
445447
# _SSLPipe instance (None until the connection is made)
446448
self._sslpipe = None
@@ -522,7 +524,16 @@ def data_received(self, data):
522524

523525
for chunk in appdata:
524526
if chunk:
525-
self._app_protocol.data_received(chunk)
527+
try:
528+
if self._app_protocol_is_buffer:
529+
_feed_data_to_bufferred_proto(
530+
self._app_protocol, chunk)
531+
else:
532+
self._app_protocol.data_received(chunk)
533+
except Exception as ex:
534+
self._fatal_error(
535+
ex, 'application protocol failed to receive SSL data')
536+
return
526537
else:
527538
self._start_shutdown()
528539
break
@@ -709,3 +720,22 @@ def _abort(self):
709720
self._transport.abort()
710721
finally:
711722
self._finalize()
723+
724+
725+
def _feed_data_to_bufferred_proto(proto, data):
726+
data_len = len(data)
727+
while data_len:
728+
buf = proto.get_buffer(data_len)
729+
buf_len = len(buf)
730+
if not buf_len:
731+
raise RuntimeError('get_buffer() returned an empty buffer')
732+
733+
if buf_len >= data_len:
734+
buf[:data_len] = data
735+
proto.buffer_updated(data_len)
736+
return
737+
else:
738+
buf[:buf_len] = data[:buf_len]
739+
proto.buffer_updated(buf_len)
740+
data = data[buf_len:]
741+
data_len = len(data)

Lib/asyncio/unix_events.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@
2020
from . import events
2121
from . import futures
2222
from . import selector_events
23+
from . import tasks
2324
from . import transports
2425
from .log import logger
2526

@@ -308,6 +309,9 @@ async def create_unix_server(
308309
ssl, backlog, ssl_handshake_timeout)
309310
if start_serving:
310311
server._start_serving()
312+
# Skip one loop iteration so that all 'loop.add_reader'
313+
# go through.
314+
await tasks.sleep(0, loop=self)
311315

312316
return server
313317

Lib/test/test_asyncio/test_buffered_proto.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@ def __init__(self, cb, con_lost_fut):
99
self.cb = cb
1010
self.con_lost_fut = con_lost_fut
1111

12-
def get_buffer(self):
12+
def get_buffer(self, sizehint):
1313
self.buffer = bytearray(100)
1414
return self.buffer
1515

Lib/test/test_asyncio/test_events.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -2095,7 +2095,7 @@ async def connect(cmd=None, **kwds):
20952095

20962096
class SendfileBase:
20972097

2098-
DATA = b"12345abcde" * 16 * 1024 # 160 KiB
2098+
DATA = b"12345abcde" * 64 * 1024 # 64 KiB (don't use smaller sizes)
20992099

21002100
@classmethod
21012101
def setUpClass(cls):
@@ -2452,7 +2452,7 @@ def test_sendfile_ssl_close_peer_after_receiving(self):
24522452
self.assertEqual(srv_proto.data, self.DATA)
24532453
self.assertEqual(self.file.tell(), len(self.DATA))
24542454

2455-
def test_sendfile_close_peer_in_middle_of_receiving(self):
2455+
def test_sendfile_close_peer_in_the_middle_of_receiving(self):
24562456
srv_proto, cli_proto = self.prepare_sendfile(close_after=1024)
24572457
with self.assertRaises(ConnectionError):
24582458
self.run_loop(
@@ -2465,7 +2465,7 @@ def test_sendfile_close_peer_in_middle_of_receiving(self):
24652465
self.file.tell())
24662466
self.assertTrue(cli_proto.transport.is_closing())
24672467

2468-
def test_sendfile_fallback_close_peer_in_middle_of_receiving(self):
2468+
def test_sendfile_fallback_close_peer_in_the_middle_of_receiving(self):
24692469

24702470
def sendfile_native(transp, file, offset, count):
24712471
# to raise SendfileNotAvailableError

0 commit comments

Comments
 (0)