Skip to content

Commit

Permalink
[PR #8632/b2691f2 backport][3.10] Fix connecting to npipe://, tcp://,…
Browse files Browse the repository at this point in the history
… and unix:// urls (#8637)

Co-authored-by: Sam Bull <git@sambull.org>
  • Loading branch information
bdraco and Dreamsorcerer authored Aug 7, 2024
1 parent bf83dbe commit 72f41aa
Show file tree
Hide file tree
Showing 5 changed files with 137 additions and 14 deletions.
1 change: 1 addition & 0 deletions CHANGES/8632.bugfix.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Fixed connecting to ``npipe://``, ``tcp://``, and ``unix://`` urls -- by :user:`bdraco`.
10 changes: 4 additions & 6 deletions aiohttp/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,6 +75,7 @@
)
from .client_ws import ClientWebSocketResponse as ClientWebSocketResponse
from .connector import (
HTTP_AND_EMPTY_SCHEMA_SET,
BaseConnector as BaseConnector,
NamedPipeConnector as NamedPipeConnector,
TCPConnector as TCPConnector,
Expand Down Expand Up @@ -209,9 +210,6 @@ class ClientTimeout:

# https://www.rfc-editor.org/rfc/rfc9110#section-9.2.2
IDEMPOTENT_METHODS = frozenset({"GET", "HEAD", "OPTIONS", "TRACE", "PUT", "DELETE"})
HTTP_SCHEMA_SET = frozenset({"http", "https", ""})
WS_SCHEMA_SET = frozenset({"ws", "wss"})
ALLOWED_PROTOCOL_SCHEMA_SET = HTTP_SCHEMA_SET | WS_SCHEMA_SET

_RetType = TypeVar("_RetType")
_CharsetResolver = Callable[[ClientResponse, bytes], str]
Expand Down Expand Up @@ -517,7 +515,8 @@ async def _request(
except ValueError as e:
raise InvalidUrlClientError(str_or_url) from e

if url.scheme not in ALLOWED_PROTOCOL_SCHEMA_SET:
assert self._connector is not None
if url.scheme not in self._connector.allowed_protocol_schema_set:
raise NonHttpUrlClientError(url)

skip_headers = set(self._skip_auto_headers)
Expand Down Expand Up @@ -655,7 +654,6 @@ async def _request(
real_timeout.connect,
ceil_threshold=real_timeout.ceil_threshold,
):
assert self._connector is not None
conn = await self._connector.connect(
req, traces=traces, timeout=real_timeout
)
Expand Down Expand Up @@ -752,7 +750,7 @@ async def _request(
) from e

scheme = parsed_redirect_url.scheme
if scheme not in HTTP_SCHEMA_SET:
if scheme not in HTTP_AND_EMPTY_SCHEMA_SET:
resp.close()
raise NonHttpUrlRedirectClientError(r_url)
elif not scheme:
Expand Down
16 changes: 16 additions & 0 deletions aiohttp/connector.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,14 @@
SSLContext = object # type: ignore[misc,assignment]


EMPTY_SCHEMA_SET = frozenset({""})
HTTP_SCHEMA_SET = frozenset({"http", "https"})
WS_SCHEMA_SET = frozenset({"ws", "wss"})

HTTP_AND_EMPTY_SCHEMA_SET = HTTP_SCHEMA_SET | EMPTY_SCHEMA_SET
HIGH_LEVEL_SCHEMA_SET = HTTP_AND_EMPTY_SCHEMA_SET | WS_SCHEMA_SET


__all__ = ("BaseConnector", "TCPConnector", "UnixConnector", "NamedPipeConnector")


Expand Down Expand Up @@ -211,6 +219,8 @@ class BaseConnector:
# abort transport after 2 seconds (cleanup broken connections)
_cleanup_closed_period = 2.0

allowed_protocol_schema_set = HIGH_LEVEL_SCHEMA_SET

def __init__(
self,
*,
Expand Down Expand Up @@ -760,6 +770,8 @@ class TCPConnector(BaseConnector):
loop - Optional event loop.
"""

allowed_protocol_schema_set = HIGH_LEVEL_SCHEMA_SET | frozenset({"tcp"})

def __init__(
self,
*,
Expand Down Expand Up @@ -1458,6 +1470,8 @@ class UnixConnector(BaseConnector):
loop - Optional event loop.
"""

allowed_protocol_schema_set = HIGH_LEVEL_SCHEMA_SET | frozenset({"unix"})

def __init__(
self,
path: str,
Expand Down Expand Up @@ -1514,6 +1528,8 @@ class NamedPipeConnector(BaseConnector):
loop - Optional event loop.
"""

allowed_protocol_schema_set = HIGH_LEVEL_SCHEMA_SET | frozenset({"npipe"})

def __init__(
self,
path: str,
Expand Down
75 changes: 70 additions & 5 deletions tests/test_client_session.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
import io
import json
from http.cookies import SimpleCookie
from typing import Any, List
from typing import Any, Awaitable, Callable, List
from unittest import mock
from uuid import uuid4

Expand All @@ -16,10 +16,12 @@
import aiohttp
from aiohttp import client, hdrs, web
from aiohttp.client import ClientSession
from aiohttp.client_proto import ResponseHandler
from aiohttp.client_reqrep import ClientRequest
from aiohttp.connector import BaseConnector, TCPConnector
from aiohttp.connector import BaseConnector, Connection, TCPConnector, UnixConnector
from aiohttp.helpers import DEBUG
from aiohttp.test_utils import make_mocked_coro
from aiohttp.tracing import Trace


@pytest.fixture
Expand Down Expand Up @@ -487,15 +489,17 @@ async def test_ws_connect_allowed_protocols(
hdrs.CONNECTION: "upgrade",
hdrs.SEC_WEBSOCKET_ACCEPT: ws_key,
}
resp.url = URL(f"{protocol}://example.com")
resp.url = URL(f"{protocol}://example")
resp.cookies = SimpleCookie()
resp.start = mock.AsyncMock()

req = mock.create_autospec(aiohttp.ClientRequest, spec_set=True)
req_factory = mock.Mock(return_value=req)
req.send = mock.AsyncMock(return_value=resp)
# BaseConnector allows all high level protocols by default
connector = BaseConnector()

session = await create_session(request_class=req_factory)
session = await create_session(connector=connector, request_class=req_factory)

connections = []
original_connect = session._connector.connect
Expand All @@ -515,7 +519,68 @@ async def create_connection(req, traces, timeout):
"aiohttp.client.os"
) as m_os:
m_os.urandom.return_value = key_data
await session.ws_connect(f"{protocol}://example.com")
await session.ws_connect(f"{protocol}://example")

# normally called during garbage collection. triggers an exception
# if the connection wasn't already closed
for c in connections:
c.close()
c.__del__()

await session.close()


@pytest.mark.parametrize("protocol", ["http", "https", "ws", "wss", "unix"])
async def test_ws_connect_unix_socket_allowed_protocols(
create_session: Callable[..., Awaitable[ClientSession]],
create_mocked_conn: Callable[[], ResponseHandler],
protocol: str,
ws_key: bytes,
key_data: bytes,
) -> None:
resp = mock.create_autospec(aiohttp.ClientResponse)
resp.status = 101
resp.headers = {
hdrs.UPGRADE: "websocket",
hdrs.CONNECTION: "upgrade",
hdrs.SEC_WEBSOCKET_ACCEPT: ws_key,
}
resp.url = URL(f"{protocol}://example")
resp.cookies = SimpleCookie()
resp.start = mock.AsyncMock()

req = mock.create_autospec(aiohttp.ClientRequest, spec_set=True)
req_factory = mock.Mock(return_value=req)
req.send = mock.AsyncMock(return_value=resp)
# UnixConnector allows all high level protocols by default and unix sockets
session = await create_session(
connector=UnixConnector(path=""), request_class=req_factory
)

connections = []
assert session._connector is not None
original_connect = session._connector.connect

async def connect(
req: ClientRequest, traces: List[Trace], timeout: aiohttp.ClientTimeout
) -> Connection:
conn = await original_connect(req, traces, timeout)
connections.append(conn)
return conn

async def create_connection(
req: object, traces: object, timeout: object
) -> ResponseHandler:
return create_mocked_conn()

connector = session._connector
with mock.patch.object(connector, "connect", connect), mock.patch.object(
connector, "_create_connection", create_connection
), mock.patch.object(connector, "_release"), mock.patch(
"aiohttp.client.os"
) as m_os:
m_os.urandom.return_value = key_data
await session.ws_connect(f"{protocol}://example")

# normally called during garbage collection. triggers an exception
# if the connection wasn't already closed
Expand Down
49 changes: 46 additions & 3 deletions tests/test_connector.py
Original file line number Diff line number Diff line change
Expand Up @@ -1481,7 +1481,19 @@ async def test_tcp_connector_ctor() -> None:
assert conn.family == 0


async def test_tcp_connector_ctor_fingerprint_valid(loop) -> None:
async def test_tcp_connector_allowed_protocols(loop: asyncio.AbstractEventLoop) -> None:
conn = aiohttp.TCPConnector()
assert conn.allowed_protocol_schema_set == {"", "tcp", "http", "https", "ws", "wss"}


async def test_invalid_ssl_param() -> None:
with pytest.raises(TypeError):
aiohttp.TCPConnector(ssl=object()) # type: ignore[arg-type]


async def test_tcp_connector_ctor_fingerprint_valid(
loop: asyncio.AbstractEventLoop,
) -> None:
valid = aiohttp.Fingerprint(hashlib.sha256(b"foo").digest())
conn = aiohttp.TCPConnector(ssl=valid, loop=loop)
assert conn._ssl is valid
Expand Down Expand Up @@ -1639,8 +1651,23 @@ async def test_ctor_with_default_loop(loop) -> None:
assert loop is conn._loop


async def test_connect_with_limit(loop, key) -> None:
proto = mock.Mock()
async def test_base_connector_allows_high_level_protocols(
loop: asyncio.AbstractEventLoop,
) -> None:
conn = aiohttp.BaseConnector()
assert conn.allowed_protocol_schema_set == {
"",
"http",
"https",
"ws",
"wss",
}


async def test_connect_with_limit(
loop: asyncio.AbstractEventLoop, key: ConnectionKey
) -> None:
proto = create_mocked_conn(loop)
proto.is_connected.return_value = True

req = ClientRequest(
Expand Down Expand Up @@ -2412,6 +2439,14 @@ async def handler(request):

connector = aiohttp.UnixConnector(unix_sockname)
assert unix_sockname == connector.path
assert connector.allowed_protocol_schema_set == {
"",
"http",
"https",
"ws",
"wss",
"unix",
}

session = client.ClientSession(connector=connector)
r = await session.get(url)
Expand All @@ -2437,6 +2472,14 @@ async def handler(request):

connector = aiohttp.NamedPipeConnector(pipe_name)
assert pipe_name == connector.path
assert connector.allowed_protocol_schema_set == {
"",
"http",
"https",
"ws",
"wss",
"npipe",
}

session = client.ClientSession(connector=connector)
r = await session.get(url)
Expand Down

0 comments on commit 72f41aa

Please sign in to comment.