Skip to content

Commit

Permalink
Add a test which provokes abort-during-read during 'run_in_transacti…
Browse files Browse the repository at this point in the history
…on'. (googleapis#3663)
  • Loading branch information
tseaver authored and landrito committed Aug 22, 2017
1 parent cee0672 commit 94ce6dd
Show file tree
Hide file tree
Showing 2 changed files with 114 additions and 21 deletions.
16 changes: 11 additions & 5 deletions spanner/google/cloud/spanner/session.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@

# pylint: disable=ungrouped-imports
from google.cloud.exceptions import NotFound
from google.cloud.exceptions import GrpcRendezvous
from google.cloud.spanner._helpers import _options_with_prefix
from google.cloud.spanner.batch import Batch
from google.cloud.spanner.snapshot import Snapshot
Expand Down Expand Up @@ -286,7 +287,7 @@ def run_in_transaction(self, func, *args, **kw):
txn.begin()
try:
return_value = func(txn, *args, **kw)
except GaxError as exc:
except (GaxError, GrpcRendezvous) as exc:
_delay_until_retry(exc, deadline)
del self._transaction
continue
Expand Down Expand Up @@ -318,15 +319,20 @@ def _delay_until_retry(exc, deadline):
:type deadline: float
:param deadline: maximum timestamp to continue retrying the transaction.
"""
if exc_to_code(exc.cause) != StatusCode.ABORTED:
if isinstance(exc, GrpcRendezvous): # pragma: NO COVER see #3663
cause = exc
else:
cause = exc.cause

if exc_to_code(cause) != StatusCode.ABORTED:
raise

now = time.time()

if now >= deadline:
raise

delay = _get_retry_delay(exc)
delay = _get_retry_delay(cause)
if delay is not None:

if now + delay > deadline:
Expand All @@ -336,7 +342,7 @@ def _delay_until_retry(exc, deadline):
# pylint: enable=misplaced-bare-raise


def _get_retry_delay(exc):
def _get_retry_delay(cause):
"""Helper for :func:`_delay_until_retry`.
:type exc: :class:`google.gax.errors.GaxError`
Expand All @@ -345,7 +351,7 @@ def _get_retry_delay(exc):
:rtype: float
:returns: seconds to wait before retrying the transaction.
"""
metadata = dict(exc.cause.trailing_metadata())
metadata = dict(cause.trailing_metadata())
retry_info_pb = metadata.get('google.rpc.retryinfo-bin')
if retry_info_pb is not None:
retry_info = RetryInfo()
Expand Down
119 changes: 103 additions & 16 deletions spanner/tests/system/test_system.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,8 @@
'google-cloud-python-systest')
DATABASE_ID = 'test_database'
EXISTING_INSTANCES = []
COUNTERS_TABLE = 'counters'
COUNTERS_COLUMNS = ('name', 'value')


class Config(object):
Expand Down Expand Up @@ -360,11 +362,6 @@ class TestSessionAPI(unittest.TestCase, _TestData):
'description',
'exactly_hwhen',
)
COUNTERS_TABLE = 'counters'
COUNTERS_COLUMNS = (
'name',
'value',
)
SOME_DATE = datetime.date(2011, 1, 17)
SOME_TIME = datetime.datetime(1989, 1, 17, 17, 59, 12, 345612)
NANO_TIME = TimestampWithNanoseconds(1995, 8, 31, nanosecond=987654321)
Expand Down Expand Up @@ -554,9 +551,7 @@ def _transaction_concurrency_helper(self, unit_of_work, pkey):

with session.batch() as batch:
batch.insert_or_update(
self.COUNTERS_TABLE,
self.COUNTERS_COLUMNS,
[[pkey, INITIAL_VALUE]])
COUNTERS_TABLE, COUNTERS_COLUMNS, [[pkey, INITIAL_VALUE]])

# We don't want to run the threads' transactions in the current
# session, which would fail.
Expand All @@ -582,21 +577,19 @@ def _transaction_concurrency_helper(self, unit_of_work, pkey):

keyset = KeySet(keys=[(pkey,)])
rows = list(session.read(
self.COUNTERS_TABLE, self.COUNTERS_COLUMNS, keyset))
COUNTERS_TABLE, COUNTERS_COLUMNS, keyset))
self.assertEqual(len(rows), 1)
_, value = rows[0]
self.assertEqual(value, INITIAL_VALUE + len(threads))

def _read_w_concurrent_update(self, transaction, pkey):
keyset = KeySet(keys=[(pkey,)])
rows = list(transaction.read(
self.COUNTERS_TABLE, self.COUNTERS_COLUMNS, keyset))
COUNTERS_TABLE, COUNTERS_COLUMNS, keyset))
self.assertEqual(len(rows), 1)
pkey, value = rows[0]
transaction.update(
self.COUNTERS_TABLE,
self.COUNTERS_COLUMNS,
[[pkey, value + 1]])
COUNTERS_TABLE, COUNTERS_COLUMNS, [[pkey, value + 1]])

def test_transaction_read_w_concurrent_updates(self):
PKEY = 'read_w_concurrent_updates'
Expand All @@ -613,15 +606,48 @@ def _query_w_concurrent_update(self, transaction, pkey):
self.assertEqual(len(rows), 1)
pkey, value = rows[0]
transaction.update(
self.COUNTERS_TABLE,
self.COUNTERS_COLUMNS,
[[pkey, value + 1]])
COUNTERS_TABLE, COUNTERS_COLUMNS, [[pkey, value + 1]])

def test_transaction_query_w_concurrent_updates(self):
PKEY = 'query_w_concurrent_updates'
self._transaction_concurrency_helper(
self._query_w_concurrent_update, PKEY)

def test_transaction_read_w_abort(self):

retry = RetryInstanceState(_has_all_ddl)
retry(self._db.reload)()

session = self._db.session()
session.create()

trigger = _ReadAbortTrigger()

with session.batch() as batch:
batch.delete(COUNTERS_TABLE, self.ALL)
batch.insert(
COUNTERS_TABLE,
COUNTERS_COLUMNS,
[[trigger.KEY1, 0], [trigger.KEY2, 0]])

provoker = threading.Thread(
target=trigger.provoke_abort, args=(self._db,))
handler = threading.Thread(
target=trigger.handle_abort, args=(self._db,))

provoker.start()
trigger.provoker_started.wait()

handler.start()
trigger.handler_done.wait()

provoker.join()
handler.join()

rows = list(session.read(COUNTERS_TABLE, COUNTERS_COLUMNS, self.ALL))
self._check_row_data(
rows, expected=[[trigger.KEY1, 1], [trigger.KEY2, 1]])

@staticmethod
def _row_data(max_index):
for index in range(max_index):
Expand Down Expand Up @@ -1103,3 +1129,64 @@ def __init__(self, db):

def delete(self):
self._db.drop()


class _ReadAbortTrigger(object):
"""Helper for tests provoking abort-during-read."""

KEY1 = 'key1'
KEY2 = 'key2'

def __init__(self):
self.provoker_started = threading.Event()
self.provoker_done = threading.Event()
self.handler_running = threading.Event()
self.handler_done = threading.Event()

def _provoke_abort_unit_of_work(self, transaction):
keyset = KeySet(keys=[(self.KEY1,)])
rows = list(
transaction.read(COUNTERS_TABLE, COUNTERS_COLUMNS, keyset))

assert len(rows) == 1
row = rows[0]
value = row[1]

self.provoker_started.set()

self.handler_running.wait()

transaction.update(
COUNTERS_TABLE, COUNTERS_COLUMNS, [[self.KEY1, value + 1]])

def provoke_abort(self, database):
database.run_in_transaction(self._provoke_abort_unit_of_work)
self.provoker_done.set()

def _handle_abort_unit_of_work(self, transaction):
keyset_1 = KeySet(keys=[(self.KEY1,)])
rows_1 = list(
transaction.read(COUNTERS_TABLE, COUNTERS_COLUMNS, keyset_1))

assert len(rows_1) == 1
row_1 = rows_1[0]
value_1 = row_1[1]

self.handler_running.set()

self.provoker_done.wait()

keyset_2 = KeySet(keys=[(self.KEY2,)])
rows_2 = list(
transaction.read(COUNTERS_TABLE, COUNTERS_COLUMNS, keyset_2))

assert len(rows_2) == 1
row_2 = rows_2[0]
value_2 = row_2[1]

transaction.update(
COUNTERS_TABLE, COUNTERS_COLUMNS, [[self.KEY2, value_1 + value_2]])

def handle_abort(self, database):
database.run_in_transaction(self._handle_abort_unit_of_work)
self.handler_done.set()

0 comments on commit 94ce6dd

Please sign in to comment.