Skip to content

Commit 3d8304e

Browse files
authored
FEAT: Adding Cursor.messages (#184)
### Work Item / Issue Reference <!-- IMPORTANT: Please follow the PR template guidelines below. For mssql-python maintainers: Insert your ADO Work Item ID below (e.g. AB#37452) For external contributors: Insert Github Issue number below (e.g. #149) Only one reference is required - either GitHub issue OR ADO Work Item. --> <!-- mssql-python maintainers: ADO Work Item --> > [AB#34893](https://sqlclientdrivers.visualstudio.com/c6d89619-62de-46a0-8b46-70b92a84d85e/_workitems/edit/34893) ------------------------------------------------------------------- ### Summary This pull request adds comprehensive support for capturing and managing diagnostic messages (such as SQL Server PRINT statements and warnings) in the `mssql_python` driver's `Cursor` class, following the DBAPI specification. It introduces a new `messages` attribute on the cursor, ensures messages are cleared or preserved at the correct times, and provides robust testing for these behaviors. Additionally, it implements the underlying C++ binding for retrieving all diagnostic records from the ODBC driver. **Diagnostic message handling improvements:** * Added a `messages` attribute to the `Cursor` class to store diagnostic messages, and ensured it is cleared before each non-fetch operation (e.g., `execute`, `executemany`, `close`, `commit`, `rollback`, `scroll`, and `nextset`) to comply with DBAPI expectations. (`mssql_python/cursor.py`) * After each statement execution and fetch operation, diagnostic messages (including informational and warning messages) are collected and appended to the `messages` list, using a new C++ binding. (`mssql_python/cursor.py`, `mssql_python/pybind/ddbc_bindings.cpp`) **Native driver and binding enhancements:** * Implemented the `SQLGetAllDiagRecords` function in the C++ pybind layer to retrieve all diagnostic records from an ODBC statement handle, handling both Windows and Unix platforms, and exposed it as `DDBCSQLGetAllDiagRecords` to Python. (`mssql_python/pybind/ddbc_bindings.cpp`) **Testing and specification compliance:** * Added a comprehensive test suite to verify message capturing, clearing, preservation across fetches, handling of multiple messages, message formatting, warning capture, manual clearing, and error scenarios, ensuring compliance with DBAPI and robust behavior. (`tests/test_004_cursor.py`) **Other cursor improvements:** * Refactored the `skip` method to validate arguments more strictly, clear messages before skipping, and improve error handling and documentation. (`mssql_python/cursor.py`) These changes significantly improve the usability and correctness of message handling in the driver, making it easier for users to access and manage SQL Server informational and warning messages in Python applications. --------- Co-authored-by: Jahnvi Thakkar <jathakkar@microsoft.com>
1 parent d9eb8d1 commit 3d8304e

File tree

5 files changed

+950
-40
lines changed

5 files changed

+950
-40
lines changed

mssql_python/cursor.py

Lines changed: 198 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -81,6 +81,8 @@ def __init__(self, connection) -> None:
8181
self._has_result_set = False # Track if we have an active result set
8282
self._skip_increment_for_next_fetch = False # Track if we need to skip incrementing the row index
8383

84+
self.messages = [] # Store diagnostic messages
85+
8486
def _is_unicode_string(self, param):
8587
"""
8688
Check if a string contains non-ASCII characters.
@@ -453,6 +455,9 @@ def close(self) -> None:
453455
if self.closed:
454456
raise Exception("Cursor is already closed.")
455457

458+
# Clear messages per DBAPI
459+
self.messages = []
460+
456461
if self.hstmt:
457462
self.hstmt.free()
458463
self.hstmt = None
@@ -698,6 +703,9 @@ def execute(
698703
if reset_cursor:
699704
self._reset_cursor()
700705

706+
# Clear any previous messages
707+
self.messages = []
708+
701709
param_info = ddbc_bindings.ParamInfo
702710
parameters_type = []
703711

@@ -745,7 +753,14 @@ def execute(
745753
self.is_stmt_prepared,
746754
use_prepare,
747755
)
756+
757+
# Check for errors but don't raise exceptions for info/warning messages
748758
check_error(ddbc_sql_const.SQL_HANDLE_STMT.value, self.hstmt, ret)
759+
760+
# Capture any diagnostic messages (SQL_SUCCESS_WITH_INFO, etc.)
761+
if self.hstmt:
762+
self.messages.extend(ddbc_bindings.DDBCSQLGetAllDiagRecords(self.hstmt))
763+
749764
self.last_executed_stmt = operation
750765

751766
# Update rowcount after execution
@@ -827,7 +842,10 @@ def executemany(self, operation: str, seq_of_parameters: list) -> None:
827842
"""
828843
self._check_closed()
829844
self._reset_cursor()
830-
845+
846+
# Clear any previous messages
847+
self.messages = []
848+
831849
if not seq_of_parameters:
832850
self.rowcount = 0
833851
return
@@ -859,6 +877,10 @@ def executemany(self, operation: str, seq_of_parameters: list) -> None:
859877
)
860878
check_error(ddbc_sql_const.SQL_HANDLE_STMT.value, self.hstmt, ret)
861879

