Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
220 changes: 198 additions & 22 deletions mssql_python/cursor.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,6 +81,8 @@ def __init__(self, connection) -> None:
self._has_result_set = False # Track if we have an active result set
self._skip_increment_for_next_fetch = False # Track if we need to skip incrementing the row index

self.messages = [] # Store diagnostic messages

def _is_unicode_string(self, param):
"""
Check if a string contains non-ASCII characters.
Expand Down Expand Up @@ -453,6 +455,9 @@ def close(self) -> None:
if self.closed:
raise Exception("Cursor is already closed.")

# Clear messages per DBAPI
self.messages = []

if self.hstmt:
self.hstmt.free()
self.hstmt = None
Expand Down Expand Up @@ -698,6 +703,9 @@ def execute(
if reset_cursor:
self._reset_cursor()

# Clear any previous messages
self.messages = []

param_info = ddbc_bindings.ParamInfo
parameters_type = []

Expand Down Expand Up @@ -745,7 +753,14 @@ def execute(
self.is_stmt_prepared,
use_prepare,
)

# Check for errors but don't raise exceptions for info/warning messages
check_error(ddbc_sql_const.SQL_HANDLE_STMT.value, self.hstmt, ret)

# Capture any diagnostic messages (SQL_SUCCESS_WITH_INFO, etc.)
if self.hstmt:
self.messages.extend(ddbc_bindings.DDBCSQLGetAllDiagRecords(self.hstmt))

self.last_executed_stmt = operation

# Update rowcount after execution
Expand Down Expand Up @@ -827,7 +842,10 @@ def executemany(self, operation: str, seq_of_parameters: list) -> None:
"""
self._check_closed()
self._reset_cursor()


# Clear any previous messages
self.messages = []

if not seq_of_parameters:
self.rowcount = 0
return
Expand Down Expand Up @@ -859,6 +877,10 @@ def executemany(self, operation: str, seq_of_parameters: list) -> None:
)
check_error(ddbc_sql_const.SQL_HANDLE_STMT.value, self.hstmt, ret)

# Capture any diagnostic messages after execution
if self.hstmt:
self.messages.extend(ddbc_bindings.DDBCSQLGetAllDiagRecords(self.hstmt))

self.rowcount = ddbc_bindings.DDBCSQLRowCount(self.hstmt)
self.last_executed_stmt = operation
self._initialize_description()
Expand All @@ -884,6 +906,9 @@ def fetchone(self) -> Union[None, Row]:
try:
ret = ddbc_bindings.DDBCSQLFetchOne(self.hstmt, row_data)

if self.hstmt:
self.messages.extend(ddbc_bindings.DDBCSQLGetAllDiagRecords(self.hstmt))

if ret == ddbc_sql_const.SQL_NO_DATA.value:
return None

Expand All @@ -894,8 +919,9 @@ def fetchone(self) -> Union[None, Row]:
else:
self._increment_rownumber()

# Create and return a Row object
return Row(row_data, self.description)
# Create and return a Row object, passing column name map if available
column_map = getattr(self, '_column_name_map', None)
return Row(row_data, self.description, column_map)
except Exception as e:
# On error, don't increment rownumber - rethrow the error
raise e
Expand Down Expand Up @@ -924,6 +950,10 @@ def fetchmany(self, size: int = None) -> List[Row]:
rows_data = []
try:
ret = ddbc_bindings.DDBCSQLFetchMany(self.hstmt, rows_data, size)

if self.hstmt:
self.messages.extend(ddbc_bindings.DDBCSQLGetAllDiagRecords(self.hstmt))


# Update rownumber for the number of rows actually fetched
if rows_data and self._has_result_set:
Expand All @@ -932,7 +962,8 @@ def fetchmany(self, size: int = None) -> List[Row]:
self._rownumber = self._next_row_index - 1

# Convert raw data to Row objects
return [Row(row_data, self.description) for row_data in rows_data]
column_map = getattr(self, '_column_name_map', None)
return [Row(row_data, self.description, column_map) for row_data in rows_data]
except Exception as e:
# On error, don't increment rownumber - rethrow the error
raise e
Expand All @@ -952,14 +983,19 @@ def fetchall(self) -> List[Row]:
rows_data = []
try:
ret = ddbc_bindings.DDBCSQLFetchAll(self.hstmt, rows_data)

if self.hstmt:
self.messages.extend(ddbc_bindings.DDBCSQLGetAllDiagRecords(self.hstmt))


# Update rownumber for the number of rows actually fetched
if rows_data and self._has_result_set:
self._next_row_index += len(rows_data)
self._rownumber = self._next_row_index - 1

# Convert raw data to Row objects
return [Row(row_data, self.description) for row_data in rows_data]
column_map = getattr(self, '_column_name_map', None)
return [Row(row_data, self.description, column_map) for row_data in rows_data]
except Exception as e:
# On error, don't increment rownumber - rethrow the error
raise e
Expand All @@ -976,6 +1012,9 @@ def nextset(self) -> Union[bool, None]:
"""
self._check_closed() # Check if the cursor is closed

# Clear messages per DBAPI
self.messages = []

# Skip to the next result set
ret = ddbc_bindings.DDBCSQLMoreResults(self.hstmt)
check_error(ddbc_sql_const.SQL_HANDLE_STMT.value, self.hstmt, ret)
Expand Down Expand Up @@ -1056,6 +1095,9 @@ def commit(self):
"""
self._check_closed() # Check if the cursor is closed

# Clear messages per DBAPI
self.messages = []

# Delegate to the connection's commit method
self._connection.commit()

