Skip to content
Merged
Show file tree
Hide file tree
Changes from 5 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
77 changes: 52 additions & 25 deletions mssql_python/cursor.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,6 @@

logger = get_logger()


class Cursor:
"""
Represents a database cursor, which is used to manage the context of a fetch operation.
Expand Down Expand Up @@ -618,37 +617,65 @@ def execute(
# Initialize description after execution
self._initialize_description()

def _transpose_rowwise_to_columnwise(self, seq_of_parameters: list) -> list:
"""
Convert list of rows (row-wise) into list of columns (column-wise),
for array binding via ODBC.
"""
if not seq_of_parameters:
return []

num_params = len(seq_of_parameters[0])
columnwise = [[] for _ in range(num_params)]
for row in seq_of_parameters:
if len(row) != num_params:
raise ValueError("Inconsistent parameter row size in executemany()")
for i, val in enumerate(row):
columnwise[i].append(val)
return columnwise

def executemany(self, operation: str, seq_of_parameters: list) -> None:
"""
Prepare a database operation and execute it against all parameter sequences.
This version uses column-wise parameter binding and a single batched SQLExecute().
"""
self._check_closed()
self._reset_cursor()

Args:
operation: SQL query or command.
seq_of_parameters: Sequence of sequences or mappings of parameters.
if not seq_of_parameters:
self.rowcount = 0
return

Raises:
Error: If the operation fails.
"""
self._check_closed() # Check if the cursor is closed
param_info = ddbc_bindings.ParamInfo
param_count = len(seq_of_parameters[0])
parameters_type = []

self._reset_cursor()
for col_index in range(param_count):
column = [row[col_index] for row in seq_of_parameters]
sample_value = column[0]
if isinstance(sample_value, str):
sample_value = max(column, key=lambda s: len(str(s)) if s is not None else 0)
elif isinstance(sample_value, decimal.Decimal):
sample_value = max(column, key=lambda d: len(d.as_tuple().digits) if d is not None else 0)
param = sample_value
dummy_row = list(seq_of_parameters[0])
parameters_type.append(self._create_parameter_types_list(param, param_info, dummy_row, col_index))

columnwise_params = self._transpose_rowwise_to_columnwise(seq_of_parameters)

# Execute batched statement
ret = ddbc_bindings.SQLExecuteMany(
self.hstmt,
operation,
columnwise_params,
parameters_type,
len(seq_of_parameters)
)
check_error(ddbc_sql_const.SQL_HANDLE_STMT.value, self.hstmt, ret)

first_execution = True
total_rowcount = 0
for parameters in seq_of_parameters:
parameters = list(parameters)
if ENABLE_LOGGING:
logger.info("Executing query with parameters: %s", parameters)
prepare_stmt = first_execution
first_execution = False
self.execute(
operation, parameters, use_prepare=prepare_stmt, reset_cursor=False
)
if self.rowcount != -1:
total_rowcount += self.rowcount
else:
total_rowcount = -1
self.rowcount = total_rowcount
self.rowcount = ddbc_bindings.DDBCSQLRowCount(self.hstmt)
self.last_executed_stmt = operation
self._initialize_description()