880+
# Capture any diagnostic messages after execution
881+
if self.hstmt:
882+
self.messages.extend(ddbc_bindings.DDBCSQLGetAllDiagRecords(self.hstmt))
883+
862884
self.rowcount = ddbc_bindings.DDBCSQLRowCount(self.hstmt)
863885
self.last_executed_stmt = operation
864886
self._initialize_description()
@@ -884,6 +906,9 @@ def fetchone(self) -> Union[None, Row]:
884906
try:
885907
ret = ddbc_bindings.DDBCSQLFetchOne(self.hstmt, row_data)
886908

909+
if self.hstmt:
910+
self.messages.extend(ddbc_bindings.DDBCSQLGetAllDiagRecords(self.hstmt))
911+
887912
if ret == ddbc_sql_const.SQL_NO_DATA.value:
888913
return None
889914

@@ -894,8 +919,9 @@ def fetchone(self) -> Union[None, Row]:
894919
else:
895920
self._increment_rownumber()
896921

897-
# Create and return a Row object
898-
return Row(row_data, self.description)
922+
# Create and return a Row object, passing column name map if available
923+
column_map = getattr(self, '_column_name_map', None)
924+
return Row(row_data, self.description, column_map)
899925
except Exception as e:
900926
# On error, don't increment rownumber - rethrow the error
901927
raise e
@@ -924,6 +950,10 @@ def fetchmany(self, size: int = None) -> List[Row]:
924950
rows_data = []
925951
try:
926952
ret = ddbc_bindings.DDBCSQLFetchMany(self.hstmt, rows_data, size)
953+
954+
if self.hstmt:
955+
self.messages.extend(ddbc_bindings.DDBCSQLGetAllDiagRecords(self.hstmt))
956+
927957

928958
# Update rownumber for the number of rows actually fetched
929959
if rows_data and self._has_result_set:
@@ -932,7 +962,8 @@ def fetchmany(self, size: int = None) -> List[Row]:
932962
self._rownumber = self._next_row_index - 1
933963

934964
# Convert raw data to Row objects
935-
return [Row(row_data, self.description) for row_data in rows_data]
965+
column_map = getattr(self, '_column_name_map', None)
966+
return [Row(row_data, self.description, column_map) for row_data in rows_data]
936967
except Exception as e:
937968
# On error, don't increment rownumber - rethrow the error
938969
raise e
@@ -952,14 +983,19 @@ def fetchall(self) -> List[Row]:
952983
rows_data = []
953984
try:
954985
ret = ddbc_bindings.DDBCSQLFetchAll(self.hstmt, rows_data)
986+
987+
if self.hstmt:
988+
self.messages.extend(ddbc_bindings.DDBCSQLGetAllDiagRecords(self.hstmt))
989+
955990

956991
# Update rownumber for the number of rows actually fetched
957992
if rows_data and self._has_result_set:
958993
self._next_row_index += len(rows_data)
959994
self._rownumber = self._next_row_index - 1
960995

