Skip to content
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 5 additions & 2 deletions mssql_python/cursor.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ class Cursor:
callproc(procname, parameters=None) ->
Modified copy of the input sequence with output parameters.
close() -> None.
execute(operation, parameters=None) -> None.
execute(operation, parameters=None) -> Cursor.
executemany(operation, seq_of_parameters) -> None.
fetchone() -> Single sequence or None if no more data is available.
fetchmany(size=None) -> Sequence of sequences (e.g. list of tuples).
Expand Down Expand Up @@ -542,7 +542,7 @@ def execute(
*parameters,
use_prepare: bool = True,
reset_cursor: bool = True
) -> None:
) -> 'Cursor':
"""
Prepare and execute a database operation (query or command).

Expand Down Expand Up @@ -613,6 +613,9 @@ def execute(
# Initialize description after execution
self._initialize_description()

# Return self for method chaining
return self

@staticmethod
def _select_best_sample_value(column):
"""
Expand Down
277 changes: 270 additions & 7 deletions tests/test_004_cursor.py
Original file line number Diff line number Diff line change
Expand Up @@ -513,12 +513,8 @@ def test_longwvarchar(cursor, db_connection):
expectedRows = 2
# fetchone test
cursor.execute("SELECT longwvarchar_column FROM #pytest_longwvarchar_test")
rows = []
for i in range(0, expectedRows):
rows.append(cursor.fetchone())
assert cursor.fetchone() == None, "longwvarchar_column is expected to have only {} rows".format(expectedRows)
assert rows[0] == ["ABCDEFGHI"], "SQL_LONGWVARCHAR parsing failed for fetchone - row 0"
assert rows[1] == [None], "SQL_LONGWVARCHAR parsing failed for fetchone - row 1"
row = cursor.fetchone()
assert row[0] == "ABCDEFGHI", "SQL_LONGWVARCHAR parsing failed for fetchone"
# fetchall test
cursor.execute("SELECT longwvarchar_column FROM #pytest_longwvarchar_test")
rows = cursor.fetchall()
Expand Down Expand Up @@ -1313,6 +1309,274 @@ def test_row_column_mapping(cursor, db_connection):
cursor.execute("DROP TABLE #pytest_row_test")
db_connection.commit()

# Method Chaining Tests
def test_execute_returns_self(cursor):
"""Test that execute() returns the cursor itself for method chaining"""
# Test basic execute returns cursor
result = cursor.execute("SELECT 1 as test_value")
assert result is cursor, "execute() should return the cursor itself"
assert id(result) == id(cursor), "Returned cursor should be the same object"

def test_execute_fetchone_chaining(cursor, db_connection):
"""Test chaining execute() with fetchone()"""
try:
# Create test table
cursor.execute("CREATE TABLE #test_chaining (id INT, value NVARCHAR(50))")
db_connection.commit()

# Insert test data
cursor.execute("INSERT INTO #test_chaining (id, value) VALUES (?, ?)", 1, "test_value")
db_connection.commit()

# Test execute().fetchone() chaining
row = cursor.execute("SELECT id, value FROM #test_chaining WHERE id = ?", 1).fetchone()
assert row is not None, "Should return a row"
assert row[0] == 1, "First column should be 1"
assert row[1] == "test_value", "Second column should be 'test_value'"

# Test with non-existent row
row = cursor.execute("SELECT id, value FROM #test_chaining WHERE id = ?", 999).fetchone()
assert row is None, "Should return None for non-existent row"

finally:
try:
cursor.execute("DROP TABLE #test_chaining")
db_connection.commit()
except:
pass

def test_execute_fetchall_chaining(cursor, db_connection):
"""Test chaining execute() with fetchall()"""
try:
# Create test table
cursor.execute("CREATE TABLE #test_chaining (id INT, value NVARCHAR(50))")
db_connection.commit()

# Insert multiple test records
cursor.execute("INSERT INTO #test_chaining (id, value) VALUES (1, 'first')")
cursor.execute("INSERT INTO #test_chaining (id, value) VALUES (2, 'second')")
cursor.execute("INSERT INTO #test_chaining (id, value) VALUES (3, 'third')")
db_connection.commit()

# Test execute().fetchall() chaining
rows = cursor.execute("SELECT id, value FROM #test_chaining ORDER BY id").fetchall()
assert len(rows) == 3, "Should return 3 rows"
assert rows[0] == [1, 'first'], "First row incorrect"
assert rows[1] == [2, 'second'], "Second row incorrect"
assert rows[2] == [3, 'third'], "Third row incorrect"

