Skip to content

chore(x-goog-request-id): commit testing scaffold #1366

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions google/cloud/spanner_v1/_helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -707,6 +707,10 @@ def __radd__(self, n):
"""
return self.__add__(n)

def reset(self):
with self.__lock:
self.__value = 0


def _metadata_with_request_id(*args, **kwargs):
return with_request_id(*args, **kwargs)
Expand Down
9 changes: 9 additions & 0 deletions google/cloud/spanner_v1/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,7 @@
except ImportError: # pragma: NO COVER
HAS_GOOGLE_CLOUD_MONITORING_INSTALLED = False

from google.cloud.spanner_v1._helpers import AtomicCounter

_CLIENT_INFO = client_info.ClientInfo(client_library_version=__version__)
EMULATOR_ENV_VAR = "SPANNER_EMULATOR_HOST"
Expand Down Expand Up @@ -182,6 +183,8 @@ class Client(ClientWithProject):
SCOPE = (SPANNER_ADMIN_SCOPE,)
"""The scopes required for Google Cloud Spanner."""

NTH_CLIENT = AtomicCounter()

def __init__(
self,
project=None,
Expand Down Expand Up @@ -263,6 +266,12 @@ def __init__(
"default_transaction_options must be an instance of DefaultTransactionOptions"
)
self._default_transaction_options = default_transaction_options
self._nth_client_id = Client.NTH_CLIENT.increment()
self._nth_request = AtomicCounter(0)

@property
def _next_nth_request(self):
return self._nth_request.increment()

@property
def credentials(self):
Expand Down
2 changes: 1 addition & 1 deletion google/cloud/spanner_v1/request_id_header.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,6 @@ def generate_rand_uint64():

def with_request_id(client_id, channel_id, nth_request, attempt, other_metadata=[]):
req_id = f"{REQ_ID_VERSION}.{REQ_RAND_PROCESS_ID}.{client_id}.{channel_id}.{nth_request}.{attempt}"
all_metadata = other_metadata.copy()
all_metadata = (other_metadata or []).copy()
all_metadata.append((REQ_ID_HEADER_KEY, req_id))
return all_metadata
9 changes: 9 additions & 0 deletions google/cloud/spanner_v1/testing/database_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
from google.cloud.spanner_v1.testing.interceptors import (
MethodCountInterceptor,
MethodAbortInterceptor,
XGoogRequestIDHeaderInterceptor,
)


Expand All @@ -34,6 +35,8 @@ class TestDatabase(Database):
currently, and we don't want to make changes in the Database class for
testing purpose as this is a hack to use interceptors in tests."""

_interceptors = []

