Skip to content

Commit 4b4958a

Browse files
committed
Changed eof_received() to reflect close_notify
1 parent 9db63f4 commit 4b4958a

File tree

2 files changed

+146
-9
lines changed

2 files changed

+146
-9
lines changed

tests/test_tcp.py

Lines changed: 115 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1573,6 +1573,7 @@ def serve(sock):
15731573
data = sock.recv_all(len(HELLO_MSG))
15741574
self.assertEqual(len(data), len(HELLO_MSG))
15751575

1576+
sock.unwrap()
15761577
sock.shutdown(socket.SHUT_RDWR)
15771578
sock.close()
15781579

@@ -1643,6 +1644,7 @@ def serve(sock):
16431644
data = sock.recv_all(len(HELLO_MSG))
16441645
self.assertEqual(len(data), len(HELLO_MSG))
16451646

1647+
sock.unwrap()
16461648
sock.shutdown(socket.SHUT_RDWR)
16471649
sock.close()
16481650

@@ -1798,6 +1800,7 @@ def client(sock, addr):
17981800
sock.starttls(client_context)
17991801
sock.sendall(HELLO_MSG)
18001802

1803+
sock.unwrap()
18011804
sock.shutdown(socket.SHUT_RDWR)
18021805
sock.close()
18031806

@@ -2303,9 +2306,11 @@ def test_write_to_closed_transport(self):
23032306

23042307
sslctx = self._create_server_ssl_context(self.ONLYCERT, self.ONLYKEY)
23052308
client_sslctx = self._create_client_ssl_context()
2309+
future = None
23062310

23072311
def server(sock):
23082312
sock.starttls(sslctx, server_side=True)
2313+
sock.shutdown(socket.SHUT_RDWR)
23092314
sock.close()
23102315

23112316
def unwrap_server(sock):
@@ -2320,6 +2325,9 @@ def unwrap_server(sock):
23202325
sock.close()
23212326

