Skip to content

Commit 88f067f

Browse files
committed
Set the close status code and reason more consistently.
Set them when the closing handshake is considered complete or aborted.
1 parent 471cbbc commit 88f067f

File tree

3 files changed

+19
-19
lines changed

3 files changed

+19
-19
lines changed

docs/index.rst

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -305,6 +305,8 @@ Changelog
305305
* Cancelling :meth:`~websockets.protocol.WebSocketCommonProtocol.recv` no
306306
longer drops the next message.
307307

308+
* Set the close status code and reason more consistently.
309+
308310
* Improved tests.
309311

310312
2.4

websockets/protocol.py

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -104,6 +104,7 @@ def __init__(self, *,
104104

105105
self.subprotocol = None
106106

107+
# Code and reason must be set when the closing handshake completes.
107108
self.close_code = None
108109
self.close_reason = ''
109110

@@ -161,7 +162,6 @@ def close(self, code=1000, reason=''):
161162
if self.state == OPEN:
162163
# 7.1.2. Start the WebSocket Closing Handshake
163164
# 7.1.3. The WebSocket Closing Handshake is Started
164-
self.close_code, self.close_reason = code, reason
165165
frame_data = serialize_close(code, reason)
166166
yield from self.write_frame(OP_CLOSE, frame_data)
167167

@@ -350,11 +350,13 @@ def read_data_frame(self, max_size):
350350
frame = yield from self.read_frame(max_size)
351351
# 5.5. Control Frames
352352
if frame.opcode == OP_CLOSE:
353-
self.close_code, self.close_reason = parse_close(frame.data)
353+
# Make sure the close frame is valid before echoing it.
354+
code, reason = parse_close(frame.data)
354355
if self.state == OPEN:
355356
# 7.1.3. The WebSocket Closing Handshake is Started
356357
yield from self.write_frame(OP_CLOSE, frame.data)
357358
if not self.closing_handshake.done():
359+
self.close_code, self.close_reason = code, reason
358360
self.closing_handshake.set_result(True)
359361
return
360362
elif frame.opcode == OP_PING:
@@ -443,10 +445,6 @@ def close_connection(self, force=False):
443445

444446
@asyncio.coroutine
445447
def fail_connection(self, code=1011, reason=''):
446-
# Losing the connection usually results in a protocol error.
447-
# Preserve the original error code in this case.
448-
if self.close_code != 1006:
449-
self.close_code, self.close_reason = code, reason
450448
# 7.1.7. Fail the WebSocket Connection
451449
logger.info("Failing the WebSocket connection: %d %s", code, reason)
452450
if self.state == OPEN:
@@ -458,6 +456,7 @@ def fail_connection(self, code=1011, reason=''):
458456
frame_data = serialize_close(code, reason)
459457
yield from self.write_frame(OP_CLOSE, frame_data)
460458
if not self.closing_handshake.done():
459+
self.close_code, self.close_reason = code, reason
461460
self.closing_handshake.set_result(False)
462461
yield from self.close_connection()
463462

@@ -472,8 +471,9 @@ def client_connected(self, reader, writer):
472471
def connection_lost(self, exc):
473472
# 7.1.4. The WebSocket Connection is Closed
474473
self.state = CLOSED
474+
if not self.closing_handshake.done():
475+
self.close_code, self.close_reason = 1006, ''
476+
self.closing_handshake.set_result(False)
475477
if not self.connection_closed.done():
476478
self.connection_closed.set_result(None)
477-
if self.close_code is None:
478-
self.close_code = 1006
479479
super().connection_lost(exc)

websockets/test_protocol.py

Lines changed: 9 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -391,7 +391,7 @@ def test_close_handshake_in_fragmented_text(self):
391391
self.receive_frame(Frame(True, OP_CLOSE, b''))
392392
self.receive_eof()
393393
self.assertIsNone(self.loop.run_until_complete(self.protocol.recv()))
394-
self.assertConnectionClosed(1002, '')
394+
self.assertConnectionClosed(1005, '')
395395

396396
def test_connection_close_in_fragmented_text(self):
397397
self.receive_frame(Frame(False, OP_TEXT, 'ca'.encode('utf-8')))
@@ -458,20 +458,19 @@ def test_close_handshake_timeout(self):
458458
# response. The server will stop waiting for the close frame and
459459
# timeout.
460460
self.loop.run_until_complete(self.protocol.close(reason='close'))
461-
self.assertConnectionClosed(1000, 'close')
461+
self.assertConnectionClosed(1006, '')
462462

463463
def test_client_close_race_with_failing_connection(self):
464464
self.make_drain_slow()
465465

466466
# Fail the connection while answering a close frame from the client.
467467
self.loop.call_soon(self.receive_frame, self.client_close)
468-
fail_connection = self.protocol.fail_connection(1000, 'server')
469-
self.loop.call_later(MS, self.async, fail_connection)
468+
self.loop.call_later(MS, self.async, self.protocol.fail_connection())
470469
next_message = self.loop.run_until_complete(self.protocol.recv())
471470

472471
self.assertIsNone(next_message)
473-
# The connection was closed before the close frame could be sent.
474-
self.assertConnectionClosed(1006, '')
472+
# The closing handshake was completed by fail_connection.
473+
self.assertConnectionClosed(1011, '')
475474
self.assertOneFrameSent(*self.client_close)
476475

477476
def test_close_protocol_error(self):
@@ -581,7 +580,7 @@ def test_close_handshake_timeout(self):
581580
# stop waiting for the close frame and timeout, then stop waiting
582581
# for the connection close and timeout again.
583582
self.loop.run_until_complete(self.protocol.close(reason='close'))
584-
self.assertConnectionClosed(1000, 'close')
583+
self.assertConnectionClosed(1006, '')
585584

586585
def test_eof_received_timeout(self):
587586
# Timeout is expected in 10ms.
@@ -600,14 +599,13 @@ def test_server_close_race_with_failing_connection(self):
600599

601600
# Fail the connection while answering a close frame from the server.
602601
self.loop.call_soon(self.receive_frame, self.server_close)
603-
fail_connection = self.protocol.fail_connection(1000, 'client')
604-
self.loop.call_later(MS, self.async, fail_connection)
602+
self.loop.call_later(MS, self.async, self.protocol.fail_connection())
605603
self.loop.call_later(2 * MS, self.receive_eof)
606604
next_message = self.loop.run_until_complete(self.protocol.recv())
607605

608606
self.assertIsNone(next_message)
609-
# The connection was closed before the close frame could be sent.
610-
self.assertConnectionClosed(1006, '')
607+
# The closing handshake was completed by fail_connection.
608+
self.assertConnectionClosed(1011, '')
611609
self.assertOneFrameSent(*self.server_close)
612610

613611
def test_close_protocol_error(self):

0 commit comments

Comments
 (0)