def __init__(
self,
database_id,
Expand Down Expand Up @@ -74,6 +77,8 @@ def spanner_api(self):
client_options = client._client_options
if self._instance.emulator_host is not None:
channel = grpc.insecure_channel(self._instance.emulator_host)
self._x_goog_request_id_interceptor = XGoogRequestIDHeaderInterceptor()
self._interceptors.append(self._x_goog_request_id_interceptor)
channel = grpc.intercept_channel(channel, *self._interceptors)
transport = SpannerGrpcTransport(channel=channel)
self._spanner_api = SpannerClient(
Expand Down Expand Up @@ -110,3 +115,7 @@ def _create_spanner_client_for_tests(self, client_options, credentials):
client_options=client_options,
transport=transport,
)

def reset(self):
if self._x_goog_request_id_interceptor:
self._x_goog_request_id_interceptor.reset()
71 changes: 71 additions & 0 deletions google/cloud/spanner_v1/testing/interceptors.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,8 @@
# limitations under the License.

from collections import defaultdict
import threading

from grpc_interceptor import ClientInterceptor
from google.api_core.exceptions import Aborted

Expand Down Expand Up @@ -63,3 +65,72 @@ def reset(self):
self._method_to_abort = None
self._count = 0
self._connection = None


X_GOOG_REQUEST_ID = "x-goog-spanner-request-id"


class XGoogRequestIDHeaderInterceptor(ClientInterceptor):
# TODO:(@odeke-em): delete this guard when PR #1367 is merged.
X_GOOG_REQUEST_ID_FUNCTIONALITY_MERGED = False

def __init__(self):
self._unary_req_segments = []
self._stream_req_segments = []
self.__lock = threading.Lock()

def intercept(self, method, request_or_iterator, call_details):
metadata = call_details.metadata
x_goog_request_id = None
for key, value in metadata:
if key == X_GOOG_REQUEST_ID:
x_goog_request_id = value
break

if self.X_GOOG_REQUEST_ID_FUNCTIONALITY_MERGED and not x_goog_request_id:
raise Exception(
f"Missing {X_GOOG_REQUEST_ID} header in {call_details.method}"
)

response_or_iterator = method(request_or_iterator, call_details)
streaming = getattr(response_or_iterator, "__iter__", None) is not None

if self.X_GOOG_REQUEST_ID_FUNCTIONALITY_MERGED:
with self.__lock:
if streaming:
self._stream_req_segments.append(
(call_details.method, parse_request_id(x_goog_request_id))
)
else:
self._unary_req_segments.append(
(call_details.method, parse_request_id(x_goog_request_id))
)

return response_or_iterator

@property
def unary_request_ids(self):
return self._unary_req_segments

@property
def stream_request_ids(self):
return self._stream_req_segments

def reset(self):
self._stream_req_segments.clear()
self._unary_req_segments.clear()


def parse_request_id(request_id_str):
splits = request_id_str.split(".")
version, rand_process_id, client_id, channel_id, nth_request, nth_attempt = list(
map(lambda v: int(v), splits)
)
return (
version,
rand_process_id,
client_id,
channel_id,
nth_request,
nth_attempt,
)
7 changes: 2 additions & 5 deletions google/cloud/spanner_v1/testing/mock_spanner.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,8 +22,6 @@
from google.cloud.spanner_v1 import (
TransactionOptions,
ResultSetMetadata,
ExecuteSqlRequest,
ExecuteBatchDmlRequest,
)
from google.cloud.spanner_v1.testing.mock_database_admin import DatabaseAdminServicer
import google.cloud.spanner_v1.testing.spanner_database_admin_pb2_grpc as database_admin_grpc
Expand Down Expand Up @@ -107,6 +105,7 @@ def CreateSession(self, request, context):

def BatchCreateSessions(self, request, context):
self._requests.append(request)
self.mock_spanner.pop_error(context)
sessions = []
for i in range(request.session_count):
sessions.append(
Expand Down Expand Up @@ -186,9 +185,7 @@ def BeginTransaction(self, request, context):
self._requests.append(request)
return self.__create_transaction(request.session, request.options)

def __maybe_create_transaction(
self, request: ExecuteSqlRequest | ExecuteBatchDmlRequest
):
def __maybe_create_transaction(self, request):
started_transaction = None
if not request.transaction.begin == TransactionOptions():
started_transaction = self.__create_transaction(
Expand Down
5 changes: 4 additions & 1 deletion tests/mockserver_tests/mock_server_test_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -153,6 +153,7 @@ def setup_class(cls):
def teardown_class(cls):
if MockServerTestBase.server is not None:
MockServerTestBase.server.stop(grace=None)
Client.NTH_CLIENT.reset()
MockServerTestBase.server = None

def setup_method(self, *args, **kwargs):
Expand Down Expand Up @@ -186,6 +187,8 @@ def instance(self) -> Instance:
def database(self) -> Database:
if self._database is None:
self._database = self.instance.database(
"test-database", pool=FixedSizePool(size=10)
"test-database",
pool=FixedSizePool(size=10),
enable_interceptors_in_tests=True,
)
return self._database
63 changes: 63 additions & 0 deletions tests/unit/test_transaction.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,10 @@
from google.cloud.spanner_v1 import TypeCode
from google.api_core.retry import Retry
from google.api_core import gapic_v1
from google.cloud.spanner_v1._helpers import (
AtomicCounter,
_metadata_with_request_id,
)

from tests._helpers import (
HAS_OPENTELEMETRY_INSTALLED,
Expand Down Expand Up @@ -197,6 +201,11 @@ def test_begin_ok(self):
[
("google-cloud-resource-prefix", database.name),
("x-goog-spanner-route-to-leader", "true"),
# TODO(@odeke-em): enable with PR #1367.
# (
# "x-goog-spanner-request-id",
# f"1.{REQ_RAND_PROCESS_ID}.{_Client.NTH_CLIENT.value}.1.1.1",
# ),
],
)

Expand Down Expand Up @@ -301,6 +310,11 @@ def test_rollback_ok(self):
[
("google-cloud-resource-prefix", database.name),
("x-goog-spanner-route-to-leader", "true"),
# TODO(@odeke-em): enable with PR #1367.
# (
# "x-goog-spanner-request-id",
# f"1.{REQ_RAND_PROCESS_ID}.{_Client.NTH_CLIENT.value}.1.1.1",
# ),
],
)

Expand Down Expand Up @@ -492,6 +506,11 @@ def _commit_helper(
[
("google-cloud-resource-prefix", database.name),
("x-goog-spanner-route-to-leader", "true"),
# TODO(@odeke-em): enable with PR #1367.
# (
# "x-goog-spanner-request-id",
# f"1.{REQ_RAND_PROCESS_ID}.{_Client.NTH_CLIENT.value}.1.1.1",
# ),
],
)
self.assertEqual(actual_request_options, expected_request_options)
Expand Down Expand Up @@ -666,6 +685,11 @@ def _execute_update_helper(
metadata=[
("google-cloud-resource-prefix", database.name),
("x-goog-spanner-route-to-leader", "true"),
# TODO(@odeke-em): enable with PR #1367.
# (
# "x-goog-spanner-request-id",
# f"1.{REQ_RAND_PROCESS_ID}.{_Client.NTH_CLIENT.value}.1.1.1",
# ),
],
)

Expand Down Expand Up @@ -859,6 +883,11 @@ def _batch_update_helper(
metadata=[
("google-cloud-resource-prefix", database.name),
("x-goog-spanner-route-to-leader", "true"),
# TODO(@odeke-em): enable with PR #1367.
# (
# "x-goog-spanner-request-id",
# f"1.{REQ_RAND_PROCESS_ID}.{_Client.NTH_CLIENT.value}.1.1.1",
# ),
],
retry=retry,
timeout=timeout,
Expand Down Expand Up @@ -974,6 +1003,11 @@ def test_context_mgr_success(self):
[
("google-cloud-resource-prefix", database.name),
("x-goog-spanner-route-to-leader", "true"),
# TODO(@odeke-em): enable with PR #1367.
# (
# "x-goog-spanner-request-id",
# f"1.{REQ_RAND_PROCESS_ID}.{_Client.NTH_CLIENT.value}.1.2.1",
# ),
],
)

Expand Down Expand Up @@ -1004,11 +1038,19 @@ def test_context_mgr_failure(self):


class _Client(object):
NTH_CLIENT = AtomicCounter()

def __init__(self):
from google.cloud.spanner_v1 import ExecuteSqlRequest

self._query_options = ExecuteSqlRequest.QueryOptions(optimizer_version="1")
self.directed_read_options = None
self._nth_client_id = _Client.NTH_CLIENT.increment()
self._nth_request = AtomicCounter()

@property
def _next_nth_request(self):
return self._nth_request.increment()


class _Instance(object):
Expand All @@ -1024,6 +1066,27 @@ def __init__(self):
self._directed_read_options = None
self.default_transaction_options = DefaultTransactionOptions()

@property
def _next_nth_request(self):
return self._instance._client._next_nth_request

@property
def _nth_client_id(self):
return self._instance._client._nth_client_id

def metadata_with_request_id(self, nth_request, nth_attempt, prior_metadata=[]):
return _metadata_with_request_id(
self._nth_client_id,
self._channel_id,
nth_request,
nth_attempt,
prior_metadata,
)

@property
def _channel_id(self):
return 1


class _Session(object):
_transaction = None
Expand Down
Loading