# Test with WHERE clause
rows = cursor.execute("SELECT id, value FROM #test_chaining WHERE id > ?", 1).fetchall()
assert len(rows) == 2, "Should return 2 rows with WHERE clause"
assert rows[0] == [2, 'second'], "Filtered first row incorrect"
assert rows[1] == [3, 'third'], "Filtered second row incorrect"

finally:
try:
cursor.execute("DROP TABLE #test_chaining")
db_connection.commit()
except:
pass

def test_execute_fetchmany_chaining(cursor, db_connection):
"""Test chaining execute() with fetchmany()"""
try:
# Create test table
cursor.execute("CREATE TABLE #test_chaining (id INT, value NVARCHAR(50))")
db_connection.commit()

# Insert test data
for i in range(1, 6): # Insert 5 records
cursor.execute("INSERT INTO #test_chaining (id, value) VALUES (?, ?)", i, f"value_{i}")
db_connection.commit()

# Test execute().fetchmany() chaining with size parameter
rows = cursor.execute("SELECT id, value FROM #test_chaining ORDER BY id").fetchmany(3)
assert len(rows) == 3, "Should return 3 rows with fetchmany(3)"
assert rows[0] == [1, 'value_1'], "First row incorrect"
assert rows[1] == [2, 'value_2'], "Second row incorrect"
assert rows[2] == [3, 'value_3'], "Third row incorrect"

# Test execute().fetchmany() chaining with arraysize
cursor.arraysize = 2
rows = cursor.execute("SELECT id, value FROM #test_chaining ORDER BY id").fetchmany()
assert len(rows) == 2, "Should return 2 rows with default arraysize"
assert rows[0] == [1, 'value_1'], "First row incorrect"
assert rows[1] == [2, 'value_2'], "Second row incorrect"

finally:
try:
cursor.execute("DROP TABLE #test_chaining")
db_connection.commit()
except:
pass

def test_execute_rowcount_chaining(cursor, db_connection):
"""Test chaining execute() with rowcount property"""
try:
# Create test table
cursor.execute("CREATE TABLE #test_chaining (id INT, value NVARCHAR(50))")
db_connection.commit()

# Test INSERT rowcount chaining
count = cursor.execute("INSERT INTO #test_chaining (id, value) VALUES (?, ?)", 1, "test").rowcount
assert count == 1, "INSERT should affect 1 row"

# Test multiple INSERT rowcount chaining
count = cursor.execute("""
INSERT INTO #test_chaining (id, value) VALUES
(2, 'test2'), (3, 'test3'), (4, 'test4')
""").rowcount
assert count == 3, "Multiple INSERT should affect 3 rows"

# Test UPDATE rowcount chaining
count = cursor.execute("UPDATE #test_chaining SET value = ? WHERE id > ?", "updated", 2).rowcount
assert count == 2, "UPDATE should affect 2 rows"

# Test DELETE rowcount chaining
count = cursor.execute("DELETE FROM #test_chaining WHERE id = ?", 1).rowcount
assert count == 1, "DELETE should affect 1 row"

# Test SELECT rowcount chaining (should be -1)
count = cursor.execute("SELECT * FROM #test_chaining").rowcount
assert count == -1, "SELECT rowcount should be -1"

finally:
try:
cursor.execute("DROP TABLE #test_chaining")
db_connection.commit()
except:
pass

def test_execute_description_chaining(cursor):
"""Test chaining execute() with description property"""
# Test description after execute
description = cursor.execute("SELECT 1 as int_col, 'test' as str_col, GETDATE() as date_col").description
assert len(description) == 3, "Should have 3 columns in description"
assert description[0][0] == "int_col", "First column name should be 'int_col'"
assert description[1][0] == "str_col", "Second column name should be 'str_col'"
assert description[2][0] == "date_col", "Third column name should be 'date_col'"

# Test description with table query
description = cursor.execute("SELECT database_id, name FROM sys.databases WHERE database_id = 1").description
assert len(description) == 2, "Should have 2 columns in description"
assert description[0][0] == "database_id", "First column should be 'database_id'"
assert description[1][0] == "name", "Second column should be 'name'"

def test_multiple_chaining_operations(cursor, db_connection):
"""Test multiple chaining operations in sequence"""
try:
# Create test table
cursor.execute("CREATE TABLE #test_multi_chain (id INT IDENTITY(1,1), value NVARCHAR(50))")
db_connection.commit()

# Chain multiple operations: execute -> rowcount, then execute -> fetchone
insert_count = cursor.execute("INSERT INTO #test_multi_chain (value) VALUES (?)", "first").rowcount
assert insert_count == 1, "First insert should affect 1 row"

