-
Notifications
You must be signed in to change notification settings - Fork 11
feat(http): switch httpx for niquests in order to stabilize network I/O #375
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -10,9 +10,11 @@ | |
from enum import IntEnum | ||
from json import loads | ||
from os import environ | ||
from urllib.parse import parse_qsl, urlencode, urlsplit, urlunsplit | ||
|
||
from httpx import AsyncClient, Client, Headers, Limits, ReadTimeout, Request, Response | ||
from httpx import __version__ as httpx_version | ||
from niquests import AsyncSession, ReadTimeout, Request, Response, Session | ||
from niquests import __version__ as niquests_version | ||
from niquests.structures import CaseInsensitiveDict | ||
from starlette.requests import HTTPConnection | ||
|
||
from . import options | ||
|
@@ -49,6 +51,13 @@ class ServerVersion(typing.TypedDict): | |
"""Indicates if the subscription has extended support""" | ||
|
||
|
||
@dataclass | ||
class Limits: | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I added that for pure backward compatibility. Niquests as no such thing as a "Limits" object. |
||
max_keepalive_connections: int | None = 20 | ||
max_connections: int | None = 100 | ||
keepalive_expiry: int | float | None = 5 | ||
|
||
|
||
@dataclass | ||
class RuntimeOptions: | ||
xdebug_session: str | ||
|
@@ -134,11 +143,11 @@ def __init__(self, **kwargs): | |
|
||
|
||
class NcSessionBase(ABC): | ||
adapter: AsyncClient | Client | ||
adapter_dav: AsyncClient | Client | ||
adapter: AsyncSession | Session | ||
adapter_dav: AsyncSession | Session | ||
cfg: BasicConfig | ||
custom_headers: dict | ||
response_headers: Headers | ||
response_headers: CaseInsensitiveDict | ||
_user: str | ||
_capabilities: dict | ||
|
||
|
@@ -150,7 +159,7 @@ def __init__(self, **kwargs): | |
self.limits = Limits(max_keepalive_connections=20, max_connections=20, keepalive_expiry=60.0) | ||
self.init_adapter() | ||
self.init_adapter_dav() | ||
self.response_headers = Headers() | ||
self.response_headers = CaseInsensitiveDict() | ||
self._ocs_regexp = re.compile(r"/ocs/v[12]\.php/|/apps/groupfolders/") | ||
|
||
def init_adapter(self, restart=False) -> None: | ||
|
@@ -172,7 +181,7 @@ def init_adapter_dav(self, restart=False) -> None: | |
self.adapter_dav.cookies.set("XDEBUG_SESSION", options.XDEBUG_SESSION) | ||
|
||
@abstractmethod | ||
def _create_adapter(self, dav: bool = False) -> AsyncClient | Client: | ||
def _create_adapter(self, dav: bool = False) -> AsyncSession | Session: | ||
pass # pragma: no cover | ||
|
||
@property | ||
|
@@ -187,8 +196,8 @@ def ae_url_v2(self) -> str: | |
|
||
|
||
class NcSessionBasic(NcSessionBase, ABC): | ||
adapter: Client | ||
adapter_dav: Client | ||
adapter: Session | ||
adapter_dav: Session | ||
|
||
def ocs( | ||
self, | ||
|
@@ -206,9 +215,7 @@ def ocs( | |
info = f"request: {method} {path}" | ||
nested_req = kwargs.pop("nested_req", False) | ||
try: | ||
response = self.adapter.request( | ||
method, path, content=content, json=json, params=params, files=files, **kwargs | ||
) | ||
response = self.adapter.request(method, path, data=content, json=json, params=params, files=files, **kwargs) | ||
except ReadTimeout: | ||
raise NextcloudException(408, info=info) from None | ||
|
||
|
@@ -281,18 +288,18 @@ def _get_adapter_kwargs(self, dav: bool) -> dict[str, typing.Any]: | |
return { | ||
"base_url": self.cfg.dav_endpoint, | ||
"timeout": self.cfg.options.timeout_dav, | ||
"event_hooks": {"request": [], "response": [self._response_event]}, | ||
"event_hooks": {"pre_request": [], "response": [self._response_event]}, | ||
} | ||
return { | ||
"base_url": self.cfg.endpoint, | ||
"timeout": self.cfg.options.timeout, | ||
"event_hooks": {"request": [self._request_event_ocs], "response": [self._response_event]}, | ||
"event_hooks": {"pre_request": [self._request_event_ocs], "response": [self._response_event]}, | ||
} | ||
|
||
def _request_event_ocs(self, request: Request) -> None: | ||
str_url = str(request.url) | ||
if re.search(self._ocs_regexp, str_url) is not None: # this is OCS call | ||
request.url = request.url.copy_merge_params({"format": "json"}) | ||
request.url = patch_param(request.url, "format", "json") | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. niquests does not have an "URL" object. we are entitled to strict backward compatibility with requests, and bringing such object would not be taken gently by our users. |
||
request.headers["Accept"] = "application/json" | ||
|
||
def _response_event(self, response: Response) -> None: | ||
|
@@ -305,15 +312,15 @@ def _response_event(self, response: Response) -> None: | |
|
||
def download2fp(self, url_path: str, fp, dav: bool, params=None, **kwargs): | ||
adapter = self.adapter_dav if dav else self.adapter | ||
with adapter.stream("GET", url_path, params=params, headers=kwargs.get("headers")) as response: | ||
with adapter.get(url_path, params=params, headers=kwargs.get("headers"), stream=True) as response: | ||
check_error(response) | ||
for data_chunk in response.iter_bytes(chunk_size=kwargs.get("chunk_size", 5 * 1024 * 1024)): | ||
for data_chunk in response.iter_raw(chunk_size=kwargs.get("chunk_size", -1)): | ||
fp.write(data_chunk) | ||
|
||
|
||
class AsyncNcSessionBasic(NcSessionBase, ABC): | ||
adapter: AsyncClient | ||
adapter_dav: AsyncClient | ||
adapter: AsyncSession | ||
adapter_dav: AsyncSession | ||
|
||
async def ocs( | ||
self, | ||
|
@@ -332,7 +339,7 @@ async def ocs( | |
nested_req = kwargs.pop("nested_req", False) | ||
try: | ||
response = await self.adapter.request( | ||
method, path, content=content, json=json, params=params, files=files, **kwargs | ||
method, path, data=content, json=json, params=params, files=files, **kwargs | ||
) | ||
except ReadTimeout: | ||
raise NextcloudException(408, info=info) from None | ||
|
@@ -350,7 +357,7 @@ async def ocs( | |
and ocs_meta["statuscode"] == 403 | ||
and str(ocs_meta["message"]).lower().find("password confirmation is required") != -1 | ||
): | ||
await self.adapter.aclose() | ||
await self.adapter.close() | ||
self.init_adapter(restart=True) | ||
return await self.ocs( | ||
method, path, **kwargs, content=content, json=json, params=params, nested_req=True | ||
|
@@ -408,18 +415,18 @@ def _get_adapter_kwargs(self, dav: bool) -> dict[str, typing.Any]: | |
return { | ||
"base_url": self.cfg.dav_endpoint, | ||
"timeout": self.cfg.options.timeout_dav, | ||
"event_hooks": {"request": [], "response": [self._response_event]}, | ||
"event_hooks": {"pre_request": [], "response": [self._response_event]}, | ||
} | ||
return { | ||
"base_url": self.cfg.endpoint, | ||
"timeout": self.cfg.options.timeout, | ||
"event_hooks": {"request": [self._request_event_ocs], "response": [self._response_event]}, | ||
"event_hooks": {"pre_request": [self._request_event_ocs], "response": [self._response_event]}, | ||
} | ||
|
||
async def _request_event_ocs(self, request: Request) -> None: | ||
str_url = str(request.url) | ||
if re.search(self._ocs_regexp, str_url) is not None: # this is OCS call | ||
request.url = request.url.copy_merge_params({"format": "json"}) | ||
request.url = patch_param(request.url, "format", "json") | ||
request.headers["Accept"] = "application/json" | ||
|
||
async def _response_event(self, response: Response) -> None: | ||
|
@@ -432,10 +439,12 @@ async def _response_event(self, response: Response) -> None: | |
|
||
async def download2fp(self, url_path: str, fp, dav: bool, params=None, **kwargs): | ||
adapter = self.adapter_dav if dav else self.adapter | ||
async with adapter.stream("GET", url_path, params=params, headers=kwargs.get("headers")) as response: | ||
check_error(response) | ||
async for data_chunk in response.aiter_bytes(chunk_size=kwargs.get("chunk_size", 5 * 1024 * 1024)): | ||
fp.write(data_chunk) | ||
response = await adapter.get(url_path, params=params, headers=kwargs.get("headers"), stream=True) | ||
|
||
check_error(response) | ||
|
||
async for data_chunk in await response.iter_raw(chunk_size=kwargs.get("chunk_size", -1)): | ||
fp.write(data_chunk) | ||
|
||
|
||
class NcSession(NcSessionBasic): | ||
|
@@ -445,15 +454,20 @@ def __init__(self, **kwargs): | |
self.cfg = Config(**kwargs) | ||
super().__init__() | ||
|
||
def _create_adapter(self, dav: bool = False) -> AsyncClient | Client: | ||
return Client( | ||
follow_redirects=True, | ||
limits=self.limits, | ||
verify=self.cfg.options.nc_cert, | ||
**self._get_adapter_kwargs(dav), | ||
auth=self.cfg.auth, | ||
def _create_adapter(self, dav: bool = False) -> AsyncSession | Session: | ||
session_kwargs = self._get_adapter_kwargs(dav) | ||
hooks = session_kwargs.pop("event_hooks") | ||
|
||
session = Session( | ||
keepalive_delay=self.limits.keepalive_expiry, pool_maxsize=self.limits.max_connections, **session_kwargs | ||
) | ||
|
||
session.auth = self.cfg.auth | ||
session.verify = self.cfg.options.nc_cert | ||
session.hooks.update(hooks) | ||
|
||
return session | ||
|
||
|
||
class AsyncNcSession(AsyncNcSessionBasic): | ||
cfg: Config | ||
|
@@ -462,21 +476,28 @@ def __init__(self, **kwargs): | |
self.cfg = Config(**kwargs) | ||
super().__init__() | ||
|
||
def _create_adapter(self, dav: bool = False) -> AsyncClient | Client: | ||
return AsyncClient( | ||
follow_redirects=True, | ||
limits=self.limits, | ||
verify=self.cfg.options.nc_cert, | ||
**self._get_adapter_kwargs(dav), | ||
auth=self.cfg.auth, | ||
def _create_adapter(self, dav: bool = False) -> AsyncSession | Session: | ||
session_kwargs = self._get_adapter_kwargs(dav) | ||
hooks = session_kwargs.pop("event_hooks") | ||
|
||
session = AsyncSession( | ||
keepalive_delay=self.limits.keepalive_expiry, | ||
pool_maxsize=self.limits.max_connections, | ||
**session_kwargs, | ||
) | ||
|
||
session.verify = self.cfg.options.nc_cert | ||
session.auth = self.cfg.auth | ||
session.hooks.update(hooks) | ||
|
||
return session | ||
|
||
|
||
class NcSessionAppBasic(ABC): | ||
cfg: AppConfig | ||
_user: str | ||
adapter: AsyncClient | Client | ||
adapter_dav: AsyncClient | Client | ||
adapter: AsyncSession | Session | ||
adapter_dav: AsyncSession | Session | ||
|
||
def __init__(self, **kwargs): | ||
self.cfg = AppConfig(**kwargs) | ||
|
@@ -505,22 +526,29 @@ def sign_check(self, request: HTTPConnection) -> str: | |
class NcSessionApp(NcSessionAppBasic, NcSessionBasic): | ||
cfg: AppConfig | ||
|
||
def _create_adapter(self, dav: bool = False) -> AsyncClient | Client: | ||
r = self._get_adapter_kwargs(dav) | ||
r["event_hooks"]["request"].append(self._add_auth) | ||
return Client( | ||
follow_redirects=True, | ||
limits=self.limits, | ||
verify=self.cfg.options.nc_cert, | ||
**r, | ||
headers={ | ||
"AA-VERSION": self.cfg.aa_version, | ||
"EX-APP-ID": self.cfg.app_name, | ||
"EX-APP-VERSION": self.cfg.app_version, | ||
"user-agent": f"ExApp/{self.cfg.app_name}/{self.cfg.app_version} (httpx/{httpx_version})", | ||
}, | ||
def _create_adapter(self, dav: bool = False) -> AsyncSession | Session: | ||
session_kwargs = self._get_adapter_kwargs(dav) | ||
session_kwargs["event_hooks"]["pre_request"].append(self._add_auth) | ||
|
||
hooks = session_kwargs.pop("event_hooks") | ||
|
||
session = Session( | ||
keepalive_delay=self.limits.keepalive_expiry, | ||
pool_maxsize=self.limits.max_connections, | ||
**session_kwargs, | ||
) | ||
|
||
session.verify = self.cfg.options.nc_cert | ||
session.headers = { | ||
"AA-VERSION": self.cfg.aa_version, | ||
"EX-APP-ID": self.cfg.app_name, | ||
"EX-APP-VERSION": self.cfg.app_version, | ||
"user-agent": f"ExApp/{self.cfg.app_name}/{self.cfg.app_version} (niquests/{niquests_version})", | ||
} | ||
session.hooks.update(hooks) | ||
|
||
return session | ||
|
||
def _add_auth(self, request: Request): | ||
request.headers.update( | ||
{"AUTHORIZATION-APP-API": b64encode(f"{self._user}:{self.cfg.app_secret}".encode("UTF=8"))} | ||
|
@@ -530,23 +558,39 @@ def _add_auth(self, request: Request): | |
class AsyncNcSessionApp(NcSessionAppBasic, AsyncNcSessionBasic): | ||
cfg: AppConfig | ||
|
||
def _create_adapter(self, dav: bool = False) -> AsyncClient | Client: | ||
r = self._get_adapter_kwargs(dav) | ||
r["event_hooks"]["request"].append(self._add_auth) | ||
return AsyncClient( | ||
follow_redirects=True, | ||
limits=self.limits, | ||
verify=self.cfg.options.nc_cert, | ||
**r, | ||
headers={ | ||
"AA-VERSION": self.cfg.aa_version, | ||
"EX-APP-ID": self.cfg.app_name, | ||
"EX-APP-VERSION": self.cfg.app_version, | ||
"User-Agent": f"ExApp/{self.cfg.app_name}/{self.cfg.app_version} (httpx/{httpx_version})", | ||
}, | ||
def _create_adapter(self, dav: bool = False) -> AsyncSession | Session: | ||
session_kwargs = self._get_adapter_kwargs(dav) | ||
session_kwargs["event_hooks"]["pre_request"].append(self._add_auth) | ||
|
||
hooks = session_kwargs.pop("event_hooks") | ||
|
||
session = AsyncSession( | ||
keepalive_delay=self.limits.keepalive_expiry, | ||
pool_maxsize=self.limits.max_connections, | ||
**session_kwargs, | ||
) | ||
session.verify = self.cfg.options.nc_cert | ||
session.headers = { | ||
"AA-VERSION": self.cfg.aa_version, | ||
"EX-APP-ID": self.cfg.app_name, | ||
"EX-APP-VERSION": self.cfg.app_version, | ||
"User-Agent": f"ExApp/{self.cfg.app_name}/{self.cfg.app_version} (niquests/{niquests_version})", | ||
} | ||
session.hooks.update(hooks) | ||
|
||
return session | ||
|
||
async def _add_auth(self, request: Request): | ||
request.headers.update( | ||
{"AUTHORIZATION-APP-API": b64encode(f"{self._user}:{self.cfg.app_secret}".encode("UTF=8"))} | ||
) | ||
|
||
|
||
def patch_param(url: str, key: str, value: str) -> str: | ||
parts = urlsplit(url) | ||
query = dict(parse_qsl(parts.query, keep_blank_values=True)) | ||
query[key] = value | ||
|
||
new_query = urlencode(query, doseq=True) | ||
|
||
return urlunsplit((parts.scheme, parts.netloc, parts.path, new_query, parts.fragment)) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
there is no strict equivalent at niquests
httpx.codes
but the patch will act exactly as before.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
a note to be considered is that http/2 and http/3 no longer accept to return "response phrase" or also known as "reason" attached to the status code.
but you could already be aware of that.