Skip to content

Commit 2dc78e5

Browse files
committed
feat: add TLS/mTLS support for experimental host
1 parent ed4735b commit 2dc78e5

File tree

10 files changed

+186
-25
lines changed

10 files changed

+186
-25
lines changed

google/cloud/spanner_dbapi/connection.py

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -736,6 +736,10 @@ def connect(
736736
route_to_leader_enabled=True,
737737
database_role=None,
738738
experimental_host=None,
739+
use_plain_text=False,
740+
ca_certificate=None,
741+
client_certificate=None,
742+
client_key=None,
739743
**kwargs,
740744
):
741745
"""Creates a connection to a Google Cloud Spanner database.
@@ -789,6 +793,28 @@ def connect(
789793
:rtype: :class:`google.cloud.spanner_dbapi.connection.Connection`
790794
:returns: Connection object associated with the given Google Cloud Spanner
791795
resource.
796+
797+
:type experimental_host: str
798+
:param experimental_host: (Optional) The endpoint for a spanner experimental host deployment.
799+
This is intended only for experimental host spanner endpoints.
800+
801+
:type use_plain_text: bool
802+
:param use_plain_text: (Optional) Whether to use plain text for the connection.
803+
This is intended only for experimental host spanner endpoints.
804+
If not set, the default behavior is to use TLS.
805+
806+
:type ca_certificate: str
807+
:param ca_certificate: (Optional) The path to the CA certificate file used for TLS connection.
808+
This is intended only for experimental host spanner endpoints.
809+
This is mandatory if the experimental_host requires a TLS connection.
810+
:type client_certificate: str
811+
:param client_certificate: (Optional) The path to the client certificate file used for mTLS connection.
812+
This is intended only for experimental host spanner endpoints.
813+
This is mandatory if the experimental_host requires an mTLS connection.
814+
:type client_key: str
815+
:param client_key: (Optional) The path to the client key file used for mTLS connection.
816+
This is intended only for experimental host spanner endpoints.
817+
This is mandatory if the experimental_host requires an mTLS connection.
792818
"""
793819
if client is None:
794820
client_info = ClientInfo(
@@ -817,6 +843,10 @@ def connect(
817843
client_info=client_info,
818844
route_to_leader_enabled=route_to_leader_enabled,
819845
client_options=client_options,
846+
use_plain_text=use_plain_text,
847+
ca_certificate=ca_certificate,
848+
client_certificate=client_certificate,
849+
client_key=client_key,
820850
)
821851
else:
822852
if project is not None and client.project != project:

google/cloud/spanner_v1/_helpers.py

Lines changed: 62 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -803,3 +803,65 @@ def _merge_Transaction_Options(
803803

804804
# Convert protobuf object back into a TransactionOptions instance
805805
return TransactionOptions(merged_pb)
806+
807+
808+
def _create_experimental_host_transport(
809+
transport_factory,
810+
experimental_host,
811+
use_plain_text,
812+
ca_certificate,
813+
client_certificate,
814+
client_key,
815+
interceptors=None,
816+
):
817+
"""Creates an experimental host transport for Spanner.
818+
819+
Args:
820+
transport_factory (type): The transport class to instantiate (e.g.
821+
`SpannerGrpcTransport`).
822+
experimental_host (str): The endpoint for the experimental host.
823+
use_plain_text (bool): Whether to use a plain text (insecure) connection.
824+
ca_certificate (str): Path to the CA certificate file for TLS.
825+
client_certificate (str): Path to the client certificate file for mTLS.
826+
client_key (str): Path to the client key file for mTLS.
827+
interceptors (list): Optional list of interceptors to add to the channel.
828+
829+
Returns:
830+
object: An instance of the transport class created by `transport_factory`.
831+
832+
Raises:
833+
ValueError: If TLS/mTLS configuration is invalid.
834+
"""
835+
import grpc
836+
from google.auth.credentials import AnonymousCredentials
837+
838+
channel = None
839+
if use_plain_text:
840+
channel = grpc.insecure_channel(target=experimental_host)
841+
elif ca_certificate:
842+
with open(ca_certificate, "rb") as f:
843+
ca_cert = f.read()
844+
if client_certificate is not None and client_key is not None:
845+
with open(client_certificate, "rb") as f:
846+
client_cert = f.read()
847+
with open(client_key, "rb") as f:
848+
private_key = f.read()
849+
ssl_creds = grpc.ssl_channel_credentials(
850+
root_certificates=ca_cert,
851+
private_key=private_key,
852+
certificate_chain=client_cert,
853+
)
854+
elif client_certificate is None and client_key is None:
855+
ssl_creds = grpc.ssl_channel_credentials(root_certificates=ca_cert)
856+
else:
857+
raise ValueError(
858+
"Both client_certificate and client_key must be provided for mTLS connection"
859+
)
860+
channel = grpc.secure_channel(experimental_host, ssl_creds)
861+
else:
862+
raise ValueError(
863+
"TLS/mTLS connection requires ca_certificate to be set for experimental_host"
864+
)
865+
if interceptors is not None:
866+
channel = grpc.intercept_channel(channel, *interceptors)
867+
return transport_factory(channel=channel, credentials=AnonymousCredentials())

google/cloud/spanner_v1/client.py

Lines changed: 53 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -48,7 +48,10 @@
4848
from google.cloud.spanner_v1 import __version__
4949
from google.cloud.spanner_v1 import ExecuteSqlRequest
5050
from google.cloud.spanner_v1 import DefaultTransactionOptions
51-
from google.cloud.spanner_v1._helpers import _merge_query_options
51+
from google.cloud.spanner_v1._helpers import (
52+
_create_experimental_host_transport,
53+
_merge_query_options,
54+
)
5255
from google.cloud.spanner_v1._helpers import _metadata_with_prefix
5356
from google.cloud.spanner_v1.instance import Instance
5457
from google.cloud.spanner_v1.metrics.constants import (
@@ -186,6 +189,30 @@ class Client(ClientWithProject):
186189
187190
:raises: :class:`ValueError <exceptions.ValueError>` if both ``read_only``
188191
and ``admin`` are :data:`True`
192+
193+
:type use_plain_text: bool
194+
:param use_plain_text: (Optional) Whether to use plain text for the connection.
195+
This is intended only for experimental host spanner endpoints.
196+
If set, this will override the `api_endpoint` in `client_options`.
197+
If not set, the default behavior is to use TLS.
198+
199+
:type ca_certificate: str
200+
:param ca_certificate: (Optional) The path to the CA certificate file used for TLS connection.
201+
This is intended only for experimental host spanner endpoints.
202+
If set, this will override the `api_endpoint` in `client_options`.
203+
This is mandatory if the experimental_host requires a TLS connection.
204+
205+
:type client_certificate: str
206+
:param client_certificate: (Optional) The path to the client certificate file used for mTLS connection.
207+
This is intended only for experimental host spanner endpoints.
208+
If set, this will override the `api_endpoint` in `client_options`.
209+
This is mandatory if the experimental_host requires a mTLS connection.
210+
211+
:type client_key: str
212+
:param client_key: (Optional) The path to the client key file used for mTLS connection.
213+
This is intended only for experimental host spanner endpoints.
214+
If set, this will override the `api_endpoint` in `client_options`.
215+
This is mandatory if the experimental_host requires a mTLS connection.
189216
"""
190217

191218
_instance_admin_api = None
@@ -210,6 +237,10 @@ def __init__(
210237
default_transaction_options: Optional[DefaultTransactionOptions] = None,
211238
experimental_host=None,
212239
disable_builtin_metrics=False,
240+
use_plain_text=False,
241+
ca_certificate=None,
242+
client_certificate=None,
243+
client_key=None,
213244
):
214245
self._emulator_host = _get_spanner_emulator_host()
215246
self._experimental_host = experimental_host
@@ -224,6 +255,12 @@ def __init__(
224255
if self._emulator_host:
225256
credentials = AnonymousCredentials()
226257
elif self._experimental_host:
258+
# For all experimental host endpoints project is default
259+
project = "default"
260+
self._use_plain_text = use_plain_text
261+
self._ca_certificate = ca_certificate
262+
self._client_certificate = client_certificate
263+
self._client_key = client_key
227264
credentials = AnonymousCredentials()
228265
elif isinstance(credentials, AnonymousCredentials):
229266
self._emulator_host = self._client_options.api_endpoint
@@ -259,7 +296,7 @@ def __init__(
259296
):
260297
meter_provider = metrics.NoOpMeterProvider()
261298
try:
262-
if not _get_spanner_emulator_host():
299+
if not _get_spanner_emulator_host() and not self._experimental_host:
263300
meter_provider = MeterProvider(
264301
metric_readers=[
265302
PeriodicExportingMetricReader(
@@ -339,8 +376,13 @@ def instance_admin_api(self):
339376
transport=transport,
340377
)
341378
elif self._experimental_host:
342-
transport = InstanceAdminGrpcTransport(
343-
channel=grpc.insecure_channel(target=self._experimental_host)
379+
transport = _create_experimental_host_transport(
380+
InstanceAdminGrpcTransport,
381+
self._experimental_host,
382+
self._use_plain_text,
383+
self._ca_certificate,
384+
self._client_certificate,
385+
self._client_key,
344386
)
345387
self._instance_admin_api = InstanceAdminClient(
346388
client_info=self._client_info,
@@ -369,8 +411,13 @@ def database_admin_api(self):
369411
transport=transport,
370412
)
371413
elif self._experimental_host:
372-
transport = DatabaseAdminGrpcTransport(
373-
channel=grpc.insecure_channel(target=self._experimental_host)
414+
transport = _create_experimental_host_transport(
415+
DatabaseAdminGrpcTransport,
416+
self._experimental_host,
417+
self._use_plain_text,
418+
self._ca_certificate,
419+
self._client_certificate,
420+
self._client_key,
374421
)
375422
self._database_admin_api = DatabaseAdminClient(
376423
client_info=self._client_info,
@@ -517,7 +564,6 @@ def instance(
517564
self._emulator_host,
518565
labels,
519566
processing_units,
520-
self._experimental_host,
521567
)
522568

523569
def list_instances(self, filter_="", page_size=None):

google/cloud/spanner_v1/database.py

Lines changed: 11 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -55,6 +55,7 @@
5555
_metadata_with_prefix,
5656
_metadata_with_leader_aware_routing,
5757
_metadata_with_request_id,
58+
_create_experimental_host_transport,
5859
)
5960
from google.cloud.spanner_v1.batch import Batch
6061
from google.cloud.spanner_v1.batch import MutationGroups
@@ -203,11 +204,9 @@ def __init__(
203204

204205
self._pool = pool
205206
pool.bind(self)
206-
is_experimental_host = self._instance.experimental_host is not None
207+
self._experimental_host = self._instance._client._experimental_host
207208

208-
self._sessions_manager = DatabaseSessionsManager(
209-
self, pool, is_experimental_host
210-
)
209+
self._sessions_manager = DatabaseSessionsManager(self, pool)
211210

212211
@classmethod
213212
def from_pb(cls, database_pb, instance, pool=None):
@@ -452,9 +451,14 @@ def spanner_api(self):
452451
client_info=client_info, transport=transport
453452
)
454453
return self._spanner_api
455-
if self._instance.experimental_host is not None:
456-
transport = SpannerGrpcTransport(
457-
channel=grpc.insecure_channel(self._instance.experimental_host)
454+
if self._experimental_host is not None:
455+
transport = _create_experimental_host_transport(
456+
SpannerGrpcTransport,
457+
self._experimental_host,
458+
self._instance._client._use_plain_text,
459+
self._instance._client._ca_certificate,
460+
self._instance._client._client_certificate,
461+
self._instance._client._client_key,
458462
)
459463
self._spanner_api = SpannerClient(
460464
client_info=client_info,

google/cloud/spanner_v1/database_sessions_manager.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -62,10 +62,9 @@ class DatabaseSessionsManager(object):
6262
_MAINTENANCE_THREAD_POLLING_INTERVAL = timedelta(minutes=10)
6363
_MAINTENANCE_THREAD_REFRESH_INTERVAL = timedelta(days=7)
6464

65-
def __init__(self, database, pool, is_experimental_host: bool = False):
65+
def __init__(self, database, pool):
6666
self._database = database
6767
self._pool = pool
68-
self._is_experimental_host = is_experimental_host
6968

7069
# Declare multiplexed session attributes. When a multiplexed session for the
7170
# database session manager is created, a maintenance thread is initialized to
@@ -89,7 +88,8 @@ def get_session(self, transaction_type: TransactionType) -> Session:
8988

9089
session = (
9190
self._get_multiplexed_session()
92-
if self._use_multiplexed(transaction_type) or self._is_experimental_host
91+
if self._use_multiplexed(transaction_type)
92+
or self._database._experimental_host is not None
9393
else self._pool.get()
9494
)
9595

google/cloud/spanner_v1/instance.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -122,7 +122,6 @@ def __init__(
122122
emulator_host=None,
123123
labels=None,
124124
processing_units=None,
125-
experimental_host=None,
126125
):
127126
self.instance_id = instance_id
128127
self._client = client
@@ -143,7 +142,6 @@ def __init__(
143142
self._node_count = processing_units // PROCESSING_UNITS_PER_NODE
144143
self.display_name = display_name or instance_id
145144
self.emulator_host = emulator_host
146-
self.experimental_host = experimental_host
147145
if labels is None:
148146
labels = {}
149147
self.labels = labels

google/cloud/spanner_v1/testing/database_test.py

Lines changed: 11 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
import google.auth.credentials
1818
from google.cloud.spanner_admin_database_v1 import DatabaseDialect
1919
from google.cloud.spanner_v1 import SpannerClient
20+
from google.cloud.spanner_v1._helpers import _create_experimental_host_transport
2021
from google.cloud.spanner_v1.database import Database, SPANNER_DATA_SCOPE
2122
from google.cloud.spanner_v1.services.spanner.transports import (
2223
SpannerGrpcTransport,
@@ -86,12 +87,18 @@ def spanner_api(self):
8687
transport=transport,
8788
)
8889
return self._spanner_api
89-
if self._instance.experimental_host is not None:
90-
channel = grpc.insecure_channel(self._instance.experimental_host)
90+
if self._experimental_host is not None:
9191
self._x_goog_request_id_interceptor = XGoogRequestIDHeaderInterceptor()
9292
self._interceptors.append(self._x_goog_request_id_interceptor)
93-
channel = grpc.intercept_channel(channel, *self._interceptors)
94-
transport = SpannerGrpcTransport(channel=channel)
93+
transport = _create_experimental_host_transport(
94+
SpannerGrpcTransport,
95+
self._experimental_host,
96+
self._instance._client._use_plain_text,
97+
self._instance._client._ca_certificate,
98+
self._instance._client._client_certificate,
99+
self._instance._client._client_key,
100+
self._interceptors,
101+
)
95102
self._spanner_api = SpannerClient(
96103
client_info=client_info,
97104
transport=transport,

tests/system/_helpers.py

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -60,7 +60,14 @@
6060
EXPERIMENTAL_HOST = os.getenv(USE_EXPERIMENTAL_HOST_ENVVAR)
6161
USE_EXPERIMENTAL_HOST = EXPERIMENTAL_HOST is not None
6262

63-
EXPERIMENTAL_HOST_PROJECT = "default"
63+
CA_CERTIFICATE_ENVVAR = "CA_CERTIFICATE"
64+
CA_CERTIFICATE = os.getenv(CA_CERTIFICATE_ENVVAR)
65+
CLIENT_CERTIFICATE_ENVVAR = "CLIENT_CERTIFICATE"
66+
CLIENT_CERTIFICATE = os.getenv(CLIENT_CERTIFICATE_ENVVAR)
67+
CLIENT_KEY_ENVVAR = "CLIENT_KEY"
68+
CLIENT_KEY = os.getenv(CLIENT_KEY_ENVVAR)
69+
USE_PLAIN_TEXT = CA_CERTIFICATE is None
70+
6471
EXPERIMENTAL_HOST_INSTANCE = "default"
6572

6673
DDL_STATEMENTS = (

tests/system/conftest.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -115,7 +115,10 @@ def spanner_client():
115115

116116
credentials = AnonymousCredentials()
117117
return spanner_v1.Client(
118-
project=_helpers.EXPERIMENTAL_HOST_PROJECT,
118+
use_plain_text=_helpers.USE_PLAIN_TEXT,
119+
ca_certificate=_helpers.CA_CERTIFICATE,
120+
client_certificate=_helpers.CLIENT_CERTIFICATE,
121+
client_key=_helpers.CLIENT_KEY,
119122
credentials=credentials,
120123
experimental_host=_helpers.EXPERIMENTAL_HOST,
121124
)

tests/system/test_dbapi.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1442,6 +1442,10 @@ def test_user_agent(self, shared_instance, dbapi_database):
14421442
experimental_host=_helpers.EXPERIMENTAL_HOST
14431443
if _helpers.USE_EXPERIMENTAL_HOST
14441444
else None,
1445+
use_plain_text=_helpers.USE_PLAIN_TEXT,
1446+
ca_certificate=_helpers.CA_CERTIFICATE,
1447+
client_certificate=_helpers.CLIENT_CERTIFICATE,
1448+
client_key=_helpers.CLIENT_KEY,
14451449
)
14461450
assert (
14471451
conn.instance._client._client_info.user_agent

0 commit comments

Comments
 (0)