Skip to content

Commit

Permalink
feat: Add support and tests for DML returning clauses (googleapis#805)
Browse files Browse the repository at this point in the history
This change adds support for DML returning clauses and includes a few prerequisite changes.

I would suggest reviewing commit-by-commit. The commit messages provide additional context and are reproduced below.

### feat: Support custom endpoint when running tests

By setting the `GOOGLE_CLOUD_TESTS_SPANNER_HOST` environment variable you can now run tests against an alternate Spanner API endpoint. This is particularly useful for running system tests against a pre-production deployment.

### refactor(dbapi): Remove most special handling of INSERTs

For historical reasons, it seems the INSERT codepath and that for UPDATE/DELETE were separated, but today there appear to be no practical differences in how these DML statements are handled. This change removes most of the special handling for INSERTs and uses existing methods for UPDATEs/DELETEs instead. The one remaining exception is the automatic addition of a WHERE clause to UPDATE and DELETE statements lacking one, which does not apply to INSERT statements.

### feat(dbapi): Add full support for rowcount

Previously, rowcount was only available after executing an UPDATE or DELETE in autocommit mode. This change extends this support so that a rowcount is available for all DML statements, regardless of whether autocommit is enabled.

### feat: Add support for returning clause in DML

This change adds support and tests for a returning clause in DML statements. This is done by moving the execution of all DML statements to `execute_sql`, which is already used when not in autocommit mode.
  • Loading branch information
c2nes authored Nov 22, 2022
1 parent 1922a2e commit 81505cd
Show file tree
Hide file tree
Showing 10 changed files with 304 additions and 178 deletions.
20 changes: 0 additions & 20 deletions google/cloud/spanner_dbapi/_helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,8 +12,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from google.cloud.spanner_dbapi.parse_utils import get_param_types
from google.cloud.spanner_dbapi.parse_utils import sql_pyformat_args_to_spanner
from google.cloud.spanner_v1 import param_types


Expand Down Expand Up @@ -47,24 +45,6 @@
}


def _execute_insert_heterogenous(
    transaction,
    sql_params_list,
    request_options=None,
):
    """Run each (sql, params) pair as an update on the given transaction.

    Every pair is first passed through ``sql_pyformat_args_to_spanner`` and
    the converted statement is then handed to ``transaction.execute_update``
    together with the parameter types derived by ``get_param_types``.

    :param transaction: Object exposing ``execute_update``.
    :param sql_params_list: Iterable of ``(sql, params)`` pairs.
    :param request_options: Forwarded unchanged to every ``execute_update``
        call; defaults to ``None``.
    """
    for raw_sql, raw_params in sql_params_list:
        spanner_sql, spanner_params = sql_pyformat_args_to_spanner(
            raw_sql, raw_params
        )
        spanner_param_types = get_param_types(spanner_params)
        transaction.execute_update(
            spanner_sql,
            spanner_params,
            spanner_param_types,
            request_options=request_options,
        )


def handle_insert(connection, sql, params):
    """Execute a single INSERT through the connection's database.

    The statement is wrapped in a one-element sequence and run by
    ``_execute_insert_heterogenous`` inside
    ``connection.database.run_in_transaction``.

    :param connection: Object exposing ``database`` and ``request_options``.
    :param sql: The INSERT statement text.
    :param params: Parameters accompanying ``sql``.
    :returns: Whatever ``run_in_transaction`` returns.
    """
    single_statement = ((sql, params),)
    database = connection.database
    return database.run_in_transaction(
        _execute_insert_heterogenous,
        single_statement,
        connection.request_options,
    )


class ColumnInfo:
"""Row column description object."""

Expand Down
10 changes: 0 additions & 10 deletions google/cloud/spanner_dbapi/connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,6 @@
from google.cloud.spanner_v1.session import _get_retry_delay
from google.cloud.spanner_v1.snapshot import Snapshot