Expand All @@ -1082,6 +1124,9 @@ def rollback(self):
"""
self._check_closed() # Check if the cursor is closed

# Clear messages per DBAPI
self.messages = []

# Delegate to the connection's rollback method
self._connection.rollback()

Expand All @@ -1107,6 +1152,10 @@ def scroll(self, value: int, mode: str = 'relative') -> None:
- absolute(k>0): next fetch returns row index k (0-based); rownumber == k after scroll.
"""
self._check_closed()

# Clear messages per DBAPI
self.messages = []

if mode not in ('relative', 'absolute'):
raise ProgrammingError("Invalid scroll mode",
f"mode must be 'relative' or 'absolute', got '{mode}'")
Expand Down Expand Up @@ -1179,29 +1228,156 @@ def scroll(self, value: int, mode: str = 'relative') -> None:

def skip(self, count: int) -> None:
"""
Skip the next 'count' records in the query result set.

This is a convenience method that advances the cursor by 'count'
positions without returning the skipped rows.
Skip the next count records in the query result set.

Args:
count: Number of records to skip. Must be non-negative.

Returns:
None
count: Number of records to skip.

Raises:
ProgrammingError: If the cursor is closed or no result set is available.
NotSupportedError: If count is negative (backward scrolling not supported).
IndexError: If attempting to skip past the end of the result set.

Note:
For convenience, skip(0) is accepted and will do nothing.
ProgrammingError: If count is not an integer.
NotSupportedError: If attempting to skip backwards.
"""
from mssql_python.exceptions import ProgrammingError, NotSupportedError

self._check_closed()

if count == 0: # Skip 0 is a no-op
return
# Clear messages
self.messages = []

# Simply delegate to the scroll method with 'relative' mode
self.scroll(count, 'relative')

def _execute_tables(self, stmt_handle, catalog_name=None, schema_name=None, table_name=None,
table_type=None, search_escape=None):
"""
Execute SQLTables ODBC function to retrieve table metadata.

Args:
stmt_handle: ODBC statement handle
catalog_name: The catalog name pattern
schema_name: The schema name pattern
table_name: The table name pattern
table_type: The table type filter
search_escape: The escape character for pattern matching
"""
# Convert None values to empty strings for ODBC
catalog = "" if catalog_name is None else catalog_name
schema = "" if schema_name is None else schema_name
table = "" if table_name is None else table_name
types = "" if table_type is None else table_type

# Call the ODBC SQLTables function
retcode = ddbc_bindings.DDBCSQLTables(
stmt_handle,
catalog,
schema,
table,
types
)

# Check return code and handle errors
check_error(ddbc_sql_const.SQL_HANDLE_STMT.value, stmt_handle, retcode)

# Capture any diagnostic messages
if stmt_handle:
self.messages.extend(ddbc_bindings.DDBCSQLGetAllDiagRecords(stmt_handle))

def tables(self, table=None, catalog=None, schema=None, tableType=None):
"""
Returns information about tables in the database that match the given criteria using
the SQLTables ODBC function.

Args:
table (str, optional): The table name pattern. Default is None (all tables).
catalog (str, optional): The catalog name. Default is None.
schema (str, optional): The schema name pattern. Default is None.
tableType (str or list, optional): The table type filter. Default is None.
Example: "TABLE" or ["TABLE", "VIEW"]

Returns:
list: A list of Row objects containing table information with these columns:
- table_cat: Catalog name
- table_schem: Schema name
- table_name: Table name
- table_type: Table type (e.g., "TABLE", "VIEW")
- remarks: Comments about the table

Notes:
This method only processes the standard five columns as defined in the ODBC
specification. Any additional columns that might be returned by specific ODBC
drivers are not included in the result set.

Example:
# Get all tables in the database
tables = cursor.tables()

# Use existing scroll method with relative mode
self.scroll(count, 'relative')
# Get all tables in schema 'dbo'
tables = cursor.tables(schema='dbo')

# Get table named 'Customers'
tables = cursor.tables(table='Customers')

# Get all views
tables = cursor.tables(tableType='VIEW')
"""
self._check_closed()

# Clear messages
self.messages = []

# Always reset the cursor first to ensure clean state
self._reset_cursor()

# Format table_type parameter - SQLTables expects comma-separated string
table_type_str = None
if tableType is not None:
if isinstance(tableType, (list, tuple)):
table_type_str = ",".join(tableType)
else:
table_type_str = str(tableType)

# Call SQLTables via the helper method
self._execute_tables(
self.hstmt,
catalog_name=catalog,
schema_name=schema,
table_name=table,
table_type=table_type_str
)

# Initialize description from column metadata
column_metadata = []
try:
ddbc_bindings.DDBCSQLDescribeCol(self.hstmt, column_metadata)
self._initialize_description(column_metadata)
except Exception:
# If describe fails, create a manual description for the standard columns
column_types = [str, str, str, str, str]
self.description = [
("table_cat", column_types[0], None, 128, 128, 0, True),
("table_schem", column_types[1], None, 128, 128, 0, True),
("table_name", column_types[2], None, 128, 128, 0, False),
("table_type", column_types[3], None, 128, 128, 0, False),
("remarks", column_types[4], None, 254, 254, 0, True)
]

# Define column names in ODBC standard order
column_names = [
"table_cat", "table_schem", "table_name", "table_type", "remarks"
]

# Fetch all rows
rows_data = []
ddbc_bindings.DDBCSQLFetchAll(self.hstmt, rows_data)

# Create a column map for attribute access
column_map = {name: i for i, name in enumerate(column_names)}

# Create Row objects with the column map
result_rows = []
for row_data in rows_data:
row = Row(row_data, self.description, column_map)
result_rows.append(row)

return result_rows
Loading