Skip to content

Commit e868a11

Browse files
poiujdvora-h
authored andcommitted
Allow to control the minimum SSL version (#3127)
* Allow to control the minimum SSL version It's useful for applications that has strict security requirements. * Add tests for minimum SSL version The commit updates test_tcp_ssl_connect for both sync and async connections. Now it sets the minimum SSL version. The test is ran with both TLSv1.2 and TLSv1.3 (if supported). A new test case is test_tcp_ssl_version_mismatch. The test added for both sync and async connections. It uses TLS 1.3 on the client side, and TLS 1.2 on the server side. It expects a connection error. The test is skipped if TLS 1.3 is not supported. * Add example of using a minimum TLS version
1 parent a54617c commit e868a11

File tree

9 files changed

+161
-10
lines changed

9 files changed

+161
-10
lines changed

CHANGES

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
* Allow to control the minimum SSL version
12
* Add an optional lock_name attribute to LockError.
23
* Fix return types for `get`, `set_path` and `strappend` in JSONCommands
34
* Connection.register_connect_callback() is made public.

docs/examples/ssl_connection_examples.ipynb

Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -76,6 +76,42 @@
7676
"ssl_connection.ping()"
7777
]
7878
},
79+
{
80+
"cell_type": "markdown",
81+
"metadata": {},
82+
"source": [
83+
"## Connecting to a Redis instance via SSL, while specifying a minimum TLS version"
84+
]
85+
},
86+
{
87+
"cell_type": "code",
88+
"execution_count": null,
89+
"metadata": {},
90+
"outputs": [
91+
{
92+
"data": {
93+
"text/plain": [
94+
"True"
95+
]
96+
},
97+
"execution_count": 6,
98+
"metadata": {},
99+
"output_type": "execute_result"
100+
}
101+
],
102+
"source": [
103+
"import redis\n",
104+
"import ssl\n",
105+
"\n",
106+
"ssl_conn = redis.Redis(\n",
107+
" host=\"localhost\",\n",
108+
" port=6666,\n",
109+
" ssl=True,\n",
110+
" ssl_min_version=ssl.TLSVersion.TLSv1_3,\n",
111+
")\n",
112+
"ssl_conn.ping()"
113+
]
114+
},
79115
{
80116
"cell_type": "markdown",
81117
"metadata": {},

redis/asyncio/client.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
import copy
33
import inspect
44
import re
5+
import ssl
56
import warnings
67
from typing import (
78
TYPE_CHECKING,
@@ -219,6 +220,7 @@ def __init__(
219220
ssl_ca_certs: Optional[str] = None,
220221
ssl_ca_data: Optional[str] = None,
221222
ssl_check_hostname: bool = False,
223+
ssl_min_version: Optional[ssl.TLSVersion] = None,
222224
max_connections: Optional[int] = None,
223225
single_connection_client: bool = False,
224226
health_check_interval: int = 0,
@@ -311,6 +313,7 @@ def __init__(
311313
"ssl_ca_certs": ssl_ca_certs,
312314
"ssl_ca_data": ssl_ca_data,
313315
"ssl_check_hostname": ssl_check_hostname,
316+
"ssl_min_version": ssl_min_version,
314317
}
315318
)
316319
# This arg only used if no pool is passed in

redis/asyncio/cluster.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
import collections
33
import random
44
import socket
5+
import ssl
56
import warnings
67
from typing import (
78
Any,
@@ -265,6 +266,7 @@ def __init__(
265266
ssl_certfile: Optional[str] = None,
266267
ssl_check_hostname: bool = False,
267268
ssl_keyfile: Optional[str] = None,
269+
ssl_min_version: Optional[ssl.TLSVersion] = None,
268270
protocol: Optional[int] = 2,
269271
address_remap: Optional[Callable[[str, int], Tuple[str, int]]] = None,
270272
) -> None:
@@ -323,6 +325,7 @@ def __init__(
323325
"ssl_certfile": ssl_certfile,
324326
"ssl_check_hostname": ssl_check_hostname,
325327
"ssl_keyfile": ssl_keyfile,
328+
"ssl_min_version": ssl_min_version,
326329
}
327330
)
328331

redis/asyncio/connection.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -738,6 +738,7 @@ def __init__(
738738
ssl_ca_certs: Optional[str] = None,
739739
ssl_ca_data: Optional[str] = None,
740740
ssl_check_hostname: bool = False,
741+
ssl_min_version: Optional[ssl.TLSVersion] = None,
741742
**kwargs,
742743
):
743744
self.ssl_context: RedisSSLContext = RedisSSLContext(
@@ -747,6 +748,7 @@ def __init__(
747748
ca_certs=ssl_ca_certs,
748749
ca_data=ssl_ca_data,
749750
check_hostname=ssl_check_hostname,
751+
min_version=ssl_min_version,
750752
)
751753
super().__init__(**kwargs)
752754

@@ -779,6 +781,10 @@ def ca_data(self):
779781
def check_hostname(self):
780782
return self.ssl_context.check_hostname
781783

784+
@property
785+
def min_version(self):
786+
return self.ssl_context.min_version
787+
782788

783789
class RedisSSLContext:
784790
__slots__ = (
@@ -789,6 +795,7 @@ class RedisSSLContext:
789795
"ca_data",
790796
"context",
791797
"check_hostname",
798+
"min_version",
792799
)
793800

794801
def __init__(
@@ -799,6 +806,7 @@ def __init__(
799806
ca_certs: Optional[str] = None,
800807
ca_data: Optional[str] = None,
801808
check_hostname: bool = False,
809+
min_version: Optional[ssl.TLSVersion] = None,
802810
):
803811
self.keyfile = keyfile
804812
self.certfile = certfile
@@ -818,6 +826,7 @@ def __init__(
818826
self.ca_certs = ca_certs
819827
self.ca_data = ca_data
820828
self.check_hostname = check_hostname
829+
self.min_version = min_version
821830
self.context: Optional[ssl.SSLContext] = None
822831

823832
def get(self) -> ssl.SSLContext:
@@ -829,6 +838,8 @@ def get(self) -> ssl.SSLContext:
829838
context.load_cert_chain(certfile=self.certfile, keyfile=self.keyfile)
830839
if self.ca_certs or self.ca_data:
831840
context.load_verify_locations(cafile=self.ca_certs, cadata=self.ca_data)
841+
if self.min_version is not None:
842+
context.minimum_version = self.min_version
832843
self.context = context
833844
return self.context
834845

redis/client.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -192,6 +192,7 @@ def __init__(
192192
ssl_validate_ocsp_stapled=False,
193193
ssl_ocsp_context=None,
194194
ssl_ocsp_expected_cert=None,
195+
ssl_min_version=None,
195196
max_connections=None,
196197
single_connection_client=False,
197198
health_check_interval=0,
@@ -291,6 +292,7 @@ def __init__(
291292
"ssl_validate_ocsp": ssl_validate_ocsp,
292293
"ssl_ocsp_context": ssl_ocsp_context,
293294
"ssl_ocsp_expected_cert": ssl_ocsp_expected_cert,
295+
"ssl_min_version": ssl_min_version,
294296
}
295297
)
296298
connection_pool = ConnectionPool(**kwargs)

redis/connection.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -684,6 +684,7 @@ def __init__(
684684
ssl_validate_ocsp_stapled=False,
685685
ssl_ocsp_context=None,
686686
ssl_ocsp_expected_cert=None,
687+
ssl_min_version=None,
687688
**kwargs,
688689
):
689690
"""Constructor
@@ -702,6 +703,7 @@ def __init__(
702703
ssl_validate_ocsp_stapled: If set, perform a validation on a stapled ocsp response
703704
ssl_ocsp_context: A fully initialized OpenSSL.SSL.Context object to be used in verifying the ssl_ocsp_expected_cert
704705
ssl_ocsp_expected_cert: A PEM armoured string containing the expected certificate to be returned from the ocsp verification service.
706+
ssl_min_version: The lowest supported SSL version. It affects the supported SSL versions of the SSLContext. None leaves the default provided by ssl module.
705707
706708
Raises:
707709
RedisError
@@ -734,6 +736,7 @@ def __init__(
734736
self.ssl_validate_ocsp_stapled = ssl_validate_ocsp_stapled
735737
self.ssl_ocsp_context = ssl_ocsp_context
736738
self.ssl_ocsp_expected_cert = ssl_ocsp_expected_cert
739+
self.ssl_min_version = ssl_min_version
737740
super().__init__(**kwargs)
738741

739742
def _connect(self):
@@ -756,6 +759,8 @@ def _connect(self):
756759
context.load_verify_locations(
757760
cafile=self.ca_certs, capath=self.ca_path, cadata=self.ca_data
758761
)
762+
if self.ssl_min_version is not None:
763+
context.minimum_version = self.ssl_min_version
759764
sslsock = context.wrap_socket(sock, server_hostname=self.host)
760765
if self.ssl_validate_ocsp is True and CRYPTOGRAPHY_AVAILABLE is False:
761766
raise RedisError("cryptography is not installed.")

tests/test_asyncio/test_connect.py

Lines changed: 51 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
SSLConnection,
1111
UnixDomainSocketConnection,
1212
)
13+
from redis.exceptions import ConnectionError
1314

1415
from ..ssl_utils import get_ssl_filename
1516

@@ -50,7 +51,17 @@ async def test_uds_connect(uds_address):
5051

5152

5253
@pytest.mark.ssl
53-
async def test_tcp_ssl_connect(tcp_address):
54+
@pytest.mark.parametrize(
55+
"ssl_min_version",
56+
[
57+
ssl.TLSVersion.TLSv1_2,
58+
pytest.param(
59+
ssl.TLSVersion.TLSv1_3,
60+
marks=pytest.mark.skipif(not ssl.HAS_TLSv1_3, reason="requires TLSv1.3"),
61+
),
62+
],
63+
)
64+
async def test_tcp_ssl_connect(tcp_address, ssl_min_version):
5465
host, port = tcp_address
5566
certfile = get_ssl_filename("server-cert.pem")
5667
keyfile = get_ssl_filename("server-key.pem")
@@ -60,12 +71,44 @@ async def test_tcp_ssl_connect(tcp_address):
6071
client_name=_CLIENT_NAME,
6172
ssl_ca_certs=certfile,
6273
socket_timeout=10,
74+
ssl_min_version=ssl_min_version,
6375
)
6476
await _assert_connect(conn, tcp_address, certfile=certfile, keyfile=keyfile)
6577
await conn.disconnect()
6678

6779

68-
async def _assert_connect(conn, server_address, certfile=None, keyfile=None):
80+
@pytest.mark.ssl
81+
@pytest.mark.skipif(not ssl.HAS_TLSv1_3, reason="requires TLSv1.3")
82+
async def test_tcp_ssl_version_mismatch(tcp_address):
83+
host, port = tcp_address
84+
certfile = get_ssl_filename("server-cert.pem")
85+
keyfile = get_ssl_filename("server-key.pem")
86+
conn = SSLConnection(
87+
host=host,
88+
port=port,
89+
client_name=_CLIENT_NAME,
90+
ssl_ca_certs=certfile,
91+
socket_timeout=1,
92+
ssl_min_version=ssl.TLSVersion.TLSv1_3,
93+
)
94+
with pytest.raises(ConnectionError):
95+
await _assert_connect(
96+
conn,
97+
tcp_address,
98+
certfile=certfile,
99+
keyfile=keyfile,
100+
ssl_version=ssl.TLSVersion.TLSv1_2,
101+
)
102+
await conn.disconnect()
103+
104+
105+
async def _assert_connect(
106+
conn,
107+
server_address,
108+
certfile=None,
109+
keyfile=None,
110+
ssl_version=None,
111+
):
69112
stop_event = asyncio.Event()
70113
finished = asyncio.Event()
71114

@@ -82,7 +125,9 @@ async def _handler(reader, writer):
82125
elif certfile:
83126
host, port = server_address
84127
context = ssl.create_default_context(ssl.Purpose.CLIENT_AUTH)
85-
context.minimum_version = ssl.TLSVersion.TLSv1_2
128+
if ssl_version is not None:
129+
context.minimum_version = ssl_version
130+
context.maximum_version = ssl_version
86131
context.load_cert_chain(certfile=certfile, keyfile=keyfile)
87132
server = await asyncio.start_server(_handler, host=host, port=port, ssl=context)
88133
else:
@@ -94,6 +139,9 @@ async def _handler(reader, writer):
94139
try:
95140
await conn.connect()
96141
await conn.disconnect()
142+
except ConnectionError:
143+
finished.set()
144+
raise
97145
finally:
98146
stop_event.set()
99147
aserver.close()

0 commit comments

Comments
 (0)