From 86827d68b01c532cf9047ea78fc602f04869b9c6 Mon Sep 17 00:00:00 2001 From: kclowes Date: Thu, 29 Aug 2024 10:27:28 -0600 Subject: [PATCH] get_request_headers combomethod (#3467) * get_request_headers combomethod * Add newsfragment --- newsfragments/3467.feature.rst | 1 + tests/core/providers/test_async_http_provider.py | 4 ++-- tests/core/providers/test_http_provider.py | 4 ++-- web3/_utils/http.py | 7 +++++-- web3/providers/rpc/async_rpc.py | 12 ++++++++++-- web3/providers/rpc/rpc.py | 13 +++++++++++-- 6 files changed, 31 insertions(+), 10 deletions(-) create mode 100644 newsfragments/3467.feature.rst diff --git a/newsfragments/3467.feature.rst b/newsfragments/3467.feature.rst new file mode 100644 index 0000000000..b309f20bc2 --- /dev/null +++ b/newsfragments/3467.feature.rst @@ -0,0 +1 @@ +HTTPProvider and AsyncHTTPProvider's get_request_headers is now available on both the class and the instance diff --git a/tests/core/providers/test_async_http_provider.py b/tests/core/providers/test_async_http_provider.py index 876e65b3ee..cd4b2d84ef 100644 --- a/tests/core/providers/test_async_http_provider.py +++ b/tests/core/providers/test_async_http_provider.py @@ -94,8 +94,8 @@ async def test_async_user_provided_session() -> None: assert cached_session == session -def test_get_request_headers(): - provider = AsyncHTTPProvider() +@pytest.mark.parametrize("provider", (AsyncHTTPProvider(), AsyncHTTPProvider)) +def test_get_request_headers(provider): headers = provider.get_request_headers() assert len(headers) == 2 assert headers["Content-Type"] == "application/json" diff --git a/tests/core/providers/test_http_provider.py b/tests/core/providers/test_http_provider.py index 3a88bef67e..82a982ded7 100644 --- a/tests/core/providers/test_http_provider.py +++ b/tests/core/providers/test_http_provider.py @@ -101,8 +101,8 @@ def test_user_provided_session(): assert adapter._pool_maxsize == 20 -def test_get_request_headers(): - provider = HTTPProvider() +@pytest.mark.parametrize("provider", (HTTPProvider(), HTTPProvider)) +def test_get_request_headers(provider): headers = provider.get_request_headers() assert len(headers) == 2 assert headers["Content-Type"] == "application/json" diff --git a/web3/_utils/http.py b/web3/_utils/http.py index a5d79e1d16..e3e266fcb2 100644 --- a/web3/_utils/http.py +++ b/web3/_utils/http.py @@ -1,9 +1,12 @@ DEFAULT_HTTP_TIMEOUT = 30.0 -def construct_user_agent(class_type: type) -> str: +def construct_user_agent( + module: str, + class_name: str, +) -> str: from web3 import ( __version__ as web3_version, ) - return f"web3.py/{web3_version}/{class_type.__module__}.{class_type.__qualname__}" + return f"web3.py/{web3_version}/{module}.{class_name}" diff --git a/web3/providers/rpc/async_rpc.py b/web3/providers/rpc/async_rpc.py index bd485b4583..8f723138eb 100644 --- a/web3/providers/rpc/async_rpc.py +++ b/web3/providers/rpc/async_rpc.py @@ -19,6 +19,7 @@ URI, ) from eth_utils import ( + combomethod, to_dict, ) @@ -108,10 +109,17 @@ def get_request_kwargs(self) -> Iterable[Tuple[str, Any]]: yield "headers", self.get_request_headers() yield from self._request_kwargs.items() - def get_request_headers(self) -> Dict[str, str]: + @combomethod + def get_request_headers(cls) -> Dict[str, str]: + if isinstance(cls, AsyncHTTPProvider): + cls_name = cls.__class__.__name__ + else: + cls_name = cls.__name__ + module = cls.__module__ + return { "Content-Type": "application/json", - "User-Agent": construct_user_agent(type(self)), + "User-Agent": construct_user_agent(module, cls_name), } async def _make_request(self, method: RPCEndpoint, request_data: bytes) -> bytes: diff --git a/web3/providers/rpc/rpc.py b/web3/providers/rpc/rpc.py index 654d80c086..23593e8458 100644 --- a/web3/providers/rpc/rpc.py +++ b/web3/providers/rpc/rpc.py @@ -16,6 +16,7 @@ URI, ) from eth_utils import ( + combomethod, to_dict, ) import requests @@ -116,10 +117,18 @@ def get_request_kwargs(self) -> Iterable[Tuple[str, Any]]: yield "headers", self.get_request_headers() yield from self._request_kwargs.items() - def get_request_headers(self) -> Dict[str, str]: + @combomethod + def get_request_headers(cls) -> Dict[str, str]: + if isinstance(cls, HTTPProvider): + cls_name = cls.__class__.__name__ + else: + cls_name = cls.__name__ + + module = cls.__module__ + return { "Content-Type": "application/json", - "User-Agent": construct_user_agent(type(self)), + "User-Agent": construct_user_agent(module, cls_name), } def _make_request(self, method: RPCEndpoint, request_data: bytes) -> bytes: