Skip to content

Commit

Permalink
[PR #9029/466448c backport][3.10] Fix SSLContext creation in the TCPC…
Browse files Browse the repository at this point in the history
…onnector with multiple loops (#9042)
  • Loading branch information
bdraco authored Sep 6, 2024
1 parent b48ebc1 commit 4d022e4
Show file tree
Hide file tree
Showing 3 changed files with 106 additions and 113 deletions.
100 changes: 40 additions & 60 deletions aiohttp/connector.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,14 +50,7 @@
)
from .client_proto import ResponseHandler
from .client_reqrep import ClientRequest, Fingerprint, _merge_ssl_params
from .helpers import (
ceil_timeout,
is_ip_address,
noop,
sentinel,
set_exception,
set_result,
)
from .helpers import ceil_timeout, is_ip_address, noop, sentinel
from .locks import EventResultOrError
from .resolver import DefaultResolver

Expand Down Expand Up @@ -748,6 +741,35 @@ def expired(self, key: Tuple[str, int]) -> bool:
return self._timestamps[key] + self._ttl < monotonic()


def _make_ssl_context(verified: bool) -> SSLContext:
"""Create SSL context.
This method is not async-friendly and should be called from a thread
because it will load certificates from disk and do other blocking I/O.
"""
if ssl is None:
# No ssl support
return None
if verified:
return ssl.create_default_context()
sslcontext = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT)
sslcontext.options |= ssl.OP_NO_SSLv2
sslcontext.options |= ssl.OP_NO_SSLv3
sslcontext.check_hostname = False
sslcontext.verify_mode = ssl.CERT_NONE
sslcontext.options |= ssl.OP_NO_COMPRESSION
sslcontext.set_default_verify_paths()
return sslcontext


# The default SSLContext objects are created at import time
# since they do blocking I/O to load certificates from disk,
# and imports should always be done before the event loop starts
# or in a thread.
_SSL_CONTEXT_VERIFIED = _make_ssl_context(True)
_SSL_CONTEXT_UNVERIFIED = _make_ssl_context(False)


class TCPConnector(BaseConnector):
"""TCP connector.
Expand Down Expand Up @@ -778,7 +800,6 @@ class TCPConnector(BaseConnector):
"""

allowed_protocol_schema_set = HIGH_LEVEL_SCHEMA_SET | frozenset({"tcp"})
_made_ssl_context: Dict[bool, "asyncio.Future[SSLContext]"] = {}