961996
# Convert raw data to Row objects
962-
return [Row(row_data, self.description) for row_data in rows_data]
997+
column_map = getattr(self, '_column_name_map', None)
998+
return [Row(row_data, self.description, column_map) for row_data in rows_data]
963999
except Exception as e:
9641000
# On error, don't increment rownumber - rethrow the error
9651001
raise e
@@ -976,6 +1012,9 @@ def nextset(self) -> Union[bool, None]:
9761012
"""
9771013
self._check_closed() # Check if the cursor is closed
9781014

1015+
# Clear messages per DBAPI
1016+
self.messages = []
1017+
9791018
# Skip to the next result set
9801019
ret = ddbc_bindings.DDBCSQLMoreResults(self.hstmt)
9811020
check_error(ddbc_sql_const.SQL_HANDLE_STMT.value, self.hstmt, ret)
@@ -1056,6 +1095,9 @@ def commit(self):
10561095
"""
10571096
self._check_closed() # Check if the cursor is closed
10581097

1098+
# Clear messages per DBAPI
1099+
self.messages = []
1100+
10591101
# Delegate to the connection's commit method
10601102
self._connection.commit()
10611103

@@ -1082,6 +1124,9 @@ def rollback(self):
10821124
"""
10831125
self._check_closed() # Check if the cursor is closed
10841126

1127+
# Clear messages per DBAPI
1128+
self.messages = []
1129+
10851130
# Delegate to the connection's rollback method
10861131
self._connection.rollback()
10871132

@@ -1107,6 +1152,10 @@ def scroll(self, value: int, mode: str = 'relative') -> None:
11071152
- absolute(k>0): next fetch returns row index k (0-based); rownumber == k after scroll.
11081153
"""
11091154
self._check_closed()
1155+
1156+
# Clear messages per DBAPI
1157+
self.messages = []
1158+
11101159
if mode not in ('relative', 'absolute'):
11111160
raise ProgrammingError("Invalid scroll mode",
11121161
f"mode must be 'relative' or 'absolute', got '{mode}'")
@@ -1179,29 +1228,156 @@ def scroll(self, value: int, mode: str = 'relative') -> None:
11791228

11801229
def skip(self, count: int) -> None:
11811230
"""
1182-
Skip the next 'count' records in the query result set.
1183-
1184-
This is a convenience method that advances the cursor by 'count'
1185-
positions without returning the skipped rows.
1231+
Skip the next count records in the query result set.
11861232
11871233
Args:
1188-
count: Number of records to skip. Must be non-negative.
1189-
1190-
Returns:
1191-
None
1234+
count: Number of records to skip.
11921235
11931236
Raises:
1194-
ProgrammingError: If the cursor is closed or no result set is available.
1195-
NotSupportedError: If count is negative (backward scrolling not supported).
11961237
IndexError: If attempting to skip past the end of the result set.
1197-
1198-
Note:
1199-
For convenience, skip(0) is accepted and will do nothing.
1238+
ProgrammingError: If count is not an integer.
1239+
NotSupportedError: If attempting to skip backwards.
12001240
"""
1241+
from mssql_python.exceptions import ProgrammingError, NotSupportedError
1242+
12011243
self._check_closed()
12021244

