Skip to content

Commit

Permalink
Make it possible to customize SSL ciphers
Browse files Browse the repository at this point in the history
Given that Python 3.10 changed the default list of SSL ciphers, it is a
good idea in general to allow customization of the list of cyphers when
using Redis with TLS.

It seems that this works only with TLS 1.2, and with TLS 1.3 it's
intentionally not possible to change the ciphers:
https://docs.python.org/3/library/ssl.html#ssl.SSLContext.set_ciphers
  • Loading branch information
gerzse committed Apr 23, 2024
1 parent 1784b37 commit 94fb28c
Show file tree
Hide file tree
Showing 9 changed files with 185 additions and 0 deletions.
2 changes: 2 additions & 0 deletions redis/asyncio/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -221,6 +221,7 @@ def __init__(
ssl_ca_data: Optional[str] = None,
ssl_check_hostname: bool = False,
ssl_min_version: Optional[ssl.TLSVersion] = None,
ssl_ciphers: Optional[str] = None,
max_connections: Optional[int] = None,
single_connection_client: bool = False,
health_check_interval: int = 0,
Expand Down Expand Up @@ -314,6 +315,7 @@ def __init__(
"ssl_ca_data": ssl_ca_data,
"ssl_check_hostname": ssl_check_hostname,
"ssl_min_version": ssl_min_version,
"ssl_ciphers": ssl_ciphers,
}
)
# This arg only used if no pool is passed in
Expand Down
2 changes: 2 additions & 0 deletions redis/asyncio/cluster.py
Original file line number Diff line number Diff line change
Expand Up @@ -267,6 +267,7 @@ def __init__(
ssl_check_hostname: bool = False,
ssl_keyfile: Optional[str] = None,
ssl_min_version: Optional[ssl.TLSVersion] = None,
ssl_ciphers: Optional[str] = None,
protocol: Optional[int] = 2,
address_remap: Optional[Callable[[str, int], Tuple[str, int]]] = None,
) -> None:
Expand Down Expand Up @@ -326,6 +327,7 @@ def __init__(
"ssl_check_hostname": ssl_check_hostname,
"ssl_keyfile": ssl_keyfile,
"ssl_min_version": ssl_min_version,
"ssl_ciphers": ssl_ciphers,
}
)

Expand Down
7 changes: 7 additions & 0 deletions redis/asyncio/connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -739,6 +739,7 @@ def __init__(
ssl_ca_data: Optional[str] = None,
ssl_check_hostname: bool = False,
ssl_min_version: Optional[ssl.TLSVersion] = None,
ssl_ciphers: Optional[str] = None,
**kwargs,
):
self.ssl_context: RedisSSLContext = RedisSSLContext(
Expand All @@ -749,6 +750,7 @@ def __init__(
ca_data=ssl_ca_data,
check_hostname=ssl_check_hostname,
min_version=ssl_min_version,
ciphers=ssl_ciphers,
)
super().__init__(**kwargs)

Expand Down Expand Up @@ -796,6 +798,7 @@ class RedisSSLContext:
"context",
"check_hostname",
"min_version",
"ciphers",
)