def __init__(
self,
Expand Down Expand Up @@ -982,25 +1003,7 @@ async def _create_connection(

return proto

@staticmethod
def _make_ssl_context(verified: bool) -> SSLContext:
"""Create SSL context.
This method is not async-friendly and should be called from a thread
because it will load certificates from disk and do other blocking I/O.
"""
if verified:
return ssl.create_default_context()
sslcontext = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT)
sslcontext.options |= ssl.OP_NO_SSLv2
sslcontext.options |= ssl.OP_NO_SSLv3
sslcontext.check_hostname = False
sslcontext.verify_mode = ssl.CERT_NONE
sslcontext.options |= ssl.OP_NO_COMPRESSION
sslcontext.set_default_verify_paths()
return sslcontext

async def _get_ssl_context(self, req: ClientRequest) -> Optional[SSLContext]:
def _get_ssl_context(self, req: ClientRequest) -> Optional[SSLContext]:
"""Logic to get the correct SSL context
0. if req.ssl is false, return None
Expand All @@ -1024,35 +1027,14 @@ async def _get_ssl_context(self, req: ClientRequest) -> Optional[SSLContext]:
return sslcontext
if sslcontext is not True:
# not verified or fingerprinted
return await self._make_or_get_ssl_context(False)
return _SSL_CONTEXT_UNVERIFIED
sslcontext = self._ssl
if isinstance(sslcontext, ssl.SSLContext):
return sslcontext
if sslcontext is not True:
# not verified or fingerprinted
return await self._make_or_get_ssl_context(False)
return await self._make_or_get_ssl_context(True)

async def _make_or_get_ssl_context(self, verified: bool) -> SSLContext:
"""Create or get cached SSL context."""
try:
return await self._made_ssl_context[verified]
except KeyError:
loop = self._loop
future = loop.create_future()
self._made_ssl_context[verified] = future
try:
result = await loop.run_in_executor(
None, self._make_ssl_context, verified
)
# BaseException is used since we might get CancelledError
except BaseException as ex:
del self._made_ssl_context[verified]
set_exception(future, ex)
raise
else:
set_result(future, result)
return result
return _SSL_CONTEXT_UNVERIFIED
return _SSL_CONTEXT_VERIFIED

def _get_fingerprint(self, req: ClientRequest) -> Optional["Fingerprint"]:
ret = req.ssl
Expand Down Expand Up @@ -1204,13 +1186,11 @@ async def _start_tls_connection(
) -> Tuple[asyncio.BaseTransport, ResponseHandler]:
"""Wrap the raw TCP transport with TLS."""
tls_proto = self._factory() # Create a brand new proto for TLS

# Safety of the `cast()` call here is based on the fact that
# internally `_get_ssl_context()` only returns `None` when
# `req.is_ssl()` evaluates to `False` which is never gonna happen
# in this code path. Of course, it's rather fragile
# maintainability-wise but this is to be solved separately.
sslcontext = cast(ssl.SSLContext, await self._get_ssl_context(req))
sslcontext = self._get_ssl_context(req)
if TYPE_CHECKING:
# _start_tls_connection is unreachable in the current code path
# if sslcontext is None.
assert sslcontext is not None

try:
async with ceil_timeout(
Expand Down Expand Up @@ -1288,7 +1268,7 @@ async def _create_direct_connection(
*,
client_error: Type[Exception] = ClientConnectorError,
) -> Tuple[asyncio.Transport, ResponseHandler]:
sslcontext = await self._get_ssl_context(req)
sslcontext = self._get_ssl_context(req)
fingerprint = self._get_fingerprint(req)

host = req.url.raw_host
Expand Down
116 changes: 64 additions & 52 deletions tests/test_connector.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
# Tests of http client with custom Connector

import asyncio
import gc
import hashlib
Expand All @@ -9,19 +8,31 @@
import sys
import uuid
from collections import deque
from concurrent import futures
from contextlib import closing, suppress
from typing import Any, List, Optional, Type
from typing import Any, List, Literal, Optional
from unittest import mock

import pytest
from aiohappyeyeballs import AddrInfoType
from yarl import URL

import aiohttp
from aiohttp import client, web
from aiohttp.client import ClientRequest, ClientTimeout
from aiohttp import (
ClientRequest,
ClientTimeout,
client,
connector as connector_module,
web,
)
from aiohttp.client_reqrep import ConnectionKey
from aiohttp.connector import Connection, TCPConnector, _DNSCacheTable
from aiohttp.connector import (
_SSL_CONTEXT_UNVERIFIED,
_SSL_CONTEXT_VERIFIED,
Connection,
TCPConnector,
_DNSCacheTable,
)
from aiohttp.locks import EventResultOrError
from aiohttp.test_utils import make_mocked_coro, unused_port
from aiohttp.tracing import Trace
Expand Down Expand Up @@ -1540,23 +1551,11 @@ async def test_tcp_connector_clear_dns_cache_bad_args(loop) -> None:
conn.clear_dns_cache("localhost")


async def test_dont_recreate_ssl_context() -> None:
conn = aiohttp.TCPConnector()
ctx = await conn._make_or_get_ssl_context(True)
assert ctx is await conn._make_or_get_ssl_context(True)


async def test_dont_recreate_ssl_context2() -> None:
conn = aiohttp.TCPConnector()
ctx = await conn._make_or_get_ssl_context(False)
assert ctx is await conn._make_or_get_ssl_context(False)


async def test___get_ssl_context1() -> None:
conn = aiohttp.TCPConnector()
req = mock.Mock()
req.is_ssl.return_value = False
assert await conn._get_ssl_context(req) is None
assert conn._get_ssl_context(req) is None


async def test___get_ssl_context2(loop) -> None:
Expand All @@ -1565,7 +1564,7 @@ async def test___get_ssl_context2(loop) -> None:
req = mock.Mock()
req.is_ssl.return_value = True
req.ssl = ctx
assert await conn._get_ssl_context(req) is ctx
assert conn._get_ssl_context(req) is ctx


async def test___get_ssl_context3(loop) -> None:
Expand All @@ -1574,7 +1573,7 @@ async def test___get_ssl_context3(loop) -> None:
req = mock.Mock()
req.is_ssl.return_value = True
req.ssl = True
assert await conn._get_ssl_context(req) is ctx
assert conn._get_ssl_context(req) is ctx


async def test___get_ssl_context4(loop) -> None:
Expand All @@ -1583,9 +1582,7 @@ async def test___get_ssl_context4(loop) -> None:
req = mock.Mock()
req.is_ssl.return_value = True
req.ssl = False
assert await conn._get_ssl_context(req) is await conn._make_or_get_ssl_context(
False
)
assert conn._get_ssl_context(req) is _SSL_CONTEXT_UNVERIFIED


async def test___get_ssl_context5(loop) -> None:
Expand All @@ -1594,17 +1591,15 @@ async def test___get_ssl_context5(loop) -> None:
req = mock.Mock()
req.is_ssl.return_value = True
req.ssl = aiohttp.Fingerprint(hashlib.sha256(b"1").digest())
assert await conn._get_ssl_context(req) is await conn._make_or_get_ssl_context(
False
)
assert conn._get_ssl_context(req) is _SSL_CONTEXT_UNVERIFIED


async def test___get_ssl_context6() -> None:
conn = aiohttp.TCPConnector()
req = mock.Mock()
req.is_ssl.return_value = True
req.ssl = True
assert await conn._get_ssl_context(req) is await conn._make_or_get_ssl_context(True)
assert conn._get_ssl_context(req) is _SSL_CONTEXT_VERIFIED


async def test_ssl_context_once() -> None:
Expand All @@ -1616,31 +1611,9 @@ async def test_ssl_context_once() -> None:
req = mock.Mock()
req.is_ssl.return_value = True
req.ssl = True
assert await conn1._get_ssl_context(req) is await conn1._make_or_get_ssl_context(
True
)
assert await conn2._get_ssl_context(req) is await conn1._make_or_get_ssl_context(
True
)
assert await conn3._get_ssl_context(req) is await conn1._make_or_get_ssl_context(
True
)
assert conn1._made_ssl_context is conn2._made_ssl_context is conn3._made_ssl_context
assert True in conn1._made_ssl_context


@pytest.mark.parametrize("exception", [OSError, ssl.SSLError, asyncio.CancelledError])
async def test_ssl_context_creation_raises(exception: Type[BaseException]) -> None:
"""Test that we try again if SSLContext creation fails the first time."""
conn = aiohttp.TCPConnector()
conn._made_ssl_context.clear()

with mock.patch.object(
conn, "_make_ssl_context", side_effect=exception
), pytest.raises(exception):
await conn._make_or_get_ssl_context(True)

assert isinstance(await conn._make_or_get_ssl_context(True), ssl.SSLContext)
assert conn1._get_ssl_context(req) is _SSL_CONTEXT_VERIFIED
assert conn2._get_ssl_context(req) is _SSL_CONTEXT_VERIFIED
assert conn3._get_ssl_context(req) is _SSL_CONTEXT_VERIFIED


async def test_close_twice(loop) -> None:
Expand Down Expand Up @@ -2717,3 +2690,42 @@ async def allow_connection_and_add_dummy_waiter():
)

await connector.close()


def test_connector_multiple_event_loop() -> None:
"""Test the connector with multiple event loops."""

async def async_connect() -> Literal[True]:
conn = aiohttp.TCPConnector()
loop = asyncio.get_running_loop()
req = ClientRequest("GET", URL("https://127.0.0.1"), loop=loop)
with suppress(aiohttp.ClientConnectorError):
with mock.patch.object(
conn._loop,
"create_connection",
autospec=True,
spec_set=True,
side_effect=ssl.CertificateError,
):
await conn.connect(req, [], ClientTimeout())
return True

def test_connect() -> Literal[True]:
loop = asyncio.new_event_loop()
try:
return loop.run_until_complete(async_connect())
finally:
loop.close()

with futures.ThreadPoolExecutor() as executor:
res_list = [executor.submit(test_connect) for _ in range(2)]
raw_response_list = [res.result() for res in futures.as_completed(res_list)]

assert raw_response_list == [True, True]


def test_default_ssl_context_creation_without_ssl() -> None:
"""Verify _make_ssl_context does not raise when ssl is not available."""
with mock.patch.object(connector_module, "ssl", None):
assert connector_module._make_ssl_context(False) is None
assert connector_module._make_ssl_context(True) is None
3 changes: 2 additions & 1 deletion tests/test_proxy.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@

import aiohttp
from aiohttp.client_reqrep import ClientRequest, ClientResponse
from aiohttp.connector import _SSL_CONTEXT_VERIFIED
from aiohttp.helpers import TimerNoop
from aiohttp.test_utils import make_mocked_coro

Expand Down Expand Up @@ -817,7 +818,7 @@ async def make_conn():
self.loop.start_tls.assert_called_with(
mock.ANY,
mock.ANY,
self.loop.run_until_complete(connector._make_or_get_ssl_context(True)),
_SSL_CONTEXT_VERIFIED,
server_hostname="www.python.org",
ssl_handshake_timeout=mock.ANY,
)
Expand Down

0 comments on commit 4d022e4

Please sign in to comment.