diff --git a/unit_tests/test_session.py b/unit_tests/test_session.py index 37fad4570e..e00571fdef 100644 --- a/unit_tests/test_session.py +++ b/unit_tests/test_session.py @@ -749,6 +749,8 @@ def unit_of_work(txn, *args, **kw): self.assertEqual(kw, {'some_arg': 'def'}) def test_run_in_transaction_w_timeout(self): + from google.cloud.spanner import session as MUT + from google.cloud._testing import _Monkey from google.gax.errors import GaxError from google.gax.grpc import exc_to_code from google.cloud.proto.spanner.v1.transaction_pb2 import ( @@ -779,12 +781,17 @@ def unit_of_work(txn, *args, **kw): called_with.append((txn, args, kw)) txn.insert(TABLE_NAME, COLUMNS, VALUES) - with self.assertRaises(GaxError) as exc: - session.run_in_transaction(unit_of_work, timeout_secs=0.01) + time_module = _FauxTimeModule() + time_module._times = [1, 1.5, 2.5] # retry once w/ timeout_secs=1 + + with _Monkey(MUT, time=time_module): + with self.assertRaises(GaxError) as exc: + session.run_in_transaction(unit_of_work, timeout_secs=1) self.assertEqual(exc_to_code(exc.exception.cause), StatusCode.ABORTED) - self.assertGreater(len(called_with), 1) + self.assertEqual(time_module._slept, None) + self.assertEqual(len(called_with), 2) for txn, args, kw in called_with: self.assertIsInstance(txn, Transaction) self.assertIsNone(txn.committed) @@ -881,9 +888,14 @@ def __init__(self, name): class _FauxTimeModule(object): _slept = None + _times = () def time(self): import time + + if len(self._times) > 0: + return self._times.pop(0) + return time.time() def sleep(self, seconds):