Skip to content

feat: Add support and tests for DML returning clauses #805

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 17 commits into from
Nov 22, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 0 additions & 20 deletions google/cloud/spanner_dbapi/_helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,8 +12,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from google.cloud.spanner_dbapi.parse_utils import get_param_types
from google.cloud.spanner_dbapi.parse_utils import sql_pyformat_args_to_spanner
from google.cloud.spanner_v1 import param_types


Expand Down Expand Up @@ -47,24 +45,6 @@
}


def _execute_insert_heterogenous(
transaction,
sql_params_list,
request_options=None,
):
for sql, params in sql_params_list:
sql, params = sql_pyformat_args_to_spanner(sql, params)
transaction.execute_update(
sql, params, get_param_types(params), request_options=request_options
)


def handle_insert(connection, sql, params):
return connection.database.run_in_transaction(
_execute_insert_heterogenous, ((sql, params),), connection.request_options
)


class ColumnInfo:
"""Row column description object."""

Expand Down
10 changes: 0 additions & 10 deletions google/cloud/spanner_dbapi/connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,6 @@
from google.cloud.spanner_v1.session import _get_retry_delay
from google.cloud.spanner_v1.snapshot import Snapshot

from google.cloud.spanner_dbapi._helpers import _execute_insert_heterogenous
from google.cloud.spanner_dbapi.checksum import _compare_checksums
from google.cloud.spanner_dbapi.checksum import ResultsChecksum
from google.cloud.spanner_dbapi.cursor import Cursor
Expand Down Expand Up @@ -450,15 +449,6 @@ def run_statement(self, statement, retried=False):
if not retried:
self._statements.append(statement)

if statement.is_insert:
_execute_insert_heterogenous(
transaction, ((statement.sql, statement.params),), self.request_options
)
return (
iter(()),
ResultsChecksum() if retried else statement.checksum,
)