from google.cloud.spanner_dbapi._helpers import _execute_insert_heterogenous
from google.cloud.spanner_dbapi.checksum import _compare_checksums
from google.cloud.spanner_dbapi.checksum import ResultsChecksum
from google.cloud.spanner_dbapi.cursor import Cursor
Expand Down Expand Up @@ -450,15 +449,6 @@ def run_statement(self, statement, retried=False):
if not retried:
self._statements.append(statement)

if statement.is_insert:
_execute_insert_heterogenous(
transaction, ((statement.sql, statement.params),), self.request_options
)
return (
iter(()),
ResultsChecksum() if retried else statement.checksum,
)

return (
transaction.execute_sql(
statement.sql,
Expand Down
48 changes: 26 additions & 22 deletions google/cloud/spanner_dbapi/cursor.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@
_UNSET_COUNT = -1

ColumnDetails = namedtuple("column_details", ["null_ok", "spanner_type"])
Statement = namedtuple("Statement", "sql, params, param_types, checksum, is_insert")
Statement = namedtuple("Statement", "sql, params, param_types, checksum")


def check_not_closed(function):
Expand Down Expand Up @@ -137,14 +137,21 @@ def description(self):

@property
def rowcount(self):
"""The number of rows updated by the last UPDATE, DELETE request's `execute()` call.
"""The number of rows updated by the last INSERT, UPDATE, DELETE request's `execute()` call.
For SELECT requests the rowcount returns -1.
:rtype: int
:returns: The number of rows updated by the last UPDATE, DELETE request's .execute*() call.
:returns: The number of rows updated by the last INSERT, UPDATE, DELETE request's .execute*() call.
"""

return self._row_count
if self._row_count != _UNSET_COUNT or self._result_set is None:
return self._row_count

stats = getattr(self._result_set, "stats", None)
if stats is not None and "row_count_exact" in stats:
return stats.row_count_exact

return _UNSET_COUNT

@check_not_closed
def callproc(self, procname, args=None):
Expand All @@ -171,17 +178,11 @@ def close(self):
self._is_closed = True

def _do_execute_update(self, transaction, sql, params):
result = transaction.execute_update(
sql,
params=params,
param_types=get_param_types(params),
request_options=self.connection.request_options,
self._result_set = transaction.execute_sql(
sql, params=params, param_types=get_param_types(params)
)
self._itr = None
if type(result) == int:
self._row_count = result

return result
self._itr = PeekIterator(self._result_set)
self._row_count = _UNSET_COUNT

def _do_batch_update(self, transaction, statements, many_result_set):
status, res = transaction.batch_update(statements)
Expand Down Expand Up @@ -227,7 +228,9 @@ def execute(self, sql, args=None):
:type args: list
:param args: Additional parameters to supplement the SQL query.
"""
self._itr = None
self._result_set = None
self._row_count = _UNSET_COUNT

try:
if self.connection.read_only:
Expand All @@ -249,18 +252,14 @@ def execute(self, sql, args=None):
if class_ == parse_utils.STMT_UPDATING:
sql = parse_utils.ensure_where_clause(sql)

if class_ != parse_utils.STMT_INSERT:
sql, args = sql_pyformat_args_to_spanner(sql, args or None)
sql, args = sql_pyformat_args_to_spanner(sql, args or None)

if not self.connection.autocommit:
statement = Statement(
sql,
args,
get_param_types(args or None)
if class_ != parse_utils.STMT_INSERT
else {},
get_param_types(args or None),
ResultsChecksum(),
class_ == parse_utils.STMT_INSERT,
)

(
Expand All @@ -277,8 +276,6 @@ def execute(self, sql, args=None):

if class_ == parse_utils.STMT_NON_UPDATING:
self._handle_DQL(sql, args or None)
elif class_ == parse_utils.STMT_INSERT:
_helpers.handle_insert(self.connection, sql, args or None)
else:
self.connection.database.run_in_transaction(
self._do_execute_update,
Expand All @@ -304,6 +301,10 @@ def executemany(self, operation, seq_of_params):
:param seq_of_params: Sequence of additional parameters to run
the query with.
"""
self._itr = None
self._result_set = None
self._row_count = _UNSET_COUNT

class_ = parse_utils.classify_stmt(operation)
if class_ == parse_utils.STMT_DDL:
raise ProgrammingError(
Expand All @@ -327,6 +328,7 @@ def executemany(self, operation, seq_of_params):
)
else:
retried = False
total_row_count = 0
while True:
try:
transaction = self.connection.transaction_checkout()
Expand All @@ -341,12 +343,14 @@ def executemany(self, operation, seq_of_params):
many_result_set.add_iter(res)
res_checksum.consume_result(res)
res_checksum.consume_result(status.code)
total_row_count += sum([max(val, 0) for val in res])

if status.code == ABORTED:
self.connection._transaction = None
raise Aborted(status.message)
elif status.code != OK:
raise OperationalError(status.message)
self._row_count = total_row_count
break
except Aborted:
self.connection.retry_transaction()
Expand Down
3 changes: 3 additions & 0 deletions tests/system/_helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,9 @@
INSTANCE_ID_DEFAULT = "google-cloud-python-systest"
INSTANCE_ID = os.environ.get(INSTANCE_ID_ENVVAR, INSTANCE_ID_DEFAULT)

API_ENDPOINT_ENVVAR = "GOOGLE_CLOUD_TESTS_SPANNER_HOST"
API_ENDPOINT = os.getenv(API_ENDPOINT_ENVVAR)

SKIP_BACKUP_TESTS_ENVVAR = "SKIP_BACKUP_TESTS"
SKIP_BACKUP_TESTS = os.getenv(SKIP_BACKUP_TESTS_ENVVAR) is not None

Expand Down
5 changes: 4 additions & 1 deletion tests/system/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,7 +85,10 @@ def spanner_client():
credentials=credentials,
)
else:
return spanner_v1.Client() # use google.auth.default credentials
client_options = {"api_endpoint": _helpers.API_ENDPOINT}
return spanner_v1.Client(
client_options=client_options
) # use google.auth.default credentials


@pytest.fixture(scope="session")
Expand Down
123 changes: 123 additions & 0 deletions tests/system/test_dbapi.py
Original file line number Diff line number Diff line change
Expand Up @@ -501,3 +501,126 @@ def test_staleness(shared_instance, dbapi_database):
assert len(cursor.fetchall()) == 1

conn.close()


@pytest.mark.parametrize("autocommit", [False, True])
def test_rowcount(shared_instance, dbapi_database, autocommit):
    """Check ``cursor.rowcount`` after each kind of statement.

    Runs once with autocommit disabled and once with it enabled. Covers
    ``executemany`` with INSERT, and ``execute`` with INSERT, UPDATE,
    SELECT (expected to report -1), DELETE, and an UPDATE matching no rows.
    """
    conn = Connection(shared_instance, dbapi_database)
    try:
        conn.autocommit = autocommit
        cur = conn.cursor()

        cur.execute(
            """
            CREATE TABLE Singers (
            SingerId INT64 NOT NULL,
            Name STRING(1024),
            ) PRIMARY KEY (SingerId)
            """
        )
        conn.commit()

        # executemany sets rowcount to the total modified rows
        rows = [(i, f"Singer {i}") for i in range(100)]
        cur.executemany(
            "INSERT INTO Singers (SingerId, Name) VALUES (%s, %s)", rows[:98]
        )
        assert cur.rowcount == 98

        # execute with INSERT; the two remaining rows are flattened into a
        # single positional argument list for the two-tuple VALUES clause
        cur.execute(
            "INSERT INTO Singers (SingerId, Name) VALUES (%s, %s), (%s, %s)",
            [x for row in rows[98:] for x in row],
        )
        assert cur.rowcount == 2

        # execute with UPDATE
        cur.execute("UPDATE Singers SET Name = 'Cher' WHERE SingerId < 25")
        assert cur.rowcount == 25

        # execute with SELECT
        cur.execute("SELECT Name FROM Singers WHERE SingerId < 75")
        assert len(cur.fetchall()) == 75
        # rowcount is not available for SELECT
        assert cur.rowcount == -1

        # execute with DELETE
        cur.execute("DELETE FROM Singers")
        assert cur.rowcount == 100

        # execute with UPDATE matching 0 rows
        cur.execute("UPDATE Singers SET Name = 'Cher' WHERE SingerId < 25")
        assert cur.rowcount == 0

        conn.commit()
        cur.execute("DROP TABLE Singers")
        conn.commit()
    finally:
        # Release the connection even when an assertion above fails.
        conn.close()


@pytest.mark.parametrize("autocommit", [False, True])
@pytest.mark.skipif(
    _helpers.USE_EMULATOR, reason="Emulator does not support DML Returning."
)
def test_dml_returning_insert(shared_instance, dbapi_database, autocommit):
    """Check that INSERT ... THEN RETURN yields the row and a rowcount of 1.

    Runs once with autocommit disabled and once with it enabled.
    NOTE(review): assumes the ``contacts`` table already exists and holds no
    row with ``contact_id = 1`` -- its setup is not visible here.
    """
    conn = Connection(shared_instance, dbapi_database)
    try:
        conn.autocommit = autocommit
        cur = conn.cursor()
        cur.execute(
            """
            INSERT INTO contacts (contact_id, first_name, last_name, email)
            VALUES (1, 'first-name', 'last-name', 'test.email@example.com')
            THEN RETURN contact_id, first_name
            """
        )
        assert cur.fetchone() == (1, "first-name")
        assert cur.rowcount == 1
        conn.commit()
    finally:
        # Release the connection even when an assertion above fails.
        conn.close()


@pytest.mark.parametrize("autocommit", [False, True])
@pytest.mark.skipif(
    _helpers.USE_EMULATOR, reason="Emulator does not support DML Returning."
)
def test_dml_returning_update(shared_instance, dbapi_database, autocommit):
    """Check that UPDATE ... THEN RETURN yields the updated row and rowcount 1.

    A row is inserted first (rowcount 1), then updated with a returning
    clause. Runs once with autocommit disabled and once with it enabled.
    NOTE(review): assumes the ``contacts`` table already exists and holds no
    row with ``contact_id = 1`` -- its setup is not visible here.
    """
    conn = Connection(shared_instance, dbapi_database)
    try:
        conn.autocommit = autocommit
        cur = conn.cursor()
        cur.execute(
            """
            INSERT INTO contacts (contact_id, first_name, last_name, email)
            VALUES (1, 'first-name', 'last-name', 'test.email@example.com')
            """
        )
        assert cur.rowcount == 1
        cur.execute(
            """
            UPDATE contacts SET first_name = 'new-name' WHERE contact_id = 1
            THEN RETURN contact_id, first_name
            """
        )
        assert cur.fetchone() == (1, "new-name")
        assert cur.rowcount == 1
        conn.commit()
    finally:
        # Release the connection even when an assertion above fails.
        conn.close()


@pytest.mark.parametrize("autocommit", [False, True])
@pytest.mark.skipif(
    _helpers.USE_EMULATOR, reason="Emulator does not support DML Returning."
)
def test_dml_returning_delete(shared_instance, dbapi_database, autocommit):
    """Check that DELETE ... THEN RETURN yields the deleted row and rowcount 1.

    A row is inserted first (rowcount 1), then deleted with a returning
    clause. Runs once with autocommit disabled and once with it enabled.
    NOTE(review): assumes the ``contacts`` table already exists and holds no
    row with ``contact_id = 1`` -- its setup is not visible here.
    """
    conn = Connection(shared_instance, dbapi_database)
    try:
        conn.autocommit = autocommit
        cur = conn.cursor()
        cur.execute(
            """
            INSERT INTO contacts (contact_id, first_name, last_name, email)
            VALUES (1, 'first-name', 'last-name', 'test.email@example.com')
            """
        )
        assert cur.rowcount == 1
        cur.execute(
            """
            DELETE FROM contacts WHERE contact_id = 1
            THEN RETURN contact_id, first_name
            """
        )
        assert cur.fetchone() == (1, "first-name")
        assert cur.rowcount == 1
        conn.commit()
    finally:
        # Release the connection even when an assertion above fails.
        conn.close()
Loading

0 comments on commit 81505cd

Please sign in to comment.