Merged
Changes from 6 commits
4 changes: 4 additions & 0 deletions .gitignore
@@ -62,3 +62,7 @@ system_tests/local_test_setup
# Make sure a generated file isn't accidentally committed.
pylintrc
pylintrc.test


# Ignore coverage files
.coverage*
2 changes: 1 addition & 1 deletion google/cloud/spanner_dbapi/transaction_helper.py
@@ -20,7 +20,7 @@

from google.cloud.spanner_dbapi.batch_dml_executor import BatchMode
from google.cloud.spanner_dbapi.exceptions import RetryAborted
from google.cloud.spanner_v1.session import _get_retry_delay
from google.cloud.spanner_v1._helpers import _get_retry_delay

if TYPE_CHECKING:
from google.cloud.spanner_dbapi import Connection, Cursor
87 changes: 87 additions & 0 deletions google/cloud/spanner_v1/_helpers.py
@@ -27,11 +27,15 @@
from google.protobuf.internal.enum_type_wrapper import EnumTypeWrapper

from google.api_core import datetime_helpers
from google.api_core.exceptions import Aborted
from google.cloud._helpers import _date_from_iso8601_date
from google.cloud.spanner_v1 import TypeCode
from google.cloud.spanner_v1 import ExecuteSqlRequest
from google.cloud.spanner_v1 import JsonObject
from google.cloud.spanner_v1.request_id_header import with_request_id
from google.rpc.error_details_pb2 import RetryInfo

import random

# Validation error messages
NUMERIC_MAX_SCALE_ERR_MSG = (
@@ -460,6 +464,34 @@ def _metadata_with_prefix(prefix, **kw):
return [("google-cloud-resource-prefix", prefix)]


def _retry_on_aborted_exception(
func,
deadline,
allowed_exceptions=None,
Contributor: I think that we can simplify this further and just remove allowed_exceptions from this function. It should only retry aborted exceptions.

):
"""
Handles retry logic for Aborted exceptions, considering the deadline.
Retries the function in case of Aborted exceptions and other allowed exceptions.
"""
attempts = 0
while True:
try:
attempts += 1
return func()
except Aborted as exc:
_delay_until_retry(exc, deadline=deadline, attempts=attempts)
continue
except Exception as exc:
try:
retry_result = _retry(func=func, allowed_exceptions=allowed_exceptions)
if retry_result is not None:
return retry_result
else:
raise exc
except Aborted:
continue
Contributor: I think that we should remove this part entirely. I know that the previous implementation of Batch retried this specific RST_STREAM error, but that was just a copy-paste from other methods. That error is not relevant for this type of operation.
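
Taken together, the two review comments above point toward a leaner helper that retries only Aborted. A minimal sketch of that simplified shape, reusing the _delay_until_retry helper this PR adds to _helpers.py (the final merged form may differ):

```python
from google.api_core.exceptions import Aborted

from google.cloud.spanner_v1._helpers import _delay_until_retry


def _retry_on_aborted_exception(func, deadline):
    """Sketch: retry ``func`` on Aborted until ``deadline`` (a time.time() value)."""
    attempts = 0
    while True:
        try:
            attempts += 1
            return func()
        except Aborted as exc:
            # Sleeps for the server-supplied (or exponential-backoff) delay,
            # re-raising the Aborted error once the deadline has passed.
            _delay_until_retry(exc, deadline=deadline, attempts=attempts)
```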



def _retry(
func,
retry_count=5,
@@ -473,6 +505,7 @@ def _retry(
Args:
func: The function to be retried.
retry_count: The maximum number of times to retry the function.
deadline: This will be used in case of Aborted transactions.
Contributor: nit: remove, this is not relevant anymore.

delay: The delay in seconds between retries.
allowed_exceptions: A tuple of exceptions that are allowed to occur without triggering a retry.
Passing allowed_exceptions as None will lead to retrying for all exceptions.
@@ -529,6 +562,60 @@ def _metadata_with_leader_aware_routing(value, **kw):
return ("x-goog-spanner-route-to-leader", str(value).lower())


def _delay_until_retry(exc, deadline, attempts):
"""Helper for :meth:`Session.run_in_transaction`.

Detect retryable abort, and impose server-supplied delay.

:type exc: :class:`google.api_core.exceptions.Aborted`
:param exc: exception for aborted transaction

:type deadline: float
:param deadline: maximum timestamp to continue retrying the transaction.

:type attempts: int
:param attempts: number of call retries
"""

cause = exc.errors[0]
now = time.time()
if now >= deadline:
raise

delay = _get_retry_delay(cause, attempts)
if delay is not None:
if now + delay > deadline:
raise

time.sleep(delay)


def _get_retry_delay(cause, attempts):
"""Helper for :func:`_delay_until_retry`.

:type cause: :class:`grpc.Call`
:param cause: exception for the aborted transaction

:rtype: float
:returns: seconds to wait before retrying the transaction.

:type attempts: int
:param attempts: number of call retries
"""
if hasattr(cause, "trailing_metadata"):
metadata = dict(cause.trailing_metadata())
else:
metadata = {}
retry_info_pb = metadata.get("google.rpc.retryinfo-bin")
if retry_info_pb is not None:
retry_info = RetryInfo()
retry_info.ParseFromString(retry_info_pb)
nanos = retry_info.retry_delay.nanos
return retry_info.retry_delay.seconds + nanos / 1.0e9

return 2**attempts + random.random()
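
Illustrative aside (not part of the diff): when the server supplies no RetryInfo in the trailing metadata, the fallback above produces exponentially growing delays with up to one second of random jitter, bounded in practice by the deadline check in _delay_until_retry.

```python
import random

# Fallback delays from _get_retry_delay when no RetryInfo is present.
for attempts in range(1, 5):
    delay = 2**attempts + random.random()
    print(f"attempt {attempts}: sleep ~{delay:.1f}s")  # ~2s, ~4s, ~8s, ~16s
```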


class AtomicCounter:
def __init__(self, start_value=0):
self.__lock = threading.Lock()
23 changes: 19 additions & 4 deletions google/cloud/spanner_v1/batch.py
@@ -29,8 +29,12 @@
from google.cloud.spanner_v1._opentelemetry_tracing import trace_call
from google.cloud.spanner_v1 import RequestOptions
from google.cloud.spanner_v1._helpers import _retry
from google.cloud.spanner_v1._helpers import _retry_on_aborted_exception
from google.cloud.spanner_v1._helpers import _check_rst_stream_error
from google.api_core.exceptions import InternalServerError
import time

DEFAULT_RETRY_TIMEOUT_SECS = 30


class _BatchBase(_SessionWrapper):
@@ -162,6 +166,7 @@ def commit(
request_options=None,
max_commit_delay=None,
exclude_txn_from_change_streams=False,
**kwargs,
):
"""Commit mutations to the database.

@@ -227,9 +232,15 @@
request=request,
metadata=metadata,
)
response = _retry(
deadline = time.time() + kwargs.get(
"timeout_secs", DEFAULT_RETRY_TIMEOUT_SECS
)
response = _retry_on_aborted_exception(
method,
allowed_exceptions={InternalServerError: _check_rst_stream_error},
allowed_exceptions={
InternalServerError: _check_rst_stream_error,
},
deadline=deadline,
)
self.committed = response.commit_timestamp
self.commit_stats = response.commit_stats
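
For context, a hedged usage sketch (not part of the diff): the timeout_secs value read above arrives via Database.batch(**kw) and BatchCheckout (see database.py below). The table and columns mirror the test at the bottom of this PR; the 60-second deadline is an arbitrary example.

```python
# Sketch: assumes `database` is a google.cloud.spanner_v1 Database instance.
# commit() retries Aborted errors for up to ~60 seconds before giving up;
# omitting timeout_secs falls back to DEFAULT_RETRY_TIMEOUT_SECS (30s).
with database.batch(timeout_secs=60) as batch:
    batch.insert(
        table="Singers",
        columns=("SingerId", "FirstName", "LastName"),
        values=[(1, "Marc", "Richards")],
    )
```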
@@ -293,7 +304,9 @@ def group(self):
self._mutation_groups.append(mutation_group)
return MutationGroup(self._session, mutation_group.mutations)

def batch_write(self, request_options=None, exclude_txn_from_change_streams=False):
def batch_write(
Contributor: batch_write is a bit different. I don't think we should include it in this PR, as it is a non-atomic, streaming operation that probably needs different error handling than "just retry if it fails with an aborted error".

Contributor (author): Understood. In that case, we can bypass the retry behavior for this operation.

Contributor: Please also remove the **kwargs addition from this PR again. It would just be confusing if that is added in this PR, when it is not relevant to the actual change in this PR.

self, request_options=None, exclude_txn_from_change_streams=False, **kwargs
):
"""Executes batch_write.

:type request_options:
@@ -348,7 +361,9 @@ def batch_write(self, request_options=None, exclude_txn_from_change_streams=Fals
)
response = _retry(
method,
allowed_exceptions={InternalServerError: _check_rst_stream_error},
allowed_exceptions={
InternalServerError: _check_rst_stream_error,
},
)
self.committed = True
return response
10 changes: 9 additions & 1 deletion google/cloud/spanner_v1/database.py
@@ -775,6 +775,7 @@ def batch(
request_options=None,
max_commit_delay=None,
exclude_txn_from_change_streams=False,
**kw,
):
"""Return an object which wraps a batch.

@@ -805,7 +806,11 @@
:returns: new wrapper
"""
return BatchCheckout(
self, request_options, max_commit_delay, exclude_txn_from_change_streams
self,
request_options,
max_commit_delay,
exclude_txn_from_change_streams,
**kw,
)

def mutation_groups(self):
@@ -1166,6 +1171,7 @@ def __init__(
request_options=None,
max_commit_delay=None,
exclude_txn_from_change_streams=False,
**kw,
):
self._database = database
self._session = self._batch = None
@@ -1177,6 +1183,7 @@
self._request_options = request_options
self._max_commit_delay = max_commit_delay
self._exclude_txn_from_change_streams = exclude_txn_from_change_streams
self._kw = kw

def __enter__(self):
"""Begin ``with`` block."""
@@ -1197,6 +1204,7 @@ def __exit__(self, exc_type, exc_val, exc_tb):
request_options=self._request_options,
max_commit_delay=self._max_commit_delay,
exclude_txn_from_change_streams=self._exclude_txn_from_change_streams,
**self._kw,
)
finally:
if self._database.log_commit_stats and self._batch.commit_stats:
58 changes: 2 additions & 56 deletions google/cloud/spanner_v1/session.py
@@ -15,15 +15,15 @@
"""Wrapper for Cloud Spanner Session objects."""

from functools import total_ordering
import random
import time
from datetime import datetime

from google.api_core.exceptions import Aborted
from google.api_core.exceptions import GoogleAPICallError
from google.api_core.exceptions import NotFound
from google.api_core.gapic_v1 import method
from google.rpc.error_details_pb2 import RetryInfo
from google.cloud.spanner_v1._helpers import _delay_until_retry
from google.cloud.spanner_v1._helpers import _get_retry_delay

from google.cloud.spanner_v1 import ExecuteSqlRequest
from google.cloud.spanner_v1 import CreateSessionRequest
@@ -554,57 +554,3 @@ def run_in_transaction(self, func, *args, **kw):
extra={"commit_stats": txn.commit_stats},
)
return return_value


# Rationale: this function factors out complex shared deadline / retry
# handling from two `except:` clauses.
def _delay_until_retry(exc, deadline, attempts):
"""Helper for :meth:`Session.run_in_transaction`.

Detect retryable abort, and impose server-supplied delay.

:type exc: :class:`google.api_core.exceptions.Aborted`
:param exc: exception for aborted transaction

:type deadline: float
:param deadline: maximum timestamp to continue retrying the transaction.

:type attempts: int
:param attempts: number of call retries
"""
cause = exc.errors[0]

now = time.time()

if now >= deadline:
raise

delay = _get_retry_delay(cause, attempts)
if delay is not None:
if now + delay > deadline:
raise

time.sleep(delay)


def _get_retry_delay(cause, attempts):
"""Helper for :func:`_delay_until_retry`.

:type exc: :class:`grpc.Call`
:param exc: exception for aborted transaction

:rtype: float
:returns: seconds to wait before retrying the transaction.

:type attempts: int
:param attempts: number of call retries
"""
metadata = dict(cause.trailing_metadata())
retry_info_pb = metadata.get("google.rpc.retryinfo-bin")
if retry_info_pb is not None:
retry_info = RetryInfo()
retry_info.ParseFromString(retry_info_pb)
nanos = retry_info.retry_delay.nanos
return retry_info.retry_delay.seconds + nanos / 1.0e9

return 2**attempts + random.random()
17 changes: 13 additions & 4 deletions google/cloud/spanner_v1/testing/mock_spanner.py
@@ -213,10 +213,19 @@ def __create_transaction(
def Commit(self, request, context):
self._requests.append(request)
self.mock_spanner.pop_error(context)
tx = self.transactions[request.transaction_id]
if tx is None:
raise ValueError(f"Transaction not found: {request.transaction_id}")
del self.transactions[request.transaction_id]
if not request.transaction_id == b"":
tx = self.transactions[request.transaction_id]
if tx is None:
raise ValueError(f"Transaction not found: {request.transaction_id}")
tx_id = request.transaction_id
elif not request.single_use_transaction == TransactionOptions():
tx = self.__create_transaction(
request.session, request.single_use_transaction
)
tx_id = tx.id
else:
raise ValueError("Unsupported transaction type")
del self.transactions[tx_id]
return commit.CommitResponse()

def Rollback(self, request, context):
24 changes: 24 additions & 0 deletions tests/mockserver_tests/test_aborted_transaction.py
@@ -95,6 +95,30 @@ def test_run_in_transaction_batch_dml_aborted(self):
self.assertTrue(isinstance(requests[2], ExecuteBatchDmlRequest))
self.assertTrue(isinstance(requests[3], CommitRequest))

def test_batch_commit_aborted(self):
# Add an Aborted error for the Commit method on the mock server.
add_error(SpannerServicer.Commit.__name__, aborted_status())
with self.database.batch() as batch:
batch.insert(
table="Singers",
columns=("SingerId", "FirstName", "LastName"),
values=[
(1, "Marc", "Richards"),
(2, "Catalina", "Smith"),
(3, "Alice", "Trentor"),
(4, "Lea", "Martin"),
(5, "David", "Lomond"),
],
)

# Verify that the transaction was retried.
requests = self.spanner_service.requests
self.assertEqual(3, len(requests), msg=requests)
self.assertTrue(isinstance(requests[0], BatchCreateSessionsRequest))
self.assertTrue(isinstance(requests[1], CommitRequest))
# The transaction is aborted and retried.
self.assertTrue(isinstance(requests[2], CommitRequest))


def _insert_mutations(transaction: Transaction):
transaction.insert("my_table", ["col1", "col2"], ["value1", "value2"])
Expand Down