Skip to content

Commit

Permalink
feat: support requiest options in !autocommit mode (googleapis#838)
Browse files Browse the repository at this point in the history
  • Loading branch information
IlyaFaer authored Oct 13, 2022
1 parent 06725fc commit ab768e4
Show file tree
Hide file tree
Showing 5 changed files with 68 additions and 16 deletions.
12 changes: 9 additions & 3 deletions google/cloud/spanner_dbapi/_helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,15 +47,21 @@
}


def _execute_insert_heterogenous(transaction, sql_params_list):
def _execute_insert_heterogenous(
transaction,
sql_params_list,
request_options=None,
):
for sql, params in sql_params_list:
sql, params = sql_pyformat_args_to_spanner(sql, params)
transaction.execute_update(sql, params, get_param_types(params))
transaction.execute_update(
sql, params, get_param_types(params), request_options=request_options
)


def handle_insert(connection, sql, params):
return connection.database.run_in_transaction(
_execute_insert_heterogenous, ((sql, params),)
_execute_insert_heterogenous, ((sql, params),), connection.request_options
)


Expand Down
25 changes: 17 additions & 8 deletions google/cloud/spanner_dbapi/connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -183,6 +183,21 @@ def read_only(self, value):
)
self._read_only = value

@property
def request_options(self):
"""Options for the next SQL operations.
Returns:
google.cloud.spanner_v1.RequestOptions:
Request options.
"""
if self.request_priority is None:
return

req_opts = RequestOptions(priority=self.request_priority)
self.request_priority = None
return req_opts

@property
def staleness(self):
"""Current read staleness option value of this `Connection`.
Expand Down Expand Up @@ -437,25 +452,19 @@ def run_statement(self, statement, retried=False):

if statement.is_insert:
_execute_insert_heterogenous(
transaction, ((statement.sql, statement.params),)
transaction, ((statement.sql, statement.params),), self.request_options
)
return (
iter(()),
ResultsChecksum() if retried else statement.checksum,
)

if self.request_priority is not None:
req_opts = RequestOptions(priority=self.request_priority)
self.request_priority = None
else:
req_opts = None

return (
transaction.execute_sql(
statement.sql,
statement.params,
param_types=statement.param_types,
request_options=req_opts,
request_options=self.request_options,
),
ResultsChecksum() if retried else statement.checksum,
)
Expand Down
16 changes: 13 additions & 3 deletions google/cloud/spanner_dbapi/cursor.py
Original file line number Diff line number Diff line change
Expand Up @@ -172,7 +172,10 @@ def close(self):

def _do_execute_update(self, transaction, sql, params):
result = transaction.execute_update(
sql, params=params, param_types=get_param_types(params)
sql,
params=params,
param_types=get_param_types(params),
request_options=self.connection.request_options,
)
self._itr = None
if type(result) == int:
Expand Down Expand Up @@ -278,7 +281,9 @@ def execute(self, sql, args=None):
_helpers.handle_insert(self.connection, sql, args or None)
else:
self.connection.database.run_in_transaction(
self._do_execute_update, sql, args or None
self._do_execute_update,
sql,
args or None,
)
except (AlreadyExists, FailedPrecondition, OutOfRange) as e:
raise IntegrityError(getattr(e, "details", e)) from e
Expand Down Expand Up @@ -421,7 +426,12 @@ def fetchmany(self, size=None):
return items

def _handle_DQL_with_snapshot(self, snapshot, sql, params):
self._result_set = snapshot.execute_sql(sql, params, get_param_types(params))
self._result_set = snapshot.execute_sql(
sql,
params,
get_param_types(params),
request_options=self.connection.request_options,
)
# Read the first element so that the StreamedResultSet can
# return the metadata after a DQL statement.
self._itr = PeekIterator(self._result_set)
Expand Down
8 changes: 6 additions & 2 deletions tests/unit/spanner_dbapi/test__helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,9 @@ def test__execute_insert_heterogenous(self):

mock_pyformat.assert_called_once_with(params[0], params[1])
mock_param_types.assert_called_once_with(None)
mock_update.assert_called_once_with(sql, None, None)
mock_update.assert_called_once_with(
sql, None, None, request_options=None
)

def test__execute_insert_heterogenous_error(self):
from google.cloud.spanner_dbapi import _helpers
Expand All @@ -62,7 +64,9 @@ def test__execute_insert_heterogenous_error(self):

mock_pyformat.assert_called_once_with(params[0], params[1])
mock_param_types.assert_called_once_with(None)
mock_update.assert_called_once_with(sql, None, None)
mock_update.assert_called_once_with(
sql, None, None, request_options=None
)

def test_handle_insert(self):
from google.cloud.spanner_dbapi import _helpers
Expand Down
23 changes: 23 additions & 0 deletions tests/unit/spanner_dbapi/test_cursor.py
Original file line number Diff line number Diff line change
Expand Up @@ -748,6 +748,29 @@ def test_handle_dql(self):
self.assertIsInstance(cursor._itr, utils.PeekIterator)
self.assertEqual(cursor._row_count, _UNSET_COUNT)

def test_handle_dql_priority(self):
from google.cloud.spanner_dbapi import utils
from google.cloud.spanner_dbapi.cursor import _UNSET_COUNT
from google.cloud.spanner_v1 import RequestOptions

connection = self._make_connection(self.INSTANCE, mock.MagicMock())
connection.database.snapshot.return_value.__enter__.return_value = (
mock_snapshot
) = mock.MagicMock()
connection.request_priority = 1

cursor = self._make_one(connection)

sql = "sql"
mock_snapshot.execute_sql.return_value = ["0"]
cursor._handle_DQL(sql, params=None)
self.assertEqual(cursor._result_set, ["0"])
self.assertIsInstance(cursor._itr, utils.PeekIterator)
self.assertEqual(cursor._row_count, _UNSET_COUNT)
mock_snapshot.execute_sql.assert_called_with(
sql, None, None, request_options=RequestOptions(priority=1)
)

def test_context(self):
connection = self._make_connection(self.INSTANCE, self.DATABASE)
cursor = self._make_one(connection)
Expand Down

0 comments on commit ab768e4

Please sign in to comment.