23222327
async def client(addr):
2328+
nonlocal future
2329+
future = self.loop.create_future()
2330+
23232331
reader, writer = await asyncio.open_connection(
23242332
*addr,
23252333
ssl=client_sslctx,
@@ -2330,7 +2338,7 @@ async def client(addr):
23302338
try:
23312339
data = await reader.read()
23322340
self.assertEqual(data, b'')
2333-
except ConnectionResetError:
2341+
except (ConnectionResetError, BrokenPipeError):
23342342
pass
23352343

23362344
for i in range(25):
@@ -2339,11 +2347,23 @@ async def client(addr):
23392347
self.assertEqual(
23402348
len(writer.transport._ssl_protocol._write_backlog), 0)
23412349

2350+
await future
2351+
2352+
def run(meth):
2353+
def wrapper(sock):
2354+
try:
2355+
meth(sock)
2356+
except Exception as ex:
2357+
self.loop.call_soon_threadsafe(future.set_exception, ex)
2358+
else:
2359+
self.loop.call_soon_threadsafe(future.set_result, None)
2360+
return wrapper
2361+
23422362
with self._silence_eof_received_warning():
2343-
with self.tcp_server(server) as srv:
2363+
with self.tcp_server(run(server)) as srv:
23442364
self.loop.run_until_complete(client(srv.addr))
23452365

2346-
with self.tcp_server(unwrap_server) as srv:
2366+
with self.tcp_server(run(unwrap_server)) as srv:
23472367
self.loop.run_until_complete(client(srv.addr))
23482368

23492369
def test_flush_before_shutdown(self):
@@ -2438,6 +2458,98 @@ async def client(addr):
24382458
with self.tcp_server(run(openssl_server)) as srv:
24392459
self.loop.run_until_complete(client(srv.addr))
24402460

2461+
def test_remote_shutdown_receives_trailing_data(self):
2462+
if self.implementation == 'asyncio':
2463+
raise unittest.SkipTest()
2464+
2465+
CHUNK = 1024 * 128
2466+
SIZE = 32
2467+
2468+
sslctx = self._create_server_ssl_context(self.ONLYCERT, self.ONLYKEY)
2469+
client_sslctx = self._create_client_ssl_context()
2470+
future = None
2471+
2472+
def server(sock):
2473+
sock.starttls(sslctx, server_side=True)
2474+
self.assertEqual(sock.recv_all(4), b'ping')
2475+
sock.send(b'pong')
2476+
2477+
time.sleep(0.2) # wait for the peer to fill its backlog
2478+
2479+
# send close_notify but don't wait for response
2480+
sock.setblocking(0)
2481+
with self.assertRaises(ssl.SSLWantReadError):
2482+
sock.unwrap()
2483+
sock.setblocking(1)
2484+
2485+
# should receive all data
2486+
data = sock.recv_all(CHUNK * SIZE)
2487+
self.assertEqual(len(data), CHUNK * SIZE)
2488+
2489+
# wait for close_notify
2490+
sock.unwrap()
2491+
2492+
sock.close()
2493+
2494+
def eof_server(sock):
2495+
sock.starttls(sslctx, server_side=True)
2496+
self.assertEqual(sock.recv_all(4), b'ping')
2497+
sock.send(b'pong')
2498+
2499+
time.sleep(0.2) # wait for the peer to fill its backlog
2500+
2501+
# send EOF
2502+
sock.shutdown(socket.SHUT_WR)
2503+
2504+
# should receive all data
2505+
data = sock.recv_all(CHUNK * SIZE)
2506+
self.assertEqual(len(data), CHUNK * SIZE)
2507+
2508+
sock.close()
2509+
2510+
async def client(addr):
2511+
nonlocal future
2512+
future = self.loop.create_future()
2513+
2514+
reader, writer = await asyncio.open_connection(
2515+
*addr,
2516+
ssl=client_sslctx,
2517+
server_hostname='',
2518+
loop=self.loop)
2519+
writer.write(b'ping')
2520+
data = await reader.readexactly(4)
2521+
self.assertEqual(data, b'pong')
2522+
2523+
# fill write backlog in a hacky way - renegotiation won't help
2524+
ssl_protocol = writer.transport._ssl_protocol
2525+
for _ in range(SIZE):
2526+
ssl_protocol._write_backlog.append(b'x' * CHUNK)
2527+
ssl_protocol._write_buffer_size += CHUNK
2528+
2529+
try:
2530+
data = await reader.read()
2531+
self.assertEqual(data, b'')
2532+
except (BrokenPipeError, ConnectionResetError):
2533+
pass
2534+
2535+
await future
2536+
2537+
def run(meth):
2538+
def wrapper(sock):
2539+
try:
2540+
meth(sock)
2541+
except Exception as ex:
2542+
self.loop.call_soon_threadsafe(future.set_exception, ex)
2543+
else:
2544+
self.loop.call_soon_threadsafe(future.set_result, None)
2545+
return wrapper
2546+
2547+
with self.tcp_server(run(server)) as srv:
2548+
self.loop.run_until_complete(client(srv.addr))
2549+
2550+
with self.tcp_server(run(eof_server)) as srv:
2551+
self.loop.run_until_complete(client(srv.addr))
2552+
24412553

24422554
class Test_UV_TCPSSL(_TestSSL, tb.UVTestCase):
24432555
pass

uvloop/sslproto.pyx

Lines changed: 31 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -262,6 +262,7 @@ class SSLProtocol(object):
262262
self._ssl_buffer = bytearray(256 * 1024)
263263
self._state = _UNWRAPPED
264264
self._conn_lost = 0 # Set when connection_lost called
265+
self._eof_received = False
265266

266267
# Flow Control
267268

@@ -362,13 +363,23 @@ class SSLProtocol(object):
362363
if self._loop.get_debug():
363364
aio_logger.debug("%r received EOF", self)
364365

365-
self._wakeup_waiter(ConnectionResetError)
366+
if self._state == _DO_HANDSHAKE:
367+
self._on_handshake_complete(ConnectionResetError)
368+
369+
elif self._state == _WRAPPED:
370+
self._set_state(_FLUSHING)
371+
self._do_write()
372+
self._set_state(_SHUTDOWN)
373+
self._do_shutdown()
374+
375+
elif self._state == _FLUSHING:
376+
self._do_write()
377+
self._set_state(_SHUTDOWN)
378+
self._do_shutdown()
379+
380+
elif self._state == _SHUTDOWN:
381+
self._do_shutdown()
366382

367-
if self._state != _DO_HANDSHAKE:
368-
keep_open = self._app_protocol.eof_received()
369-
if keep_open:
370-
aio_logger.warning('returning true from eof_received() '
371-
'has no effect when using ssl')
372383
finally:
373384
self._transport.close()
374385

@@ -529,6 +540,7 @@ class SSLProtocol(object):
529540
self._on_shutdown_complete(exc)
530541
else:
531542
self._process_outgoing()
543+
self._call_eof_received()
532544
self._on_shutdown_complete(None)
533545

534546
def _on_shutdown_complete(self, shutdown_exc):
@@ -627,6 +639,7 @@ class SSLProtocol(object):
627639
self._app_protocol.buffer_updated(offset)
628640
if not count:
629641
# close_notify
642+
self._call_eof_received()
630643
self._start_shutdown()
631644

632645
def _do_read__copied(self):
@@ -647,8 +660,20 @@ class SSLProtocol(object):
647660
self._app_protocol.data_received(b''.join(data))
648661
if not chunk:
649662
# close_notify
663+
self._call_eof_received()
650664
self._start_shutdown()
651665

666+
def _call_eof_received(self):
667+
try:
668+
if not self._eof_received:
669+
self._eof_received = True
670+
keep_open = self._app_protocol.eof_received()
671+
if keep_open:
672+
aio_logger.warning('returning true from eof_received() '
673+
'has no effect when using ssl')
674+
except Exception as ex:
675+
self._fatal_error(ex, 'Error calling eof_received()')
676+
652677
# Flow control for writes from APP socket
653678

654679
def _control_app_writing(self):

0 commit comments

Comments
 (0)