From 684746b52e1e00c1124dea92b72aa775449c26f8 Mon Sep 17 00:00:00 2001 From: "J. Nick Koston" Date: Tue, 17 Sep 2024 17:22:28 +0200 Subject: [PATCH] Cache construction of middleware handlers (#9158) (cherry picked from commit bf022b390ee97172d20c0f554b87868d0ef4d938) --- CHANGES/9158.misc.rst | 3 +++ aiohttp/web_app.py | 42 +++++++++++++++++++++++++++--------- tests/test_web_middleware.py | 22 ++++++++++++------- 3 files changed, 49 insertions(+), 18 deletions(-) create mode 100644 CHANGES/9158.misc.rst diff --git a/CHANGES/9158.misc.rst b/CHANGES/9158.misc.rst new file mode 100644 index 00000000000..8d87623c056 --- /dev/null +++ b/CHANGES/9158.misc.rst @@ -0,0 +1,3 @@ +Significantly improved performance of middlewares -- by :user:`bdraco`. + +The construction of the middleware wrappers is now cached and is built once per handler instead of on every request. diff --git a/aiohttp/web_app.py b/aiohttp/web_app.py index 3510bffda60..b8768064507 100644 --- a/aiohttp/web_app.py +++ b/aiohttp/web_app.py @@ -1,7 +1,7 @@ import asyncio import logging import warnings -from functools import partial, update_wrapper +from functools import lru_cache, partial, update_wrapper from typing import ( TYPE_CHECKING, Any, @@ -38,7 +38,7 @@ from .http_parser import RawRequestMessage from .log import web_logger from .streams import StreamReader -from .typedefs import Middleware +from .typedefs import Handler, Middleware from .web_exceptions import NotAppKeyWarning from .web_log import AccessLogger from .web_middlewares import _fix_request_current_app @@ -79,6 +79,17 @@ _Resource = TypeVar("_Resource", bound=AbstractResource) +@lru_cache(None) +def _build_middlewares( + handler: Handler, apps: Tuple["Application", ...] +) -> Callable[[Request], Awaitable[StreamResponse]]: + """Apply middlewares to handler.""" + for app in apps: + for m, _ in app._middlewares_handlers: # type: ignore[union-attr] + handler = update_wrapper(partial(m, handler=handler), handler) # type: ignore[misc] + return handler + + class Application(MutableMapping[Union[str, AppKey[Any]], Any]): ATTRS = frozenset( [ @@ -89,6 +100,7 @@ class Application(MutableMapping[Union[str, AppKey[Any]], Any]): "_handler_args", "_middlewares", "_middlewares_handlers", + "_has_legacy_middlewares", "_run_middlewares", "_state", "_frozen", @@ -143,6 +155,7 @@ def __init__( self._middlewares_handlers: _MiddlewaresHandlers = None # initialized on freezing self._run_middlewares: Optional[bool] = None + self._has_legacy_middlewares: bool = True self._state: Dict[Union[AppKey[Any], str], object] = {} self._frozen = False @@ -228,6 +241,9 @@ def __len__(self) -> int: def __iter__(self) -> Iterator[Union[str, AppKey[Any]]]: return iter(self._state) + def __hash__(self) -> int: + return id(self) + @overload # type: ignore[override] def get(self, key: AppKey[_T], default: None = ...) -> Optional[_T]: ... @@ -284,6 +300,9 @@ def pre_freeze(self) -> None: self._on_shutdown.freeze() self._on_cleanup.freeze() self._middlewares_handlers = tuple(self._prepare_middleware()) + self._has_legacy_middlewares = any( + not new_style for _, new_style in self._middlewares_handlers + ) # If current app and any subapp do not have middlewares avoid run all # of the code footprint that it implies, which have a middleware @@ -525,14 +544,17 @@ async def _handle(self, request: Request) -> StreamResponse: handler = match_info.handler if self._run_middlewares: - for app in match_info.apps[::-1]: - for m, new_style in app._middlewares_handlers: # type: ignore[union-attr] - if new_style: - handler = update_wrapper( - partial(m, handler=handler), handler # type: ignore[misc] - ) - else: - handler = await m(app, handler) # type: ignore[arg-type,assignment] + if not self._has_legacy_middlewares: + handler = _build_middlewares(handler, match_info.apps[::-1]) + else: + for app in match_info.apps[::-1]: + for m, new_style in app._middlewares_handlers: # type: ignore[union-attr] + if new_style: + handler = update_wrapper( + partial(m, handler=handler), handler # type: ignore[misc] + ) + else: + handler = await m(app, handler) # type: ignore[arg-type,assignment] resp = await handler(request) diff --git a/tests/test_web_middleware.py b/tests/test_web_middleware.py index dbe23e02035..9c4462be409 100644 --- a/tests/test_web_middleware.py +++ b/tests/test_web_middleware.py @@ -24,10 +24,13 @@ async def middleware(request, handler: Handler): app.middlewares.append(middleware) app.router.add_route("GET", "/", handler) client = await aiohttp_client(app) - resp = await client.get("/") - assert 201 == resp.status - txt = await resp.text() - assert "OK[MIDDLEWARE]" == txt + + # Call twice to verify cache works + for _ in range(2): + resp = await client.get("/") + assert 201 == resp.status + txt = await resp.text() + assert "OK[MIDDLEWARE]" == txt async def test_middleware_handles_exception(loop, aiohttp_client) -> None: @@ -44,10 +47,13 @@ async def middleware(request, handler: Handler): app.middlewares.append(middleware) app.router.add_route("GET", "/", handler) client = await aiohttp_client(app) - resp = await client.get("/") - assert 501 == resp.status - txt = await resp.text() - assert "Error text[MIDDLEWARE]" == txt + + # Call twice to verify cache works + for _ in range(2): + resp = await client.get("/") + assert 501 == resp.status + txt = await resp.text() + assert "Error text[MIDDLEWARE]" == txt async def test_middleware_chain(loop, aiohttp_client) -> None: