Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
89 changes: 71 additions & 18 deletions mssql_python/cursor.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,6 +80,8 @@ def __init__(self, connection) -> None:
self._next_row_index = 0 # internal: index of the next row the driver will return (0-based)
self._has_result_set = False # Track if we have an active result set

self.messages = [] # Store diagnostic messages

def _is_unicode_string(self, param):
"""
Check if a string contains non-ASCII characters.
Expand Down Expand Up @@ -452,6 +454,9 @@ def close(self) -> None:
if self.closed:
raise Exception("Cursor is already closed.")

# Clear messages per DBAPI
self.messages = []

if self.hstmt:
self.hstmt.free()
self.hstmt = None
Expand Down Expand Up @@ -695,6 +700,9 @@ def execute(
if reset_cursor:
self._reset_cursor()

# Clear any previous messages
self.messages = []

param_info = ddbc_bindings.ParamInfo
parameters_type = []

Expand Down Expand Up @@ -742,7 +750,14 @@ def execute(
self.is_stmt_prepared,
use_prepare,
)

# Check for errors but don't raise exceptions for info/warning messages
check_error(ddbc_sql_const.SQL_HANDLE_STMT.value, self.hstmt, ret)

# Capture any diagnostic messages (SQL_SUCCESS_WITH_INFO, etc.)
if self.hstmt:
self.messages.extend(ddbc_bindings.DDBCSQLGetAllDiagRecords(self.hstmt))

self.last_executed_stmt = operation

# Update rowcount after execution
Expand Down Expand Up @@ -822,7 +837,10 @@ def executemany(self, operation: str, seq_of_parameters: list) -> None:
"""
self._check_closed()
self._reset_cursor()


# Clear any previous messages
self.messages = []

if not seq_of_parameters:
self.rowcount = 0
return
Expand Down Expand Up @@ -854,6 +872,10 @@ def executemany(self, operation: str, seq_of_parameters: list) -> None:
)
check_error(ddbc_sql_const.SQL_HANDLE_STMT.value, self.hstmt, ret)

# Capture any diagnostic messages after execution
if self.hstmt:
self.messages.extend(ddbc_bindings.DDBCSQLGetAllDiagRecords(self.hstmt))

self.rowcount = ddbc_bindings.DDBCSQLRowCount(self.hstmt)
self.last_executed_stmt = operation
self._initialize_description()
Expand All @@ -877,6 +899,9 @@ def fetchone(self) -> Union[None, Row]:
try:
ret = ddbc_bindings.DDBCSQLFetchOne(self.hstmt, row_data)

if self.hstmt:
self.messages.extend(ddbc_bindings.DDBCSQLGetAllDiagRecords(self.hstmt))

if ret == ddbc_sql_const.SQL_NO_DATA.value:
return None

Expand Down Expand Up @@ -911,6 +936,10 @@ def fetchmany(self, size: int = None) -> List[Row]:
rows_data = []
try:
ret = ddbc_bindings.DDBCSQLFetchMany(self.hstmt, rows_data, size)

if self.hstmt:
self.messages.extend(ddbc_bindings.DDBCSQLGetAllDiagRecords(self.hstmt))


# Update rownumber for the number of rows actually fetched
if rows_data and self._has_result_set:
Expand All @@ -937,6 +966,10 @@ def fetchall(self) -> List[Row]:
rows_data = []
try:
ret = ddbc_bindings.DDBCSQLFetchAll(self.hstmt, rows_data)

if self.hstmt:
self.messages.extend(ddbc_bindings.DDBCSQLGetAllDiagRecords(self.hstmt))


# Update rownumber for the number of rows actually fetched
if rows_data and self._has_result_set:
Expand All @@ -961,6 +994,9 @@ def nextset(self) -> Union[bool, None]:
"""
self._check_closed() # Check if the cursor is closed

# Clear messages per DBAPI
self.messages = []

# Skip to the next result set
ret = ddbc_bindings.DDBCSQLMoreResults(self.hstmt)
check_error(ddbc_sql_const.SQL_HANDLE_STMT.value, self.hstmt, ret)
Expand Down Expand Up @@ -1041,6 +1077,9 @@ def commit(self):
"""
self._check_closed() # Check if the cursor is closed

# Clear messages per DBAPI
self.messages = []

# Delegate to the connection's commit method
self._connection.commit()

Expand All @@ -1067,6 +1106,9 @@ def rollback(self):
"""
self._check_closed() # Check if the cursor is closed

# Clear messages per DBAPI
self.messages = []

# Delegate to the connection's rollback method
self._connection.rollback()

Expand All @@ -1090,6 +1132,10 @@ def scroll(self, value: int, mode: str = 'relative') -> None:
This implementation emulates scrolling for forward-only cursors by consuming rows.
"""
self._check_closed()

# Clear messages per DBAPI
self.messages = []

if mode not in ('relative', 'absolute'):
raise ProgrammingError(
driver_error="Invalid scroll mode",
Expand Down Expand Up @@ -1195,29 +1241,36 @@ def _consume_rows_for_scroll(self, rows_to_consume: int) -> None:

def skip(self, count: int) -> None:
"""
Skip the next 'count' records in the query result set.

This is a convenience method that advances the cursor by 'count'
positions without returning the skipped rows.
Skip the next count records in the query result set.

Args:
count: Number of records to skip. Must be non-negative.

Returns:
None
count: Number of records to skip.

Raises:
ProgrammingError: If the cursor is closed or no result set is available.
NotSupportedError: If count is negative (backward scrolling not supported).
IndexError: If attempting to skip past the end of the result set.

Note:
For convenience, skip(0) is accepted and will do nothing.
ProgrammingError: If count is not an integer.
NotSupportedError: If attempting to skip backwards.
"""
from mssql_python.exceptions import ProgrammingError, NotSupportedError

self._check_closed()

if count == 0: # Skip 0 is a no-op
# Clear messages
self.messages = []

# Validate arguments
if not isinstance(count, int):
raise ProgrammingError("Count must be an integer", "Invalid argument type")

if count < 0:
raise NotSupportedError("Negative skip values are not supported", "Backward scrolling not supported")

# Skip zero is a no-op
if count == 0:
return

# Use existing scroll method with relative mode
self.scroll(count, 'relative')

# Skip the rows by fetching and discarding
for _ in range(count):
row = self.fetchone()
if row is None:
raise IndexError("Cannot skip beyond the end of the result set")
63 changes: 63 additions & 0 deletions mssql_python/pybind/ddbc_bindings.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -901,6 +901,65 @@ ErrorInfo SQLCheckError_Wrap(SQLSMALLINT handleType, SqlHandlePtr handle, SQLRET
return errorInfo;
}

py::list SQLGetAllDiagRecords(SqlHandlePtr handle) {
LOG("Retrieving all diagnostic records");
if (!SQLGetDiagRec_ptr) {
LOG("Function pointer not initialized. Loading the driver.");
DriverLoader::getInstance().loadDriver();
}

py::list records;
SQLHANDLE rawHandle = handle->get();
SQLSMALLINT handleType = handle->type();

// Iterate through all available diagnostic records
for (SQLSMALLINT recNumber = 1; ; recNumber++) {
SQLWCHAR sqlState[6] = {0};
SQLWCHAR message[SQL_MAX_MESSAGE_LENGTH] = {0};
SQLINTEGER nativeError = 0;
SQLSMALLINT messageLen = 0;

SQLRETURN diagReturn = SQLGetDiagRec_ptr(
handleType, rawHandle, recNumber, sqlState, &nativeError,
message, SQL_MAX_MESSAGE_LENGTH, &messageLen);

if (diagReturn == SQL_NO_DATA || !SQL_SUCCEEDED(diagReturn))
break;

#if defined(_WIN32)
// On Windows, create a formatted UTF-8 string for state+error
char stateWithError[50];
sprintf(stateWithError, "[%ls] (%d)", sqlState, nativeError);

// Convert wide string message to UTF-8
int msgSize = WideCharToMultiByte(CP_UTF8, 0, message, -1, NULL, 0, NULL, NULL);
std::vector<char> msgBuffer(msgSize);
WideCharToMultiByte(CP_UTF8, 0, message, -1, msgBuffer.data(), msgSize, NULL, NULL);

// Create the tuple with converted strings
records.append(py::make_tuple(
py::str(stateWithError),
py::str(msgBuffer.data())
));
#else
// On Unix, use the SQLWCHARToWString utility and then convert to UTF-8
std::string stateStr = WideToUTF8(SQLWCHARToWString(sqlState));
std::string msgStr = WideToUTF8(SQLWCHARToWString(message, messageLen));

// Format the state string
std::string stateWithError = "[" + stateStr + "] (" + std::to_string(nativeError) + ")";

// Create the tuple with converted strings
records.append(py::make_tuple(
py::str(stateWithError),
py::str(msgStr)
));
#endif
}

return records;
}

// Wrap SQLExecDirect
SQLRETURN SQLExecDirect_wrap(SqlHandlePtr StatementHandle, const std::wstring& Query) {
LOG("Execute SQL query directly - {}", Query.c_str());
Expand Down Expand Up @@ -2553,6 +2612,10 @@ PYBIND11_MODULE(ddbc_bindings, m) {
m.def("DDBCSQLFetchAll", &FetchAll_wrap, "Fetch all rows from the result set");
m.def("DDBCSQLFreeHandle", &SQLFreeHandle_wrap, "Free a handle");
m.def("DDBCSQLCheckError", &SQLCheckError_Wrap, "Check for driver errors");
// Add this to your PYBIND11_MODULE section
m.def("DDBCSQLGetAllDiagRecords", &SQLGetAllDiagRecords,
"Get all diagnostic records for a handle",
py::arg("handle"));

// Add a version attribute
m.attr("__version__") = "1.0.0";
Expand Down
Loading