diff --git a/tests/test_connection.py b/tests/test_connection.py index d8d5f0f..54aeb04 100644 --- a/tests/test_connection.py +++ b/tests/test_connection.py @@ -921,3 +921,15 @@ async def handler(request): await connection.get_message() await connection.aclose() await trio.sleep(.1) + + +async def test_finalization_dropped_exception(echo_server, autojump_clock): + # Confirm that open_websocket finalization does not contribute to dropped + # exceptions as described in https://github.com/python-trio/trio/issues/1559. + with pytest.raises(ValueError): + with trio.move_on_after(1): + async with open_websocket(HOST, echo_server.port, RESOURCE, use_ssl=False): + try: + await trio.sleep_forever() + finally: + raise ValueError diff --git a/trio_websocket/_impl.py b/trio_websocket/_impl.py index a92ce37..3f067ad 100644 --- a/trio_websocket/_impl.py +++ b/trio_websocket/_impl.py @@ -1,3 +1,4 @@ +import sys from collections import OrderedDict from functools import partial from ipaddress import ip_address @@ -35,6 +36,31 @@ logger = logging.getLogger('trio-websocket') +class _preserve_current_exception: + """A context manager which should surround an ``__exit__`` or + ``__aexit__`` handler or the contents of a ``finally:`` + block. It ensures that any exception that was being handled + upon entry is not masked by a `trio.Cancelled` raised within + the body of the context manager. + + https://github.com/python-trio/trio/issues/1559 + https://gitter.im/python-trio/general?at=5faf2293d37a1a13d6a582cf + """ + __slots__ = ("_armed",) + + def __enter__(self): + self._armed = sys.exc_info()[1] is not None + + def __exit__(self, ty, value, tb): + if value is None or not self._armed: + return False + + def remove_cancels(exc): + return None if isinstance(exc, trio.Cancelled) else exc + + return trio.MultiError.filter(remove_cancels, value) is None + + @asynccontextmanager @async_generator async def open_websocket(host, port, resource, *, use_ssl, subprotocols=None, @@ -792,6 +818,10 @@ async def aclose(self, code=1000, reason=None): # pylint: disable=arguments-dif :param int code: A 4-digit code number indicating the type of closure. :param str reason: An optional string describing the closure. ''' + with _preserve_current_exception(): + await self._aclose(code, reason) + + async def _aclose(self, code=1000, reason=None): if self._close_reason: # Per AsyncResource interface, calling aclose() on a closed resource # should succeed. @@ -964,7 +994,8 @@ async def _close_stream(self): ''' Close the TCP connection. ''' self._reader_running = False try: - await self._stream.aclose() + with _preserve_current_exception(): + await self._stream.aclose() except trio.BrokenResourceError: # This means the TCP connection is already dead. pass @@ -1378,6 +1409,6 @@ async def _handle_connection(self, stream): await self._handler(request) finally: with trio.move_on_after(self._disconnect_timeout): - # aclose() will shut down the reader task even if its + # aclose() will shut down the reader task even if it's # cancelled: await connection.aclose()