return (
transaction.execute_sql(
statement.sql,
Expand Down
48 changes: 26 additions & 22 deletions google/cloud/spanner_dbapi/cursor.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@
_UNSET_COUNT = -1

ColumnDetails = namedtuple("column_details", ["null_ok", "spanner_type"])
Statement = namedtuple("Statement", "sql, params, param_types, checksum, is_insert")
Statement = namedtuple("Statement", "sql, params, param_types, checksum")


def check_not_closed(function):
Expand Down Expand Up @@ -137,14 +137,21 @@ def description(self):

@property
def rowcount(self):
"""The number of rows updated by the last UPDATE, DELETE request's `execute()` call.
"""The number of rows updated by the last INSERT, UPDATE, DELETE request's `execute()` call.
For SELECT requests the rowcount returns -1.

:rtype: int
:returns: The number of rows updated by the last UPDATE, DELETE request's .execute*() call.
:returns: The number of rows updated by the last INSERT, UPDATE, DELETE request's .execute*() call.
"""

return self._row_count
if self._row_count != _UNSET_COUNT or self._result_set is None:
return self._row_count

stats = getattr(self._result_set, "stats", None)
if stats is not None and "row_count_exact" in stats:
return stats.row_count_exact

return _UNSET_COUNT

@check_not_closed
def callproc(self, procname, args=None):
Expand All @@ -171,17 +178,11 @@ def close(self):
self._is_closed = True

def _do_execute_update(self, transaction, sql, params):
result = transaction.execute_update(
sql,
params=params,
param_types=get_param_types(params),
request_options=self.connection.request_options,
self._result_set = transaction.execute_sql(
sql, params=params, param_types=get_param_types(params)
)
self._itr = None
if type(result) == int:
self._row_count = result

return result
self._itr = PeekIterator(self._result_set)
self._row_count = _UNSET_COUNT

def _do_batch_update(self, transaction, statements, many_result_set):
status, res = transaction.batch_update(statements)
Expand Down Expand Up @@ -227,7 +228,9 @@ def execute(self, sql, args=None):
:type args: list
:param args: Additional parameters to supplement the SQL query.
"""
self._itr = None
self._result_set = None
self._row_count = _UNSET_COUNT

try:
if self.connection.read_only:
Expand All @@ -249,18 +252,14 @@ def execute(self, sql, args=None):
if class_ == parse_utils.STMT_UPDATING:
sql = parse_utils.ensure_where_clause(sql)

if class_ != parse_utils.STMT_INSERT:
sql, args = sql_pyformat_args_to_spanner(sql, args or None)
sql, args = sql_pyformat_args_to_spanner(sql, args or None)

if not self.connection.autocommit:
statement = Statement(
sql,
args,
get_param_types(args or None)
if class_ != parse_utils.STMT_INSERT
else {},
get_param_types(args or None),
ResultsChecksum(),
class_ == parse_utils.STMT_INSERT,
)

(
Expand All @@ -277,8 +276,6 @@ def execute(self, sql, args=None):

if class_ == parse_utils.STMT_NON_UPDATING:
self._handle_DQL(sql, args or None)
elif class_ == parse_utils.STMT_INSERT:
_helpers.handle_insert(self.connection, sql, args or None)
else:
self.connection.database.run_in_transaction(
self._do_execute_update,
Expand All @@ -304,6 +301,10 @@ def executemany(self, operation, seq_of_params):
:param seq_of_params: Sequence of additional parameters to run
the query with.
"""
self._itr = None
self._result_set = None
self._row_count = _UNSET_COUNT

class_ = parse_utils.classify_stmt(operation)
if class_ == parse_utils.STMT_DDL:
raise ProgrammingError(
Expand All @@ -327,6 +328,7 @@ def executemany(self, operation, seq_of_params):
)
else:
retried = False
total_row_count = 0
while True:
try:
transaction = self.connection.transaction_checkout()
Expand All @@ -341,12 +343,14 @@ def executemany(self, operation, seq_of_params):
many_result_set.add_iter(res)
res_checksum.consume_result(res)
res_checksum.consume_result(status.code)
total_row_count += sum([max(val, 0) for val in res])

if status.code == ABORTED:
self.connection._transaction = None
raise Aborted(status.message)
elif status.code != OK:
raise OperationalError(status.message)
self._row_count = total_row_count
break
except Aborted:
self.connection.retry_transaction()
Expand Down
3 changes: 3 additions & 0 deletions tests/system/_helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,9 @@
INSTANCE_ID_DEFAULT = "google-cloud-python-systest"
INSTANCE_ID = os.environ.get(INSTANCE_ID_ENVVAR, INSTANCE_ID_DEFAULT)

API_ENDPOINT_ENVVAR = "GOOGLE_CLOUD_TESTS_SPANNER_HOST"
API_ENDPOINT = os.getenv(API_ENDPOINT_ENVVAR)

SKIP_BACKUP_TESTS_ENVVAR = "SKIP_BACKUP_TESTS"
SKIP_BACKUP_TESTS = os.getenv(SKIP_BACKUP_TESTS_ENVVAR) is not None

Expand Down
5 changes: 4 additions & 1 deletion tests/system/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,7 +85,10 @@ def spanner_client():
credentials=credentials,
)
else:
return spanner_v1.Client() # use google.auth.default credentials
client_options = {"api_endpoint": _helpers.API_ENDPOINT}
return spanner_v1.Client(
client_options=client_options
) # use google.auth.default credentials


@pytest.fixture(scope="session")
Expand Down
123 changes: 123 additions & 0 deletions tests/system/test_dbapi.py
Original file line number Diff line number Diff line change
Expand Up @@ -501,3 +501,126 @@ def test_staleness(shared_instance, dbapi_database):
assert len(cursor.fetchall()) == 1

conn.close()


@pytest.mark.parametrize("autocommit", [False, True])
def test_rowcount(shared_instance, dbapi_database, autocommit):
conn = Connection(shared_instance, dbapi_database)
conn.autocommit = autocommit
cur = conn.cursor()

cur.execute(
"""
CREATE TABLE Singers (
SingerId INT64 NOT NULL,
Name STRING(1024),
) PRIMARY KEY (SingerId)
"""
)
conn.commit()

# executemany sets rowcount to the total modified rows
rows = [(i, f"Singer {i}") for i in range(100)]
cur.executemany("INSERT INTO Singers (SingerId, Name) VALUES (%s, %s)", rows[:98])
assert cur.rowcount == 98

# execute with INSERT
cur.execute(
"INSERT INTO Singers (SingerId, Name) VALUES (%s, %s), (%s, %s)",
[x for row in rows[98:] for x in row],
)
assert cur.rowcount == 2

# execute with UPDATE
cur.execute("UPDATE Singers SET Name = 'Cher' WHERE SingerId < 25")
assert cur.rowcount == 25

# execute with SELECT
cur.execute("SELECT Name FROM Singers WHERE SingerId < 75")
assert len(cur.fetchall()) == 75
# rowcount is not available for SELECT
assert cur.rowcount == -1

# execute with DELETE
cur.execute("DELETE FROM Singers")
assert cur.rowcount == 100

# execute with UPDATE matching 0 rows
cur.execute("UPDATE Singers SET Name = 'Cher' WHERE SingerId < 25")
assert cur.rowcount == 0

conn.commit()
cur.execute("DROP TABLE Singers")
conn.commit()


@pytest.mark.parametrize("autocommit", [False, True])
@pytest.mark.skipif(
_helpers.USE_EMULATOR, reason="Emulator does not support DML Returning."
)
def test_dml_returning_insert(shared_instance, dbapi_database, autocommit):
conn = Connection(shared_instance, dbapi_database)
conn.autocommit = autocommit
cur = conn.cursor()
cur.execute(
"""
INSERT INTO contacts (contact_id, first_name, last_name, email)
VALUES (1, 'first-name', 'last-name', 'test.email@example.com')
THEN RETURN contact_id, first_name
"""
)
assert cur.fetchone() == (1, "first-name")
assert cur.rowcount == 1
conn.commit()


@pytest.mark.parametrize("autocommit", [False, True])
@pytest.mark.skipif(
_helpers.USE_EMULATOR, reason="Emulator does not support DML Returning."
)
def test_dml_returning_update(shared_instance, dbapi_database, autocommit):
conn = Connection(shared_instance, dbapi_database)
conn.autocommit = autocommit
cur = conn.cursor()
cur.execute(
"""
INSERT INTO contacts (contact_id, first_name, last_name, email)
VALUES (1, 'first-name', 'last-name', 'test.email@example.com')
"""
)
assert cur.rowcount == 1
cur.execute(
"""
UPDATE contacts SET first_name = 'new-name' WHERE contact_id = 1
THEN RETURN contact_id, first_name
"""
)
assert cur.fetchone() == (1, "new-name")
assert cur.rowcount == 1
conn.commit()


@pytest.mark.parametrize("autocommit", [False, True])
@pytest.mark.skipif(
_helpers.USE_EMULATOR, reason="Emulator does not support DML Returning."
)
def test_dml_returning_delete(shared_instance, dbapi_database, autocommit):
conn = Connection(shared_instance, dbapi_database)
conn.autocommit = autocommit
cur = conn.cursor()
cur.execute(
"""
INSERT INTO contacts (contact_id, first_name, last_name, email)
VALUES (1, 'first-name', 'last-name', 'test.email@example.com')
"""
)
assert cur.rowcount == 1
cur.execute(
"""
DELETE FROM contacts WHERE contact_id = 1
THEN RETURN contact_id, first_name
"""
)
assert cur.fetchone() == (1, "first-name")
assert cur.rowcount == 1
conn.commit()
Loading