From ab768e45efe7334823ec6bcdccfac2a6dde73bd7 Mon Sep 17 00:00:00 2001 From: Ilya Gurov Date: Thu, 13 Oct 2022 01:12:29 -0700 Subject: [PATCH] feat: support requiest options in !autocommit mode (#838) --- google/cloud/spanner_dbapi/_helpers.py | 12 ++++++++--- google/cloud/spanner_dbapi/connection.py | 25 +++++++++++++++-------- google/cloud/spanner_dbapi/cursor.py | 16 ++++++++++++--- tests/unit/spanner_dbapi/test__helpers.py | 8 ++++++-- tests/unit/spanner_dbapi/test_cursor.py | 23 +++++++++++++++++++++ 5 files changed, 68 insertions(+), 16 deletions(-) diff --git a/google/cloud/spanner_dbapi/_helpers.py b/google/cloud/spanner_dbapi/_helpers.py index ee4883d74f..02901ffc3a 100644 --- a/google/cloud/spanner_dbapi/_helpers.py +++ b/google/cloud/spanner_dbapi/_helpers.py @@ -47,15 +47,21 @@ } -def _execute_insert_heterogenous(transaction, sql_params_list): +def _execute_insert_heterogenous( + transaction, + sql_params_list, + request_options=None, +): for sql, params in sql_params_list: sql, params = sql_pyformat_args_to_spanner(sql, params) - transaction.execute_update(sql, params, get_param_types(params)) + transaction.execute_update( + sql, params, get_param_types(params), request_options=request_options + ) def handle_insert(connection, sql, params): return connection.database.run_in_transaction( - _execute_insert_heterogenous, ((sql, params),) + _execute_insert_heterogenous, ((sql, params),), connection.request_options ) diff --git a/google/cloud/spanner_dbapi/connection.py b/google/cloud/spanner_dbapi/connection.py index 9fa2269eae..75263400f8 100644 --- a/google/cloud/spanner_dbapi/connection.py +++ b/google/cloud/spanner_dbapi/connection.py @@ -183,6 +183,21 @@ def read_only(self, value): ) self._read_only = value + @property + def request_options(self): + """Options for the next SQL operations. + + Returns: + google.cloud.spanner_v1.RequestOptions: + Request options. + """ + if self.request_priority is None: + return + + req_opts = RequestOptions(priority=self.request_priority) + self.request_priority = None + return req_opts + @property def staleness(self): """Current read staleness option value of this `Connection`. @@ -437,25 +452,19 @@ def run_statement(self, statement, retried=False): if statement.is_insert: _execute_insert_heterogenous( - transaction, ((statement.sql, statement.params),) + transaction, ((statement.sql, statement.params),), self.request_options ) return ( iter(()), ResultsChecksum() if retried else statement.checksum, ) - if self.request_priority is not None: - req_opts = RequestOptions(priority=self.request_priority) - self.request_priority = None - else: - req_opts = None - return ( transaction.execute_sql( statement.sql, statement.params, param_types=statement.param_types, - request_options=req_opts, + request_options=self.request_options, ), ResultsChecksum() if retried else statement.checksum, ) diff --git a/google/cloud/spanner_dbapi/cursor.py b/google/cloud/spanner_dbapi/cursor.py index 4ffeac1a70..f8220d2c68 100644 --- a/google/cloud/spanner_dbapi/cursor.py +++ b/google/cloud/spanner_dbapi/cursor.py @@ -172,7 +172,10 @@ def close(self): def _do_execute_update(self, transaction, sql, params): result = transaction.execute_update( - sql, params=params, param_types=get_param_types(params) + sql, + params=params, + param_types=get_param_types(params), + request_options=self.connection.request_options, ) self._itr = None if type(result) == int: @@ -278,7 +281,9 @@ def execute(self, sql, args=None): _helpers.handle_insert(self.connection, sql, args or None) else: self.connection.database.run_in_transaction( - self._do_execute_update, sql, args or None + self._do_execute_update, + sql, + args or None, ) except (AlreadyExists, FailedPrecondition, OutOfRange) as e: raise IntegrityError(getattr(e, "details", e)) from e @@ -421,7 +426,12 @@ def fetchmany(self, size=None): return items def _handle_DQL_with_snapshot(self, snapshot, sql, params): - self._result_set = snapshot.execute_sql(sql, params, get_param_types(params)) + self._result_set = snapshot.execute_sql( + sql, + params, + get_param_types(params), + request_options=self.connection.request_options, + ) # Read the first element so that the StreamedResultSet can # return the metadata after a DQL statement. self._itr = PeekIterator(self._result_set) diff --git a/tests/unit/spanner_dbapi/test__helpers.py b/tests/unit/spanner_dbapi/test__helpers.py index 1782978d62..c770ff6e4b 100644 --- a/tests/unit/spanner_dbapi/test__helpers.py +++ b/tests/unit/spanner_dbapi/test__helpers.py @@ -37,7 +37,9 @@ def test__execute_insert_heterogenous(self): mock_pyformat.assert_called_once_with(params[0], params[1]) mock_param_types.assert_called_once_with(None) - mock_update.assert_called_once_with(sql, None, None) + mock_update.assert_called_once_with( + sql, None, None, request_options=None + ) def test__execute_insert_heterogenous_error(self): from google.cloud.spanner_dbapi import _helpers @@ -62,7 +64,9 @@ def test__execute_insert_heterogenous_error(self): mock_pyformat.assert_called_once_with(params[0], params[1]) mock_param_types.assert_called_once_with(None) - mock_update.assert_called_once_with(sql, None, None) + mock_update.assert_called_once_with( + sql, None, None, request_options=None + ) def test_handle_insert(self): from google.cloud.spanner_dbapi import _helpers diff --git a/tests/unit/spanner_dbapi/test_cursor.py b/tests/unit/spanner_dbapi/test_cursor.py index 3f379f96ac..75089362af 100644 --- a/tests/unit/spanner_dbapi/test_cursor.py +++ b/tests/unit/spanner_dbapi/test_cursor.py @@ -748,6 +748,29 @@ def test_handle_dql(self): self.assertIsInstance(cursor._itr, utils.PeekIterator) self.assertEqual(cursor._row_count, _UNSET_COUNT) + def test_handle_dql_priority(self): + from google.cloud.spanner_dbapi import utils + from google.cloud.spanner_dbapi.cursor import _UNSET_COUNT + from google.cloud.spanner_v1 import RequestOptions + + connection = self._make_connection(self.INSTANCE, mock.MagicMock()) + connection.database.snapshot.return_value.__enter__.return_value = ( + mock_snapshot + ) = mock.MagicMock() + connection.request_priority = 1 + + cursor = self._make_one(connection) + + sql = "sql" + mock_snapshot.execute_sql.return_value = ["0"] + cursor._handle_DQL(sql, params=None) + self.assertEqual(cursor._result_set, ["0"]) + self.assertIsInstance(cursor._itr, utils.PeekIterator) + self.assertEqual(cursor._row_count, _UNSET_COUNT) + mock_snapshot.execute_sql.assert_called_with( + sql, None, None, request_options=RequestOptions(priority=1) + ) + def test_context(self): connection = self._make_connection(self.INSTANCE, self.DATABASE) cursor = self._make_one(connection)