row = cursor.execute("SELECT id, value FROM #test_multi_chain WHERE value = ?", "first").fetchone()
assert row is not None, "Should find the inserted row"
assert row[1] == "first", "Value should be 'first'"

# Chain more operations
insert_count = cursor.execute("INSERT INTO #test_multi_chain (value) VALUES (?)", "second").rowcount
assert insert_count == 1, "Second insert should affect 1 row"

all_rows = cursor.execute("SELECT value FROM #test_multi_chain ORDER BY id").fetchall()
assert len(all_rows) == 2, "Should have 2 rows total"
assert all_rows[0] == ["first"], "First row should be 'first'"
assert all_rows[1] == ["second"], "Second row should be 'second'"

finally:
try:
cursor.execute("DROP TABLE #test_multi_chain")
db_connection.commit()
except:
pass

def test_chaining_with_parameters(cursor, db_connection):
"""Test method chaining with various parameter formats"""
try:
# Create test table
cursor.execute("CREATE TABLE #test_params (id INT, name NVARCHAR(50), age INT)")
db_connection.commit()

# Test chaining with tuple parameters
row = cursor.execute("INSERT INTO #test_params VALUES (?, ?, ?)", (1, "Alice", 25)).rowcount
assert row == 1, "Tuple parameter insert should affect 1 row"

# Test chaining with individual parameters
row = cursor.execute("INSERT INTO #test_params VALUES (?, ?, ?)", 2, "Bob", 30).rowcount
assert row == 1, "Individual parameter insert should affect 1 row"

# Test chaining with list parameters
row = cursor.execute("INSERT INTO #test_params VALUES (?, ?, ?)", [3, "Charlie", 35]).rowcount
assert row == 1, "List parameter insert should affect 1 row"

# Test chaining query with parameters and fetchall
rows = cursor.execute("SELECT name, age FROM #test_params WHERE age > ?", 28).fetchall()
assert len(rows) == 2, "Should find 2 people over 28"
assert rows[0] == ["Bob", 30], "First result should be Bob"
assert rows[1] == ["Charlie", 35], "Second result should be Charlie"

finally:
try:
cursor.execute("DROP TABLE #test_params")
db_connection.commit()
except:
pass

def test_chaining_error_handling(cursor):
"""Test that chaining works properly even when errors occur"""
# Test that cursor is still chainable after an error
with pytest.raises(Exception):
cursor.execute("SELECT * FROM nonexistent_table").fetchone()

# Cursor should still be usable for chaining after error
row = cursor.execute("SELECT 1 as test").fetchone()
assert row[0] == 1, "Cursor should still work after error"

# Test chaining with invalid SQL
with pytest.raises(Exception):
cursor.execute("INVALID SQL SYNTAX").rowcount

# Should still be chainable
count = cursor.execute("SELECT COUNT(*) FROM sys.databases").fetchone()[0]
assert isinstance(count, int), "Should return integer count"
assert count > 0, "Should have at least one database"

def test_chaining_performance_statement_reuse(cursor, db_connection):
"""Test that chaining works with statement reuse (same SQL, different parameters)"""
try:
# Create test table
cursor.execute("CREATE TABLE #test_reuse (id INT, value NVARCHAR(50))")
db_connection.commit()

# Execute same SQL multiple times with different parameters (should reuse prepared statement)
sql = "INSERT INTO #test_reuse (id, value) VALUES (?, ?)"

count1 = cursor.execute(sql, 1, "first").rowcount
count2 = cursor.execute(sql, 2, "second").rowcount
count3 = cursor.execute(sql, 3, "third").rowcount

assert count1 == 1, "First insert should affect 1 row"
assert count2 == 1, "Second insert should affect 1 row"
assert count3 == 1, "Third insert should affect 1 row"

# Verify all data was inserted correctly
rows = cursor.execute("SELECT id, value FROM #test_reuse ORDER BY id").fetchall()
assert len(rows) == 3, "Should have 3 rows"
assert rows[0] == [1, "first"], "First row incorrect"
assert rows[1] == [2, "second"], "Second row incorrect"
assert rows[2] == [3, "third"], "Third row incorrect"

finally:
try:
cursor.execute("DROP TABLE #test_reuse")
db_connection.commit()
except:
pass

def test_close(db_connection):
"""Test closing the cursor"""
try:
Expand All @@ -1323,4 +1587,3 @@ def test_close(db_connection):
pytest.fail(f"Cursor close test failed: {e}")
finally:
cursor = db_connection.cursor()