Skip to content

Commit

Permalink
feat: Inline Begin transaction for RW transactions (googleapis#840)
Browse files Browse the repository at this point in the history
* feat: Inline Begin transaction for RW transactions

* ILB with lock for execute update and batch update

* Added lock for execute sql and read method

* fix: lint fix and testcases

* fix: lint

* fix: Set transaction id along with resume token

* fix: lint

* fix: test cases

* fix: a few more test cases for restart on unavailable

* test: Batch update error test case

* fix: lint

* fix: Code review comments

* fix: test cases + lint

* fix: code review comments

* fix: deprecate transactionpingingpool msg

* fix: review comments

Co-authored-by: larkee <31196561+larkee@users.noreply.github.com>

* fix: Apply suggestions from code review

Co-authored-by: larkee <31196561+larkee@users.noreply.github.com>

* fix: review comments

* fix: review comment Update tests/unit/test_session.py

Co-authored-by: larkee <31196561+larkee@users.noreply.github.com>

Co-authored-by: larkee <31196561+larkee@users.noreply.github.com>
  • Loading branch information
surbhigarg92 and larkee authored Dec 14, 2022
1 parent 234b21e commit c2456be
Show file tree
Hide file tree
Showing 11 changed files with 1,346 additions and 122 deletions.
7 changes: 5 additions & 2 deletions google/cloud/spanner_v1/database.py
Original file line number Diff line number Diff line change
Expand Up @@ -578,7 +578,6 @@ def execute_pdml():
request = ExecuteSqlRequest(
session=session.name,
sql=dml,
transaction=txn_selector,
params=params_pb,
param_types=param_types,
query_options=query_options,
Expand All @@ -589,7 +588,11 @@ def execute_pdml():
metadata=metadata,
)

iterator = _restart_on_unavailable(method, request)
iterator = _restart_on_unavailable(
method=method,
request=request,
transaction_selector=txn_selector,
)

result_set = StreamedResultSet(iterator)
list(result_set) # consume all partials
Expand Down
13 changes: 11 additions & 2 deletions google/cloud/spanner_v1/pool.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@
from google.cloud.spanner_v1 import BatchCreateSessionsRequest
from google.cloud.spanner_v1 import Session
from google.cloud.spanner_v1._helpers import _metadata_with_prefix

from warnings import warn

_NOW = datetime.datetime.utcnow # unit tests may replace

Expand Down Expand Up @@ -497,6 +497,10 @@ def ping(self):
class TransactionPingingPool(PingingPool):
"""Concrete session pool implementation:
Deprecated: TransactionPingingPool no longer begins a transaction for each of its sessions at startup.
Hence the TransactionPingingPool is same as :class:`PingingPool` and maybe removed in the future.
In addition to the features of :class:`PingingPool`, this class
creates and begins a transaction for each of its sessions at startup.
Expand Down Expand Up @@ -532,6 +536,12 @@ def __init__(
labels=None,
database_role=None,
):
"""This throws a deprecation warning on initialization."""
warn(
f"{self.__class__.__name__} is deprecated.",
DeprecationWarning,
stacklevel=2,
)
self._pending_sessions = queue.Queue()

super(TransactionPingingPool, self).__init__(
Expand Down Expand Up @@ -579,7 +589,6 @@ def begin_pending_transactions(self):
"""Begin all transactions for sessions added to the pool."""
while not self._pending_sessions.empty():
session = self._pending_sessions.get()
session._transaction.begin()
super(TransactionPingingPool, self).put(session)


Expand Down
2 changes: 0 additions & 2 deletions google/cloud/spanner_v1/session.py
Original file line number Diff line number Diff line change
Expand Up @@ -366,8 +366,6 @@ def run_in_transaction(self, func, *args, **kw):
txn.transaction_tag = transaction_tag
else:
txn = self._transaction
if txn._transaction_id is None:
txn.begin()

try:
attempts += 1
Expand Down
121 changes: 99 additions & 22 deletions google/cloud/spanner_v1/snapshot.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
"""Model a set of read-only queries to a database as a snapshot."""

import functools

import threading
from google.protobuf.struct_pb2 import Struct
from google.cloud.spanner_v1 import ExecuteSqlRequest
from google.cloud.spanner_v1 import ReadRequest
Expand All @@ -27,6 +27,7 @@

from google.api_core.exceptions import InternalServerError
from google.api_core.exceptions import ServiceUnavailable
from google.api_core.exceptions import InvalidArgument
from google.api_core import gapic_v1
from google.cloud.spanner_v1._helpers import _make_value_pb
from google.cloud.spanner_v1._helpers import _merge_query_options
Expand All @@ -43,7 +44,13 @@


def _restart_on_unavailable(
method, request, trace_name=None, session=None, attributes=None
method,
request,
trace_name=None,
session=None,
attributes=None,
transaction=None,
transaction_selector=None,
):
"""Restart iteration after :exc:`.ServiceUnavailable`.
Expand All @@ -52,22 +59,51 @@ def _restart_on_unavailable(
:type request: proto
:param request: request proto to call the method with
:type transaction: :class:`google.cloud.spanner_v1.snapshot._SnapshotBase`
:param transaction: Snapshot or Transaction class object based on the type of transaction
:type transaction_selector: :class:`transaction_pb2.TransactionSelector`
:param transaction_selector: Transaction selector object to be used in request if transaction is not passed,
if both transaction_selector and transaction are passed, then transaction is given priority.
"""

resume_token = b""
item_buffer = []

if transaction is not None:
transaction_selector = transaction._make_txn_selector()
elif transaction_selector is None:
raise InvalidArgument(
"Either transaction or transaction_selector should be set"
)

request.transaction = transaction_selector
with trace_call(trace_name, session, attributes):
iterator = method(request=request)
while True:
try:
for item in iterator:
item_buffer.append(item)
# Setting the transaction id because the transaction begin was inlined for first rpc.
if (
transaction is not None
and transaction._transaction_id is None
and item.metadata is not None
and item.metadata.transaction is not None
and item.metadata.transaction.id is not None
):
transaction._transaction_id = item.metadata.transaction.id
if item.resume_token:
resume_token = item.resume_token
break
except ServiceUnavailable:
del item_buffer[:]
with trace_call(trace_name, session, attributes):
request.resume_token = resume_token
if transaction is not None:
transaction_selector = transaction._make_txn_selector()
request.transaction = transaction_selector
iterator = method(request=request)
continue
except InternalServerError as exc:
Expand All @@ -80,6 +116,9 @@ def _restart_on_unavailable(
del item_buffer[:]
with trace_call(trace_name, session, attributes):
request.resume_token = resume_token
if transaction is not None:
transaction_selector = transaction._make_txn_selector()
request.transaction = transaction_selector
iterator = method(request=request)
continue

Expand All @@ -106,6 +145,7 @@ class _SnapshotBase(_SessionWrapper):
_transaction_id = None
_read_request_count = 0
_execute_sql_count = 0
_lock = threading.Lock()

def _make_txn_selector(self):
"""Helper for :meth:`read` / :meth:`execute_sql`.
Expand Down Expand Up @@ -180,13 +220,12 @@ def read(
if self._read_request_count > 0:
if not self._multi_use:
raise ValueError("Cannot re-use single-use snapshot.")
if self._transaction_id is None:
if self._transaction_id is None and self._read_only:
raise ValueError("Transaction ID pending.")

database = self._session._database
api = database.spanner_api
metadata = _metadata_with_prefix(database.name)
transaction = self._make_txn_selector()

if request_options is None:
request_options = RequestOptions()
Expand All @@ -204,7 +243,6 @@ def read(
table=table,
columns=columns,
key_set=keyset._to_pb(),
transaction=transaction,
index=index,
limit=limit,
partition_token=partition,
Expand All @@ -219,13 +257,32 @@ def read(
)

trace_attributes = {"table_id": table, "columns": columns}
iterator = _restart_on_unavailable(
restart,
request,
"CloudSpanner.ReadOnlyTransaction",
self._session,
trace_attributes,
)

if self._transaction_id is None:
# lock is added to handle the inline begin for first rpc
with self._lock:
iterator = _restart_on_unavailable(
restart,
request,
"CloudSpanner.ReadOnlyTransaction",
self._session,
trace_attributes,
transaction=self,
)
self._read_request_count += 1
if self._multi_use:
return StreamedResultSet(iterator, source=self)
else:
return StreamedResultSet(iterator)
else:
iterator = _restart_on_unavailable(
restart,
request,
"CloudSpanner.ReadOnlyTransaction",
self._session,
trace_attributes,
transaction=self,
)

self._read_request_count += 1

Expand Down Expand Up @@ -301,7 +358,7 @@ def execute_sql(
if self._read_request_count > 0:
if not self._multi_use:
raise ValueError("Cannot re-use single-use snapshot.")
if self._transaction_id is None:
if self._transaction_id is None and self._read_only:
raise ValueError("Transaction ID pending.")

if params is not None:
Expand All @@ -315,7 +372,7 @@ def execute_sql(

database = self._session._database
metadata = _metadata_with_prefix(database.name)
transaction = self._make_txn_selector()

api = database.spanner_api

# Query-level options have higher precedence than client-level and
Expand All @@ -336,7 +393,6 @@ def execute_sql(
request = ExecuteSqlRequest(
session=self._session.name,
sql=sql,
transaction=transaction,
params=params_pb,
param_types=param_types,
query_mode=query_mode,
Expand All @@ -354,13 +410,34 @@ def execute_sql(
)

trace_attributes = {"db.statement": sql}
iterator = _restart_on_unavailable(
restart,
request,
"CloudSpanner.ReadWriteTransaction",
self._session,
trace_attributes,
)

if self._transaction_id is None:
# lock is added to handle the inline begin for first rpc
with self._lock:
iterator = _restart_on_unavailable(
restart,
request,
"CloudSpanner.ReadWriteTransaction",
self._session,
trace_attributes,
transaction=self,
)
self._read_request_count += 1
self._execute_sql_count += 1

if self._multi_use:
return StreamedResultSet(iterator, source=self)
else:
return StreamedResultSet(iterator)
else:
iterator = _restart_on_unavailable(
restart,
request,
"CloudSpanner.ReadWriteTransaction",
self._session,
trace_attributes,
transaction=self,
)

self._read_request_count += 1
self._execute_sql_count += 1
Expand Down
Loading

0 comments on commit c2456be

Please sign in to comment.