Skip to content

Commit

Permalink
Use URL.extend_query to add params in ClientRequest (#9068)
Browse files Browse the repository at this point in the history
  • Loading branch information
bdraco authored Sep 9, 2024
1 parent 7fb1631 commit 841d00e
Show file tree
Hide file tree
Showing 3 changed files with 31 additions and 12 deletions.
3 changes: 3 additions & 0 deletions CHANGES/9068.misc.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
Use :meth:`URL.extend_query() <yarl.URL.extend_query>` to extend query params (requires yarl 1.11.0+) -- by :user:`bdraco`.

If yarl is older than 1.11.0, the previous slower hand rolled version will be used.
14 changes: 9 additions & 5 deletions aiohttp/client_reqrep.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@
)

from multidict import CIMultiDict, CIMultiDictProxy, MultiDict, MultiDictProxy
from yarl import URL
from yarl import URL, __version__ as yarl_version

from . import hdrs, helpers, http, multipart, payload
from .abc import AbstractStreamWriter
Expand Down Expand Up @@ -90,6 +90,7 @@


_CONTAINS_CONTROL_CHAR_RE = re.compile(r"[^-!#$%&'*+.^_`|~0-9a-zA-Z]")
_YARL_SUPPORTS_EXTEND_QUERY = tuple(map(int, yarl_version.split(".")[:2])) >= (1, 11)


def _gen_default_accept_encoding() -> str:
Expand Down Expand Up @@ -229,10 +230,13 @@ def __init__(
# assert session is not None
self._session = cast("ClientSession", session)
if params:
q = MultiDict(url.query)
url2 = url.with_query(params)
q.extend(url2.query)
url = url.with_query(q)
if _YARL_SUPPORTS_EXTEND_QUERY:
url = url.extend_query(params)
else:
q = MultiDict(url.query)
url2 = url.with_query(params)
q.extend(url2.query)
url = url.with_query(q)
self.original_url = url
self.url = url.with_fragment(None)
self.method = method.upper()
Expand Down
26 changes: 19 additions & 7 deletions tests/test_client_functional.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@
from yarl import URL

import aiohttp
from aiohttp import Fingerprint, ServerFingerprintMismatch, hdrs, web
from aiohttp import Fingerprint, ServerFingerprintMismatch, client_reqrep, hdrs, web
from aiohttp.abc import AbstractResolver, ResolveResult
from aiohttp.client_exceptions import (
InvalidURL,
Expand Down Expand Up @@ -702,7 +702,10 @@ async def handler(request: web.Request) -> web.Response:
assert 200 == resp.status


async def test_params_and_query_string(aiohttp_client: AiohttpClient) -> None:
@pytest.mark.parametrize("yarl_supports_extend_query", [True, False])
async def test_params_and_query_string(
aiohttp_client: AiohttpClient, yarl_supports_extend_query: bool
) -> None:
"""Test combining params with an existing query_string."""

async def handler(request: web.Request) -> web.Response:
Expand All @@ -713,13 +716,18 @@ async def handler(request: web.Request) -> web.Response:
app.router.add_route("GET", "/", handler)
client = await aiohttp_client(app)

async with client.get("/?q=abc", params="q=test&d=dog") as resp:
assert resp.status == 200
# Ensure the old path is tested for old yarl versions
with mock.patch.object(
client_reqrep, "_YARL_SUPPORTS_EXTEND_QUERY", yarl_supports_extend_query
):
async with client.get("/?q=abc", params="q=test&d=dog") as resp:
assert resp.status == 200


@pytest.mark.parametrize("params", [None, "", {}, MultiDict()])
@pytest.mark.parametrize("yarl_supports_extend_query", [True, False])
async def test_empty_params_and_query_string(
aiohttp_client: AiohttpClient, params: Any
aiohttp_client: AiohttpClient, params: Any, yarl_supports_extend_query: bool
) -> None:
"""Test combining empty params with an existing query_string."""

Expand All @@ -731,8 +739,12 @@ async def handler(request: web.Request) -> web.Response:
app.router.add_route("GET", "/", handler)
client = await aiohttp_client(app)

async with client.get("/?q=abc", params=params) as resp:
assert resp.status == 200
# Ensure the old path is tested for old yarl versions
with mock.patch.object(
client_reqrep, "_YARL_SUPPORTS_EXTEND_QUERY", yarl_supports_extend_query
):
async with client.get("/?q=abc", params=params) as resp:
assert resp.status == 200


async def test_drop_params_on_redirect(aiohttp_client: AiohttpClient) -> None:
Expand Down

0 comments on commit 841d00e

Please sign in to comment.