From 006fbe03fede4eaa1eeba7b8393cbf4d63cb44b6 Mon Sep 17 00:00:00 2001 From: "J. Nick Koston" Date: Tue, 20 Feb 2024 14:46:57 -0600 Subject: [PATCH] Avoid creating a task to do DNS resolution if there is no throttle (#8163) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Co-authored-by: Sviatoslav Sydorenko (Святослав Сидоренко) --- CHANGES/8163.bugfix.rst | 5 +++++ aiohttp/connector.py | 50 +++++++++++++++++++++++++++++------------ tests/test_connector.py | 6 +++++ 3 files changed, 47 insertions(+), 14 deletions(-) create mode 100644 CHANGES/8163.bugfix.rst diff --git a/CHANGES/8163.bugfix.rst b/CHANGES/8163.bugfix.rst new file mode 100644 index 00000000000..8bfb10260c6 --- /dev/null +++ b/CHANGES/8163.bugfix.rst @@ -0,0 +1,5 @@ +Improved the DNS resolution performance on cache hit +-- by :user:`bdraco`. + +This is achieved by avoiding an :mod:`asyncio` task creation +in this case. diff --git a/aiohttp/connector.py b/aiohttp/connector.py index 47c32cd1f3e..2dea12d7adf 100644 --- a/aiohttp/connector.py +++ b/aiohttp/connector.py @@ -814,6 +814,7 @@ def clear_dns_cache( async def _resolve_host( self, host: str, port: int, traces: Optional[List["Trace"]] = None ) -> List[Dict[str, Any]]: + """Resolve host and return list of addresses.""" if is_ip_address(host): return [ { @@ -840,8 +841,7 @@ async def _resolve_host( return res key = (host, port) - - if (key in self._cached_hosts) and (not self._cached_hosts.expired(key)): + if key in self._cached_hosts and not self._cached_hosts.expired(key): # get result early, before any await (#4014) result = self._cached_hosts.next_addrs(key) @@ -850,6 +850,39 @@ async def _resolve_host( await trace.send_dns_cache_hit(host) return result + # + # If multiple connectors are resolving the same host, we wait + # for the first one to resolve and then use the result for all of them. + # We use a throttle event to ensure that we only resolve the host once + # and then use the result for all the waiters. + # + # In this case we need to create a task to ensure that we can shield + # the task from cancellation as cancelling this lookup should not cancel + # the underlying lookup or else the cancel event will get broadcast to + # all the waiters across all connections. + # + resolved_host_task = asyncio.create_task( + self._resolve_host_with_throttle(key, host, port, traces) + ) + try: + return await asyncio.shield(resolved_host_task) + except asyncio.CancelledError: + + def drop_exception(fut: "asyncio.Future[List[Dict[str, Any]]]") -> None: + with suppress(Exception, asyncio.CancelledError): + fut.result() + + resolved_host_task.add_done_callback(drop_exception) + raise + + async def _resolve_host_with_throttle( + self, + key: Tuple[str, int], + host: str, + port: int, + traces: Optional[List["Trace"]], + ) -> List[Dict[str, Any]]: + """Resolve host with a dns events throttle.""" if key in self._throttle_dns_events: # get event early, before any await (#4014) event = self._throttle_dns_events[key] @@ -1136,22 +1169,11 @@ async def _create_direct_connection( host = host.rstrip(".") + "." port = req.port assert port is not None - host_resolved = asyncio.ensure_future( - self._resolve_host(host, port, traces=traces), loop=self._loop - ) try: # Cancelling this lookup should not cancel the underlying lookup # or else the cancel event will get broadcast to all the waiters # across all connections. - hosts = await asyncio.shield(host_resolved) - except asyncio.CancelledError: - - def drop_exception(fut: "asyncio.Future[List[Dict[str, Any]]]") -> None: - with suppress(Exception, asyncio.CancelledError): - fut.result() - - host_resolved.add_done_callback(drop_exception) - raise + hosts = await self._resolve_host(host, port, traces=traces) except OSError as exc: if exc.errno is None and isinstance(exc, asyncio.TimeoutError): raise diff --git a/tests/test_connector.py b/tests/test_connector.py index a7b92ebbd21..00faeaf44d1 100644 --- a/tests/test_connector.py +++ b/tests/test_connector.py @@ -1021,6 +1021,7 @@ async def test_tcp_connector_dns_throttle_requests( loop.create_task(conn._resolve_host("localhost", 8080)) loop.create_task(conn._resolve_host("localhost", 8080)) await asyncio.sleep(0) + await asyncio.sleep(0) m_resolver().resolve.assert_called_once_with("localhost", 8080, family=0) @@ -1032,6 +1033,9 @@ async def test_tcp_connector_dns_throttle_requests_exception_spread(loop: Any) - r1 = loop.create_task(conn._resolve_host("localhost", 8080)) r2 = loop.create_task(conn._resolve_host("localhost", 8080)) await asyncio.sleep(0) + await asyncio.sleep(0) + await asyncio.sleep(0) + await asyncio.sleep(0) assert r1.exception() == e assert r2.exception() == e @@ -1045,6 +1049,7 @@ async def test_tcp_connector_dns_throttle_requests_cancelled_when_close( loop.create_task(conn._resolve_host("localhost", 8080)) f = loop.create_task(conn._resolve_host("localhost", 8080)) + await asyncio.sleep(0) await asyncio.sleep(0) await conn.close() @@ -1212,6 +1217,7 @@ async def test_tcp_connector_dns_tracing_throttle_requests( loop.create_task(conn._resolve_host("localhost", 8080, traces=traces)) loop.create_task(conn._resolve_host("localhost", 8080, traces=traces)) await asyncio.sleep(0) + await asyncio.sleep(0) on_dns_cache_hit.assert_called_once_with( session, trace_config_ctx, aiohttp.TraceDnsCacheHitParams("localhost") )