From eb685564f339679d719eb748dc615fdb0d97f604 Mon Sep 17 00:00:00 2001 From: Sam Bull Date: Wed, 18 Sep 2024 18:28:19 +0100 Subject: [PATCH] Add and use ClientConnectionResetError (#9137) (#9194) (cherry picked from commit f95bcaf4e0b2344d09df8dbb565150dcb4e73c0f) --- CHANGES/9137.bugfix.rst | 2 ++ aiohttp/__init__.py | 2 ++ aiohttp/base_protocol.py | 3 ++- aiohttp/client.py | 2 ++ aiohttp/client_exceptions.py | 9 +++++++-- aiohttp/http_websocket.py | 5 +++-- aiohttp/http_writer.py | 3 ++- docs/client_reference.rst | 6 ++++++ tests/test_client_ws.py | 15 ++++++++++----- tests/test_client_ws_functional.py | 4 ++-- tests/test_http_writer.py | 10 ++++++---- 11 files changed, 44 insertions(+), 17 deletions(-) create mode 100644 CHANGES/9137.bugfix.rst diff --git a/CHANGES/9137.bugfix.rst b/CHANGES/9137.bugfix.rst new file mode 100644 index 00000000000..d99802095bd --- /dev/null +++ b/CHANGES/9137.bugfix.rst @@ -0,0 +1,2 @@ +Added :exc:`aiohttp.ClientConnectionResetError`. Client code that previously threw :exc:`ConnectionResetError` +will now throw this -- by :user:`Dreamsorcerer`. diff --git a/aiohttp/__init__.py b/aiohttp/__init__.py index 15602a7dc85..c5f13c6dc49 100644 --- a/aiohttp/__init__.py +++ b/aiohttp/__init__.py @@ -6,6 +6,7 @@ from .client import ( BaseConnector, ClientConnectionError, + ClientConnectionResetError, ClientConnectorCertificateError, ClientConnectorError, ClientConnectorSSLError, @@ -125,6 +126,7 @@ # client "BaseConnector", "ClientConnectionError", + "ClientConnectionResetError", "ClientConnectorCertificateError", "ClientConnectorError", "ClientConnectorSSLError", diff --git a/aiohttp/base_protocol.py b/aiohttp/base_protocol.py index dc1f24f99cd..2fc2fa65885 100644 --- a/aiohttp/base_protocol.py +++ b/aiohttp/base_protocol.py @@ -1,6 +1,7 @@ import asyncio from typing import Optional, cast +from .client_exceptions import ClientConnectionResetError from .helpers import set_exception from .tcp_helpers import tcp_nodelay @@ -85,7 +86,7 @@ def connection_lost(self, exc: Optional[BaseException]) -> None: async def _drain_helper(self) -> None: if not self.connected: - raise ConnectionResetError("Connection lost") + raise ClientConnectionResetError("Connection lost") if not self._paused: return waiter = self._drain_waiter diff --git a/aiohttp/client.py b/aiohttp/client.py index d59d03fa5ec..443335c6061 100644 --- a/aiohttp/client.py +++ b/aiohttp/client.py @@ -40,6 +40,7 @@ from .abc import AbstractCookieJar from .client_exceptions import ( ClientConnectionError, + ClientConnectionResetError, ClientConnectorCertificateError, ClientConnectorError, ClientConnectorSSLError, @@ -106,6 +107,7 @@ __all__ = ( # client_exceptions "ClientConnectionError", + "ClientConnectionResetError", "ClientConnectorCertificateError", "ClientConnectorError", "ClientConnectorSSLError", diff --git a/aiohttp/client_exceptions.py b/aiohttp/client_exceptions.py index 36bb6d1c0d8..94991c42477 100644 --- a/aiohttp/client_exceptions.py +++ b/aiohttp/client_exceptions.py @@ -6,7 +6,6 @@ from multidict import MultiMapping -from .http_parser import RawResponseMessage from .typedefs import StrOrURL try: @@ -19,12 +18,14 @@ if TYPE_CHECKING: from .client_reqrep import ClientResponse, ConnectionKey, Fingerprint, RequestInfo + from .http_parser import RawResponseMessage else: - RequestInfo = ClientResponse = ConnectionKey = None + RequestInfo = ClientResponse = ConnectionKey = RawResponseMessage = None __all__ = ( "ClientError", "ClientConnectionError", + "ClientConnectionResetError", "ClientOSError", "ClientConnectorError", "ClientProxyConnectionError", @@ -159,6 +160,10 @@ class ClientConnectionError(ClientError): """Base class for client socket errors.""" +class ClientConnectionResetError(ClientConnectionError, ConnectionResetError): + """ConnectionResetError""" + + class ClientOSError(ClientConnectionError, OSError): """OSError error.""" diff --git a/aiohttp/http_websocket.py b/aiohttp/http_websocket.py index 9d03d2773c7..c6521695d94 100644 --- a/aiohttp/http_websocket.py +++ b/aiohttp/http_websocket.py @@ -25,6 +25,7 @@ ) from .base_protocol import BaseProtocol +from .client_exceptions import ClientConnectionResetError from .compression_utils import ZLibCompressor, ZLibDecompressor from .helpers import NO_EXTENSIONS, set_exception from .streams import DataQueue @@ -624,7 +625,7 @@ async def _send_frame( ) -> None: """Send a frame over the websocket with message as its payload.""" if self._closing and not (opcode & WSMsgType.CLOSE): - raise ConnectionResetError("Cannot write to closing transport") + raise ClientConnectionResetError("Cannot write to closing transport") # RSV are the reserved bits in the frame header. They are used to # indicate that the frame is using an extension. @@ -719,7 +720,7 @@ def _make_compress_obj(self, compress: int) -> ZLibCompressor: def _write(self, data: bytes) -> None: if self.transport is None or self.transport.is_closing(): - raise ConnectionResetError("Cannot write to closing transport") + raise ClientConnectionResetError("Cannot write to closing transport") self.transport.write(data) async def pong(self, message: Union[bytes, str] = b"") -> None: diff --git a/aiohttp/http_writer.py b/aiohttp/http_writer.py index d6b02e6f566..f54fa0f0774 100644 --- a/aiohttp/http_writer.py +++ b/aiohttp/http_writer.py @@ -8,6 +8,7 @@ from .abc import AbstractStreamWriter from .base_protocol import BaseProtocol +from .client_exceptions import ClientConnectionResetError from .compression_utils import ZLibCompressor from .helpers import NO_EXTENSIONS @@ -72,7 +73,7 @@ def _write(self, chunk: bytes) -> None: self.output_size += size transport = self.transport if not self._protocol.connected or transport is None or transport.is_closing(): - raise ConnectionResetError("Cannot write to closing transport") + raise ClientConnectionResetError("Cannot write to closing transport") transport.write(chunk) async def write( diff --git a/docs/client_reference.rst b/docs/client_reference.rst index a16443f275e..7f88fda14c9 100644 --- a/docs/client_reference.rst +++ b/docs/client_reference.rst @@ -2225,6 +2225,10 @@ Connection errors Derived from :exc:`ClientError` +.. class:: ClientConnectionResetError + + Derived from :exc:`ClientConnectionError` and :exc:`ConnectionResetError` + .. class:: ClientOSError Subset of connection errors that are initiated by an :exc:`OSError` @@ -2311,6 +2315,8 @@ Hierarchy of exceptions * :exc:`ClientConnectionError` + * :exc:`ClientConnectionResetError` + * :exc:`ClientOSError` * :exc:`ClientConnectorError` diff --git a/tests/test_client_ws.py b/tests/test_client_ws.py index 31ec7576c97..afe7983648f 100644 --- a/tests/test_client_ws.py +++ b/tests/test_client_ws.py @@ -2,14 +2,13 @@ import base64 import hashlib import os -from typing import Any +from typing import Any, Type from unittest import mock import pytest import aiohttp -from aiohttp import client, hdrs -from aiohttp.client_exceptions import ServerDisconnectedError +from aiohttp import ClientConnectionResetError, ServerDisconnectedError, client, hdrs from aiohttp.http import WS_KEY from aiohttp.streams import EofStream from aiohttp.test_utils import make_mocked_coro @@ -508,7 +507,13 @@ async def test_close_exc2(loop, ws_key, key_data) -> None: await resp.close() -async def test_send_data_after_close(ws_key, key_data, loop) -> None: +@pytest.mark.parametrize("exc", (ClientConnectionResetError, ConnectionResetError)) +async def test_send_data_after_close( + exc: Type[Exception], + ws_key: bytes, + key_data: bytes, + loop: asyncio.AbstractEventLoop, +) -> None: resp = mock.Mock() resp.status = 101 resp.headers = { @@ -533,7 +538,7 @@ async def test_send_data_after_close(ws_key, key_data, loop) -> None: (resp.send_bytes, (b"b",)), (resp.send_json, ({},)), ): - with pytest.raises(ConnectionResetError): + with pytest.raises(exc): # Verify exc can be caught with both classes await meth(*args) diff --git a/tests/test_client_ws_functional.py b/tests/test_client_ws_functional.py index 30da0dca802..0a8008f07ca 100644 --- a/tests/test_client_ws_functional.py +++ b/tests/test_client_ws_functional.py @@ -6,7 +6,7 @@ import pytest import aiohttp -from aiohttp import ServerTimeoutError, WSMsgType, hdrs, web +from aiohttp import ClientConnectionResetError, ServerTimeoutError, WSMsgType, hdrs, web from aiohttp.client_ws import ClientWSTimeout from aiohttp.http import WSCloseCode from aiohttp.pytest_plugin import AiohttpClient @@ -681,7 +681,7 @@ async def handler(request: web.Request) -> NoReturn: # would cancel the heartbeat task and we wouldn't get a ping assert resp._conn is not None with mock.patch.object( - resp._conn.transport, "write", side_effect=ConnectionResetError + resp._conn.transport, "write", side_effect=ClientConnectionResetError ), mock.patch.object(resp._writer, "ping", wraps=resp._writer.ping) as ping: await resp.receive() ping_count = ping.call_count diff --git a/tests/test_http_writer.py b/tests/test_http_writer.py index db50ad65f67..ed853c8744a 100644 --- a/tests/test_http_writer.py +++ b/tests/test_http_writer.py @@ -5,7 +5,7 @@ import pytest from multidict import CIMultiDict -from aiohttp import http +from aiohttp import ClientConnectionResetError, http from aiohttp.test_utils import make_mocked_coro @@ -232,12 +232,12 @@ async def test_write_to_closing_transport(protocol, transport, loop) -> None: await msg.write(b"Before closing") transport.is_closing.return_value = True - with pytest.raises(ConnectionResetError): + with pytest.raises(ClientConnectionResetError): await msg.write(b"After closing") async def test_write_to_closed_transport(protocol, transport, loop) -> None: - """Test that writing to a closed transport raises ConnectionResetError. + """Test that writing to a closed transport raises ClientConnectionResetError. The StreamWriter checks to see if protocol.transport is None before writing to the transport. If it is None, it raises ConnectionResetError. @@ -247,7 +247,9 @@ async def test_write_to_closed_transport(protocol, transport, loop) -> None: await msg.write(b"Before transport close") protocol.transport = None - with pytest.raises(ConnectionResetError, match="Cannot write to closing transport"): + with pytest.raises( + ClientConnectionResetError, match="Cannot write to closing transport" + ): await msg.write(b"After transport closed")