Skip to content

Commit 6b28b98

Browse files
committed
Support specifying SNI tls_server_name
See geldata/gel-js#1088
1 parent 5b5be68 commit 6b28b98

File tree

5 files changed

+51
-3
lines changed

5 files changed

+51
-3
lines changed

edgedb/asyncio_client.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -88,7 +88,12 @@ async def _connect_addr(self, addr):
8888
else:
8989
try:
9090
tr, pr = await self._loop.create_connection(
91-
self._protocol_factory, *addr, ssl=self._params.ssl_ctx
91+
self._protocol_factory,
92+
*addr,
93+
ssl=self._params.ssl_ctx,
94+
server_hostname=(
95+
self._params.tls_server_name or addr[0]
96+
),
9297
)
9398
except ssl.CertificateError as e:
9499
raise con_utils.wrap_error(e) from e

edgedb/base_client.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -642,6 +642,7 @@ def __init__(
642642
tls_ca: str = None,
643643
tls_ca_file: str = None,
644644
tls_security: str = None,
645+
tls_server_name: str = None,
645646
wait_until_available: int = 30,
646647
timeout: int = 10,
647648
**kwargs,
@@ -662,6 +663,7 @@ def __init__(
662663
"tls_ca": tls_ca,
663664
"tls_ca_file": tls_ca_file,
664665
"tls_security": tls_security,
666+
"tls_server_name": tls_server_name,
665667
"wait_until_available": wait_until_available,
666668
}
667669

edgedb/blocking_client.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -102,7 +102,10 @@ async def _connect_addr(self, sock, addr, sa, deadline):
102102
sock.settimeout(time_left)
103103
try:
104104
sock = self._params.ssl_ctx.wrap_socket(
105-
sock, server_hostname=addr[0]
105+
sock,
106+
server_hostname=(
107+
self._params.tls_server_name or addr[0]
108+
),
106109
)
107110
except ssl.CertificateError as e:
108111
raise con_utils.wrap_error(e) from e

edgedb/con_utils.py

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -202,6 +202,7 @@ class ResolvedConnectConfig:
202202
_tls_ca_data = None
203203
_tls_ca_data_source = None
204204

205+
_tls_server_name = None
205206
_tls_security = None
206207
_tls_security_source = None
207208

@@ -254,6 +255,9 @@ def read_ca_file(file_path):
254255

255256
self._set_param('tls_ca_data', ca_file, source, read_ca_file)
256257

258+
def set_tls_server_name(self, ca_data, source):
259+
self._set_param('tls_server_name', ca_data, source)
260+
257261
def set_tls_security(self, security, source):
258262
self._set_param('tls_security', security, source,
259263
_validate_tls_security)
@@ -308,6 +312,10 @@ def password(self):
308312
def secret_key(self):
309313
return self._secret_key
310314

315+
@property
316+
def tls_server_name(self):
317+
return self._tls_server_name
318+
311319
@property
312320
def tls_security(self):
313321
tls_security = self._tls_security or 'default'
@@ -555,6 +563,7 @@ def _parse_connect_dsn_and_args(
555563
tls_ca,
556564
tls_ca_file,
557565
tls_security,
566+
tls_server_name,
558567
server_settings,
559568
wait_until_available,
560569
):
@@ -618,6 +627,10 @@ def _parse_connect_dsn_and_args(
618627
(tls_security, '"tls_security" option')
619628
if tls_security is not None else None
620629
),
630+
tls_server_name=(
631+
(tls_server_name, '"tls_server_name" option')
632+
if tls_server_name is not None else None
633+
),
621634
server_settings=(
622635
(server_settings, '"server_settings" option')
623636
if server_settings is not None else None
@@ -655,6 +668,7 @@ def _parse_connect_dsn_and_args(
655668
env_secret_key = os.getenv('EDGEDB_SECRET_KEY')
656669
env_tls_ca = os.getenv('EDGEDB_TLS_CA')
657670
env_tls_ca_file = os.getenv('EDGEDB_TLS_CA_FILE')
671+
env_tls_server_name = os.getenv('EDGEDB_TLS_SERVER_NAME')
658672
env_tls_security = os.getenv('EDGEDB_CLIENT_TLS_SECURITY')
659673
env_wait_until_available = os.getenv('EDGEDB_WAIT_UNTIL_AVAILABLE')
660674

@@ -717,6 +731,11 @@ def _parse_connect_dsn_and_args(
717731
'"EDGEDB_CLIENT_TLS_SECURITY" environment variable')
718732
if env_tls_security is not None else None
719733
),
734+
tls_server_name=(
735+
(env_tls_server_name,
736+
'"EDGEDB_TLS_SERVER_NAME" environment variable')
737+
if env_tls_server_name is not None else None
738+
),
720739
wait_until_available=(
721740
(
722741
env_wait_until_available,
@@ -924,6 +943,12 @@ def strip_leading_slash(str):
924943
resolved_config._tls_ca_data, resolved_config.set_tls_ca_file
925944
)
926945

946+
handle_dsn_part(
947+
'tls_server_name', None,
948+
resolved_config._tls_server_name,
949+
resolved_config.set_tls_server_name
950+
)
951+
927952
handle_dsn_part(
928953
'tls_security', None,
929954
resolved_config._tls_security,
@@ -1017,6 +1042,7 @@ def _resolve_config_options(
10171042
tls_ca=None,
10181043
tls_ca_file=None,
10191044
tls_security=None,
1045+
tls_server_name=None,
10201046
server_settings=None,
10211047
wait_until_available=None,
10221048
cloud_profile=None,
@@ -1051,6 +1077,8 @@ def _resolve_config_options(
10511077
resolved_config.set_tls_ca_data(*tls_ca)
10521078
if tls_security is not None:
10531079
resolved_config.set_tls_security(*tls_security)
1080+
if tls_server_name is not None:
1081+
resolved_config.set_tls_server_name(*tls_server_name)
10541082
if server_settings is not None:
10551083
resolved_config.add_server_settings(server_settings[0])
10561084
if wait_until_available is not None:
@@ -1178,6 +1206,7 @@ def parse_connect_arguments(
11781206
tls_ca,
11791207
tls_ca_file,
11801208
tls_security,
1209+
tls_server_name,
11811210
timeout,
11821211
command_timeout,
11831212
wait_until_available,
@@ -1211,6 +1240,7 @@ def parse_connect_arguments(
12111240
tls_ca=tls_ca,
12121241
tls_ca_file=tls_ca_file,
12131242
tls_security=tls_security,
1243+
tls_server_name=tls_server_name,
12141244
server_settings=server_settings,
12151245
wait_until_available=wait_until_available,
12161246
)

tests/test_con_utils.py

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -127,6 +127,7 @@ def run_testcase(self, testcase):
127127
tls_ca = opts.get('tlsCA')
128128
tls_ca_file = opts.get('tlsCAFile')
129129
tls_security = opts.get('tlsSecurity')
130+
tls_server_name = opts.get('tlsServerName')
130131
server_settings = opts.get('serverSettings')
131132
wait_until_available = opts.get('waitUntilAvailable')
132133

@@ -241,6 +242,7 @@ def mocked_open(filepath, *args, **kwargs):
241242
tls_ca=tls_ca,
242243
tls_ca_file=tls_ca_file,
243244
tls_security=tls_security,
245+
tls_server_name=tls_server_name,
244246
timeout=timeout,
245247
command_timeout=command_timeout,
246248
server_settings=server_settings,
@@ -259,7 +261,10 @@ def mocked_open(filepath, *args, **kwargs):
259261
'tlsCAData': connect_config._tls_ca_data,
260262
'tlsSecurity': connect_config.tls_security,
261263
'serverSettings': connect_config.server_settings,
262-
'waitUntilAvailable': client_config.wait_until_available,
264+
'waitUntilAvailable': float(
265+
client_config.wait_until_available
266+
),
267+
'tlsServerName': connect_config.tls_server_name,
263268
}
264269

265270
if expected is not None:
@@ -312,6 +317,7 @@ def test_test_connect_params_run_testcase_01(self):
312317
'tlsSecurity': 'strict',
313318
'serverSettings': {},
314319
'waitUntilAvailable': 30,
320+
'tlsServerName': None,
315321
},
316322
})
317323

@@ -336,6 +342,7 @@ def test_test_connect_params_run_testcase_02(self):
336342
'tlsSecurity': 'strict',
337343
'serverSettings': {},
338344
'waitUntilAvailable': 30,
345+
'tlsServerName': None,
339346
},
340347
})
341348

@@ -431,6 +438,7 @@ def test_project_config(self):
431438
tls_ca=None,
432439
tls_ca_file=None,
433440
tls_security=None,
441+
tls_server_name=None,
434442
timeout=10,
435443
command_timeout=None,
436444
server_settings=None,

0 commit comments

Comments
 (0)