Skip to content

Commit 411804f

Browse files
authored
test: add mock server tests (#1217)
* test: add mock server tests * chore: move to testing folder + fix formatting * refactor: move mock server tests to separate directory * feat: add database admin service Adds a DatabaseAdminService to the mock server and sets up a basic test case for this. Also removes the generated stubs in the grpc files, as these are not needed. * test: add DDL test * test: add async client tests * chore: remove async + add transaction handling * chore: cleanup * chore: run code formatter
1 parent 2225a5e commit 411804f

File tree

10 files changed

+2605
-1
lines changed

10 files changed

+2605
-1
lines changed
Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,21 @@
1+
on:
2+
push:
3+
branches:
4+
- main
5+
pull_request:
6+
name: Run Spanner tests against an in-mem mock server
7+
jobs:
8+
system-tests:
9+
runs-on: ubuntu-latest
10+
11+
steps:
12+
- name: Checkout code
13+
uses: actions/checkout@v4
14+
- name: Setup Python
15+
uses: actions/setup-python@v5
16+
with:
17+
python-version: 3.12
18+
- name: Install nox
19+
run: python -m pip install nox
20+
- name: Run mock server tests
21+
run: nox -s mockserver

packages/google-cloud-spanner/google/cloud/spanner_v1/database.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -142,7 +142,7 @@ class Database(object):
142142
statements in 'ddl_statements' above.
143143
"""
144144

145-
_spanner_api = None
145+
_spanner_api: SpannerClient = None
146146

147147
def __init__(
148148
self,

packages/google-cloud-spanner/google/cloud/spanner_v1/testing/__init__.py

Whitespace-only changes.
Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,38 @@
1+
# Copyright 2024 Google LLC All rights reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
from google.longrunning import operations_pb2 as operations_pb2
16+
from google.protobuf import empty_pb2
17+
import google.cloud.spanner_v1.testing.spanner_database_admin_pb2_grpc as database_admin_grpc
18+
19+
20+
# An in-memory mock DatabaseAdmin server that can be used for testing.
21+
class DatabaseAdminServicer(database_admin_grpc.DatabaseAdminServicer):
22+
def __init__(self):
23+
self._requests = []
24+
25+
@property
26+
def requests(self):
27+
return self._requests
28+
29+
def clear_requests(self):
30+
self._requests = []
31+
32+
def UpdateDatabaseDdl(self, request, context):
33+
self._requests.append(request)
34+
operation = operations_pb2.Operation()
35+
operation.done = True
36+
operation.name = "projects/test-project/operations/test-operation"
37+
operation.response.Pack(empty_pb2.Empty())
38+
return operation
Lines changed: 216 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,216 @@
1+
# Copyright 2024 Google LLC All rights reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
import base64
15+
import grpc
16+
from concurrent import futures
17+
18+
from google.protobuf import empty_pb2
19+
from google.cloud.spanner_v1.testing.mock_database_admin import DatabaseAdminServicer
20+
import google.cloud.spanner_v1.testing.spanner_database_admin_pb2_grpc as database_admin_grpc
21+
import google.cloud.spanner_v1.testing.spanner_pb2_grpc as spanner_grpc
22+
import google.cloud.spanner_v1.types.commit_response as commit
23+
import google.cloud.spanner_v1.types.result_set as result_set
24+
import google.cloud.spanner_v1.types.spanner as spanner
25+
import google.cloud.spanner_v1.types.transaction as transaction
26+
27+
28+
class MockSpanner:
29+
def __init__(self):
30+
self.results = {}
31+
32+
def add_result(self, sql: str, result: result_set.ResultSet):
33+
self.results[sql.lower().strip()] = result
34+
35+
def get_result(self, sql: str) -> result_set.ResultSet:
36+
result = self.results.get(sql.lower().strip())
37+
if result is None:
38+
raise ValueError(f"No result found for {sql}")
39+
return result
40+
41+
def get_result_as_partial_result_sets(
42+
self, sql: str
43+
) -> [result_set.PartialResultSet]:
44+
result: result_set.ResultSet = self.get_result(sql)
45+
partials = []
46+
first = True
47+
if len(result.rows) == 0:
48+
partial = result_set.PartialResultSet()
49+
partial.metadata = result.metadata
50+
partials.append(partial)
51+
else:
52+
for row in result.rows:
53+
partial = result_set.PartialResultSet()
54+
if first:
55+
partial.metadata = result.metadata
56+
partial.values.extend(row)
57+
partials.append(partial)
58+
partials[len(partials) - 1].stats = result.stats
59+
return partials
60+
61+
62+
# An in-memory mock Spanner server that can be used for testing.
63+
class SpannerServicer(spanner_grpc.SpannerServicer):
64+
def __init__(self):
65+
self._requests = []
66+
self.session_counter = 0
67+
self.sessions = {}
68+
self.transaction_counter = 0
69+
self.transactions = {}
70+
self._mock_spanner = MockSpanner()
71+
72+
@property
73+
def mock_spanner(self):
74+
return self._mock_spanner
75+
76+
@property
77+
def requests(self):
78+
return self._requests
79+
80+
def clear_requests(self):
81+
self._requests = []
82+
83+
def CreateSession(self, request, context):
84+
self._requests.append(request)
85+
return self.__create_session(request.database, request.session)
86+
87+
def BatchCreateSessions(self, request, context):
88+
self._requests.append(request)
89+
sessions = []
90+
for i in range(request.session_count):
91+
sessions.append(
92+
self.__create_session(request.database, request.session_template)
93+
)
94+
return spanner.BatchCreateSessionsResponse(dict(session=sessions))
95+
96+
def __create_session(self, database: str, session_template: spanner.Session):
97+
self.session_counter += 1
98+
session = spanner.Session()
99+
session.name = database + "/sessions/" + str(self.session_counter)
100+
session.multiplexed = session_template.multiplexed
101+
session.labels.MergeFrom(session_template.labels)
102+
session.creator_role = session_template.creator_role
103+
self.sessions[session.name] = session
104+
return session
105+
106+
def GetSession(self, request, context):
107+
self._requests.append(request)
108+
return spanner.Session()
109+
110+
def ListSessions(self, request, context):
111+
self._requests.append(request)
112+
return [spanner.Session()]
113+
114+
def DeleteSession(self, request, context):
115+
self._requests.append(request)
116+
return empty_pb2.Empty()
117+
118+
def ExecuteSql(self, request, context):
119+
self._requests.append(request)
120+
return result_set.ResultSet()
121+
122+
def ExecuteStreamingSql(self, request, context):
123+
self._requests.append(request)
124+
partials = self.mock_spanner.get_result_as_partial_result_sets(request.sql)
125+
for result in partials:
126+
yield result
127+
128+
def ExecuteBatchDml(self, request, context):
129+
self._requests.append(request)
130+
response = spanner.ExecuteBatchDmlResponse()
131+
started_transaction = None
132+
if not request.transaction.begin == transaction.TransactionOptions():
133+
started_transaction = self.__create_transaction(
134+
request.session, request.transaction.begin
135+
)
136+
first = True
137+
for statement in request.statements:
138+
result = self.mock_spanner.get_result(statement.sql)
139+
if first and started_transaction is not None:
140+
result = result_set.ResultSet(
141+
self.mock_spanner.get_result(statement.sql)
142+
)
143+
result.metadata = result_set.ResultSetMetadata(result.metadata)
144+
result.metadata.transaction = started_transaction
145+
response.result_sets.append(result)
146+
return response
147+
148+
def Read(self, request, context):
149+
self._requests.append(request)
150+
return result_set.ResultSet()
151+
152+
def StreamingRead(self, request, context):
153+
self._requests.append(request)
154+
for result in [result_set.PartialResultSet(), result_set.PartialResultSet()]:
155+
yield result
156+
157+
def BeginTransaction(self, request, context):
158+
self._requests.append(request)
159+
return self.__create_transaction(request.session, request.options)
160+
161+
def __create_transaction(
162+
self, session: str, options: transaction.TransactionOptions
163+
) -> transaction.Transaction:
164+
session = self.sessions[session]
165+
if session is None:
166+
raise ValueError(f"Session not found: {session}")
167+
self.transaction_counter += 1
168+
id_bytes = bytes(
169+
f"{session.name}/transactions/{self.transaction_counter}", "UTF-8"
170+
)
171+
transaction_id = base64.urlsafe_b64encode(id_bytes)
172+
self.transactions[transaction_id] = options
173+
return transaction.Transaction(dict(id=transaction_id))
174+
175+
def Commit(self, request, context):
176+
self._requests.append(request)
177+
tx = self.transactions[request.transaction_id]
178+
if tx is None:
179+
raise ValueError(f"Transaction not found: {request.transaction_id}")
180+
del self.transactions[request.transaction_id]
181+
return commit.CommitResponse()
182+
183+
def Rollback(self, request, context):
184+
self._requests.append(request)
185+
return empty_pb2.Empty()
186+
187+
def PartitionQuery(self, request, context):
188+
self._requests.append(request)
189+
return spanner.PartitionResponse()
190+
191+
def PartitionRead(self, request, context):
192+
self._requests.append(request)
193+
return spanner.PartitionResponse()
194+
195+
def BatchWrite(self, request, context):
196+
self._requests.append(request)
197+
for result in [spanner.BatchWriteResponse(), spanner.BatchWriteResponse()]:
198+
yield result
199+
200+
201+
def start_mock_server() -> (grpc.Server, SpannerServicer, DatabaseAdminServicer, int):
202+
# Create a gRPC server.
203+
spanner_server = grpc.server(futures.ThreadPoolExecutor(max_workers=10))
204+
205+
# Add the Spanner services to the gRPC server.
206+
spanner_servicer = SpannerServicer()
207+
spanner_grpc.add_SpannerServicer_to_server(spanner_servicer, spanner_server)
208+
database_admin_servicer = DatabaseAdminServicer()
209+
database_admin_grpc.add_DatabaseAdminServicer_to_server(
210+
database_admin_servicer, spanner_server
211+
)
212+
213+
# Start the server on a random port.
214+
port = spanner_server.add_insecure_port("[::]:0")
215+
spanner_server.start()
216+
return spanner_server, spanner_servicer, database_admin_servicer, port

0 commit comments

Comments
 (0)