def fetchone(self) -> Union[None, Row]:
"""
Expand Down
177 changes: 177 additions & 0 deletions mssql_python/pybind/ddbc_bindings.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -185,6 +185,15 @@
return static_cast<ParamType*>(paramBuffers.back().get());
}

template <typename ParamType>
ParamType* AllocateParamBufferArray(std::vector<std::shared_ptr<void>>& paramBuffers,
size_t count) {
std::shared_ptr<ParamType> buffer(new ParamType[count], std::default_delete<ParamType[]>());
ParamType* raw = buffer.get();
paramBuffers.push_back(buffer);
return raw;
}

std::string DescribeChar(unsigned char ch) {
if (ch >= 32 && ch <= 126) {
return std::string("'") + static_cast<char>(ch) + "'";
Expand Down Expand Up @@ -933,6 +942,173 @@
}
}

SQLRETURN BindParameterArray(SQLHANDLE hStmt,
const py::list& columnwise_params,
const std::vector<ParamInfo>& paramInfos,
size_t paramSetSize,
std::vector<std::shared_ptr<void>>& paramBuffers) {
LOG("Starting column-wise parameter array binding. paramSetSize: {}, paramCount: {}", paramSetSize, columnwise_params.size());
for (int paramIndex = 0; paramIndex < columnwise_params.size(); ++paramIndex) {
const py::list& columnValues = columnwise_params[paramIndex].cast<py::list>();
const ParamInfo& info = paramInfos[paramIndex];
if (columnValues.size() != paramSetSize) {
ThrowStdException("Column " + std::to_string(paramIndex) + " has mismatched size.");
}
void* dataPtr = nullptr;
SQLLEN* strLenOrIndArray = nullptr;
SQLLEN bufferLength = 0;

switch (info.paramCType) {
case SQL_C_LONG: {
int* dataArray = AllocateParamBufferArray<int>(paramBuffers, paramSetSize);
for (size_t i = 0; i < paramSetSize; ++i) {
if (columnValues[i].is_none()) {
if (!strLenOrIndArray)
strLenOrIndArray = AllocateParamBufferArray<SQLLEN>(paramBuffers, paramSetSize);
dataArray[i] = 0;
strLenOrIndArray[i] = SQL_NULL_DATA;
} else {
dataArray[i] = columnValues[i].cast<int>();
if (strLenOrIndArray) strLenOrIndArray[i] = 0;
}
}
dataPtr = dataArray;
break;
}
case SQL_C_DOUBLE: {
double* dataArray = AllocateParamBufferArray<double>(paramBuffers, paramSetSize);
for (size_t i = 0; i < paramSetSize; ++i) {
if (columnValues[i].is_none()) {
if (!strLenOrIndArray)
strLenOrIndArray = AllocateParamBufferArray<SQLLEN>(paramBuffers, paramSetSize);
dataArray[i] = 0;
strLenOrIndArray[i] = SQL_NULL_DATA;
} else {
dataArray[i] = columnValues[i].cast<double>();
if (strLenOrIndArray) strLenOrIndArray[i] = 0;
}
}
dataPtr = dataArray;
break;
}
case SQL_C_WCHAR: {
SQLWCHAR* wcharArray = AllocateParamBufferArray<SQLWCHAR>(paramBuffers, paramSetSize * (info.columnSize + 1));
strLenOrIndArray = AllocateParamBufferArray<SQLLEN>(paramBuffers, paramSetSize);
for (size_t i = 0; i < paramSetSize; ++i) {
if (columnValues[i].is_none()) {
strLenOrIndArray[i] = SQL_NULL_DATA;
std::memset(wcharArray + i * (info.columnSize + 1), 0, (info.columnSize + 1) * sizeof(SQLWCHAR));
continue;
}
std::wstring wstr = columnValues[i].cast<std::wstring>();
if (wstr.length() > info.columnSize) {
std::string offending = WideToUTF8(wstr);
ThrowStdException("String too long at param " + std::to_string(paramIndex) +
", value: " + offending +
", len: " + std::to_string(wstr.length()) +
" > columnSize: " + std::to_string(info.columnSize));
}
std::memcpy(wcharArray + i * (info.columnSize + 1), wstr.c_str(), (wstr.length() + 1) * sizeof(SQLWCHAR));
strLenOrIndArray[i] = SQL_NTS;
}
dataPtr = wcharArray;
bufferLength = (info.columnSize + 1) * sizeof(SQLWCHAR);
break;
}
case SQL_C_TINYINT:
case SQL_C_UTINYINT: {
unsigned char* dataArray = AllocateParamBufferArray<unsigned char>(paramBuffers, paramSetSize);
for (size_t i = 0; i < paramSetSize; ++i) {
if (columnValues[i].is_none()) {
if (!strLenOrIndArray)
strLenOrIndArray = AllocateParamBufferArray<SQLLEN>(paramBuffers, paramSetSize);
dataArray[i] = 0;
strLenOrIndArray[i] = SQL_NULL_DATA;
} else {
int intVal = columnValues[i].cast<int>();
if (intVal < 0 || intVal > 255) {
ThrowStdException("UTINYINT value out of range at rowIndex " + std::to_string(i));
}
dataArray[i] = static_cast<unsigned char>(intVal);
if (strLenOrIndArray) strLenOrIndArray[i] = 0;
}
}
dataPtr = dataArray;
bufferLength = sizeof(unsigned char);
break;
}
case SQL_C_SHORT: {
short* dataArray = AllocateParamBufferArray<short>(paramBuffers, paramSetSize);
for (size_t i = 0; i < paramSetSize; ++i) {
if (columnValues[i].is_none()) {
if (!strLenOrIndArray)
strLenOrIndArray = AllocateParamBufferArray<SQLLEN>(paramBuffers, paramSetSize);
dataArray[i] = 0;
strLenOrIndArray[i] = SQL_NULL_DATA;
} else {
int intVal = columnValues[i].cast<int>();
if (intVal < std::numeric_limits<short>::min() ||
intVal > std::numeric_limits<short>::max()) {
ThrowStdException("SHORT value out of range at rowIndex " + std::to_string(i));
}
dataArray[i] = static_cast<short>(intVal);
if (strLenOrIndArray) strLenOrIndArray[i] = 0;
}
}
dataPtr = dataArray;
bufferLength = sizeof(short);
break;
}
default: {
ThrowStdException("BindParameterArray: Unsupported C type: " + std::to_string(info.paramCType));
}
}

RETCODE rc = SQLBindParameter_ptr(
hStmt,
static_cast<SQLUSMALLINT>(paramIndex + 1),
static_cast<SQLUSMALLINT>(info.inputOutputType),
static_cast<SQLSMALLINT>(info.paramCType),
static_cast<SQLSMALLINT>(info.paramSQLType),
info.columnSize,
info.decimalDigits,
dataPtr,
bufferLength,
strLenOrIndArray
);
if (!SQL_SUCCEEDED(rc)) {
LOG("Failed to bind array param {}", paramIndex);
return rc;
}
}
LOG("Finished column-wise parameter array binding.");
return SQL_SUCCESS;
}

SQLRETURN SQLExecuteMany_wrap(const SqlHandlePtr statementHandle,
const std::wstring& query,
const py::list& columnwise_params,
const std::vector<ParamInfo>& paramInfos,
size_t paramSetSize) {
SQLHANDLE hStmt = statementHandle->get();
SQLWCHAR* queryPtr;
#if defined(__APPLE__) || defined(__linux__)
std::vector<SQLWCHAR> queryBuffer = WStringToSQLWCHAR(query);
queryPtr = queryBuffer.data();
#else
queryPtr = const_cast<SQLWCHAR*>(query.c_str());
#endif
RETCODE rc = SQLPrepare_ptr(hStmt, queryPtr, SQL_NTS);
if (!SQL_SUCCEEDED(rc)) return rc;
std::vector<std::shared_ptr<void>> paramBuffers;
rc = BindParameterArray(hStmt, columnwise_params, paramInfos, paramSetSize, paramBuffers);
if (!SQL_SUCCEEDED(rc)) return rc;
rc = SQLSetStmtAttr_ptr(hStmt, SQL_ATTR_PARAMSET_SIZE, (SQLPOINTER)paramSetSize, 0);
if (!SQL_SUCCEEDED(rc)) return rc;
rc = SQLExecute_ptr(hStmt);
return rc;
}

// Wrap SQLNumResultCols
SQLSMALLINT SQLNumResultCols_wrap(SqlHandlePtr statementHandle) {
LOG("Get number of columns in result set");
Expand Down Expand Up @@ -2112,6 +2288,7 @@
m.def("close_pooling", []() {ConnectionPoolManager::getInstance().closePools();});
m.def("DDBCSQLExecDirect", &SQLExecDirect_wrap, "Execute a SQL query directly");
m.def("DDBCSQLExecute", &SQLExecute_wrap, "Prepare and execute T-SQL statements");
m.def("SQLExecuteMany", &SQLExecuteMany_wrap, "Execute statement with multiple parameter sets");
m.def("DDBCSQLRowCount", &SQLRowCount_wrap,
"Get the number of rows affected by the last statement");
m.def("DDBCSQLFetch", &SQLFetch_wrap, "Fetch the next row from the result set");
Expand Down
Loading