Skip to content

Commit

Permalink
Fix SSLContext creation in the TCPConnector with multiple loops (#9029)
Browse files Browse the repository at this point in the history
  • Loading branch information
bdraco authored Sep 6, 2024
1 parent 301a1cd commit 466448c
Show file tree
Hide file tree
Showing 3 changed files with 106 additions and 114 deletions.
100 changes: 40 additions & 60 deletions aiohttp/connector.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,14 +51,7 @@
)
from .client_proto import ResponseHandler
from .client_reqrep import SSL_ALLOWED_TYPES, ClientRequest, Fingerprint
from .helpers import (
_SENTINEL,
ceil_timeout,
is_ip_address,
sentinel,
set_exception,
set_result,
)
from .helpers import _SENTINEL, ceil_timeout, is_ip_address, sentinel, set_result
from .locks import EventResultOrError
from .resolver import DefaultResolver

Expand Down Expand Up @@ -729,6 +722,35 @@ def expired(self, key: Tuple[str, int]) -> bool:
return self._timestamps[key] + self._ttl < monotonic()


def _make_ssl_context(verified: bool) -> SSLContext:
"""Create SSL context.
This method is not async-friendly and should be called from a thread
because it will load certificates from disk and do other blocking I/O.
"""
if ssl is None:
# No ssl support
return None
if verified:
return ssl.create_default_context()
sslcontext = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT)
sslcontext.options |= ssl.OP_NO_SSLv2
sslcontext.options |= ssl.OP_NO_SSLv3
sslcontext.check_hostname = False
sslcontext.verify_mode = ssl.CERT_NONE
sslcontext.options |= ssl.OP_NO_COMPRESSION
sslcontext.set_default_verify_paths()
return sslcontext


# The default SSLContext objects are created at import time
# since they do blocking I/O to load certificates from disk,
# and imports should always be done before the event loop starts
# or in a thread.
_SSL_CONTEXT_VERIFIED = _make_ssl_context(True)
_SSL_CONTEXT_UNVERIFIED = _make_ssl_context(False)


class TCPConnector(BaseConnector):
"""TCP connector.
Expand Down Expand Up @@ -759,7 +781,6 @@ class TCPConnector(BaseConnector):
"""

allowed_protocol_schema_set = HIGH_LEVEL_SCHEMA_SET | frozenset({"tcp"})
_made_ssl_context: Dict[bool, "asyncio.Future[SSLContext]"] = {}