def __init__(
Expand All @@ -807,6 +810,7 @@ def __init__(
ca_data: Optional[str] = None,
check_hostname: bool = False,
min_version: Optional[ssl.TLSVersion] = None,
ciphers: Optional[str] = None,
):
self.keyfile = keyfile
self.certfile = certfile
Expand All @@ -827,6 +831,7 @@ def __init__(
self.ca_data = ca_data
self.check_hostname = check_hostname
self.min_version = min_version
self.ciphers = ciphers
self.context: Optional[ssl.SSLContext] = None

def get(self) -> ssl.SSLContext:
Expand All @@ -840,6 +845,8 @@ def get(self) -> ssl.SSLContext:
context.load_verify_locations(cafile=self.ca_certs, cadata=self.ca_data)
if self.min_version is not None:
context.minimum_version = self.min_version
if self.ciphers is not None:
context.set_ciphers(self.ciphers)
self.context = context
return self.context

Expand Down
2 changes: 2 additions & 0 deletions redis/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -198,6 +198,7 @@ def __init__(
ssl_ocsp_context=None,
ssl_ocsp_expected_cert=None,
ssl_min_version=None,
ssl_ciphers=None,
max_connections=None,
single_connection_client=False,
health_check_interval=0,
Expand Down Expand Up @@ -298,6 +299,7 @@ def __init__(
"ssl_ocsp_context": ssl_ocsp_context,
"ssl_ocsp_expected_cert": ssl_ocsp_expected_cert,
"ssl_min_version": ssl_min_version,
"ssl_ciphers": ssl_ciphers,
}
)
connection_pool = ConnectionPool(**kwargs)
Expand Down
5 changes: 5 additions & 0 deletions redis/connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -685,6 +685,7 @@ def __init__(
ssl_ocsp_context=None,
ssl_ocsp_expected_cert=None,
ssl_min_version=None,
ssl_ciphers=None,
**kwargs,
):
"""Constructor
Expand All @@ -704,6 +705,7 @@ def __init__(
ssl_ocsp_context: A fully initialized OpenSSL.SSL.Context object to be used in verifying the ssl_ocsp_expected_cert
ssl_ocsp_expected_cert: A PEM armoured string containing the expected certificate to be returned from the ocsp verification service.
ssl_min_version: The lowest supported SSL version. It affects the supported SSL versions of the SSLContext. None leaves the default provided by ssl module.
ssl_ciphers: A string listing the ciphers that are allowed to be used. Defaults to None, which means that the default ciphers are used. See https://docs.python.org/3/library/ssl.html#ssl.SSLContext.set_ciphers for more information.
Raises:
RedisError
Expand Down Expand Up @@ -737,6 +739,7 @@ def __init__(
self.ssl_ocsp_context = ssl_ocsp_context
self.ssl_ocsp_expected_cert = ssl_ocsp_expected_cert
self.ssl_min_version = ssl_min_version
self.ssl_ciphers = ssl_ciphers
super().__init__(**kwargs)

def _connect(self):
Expand All @@ -761,6 +764,8 @@ def _connect(self):
)
if self.ssl_min_version is not None:
context.minimum_version = self.ssl_min_version
if self.ssl_ciphers:
context.set_ciphers(self.ssl_ciphers)
sslsock = context.wrap_socket(sock, server_hostname=self.host)
if self.ssl_validate_ocsp is True and CRYPTOGRAPHY_AVAILABLE is False:
raise RedisError("cryptography is not installed.")
Expand Down
54 changes: 54 additions & 0 deletions tests/test_asyncio/test_cluster.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import asyncio
import binascii
import datetime
import ssl
import warnings
from typing import Any, Awaitable, Callable, Dict, List, Optional, Type, Union
from urllib.parse import urlparse
Expand Down Expand Up @@ -2951,6 +2952,59 @@ async def test_ssl_connection(
async with await create_client(ssl=True, ssl_cert_reqs="none") as rc:
assert await rc.ping()

@pytest.mark.parametrize(
"ssl_ciphers",
[
"AES256-SHA:DHE-RSA-AES256-SHA:AES128-SHA:DHE-RSA-AES128-SHA",
"ECDHE-ECDSA-AES256-GCM-SHA384",
"ECDHE-RSA-AES128-GCM-SHA256",
],
)
async def test_ssl_connection_tls12_custom_ciphers(
self, ssl_ciphers, create_client: Callable[..., Awaitable[RedisCluster]]
) -> None:
async with await create_client(
ssl=True,
ssl_cert_reqs="none",
ssl_min_version=ssl.TLSVersion.TLSv1_2,
ssl_ciphers=ssl_ciphers,
) as rc:
assert await rc.ping()

async def test_ssl_connection_tls12_custom_ciphers_invalid(
self, create_client: Callable[..., Awaitable[RedisCluster]]
) -> None:
async with await create_client(
ssl=True,
ssl_cert_reqs="none",
ssl_min_version=ssl.TLSVersion.TLSv1_2,
ssl_ciphers="foo:bar",
) as rc:
with pytest.raises(RedisClusterException) as e:
assert await rc.ping()
assert "Redis Cluster cannot be connected" in str(e.value)

@pytest.mark.parametrize(
"ssl_ciphers",
[
"TLS_CHACHA20_POLY1305_SHA256",
"TLS_AES_256_GCM_SHA384:TLS_CHACHA20_POLY1305_SHA256",
],
)
async def test_ssl_connection_tls13_custom_ciphers(
self, ssl_ciphers, create_client: Callable[..., Awaitable[RedisCluster]]
) -> None:
# TLSv1.3 does not support changing the ciphers
async with await create_client(
ssl=True,
ssl_cert_reqs="none",
ssl_min_version=ssl.TLSVersion.TLSv1_2,
ssl_ciphers=ssl_ciphers,
) as rc:
with pytest.raises(RedisClusterException) as e:
assert await rc.ping()
assert "Redis Cluster cannot be connected" in str(e.value)

async def test_validating_self_signed_certificate(
self, create_client: Callable[..., Awaitable[RedisCluster]]
) -> None:
Expand Down
26 changes: 26 additions & 0 deletions tests/test_asyncio/test_connect.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,32 @@ async def test_uds_connect(uds_address):
await _assert_connect(conn, path)


@pytest.mark.ssl
@pytest.mark.parametrize(
"ssl_ciphers",
[
"AES256-SHA:DHE-RSA-AES256-SHA:AES128-SHA:DHE-RSA-AES128-SHA",
"ECDHE-ECDSA-AES256-GCM-SHA384",
"ECDHE-RSA-AES128-GCM-SHA256",
],
)
async def test_tcp_ssl_tls12_custom_ciphers(tcp_address, ssl_ciphers):
host, port = tcp_address
certfile = get_ssl_filename("server-cert.pem")
keyfile = get_ssl_filename("server-key.pem")
conn = SSLConnection(
host=host,
port=port,
client_name=_CLIENT_NAME,
ssl_ca_certs=certfile,
socket_timeout=10,
ssl_min_version=ssl.TLSVersion.TLSv1_2,
ssl_ciphers=ssl_ciphers,
)
await _assert_connect(conn, tcp_address, certfile=certfile, keyfile=keyfile)
await conn.disconnect()


@pytest.mark.ssl
@pytest.mark.parametrize(
"ssl_min_version",
Expand Down
25 changes: 25 additions & 0 deletions tests/test_connect.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,31 @@ def test_tcp_ssl_connect(tcp_address, ssl_min_version):
_assert_connect(conn, tcp_address, certfile=certfile, keyfile=keyfile)


@pytest.mark.ssl
@pytest.mark.parametrize(
"ssl_ciphers",
[
"AES256-SHA:DHE-RSA-AES256-SHA:AES128-SHA:DHE-RSA-AES128-SHA",
"ECDHE-ECDSA-AES256-GCM-SHA384",
"ECDHE-RSA-AES128-GCM-SHA256",
],
)
def test_tcp_ssl_tls12_custom_ciphers(tcp_address, ssl_ciphers):
host, port = tcp_address
certfile = get_ssl_filename("server-cert.pem")
keyfile = get_ssl_filename("server-key.pem")
conn = SSLConnection(
host=host,
port=port,
client_name=_CLIENT_NAME,
ssl_ca_certs=certfile,
socket_timeout=10,
ssl_min_version=ssl.TLSVersion.TLSv1_2,
ssl_ciphers=ssl_ciphers,
)
_assert_connect(conn, tcp_address, certfile=certfile, keyfile=keyfile)


@pytest.mark.ssl
@pytest.mark.skipif(not ssl.HAS_TLSv1_3, reason="requires TLSv1.3")
def test_tcp_ssl_version_mismatch(tcp_address):
Expand Down
62 changes: 62 additions & 0 deletions tests/test_ssl.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,68 @@ def test_validating_self_signed_string_certificate(self, request):
)
assert r.ping()

@pytest.mark.parametrize(
"ssl_ciphers",
[
"AES256-SHA:DHE-RSA-AES256-SHA:AES128-SHA:DHE-RSA-AES128-SHA",
"DHE-RSA-AES256-GCM-SHA384",
"ECDHE-RSA-AES256-SHA384:ECDHE-ECDSA-CHACHA20-POLY1305",
],
)
def test_ssl_connection_tls12_custom_ciphers(self, request, ssl_ciphers):
ssl_url = request.config.option.redis_ssl_url
p = urlparse(ssl_url)[1].split(":")
r = redis.Redis(
host=p[0],
port=p[1],
ssl=True,
ssl_cert_reqs="none",
ssl_min_version=ssl.TLSVersion.TLSv1_3,
ssl_ciphers=ssl_ciphers,
)
assert r.ping()
r.close()

def test_ssl_connection_tls12_custom_ciphers_invalid(self, request):
ssl_url = request.config.option.redis_ssl_url
p = urlparse(ssl_url)[1].split(":")
r = redis.Redis(
host=p[0],
port=p[1],
ssl=True,
ssl_cert_reqs="none",
ssl_min_version=ssl.TLSVersion.TLSv1_2,
ssl_ciphers="foo:bar",
)
with pytest.raises(RedisError) as e:
r.ping()
assert "No cipher can be selected" in str(e)
r.close()

@pytest.mark.parametrize(
"ssl_ciphers",
[
"TLS_CHACHA20_POLY1305_SHA256",
"TLS_AES_256_GCM_SHA384:TLS_CHACHA20_POLY1305_SHA256",
],
)
def test_ssl_connection_tls13_custom_ciphers(self, request, ssl_ciphers):
# TLSv1.3 does not support changing the ciphers
ssl_url = request.config.option.redis_ssl_url
p = urlparse(ssl_url)[1].split(":")
r = redis.Redis(
host=p[0],
port=p[1],
ssl=True,
ssl_cert_reqs="none",
ssl_min_version=ssl.TLSVersion.TLSv1_2,
ssl_ciphers=ssl_ciphers,
)
with pytest.raises(RedisError) as e:
r.ping()
assert "No cipher can be selected" in str(e)
r.close()

def _create_oscp_conn(self, request):
ssl_url = request.config.option.redis_ssl_url
p = urlparse(ssl_url)[1].split(":")
Expand Down

0 comments on commit 94fb28c

Please sign in to comment.