1203-
if count == 0: # Skip 0 is a no-op
1204-
return
1245+
# Clear messages
1246+
self.messages = []
1247+
1248+
# Simply delegate to the scroll method with 'relative' mode
1249+
self.scroll(count, 'relative')
1250+
1251+
def _execute_tables(self, stmt_handle, catalog_name=None, schema_name=None, table_name=None,
1252+
table_type=None, search_escape=None):
1253+
"""
1254+
Execute SQLTables ODBC function to retrieve table metadata.
1255+
1256+
Args:
1257+
stmt_handle: ODBC statement handle
1258+
catalog_name: The catalog name pattern
1259+
schema_name: The schema name pattern
1260+
table_name: The table name pattern
1261+
table_type: The table type filter
1262+
search_escape: The escape character for pattern matching
1263+
"""
1264+
# Convert None values to empty strings for ODBC
1265+
catalog = "" if catalog_name is None else catalog_name
1266+
schema = "" if schema_name is None else schema_name
1267+
table = "" if table_name is None else table_name
1268+
types = "" if table_type is None else table_type
1269+
1270+
# Call the ODBC SQLTables function
1271+
retcode = ddbc_bindings.DDBCSQLTables(
1272+
stmt_handle,
1273+
catalog,
1274+
schema,
1275+
table,
1276+
types
1277+
)
1278+
1279+
# Check return code and handle errors
1280+
check_error(ddbc_sql_const.SQL_HANDLE_STMT.value, stmt_handle, retcode)
1281+
1282+
# Capture any diagnostic messages
1283+
if stmt_handle:
1284+
self.messages.extend(ddbc_bindings.DDBCSQLGetAllDiagRecords(stmt_handle))
1285+
1286+
def tables(self, table=None, catalog=None, schema=None, tableType=None):
1287+
"""
1288+
Returns information about tables in the database that match the given criteria using
1289+
the SQLTables ODBC function.
1290+
1291+
Args:
1292+
table (str, optional): The table name pattern. Default is None (all tables).
1293+
catalog (str, optional): The catalog name. Default is None.
1294+
schema (str, optional): The schema name pattern. Default is None.
1295+
tableType (str or list, optional): The table type filter. Default is None.
1296+
Example: "TABLE" or ["TABLE", "VIEW"]
1297+
1298+
Returns:
1299+
list: A list of Row objects containing table information with these columns:
1300+
- table_cat: Catalog name
1301+
- table_schem: Schema name
1302+
- table_name: Table name
1303+
- table_type: Table type (e.g., "TABLE", "VIEW")
1304+
- remarks: Comments about the table
1305+
1306+
Notes:
1307+
This method only processes the standard five columns as defined in the ODBC
1308+
specification. Any additional columns that might be returned by specific ODBC
1309+
drivers are not included in the result set.
1310+
1311+
Example:
1312+
# Get all tables in the database
1313+
tables = cursor.tables()
12051314
1206-
# Use existing scroll method with relative mode
1207-
self.scroll(count, 'relative')
1315+
# Get all tables in schema 'dbo'
1316+
tables = cursor.tables(schema='dbo')
1317+
1318+
# Get table named 'Customers'
1319+
tables = cursor.tables(table='Customers')
1320+
1321+
# Get all views
1322+
tables = cursor.tables(tableType='VIEW')
1323+
"""
1324+
self._check_closed()
1325+
1326+
# Clear messages
1327+
self.messages = []
1328+
1329+
# Always reset the cursor first to ensure clean state
1330+
self._reset_cursor()
1331+
1332+
# Format table_type parameter - SQLTables expects comma-separated string
1333+
table_type_str = None
1334+
if tableType is not None:
1335+
if isinstance(tableType, (list, tuple)):
1336+
table_type_str = ",".join(tableType)
1337+
else:
1338+
table_type_str = str(tableType)
1339+
1340+
# Call SQLTables via the helper method
1341+
self._execute_tables(
1342+
self.hstmt,
1343+
catalog_name=catalog,
1344+
schema_name=schema,
1345+
table_name=table,
1346+
table_type=table_type_str
1347+
)
1348+
1349+
# Initialize description from column metadata
1350+
column_metadata = []
1351+
try:
1352+
ddbc_bindings.DDBCSQLDescribeCol(self.hstmt, column_metadata)
1353+
self._initialize_description(column_metadata)
1354+
except Exception:
1355+
# If describe fails, create a manual description for the standard columns
1356+
column_types = [str, str, str, str, str]
1357+
self.description = [
1358+
("table_cat", column_types[0], None, 128, 128, 0, True),
1359+
("table_schem", column_types[1], None, 128, 128, 0, True),
1360+
("table_name", column_types[2], None, 128, 128, 0, False),
1361+
("table_type", column_types[3], None, 128, 128, 0, False),
1362+
("remarks", column_types[4], None, 254, 254, 0, True)
1363+
]
1364+
1365+
# Define column names in ODBC standard order
1366+
column_names = [
1367+
"table_cat", "table_schem", "table_name", "table_type", "remarks"
1368+
]
1369+
1370+
# Fetch all rows
1371+
rows_data = []
1372+
ddbc_bindings.DDBCSQLFetchAll(self.hstmt, rows_data)
1373+
1374+
# Create a column map for attribute access
1375+
column_map = {name: i for i, name in enumerate(column_names)}
1376+
1377+
# Create Row objects with the column map
1378+
result_rows = []
1379+
for row_data in rows_data:
1380+
row = Row(row_data, self.description, column_map)
1381+
result_rows.append(row)
1382+
1383+
return result_rows

0 commit comments

Comments
 (0)