def __init__(
self,
Expand Down Expand Up @@ -963,25 +984,7 @@ async def _create_connection(

return proto

@staticmethod
def _make_ssl_context(verified: bool) -> SSLContext:
"""Create SSL context.
This method is not async-friendly and should be called from a thread
because it will load certificates from disk and do other blocking I/O.
"""
if verified:
return ssl.create_default_context()
sslcontext = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT)
sslcontext.options |= ssl.OP_NO_SSLv2
sslcontext.options |= ssl.OP_NO_SSLv3
sslcontext.check_hostname = False
sslcontext.verify_mode = ssl.CERT_NONE
sslcontext.options |= ssl.OP_NO_COMPRESSION
sslcontext.set_default_verify_paths()
return sslcontext

async def _get_ssl_context(self, req: ClientRequest) -> Optional[SSLContext]:
def _get_ssl_context(self, req: ClientRequest) -> Optional[SSLContext]:
"""Logic to get the correct SSL context
0. if req.ssl is false, return None
Expand All @@ -1005,35 +1008,14 @@ async def _get_ssl_context(self, req: ClientRequest) -> Optional[SSLContext]:
return sslcontext
if sslcontext is not True:
# not verified or fingerprinted
return await self._make_or_get_ssl_context(False)
return _SSL_CONTEXT_UNVERIFIED
sslcontext = self._ssl
if isinstance(sslcontext, ssl.SSLContext):
return sslcontext
if sslcontext is not True:
# not verified or fingerprinted
return await self._make_or_get_ssl_context(False)
return await self._make_or_get_ssl_context(True)

async def _make_or_get_ssl_context(self, verified: bool) -> SSLContext:
"""Create or get cached SSL context."""
try:
return await self._made_ssl_context[verified]
except KeyError:
loop = self._loop
future = loop.create_future()
self._made_ssl_context[verified] = future
try:
result = await loop.run_in_executor(
None, self._make_ssl_context, verified
)
# BaseException is used since we might get CancelledError
except BaseException as ex:
del self._made_ssl_context[verified]
set_exception(future, ex)
raise
else:
set_result(future, result)
return result
return _SSL_CONTEXT_UNVERIFIED
return _SSL_CONTEXT_VERIFIED

def _get_fingerprint(self, req: ClientRequest) -> Optional["Fingerprint"]:
ret = req.ssl
Expand Down Expand Up @@ -1120,13 +1102,11 @@ async def _start_tls_connection(
) -> Tuple[asyncio.BaseTransport, ResponseHandler]:
"""Wrap the raw TCP transport with TLS."""
tls_proto = self._factory() # Create a brand new proto for TLS

# Safety of the `cast()` call here is based on the fact that
# internally `_get_ssl_context()` only returns `None` when
# `req.is_ssl()` evaluates to `False` which is never gonna happen
# in this code path. Of course, it's rather fragile
# maintainability-wise but this is to be solved separately.
sslcontext = cast(ssl.SSLContext, await self._get_ssl_context(req))
sslcontext = self._get_ssl_context(req)
if TYPE_CHECKING:
# _start_tls_connection is unreachable in the current code path
# if sslcontext is None.
assert sslcontext is not None

try:
async with ceil_timeout(
Expand Down Expand Up @@ -1204,7 +1184,7 @@ async def _create_direct_connection(
*,
client_error: Type[Exception] = ClientConnectorError,
) -> Tuple[asyncio.Transport, ResponseHandler]:
sslcontext = await self._get_ssl_context(req)
sslcontext = self._get_ssl_context(req)
fingerprint = self._get_fingerprint(req)

host = req.url.raw_host
Expand Down
115 changes: 64 additions & 51 deletions tests/test_connector.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
# Tests of http client with custom Connector

import asyncio
import gc
import hashlib
Expand All @@ -9,18 +8,19 @@
import sys
import uuid
from collections import deque
from concurrent import futures
from contextlib import closing, suppress
from typing import (
Awaitable,
Callable,
Dict,
Iterator,
List,
Literal,
NoReturn,
Optional,
Sequence,
Tuple,
Type,
)
from unittest import mock

Expand All @@ -30,11 +30,23 @@
from yarl import URL

import aiohttp
from aiohttp import ClientRequest, ClientSession, ClientTimeout, web
from aiohttp import (
ClientRequest,
ClientSession,
ClientTimeout,
connector as connector_module,
web,
)
from aiohttp.abc import ResolveResult
from aiohttp.client_proto import ResponseHandler
from aiohttp.client_reqrep import ConnectionKey
from aiohttp.connector import Connection, TCPConnector, _DNSCacheTable
from aiohttp.connector import (
_SSL_CONTEXT_UNVERIFIED,
_SSL_CONTEXT_VERIFIED,
Connection,
TCPConnector,
_DNSCacheTable,
)
from aiohttp.locks import EventResultOrError
from aiohttp.pytest_plugin import AiohttpClient, AiohttpServer
from aiohttp.test_utils import make_mocked_coro, unused_port
Expand Down Expand Up @@ -1710,23 +1722,11 @@ async def test_tcp_connector_clear_dns_cache_bad_args(
conn.clear_dns_cache("localhost")


async def test_dont_recreate_ssl_context() -> None:
conn = aiohttp.TCPConnector()
ctx = await conn._make_or_get_ssl_context(True)
assert ctx is await conn._make_or_get_ssl_context(True)


async def test_dont_recreate_ssl_context2() -> None:
conn = aiohttp.TCPConnector()
ctx = await conn._make_or_get_ssl_context(False)
assert ctx is await conn._make_or_get_ssl_context(False)


async def test___get_ssl_context1() -> None:
conn = aiohttp.TCPConnector()
req = mock.Mock()
req.is_ssl.return_value = False
assert await conn._get_ssl_context(req) is None
assert conn._get_ssl_context(req) is None


async def test___get_ssl_context2() -> None:
Expand All @@ -1735,7 +1735,7 @@ async def test___get_ssl_context2() -> None:
req = mock.Mock()
req.is_ssl.return_value = True
req.ssl = ctx
assert await conn._get_ssl_context(req) is ctx
assert conn._get_ssl_context(req) is ctx


async def test___get_ssl_context3() -> None:
Expand All @@ -1744,7 +1744,7 @@ async def test___get_ssl_context3() -> None:
req = mock.Mock()
req.is_ssl.return_value = True
req.ssl = True
assert await conn._get_ssl_context(req) is ctx
assert conn._get_ssl_context(req) is ctx


async def test___get_ssl_context4() -> None:
Expand All @@ -1753,9 +1753,7 @@ async def test___get_ssl_context4() -> None:
req = mock.Mock()
req.is_ssl.return_value = True
req.ssl = False
assert await conn._get_ssl_context(req) is await conn._make_or_get_ssl_context(
False
)
assert conn._get_ssl_context(req) is _SSL_CONTEXT_UNVERIFIED


async def test___get_ssl_context5() -> None:
Expand All @@ -1764,17 +1762,15 @@ async def test___get_ssl_context5() -> None:
req = mock.Mock()
req.is_ssl.return_value = True
req.ssl = aiohttp.Fingerprint(hashlib.sha256(b"1").digest())
assert await conn._get_ssl_context(req) is await conn._make_or_get_ssl_context(
False
)
assert conn._get_ssl_context(req) is _SSL_CONTEXT_UNVERIFIED


async def test___get_ssl_context6() -> None:
conn = aiohttp.TCPConnector()
req = mock.Mock()
req.is_ssl.return_value = True
req.ssl = True
assert await conn._get_ssl_context(req) is await conn._make_or_get_ssl_context(True)
assert conn._get_ssl_context(req) is _SSL_CONTEXT_VERIFIED


async def test_ssl_context_once() -> None:
Expand All @@ -1786,31 +1782,9 @@ async def test_ssl_context_once() -> None:
req = mock.Mock()
req.is_ssl.return_value = True
req.ssl = True
assert await conn1._get_ssl_context(req) is await conn1._make_or_get_ssl_context(
True
)
assert await conn2._get_ssl_context(req) is await conn1._make_or_get_ssl_context(
True
)
assert await conn3._get_ssl_context(req) is await conn1._make_or_get_ssl_context(
True
)
assert conn1._made_ssl_context is conn2._made_ssl_context is conn3._made_ssl_context
assert True in conn1._made_ssl_context


@pytest.mark.parametrize("exception", [OSError, ssl.SSLError, asyncio.CancelledError])
async def test_ssl_context_creation_raises(exception: Type[BaseException]) -> None:
"""Test that we try again if SSLContext creation fails the first time."""
conn = aiohttp.TCPConnector()
conn._made_ssl_context.clear()

with mock.patch.object(
conn, "_make_ssl_context", side_effect=exception
), pytest.raises(exception):
await conn._make_or_get_ssl_context(True)

assert isinstance(await conn._make_or_get_ssl_context(True), ssl.SSLContext)
assert conn1._get_ssl_context(req) is _SSL_CONTEXT_VERIFIED
assert conn2._get_ssl_context(req) is _SSL_CONTEXT_VERIFIED
assert conn3._get_ssl_context(req) is _SSL_CONTEXT_VERIFIED


async def test_close_twice(loop: asyncio.AbstractEventLoop, key: ConnectionKey) -> None:
Expand Down Expand Up @@ -2977,3 +2951,42 @@ async def allow_connection_and_add_dummy_waiter() -> None:
)

await connector.close()


def test_connector_multiple_event_loop() -> None:
"""Test the connector with multiple event loops."""

async def async_connect() -> Literal[True]:
conn = aiohttp.TCPConnector()
loop = asyncio.get_running_loop()
req = ClientRequest("GET", URL("https://127.0.0.1"), loop=loop)
with suppress(aiohttp.ClientConnectorError):
with mock.patch.object(
conn._loop,
"create_connection",
autospec=True,
spec_set=True,
side_effect=ssl.CertificateError,
):
await conn.connect(req, [], ClientTimeout())
return True

def test_connect() -> Literal[True]:
loop = asyncio.new_event_loop()
try:
return loop.run_until_complete(async_connect())
finally:
loop.close()

with futures.ThreadPoolExecutor() as executor:
res_list = [executor.submit(test_connect) for _ in range(2)]
raw_response_list = [res.result() for res in futures.as_completed(res_list)]

assert raw_response_list == [True, True]


def test_default_ssl_context_creation_without_ssl() -> None:
"""Verify _make_ssl_context does not raise when ssl is not available."""
with mock.patch.object(connector_module, "ssl", None):
assert connector_module._make_ssl_context(False) is None
assert connector_module._make_ssl_context(True) is None
5 changes: 2 additions & 3 deletions tests/test_proxy.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@

import aiohttp
from aiohttp.client_reqrep import ClientRequest, ClientResponse
from aiohttp.connector import _SSL_CONTEXT_VERIFIED
from aiohttp.helpers import TimerNoop
from aiohttp.test_utils import make_mocked_coro

Expand Down Expand Up @@ -934,9 +935,7 @@ async def make_conn() -> aiohttp.TCPConnector:
tls_m.assert_called_with(
mock.ANY,
mock.ANY,
self.loop.run_until_complete(
connector._make_or_get_ssl_context(True)
),
_SSL_CONTEXT_VERIFIED,
server_hostname="www.python.org",
ssl_handshake_timeout=mock.ANY,
)
Expand Down

0 comments on commit 466448c

Please sign in to comment.