Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion mssql_python/cursor.py
Original file line number Diff line number Diff line change
Expand Up @@ -231,7 +231,7 @@ def _map_sql_type(self, param, parameters_list, i):
"""
if param is None:
return (
ddbc_sql_const.SQL_VARCHAR.value, # TODO: Add SQLDescribeParam to get correct type
ddbc_sql_const.SQL_VARCHAR.value,
ddbc_sql_const.SQL_C_DEFAULT.value,
1,
0,
Expand Down
43 changes: 37 additions & 6 deletions mssql_python/pybind/ddbc_bindings.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -140,6 +140,8 @@ SQLParamDataFunc SQLParamData_ptr = nullptr;
SQLPutDataFunc SQLPutData_ptr = nullptr;
SQLTablesFunc SQLTables_ptr = nullptr;

SQLDescribeParamFunc SQLDescribeParam_ptr = nullptr;

namespace {

const char* GetSqlCTypeAsString(const SQLSMALLINT cType) {
Expand Down Expand Up @@ -212,12 +214,12 @@ std::string DescribeChar(unsigned char ch) {
// Given a list of parameters and their ParamInfo, calls SQLBindParameter on each of them with
// appropriate arguments
SQLRETURN BindParameters(SQLHANDLE hStmt, const py::list& params,
const std::vector<ParamInfo>& paramInfos,
std::vector<ParamInfo>& paramInfos,
std::vector<std::shared_ptr<void>>& paramBuffers) {
LOG("Starting parameter binding. Number of parameters: {}", params.size());
for (int paramIndex = 0; paramIndex < params.size(); paramIndex++) {
const auto& param = params[paramIndex];
const ParamInfo& paramInfo = paramInfos[paramIndex];
ParamInfo& paramInfo = paramInfos[paramIndex];
LOG("Binding parameter {} - C Type: {}, SQL Type: {}", paramIndex, paramInfo.paramCType, paramInfo.paramSQLType);
void* dataPtr = nullptr;
SQLLEN bufferLength = 0;
Expand Down Expand Up @@ -283,11 +285,37 @@ SQLRETURN BindParameters(SQLHANDLE hStmt, const py::list& params,
if (!py::isinstance<py::none>(param)) {
ThrowStdException(MakeParamMismatchErrorStr(paramInfo.paramCType, paramIndex));
}
// TODO: This wont work for None values added to BINARY/VARBINARY columns. None values
// of binary columns need to have C type = SQL_C_BINARY & SQL type = SQL_BINARY
SQLSMALLINT sqlType = paramInfo.paramSQLType;
SQLULEN columnSize = paramInfo.columnSize;
SQLSMALLINT decimalDigits = paramInfo.decimalDigits;
if (sqlType == SQL_UNKNOWN_TYPE) {
SQLSMALLINT describedType;
SQLULEN describedSize;
SQLSMALLINT describedDigits;
SQLSMALLINT nullable;
RETCODE rc = SQLDescribeParam_ptr(
hStmt,
static_cast<SQLUSMALLINT>(paramIndex + 1),
&describedType,
&describedSize,
&describedDigits,
&nullable
);
if (!SQL_SUCCEEDED(rc)) {
LOG("SQLDescribeParam failed for parameter {} with error code {}", paramIndex, rc);
return rc;
}
sqlType = describedType;
columnSize = describedSize;
decimalDigits = describedDigits;
}
dataPtr = nullptr;
strLenOrIndPtr = AllocateParamBuffer<SQLLEN>(paramBuffers);
*strLenOrIndPtr = SQL_NULL_DATA;
bufferLength = 0;
paramInfo.paramSQLType = sqlType;
paramInfo.columnSize = columnSize;
paramInfo.decimalDigits = decimalDigits;
break;
}
case SQL_C_STINYINT:
Expand Down Expand Up @@ -767,6 +795,8 @@ DriverHandle LoadDriverOrThrowException() {
SQLPutData_ptr = GetFunctionPointer<SQLPutDataFunc>(handle, "SQLPutData");
SQLTables_ptr = GetFunctionPointer<SQLTablesFunc>(handle, "SQLTablesW");

SQLDescribeParam_ptr = GetFunctionPointer<SQLDescribeParamFunc>(handle, "SQLDescribeParam");

bool success =
SQLAllocHandle_ptr && SQLSetEnvAttr_ptr && SQLSetConnectAttr_ptr &&
SQLSetStmtAttr_ptr && SQLGetConnectAttr_ptr && SQLDriverConnect_ptr &&
Expand All @@ -777,7 +807,8 @@ DriverHandle LoadDriverOrThrowException() {
SQLDescribeCol_ptr && SQLMoreResults_ptr && SQLColAttribute_ptr &&
SQLEndTran_ptr && SQLDisconnect_ptr && SQLFreeHandle_ptr &&
SQLFreeStmt_ptr && SQLGetDiagRec_ptr && SQLParamData_ptr &&
SQLPutData_ptr && SQLTables_ptr;
SQLPutData_ptr && SQLTables_ptr &&
SQLDescribeParam_ptr;

if (!success) {
ThrowStdException("Failed to load required function pointers from driver.");
Expand Down Expand Up @@ -1072,7 +1103,7 @@ SQLRETURN SQLTables_wrap(SqlHandlePtr StatementHandle,
// be prepared in a previous call.
SQLRETURN SQLExecute_wrap(const SqlHandlePtr statementHandle,
const std::wstring& query /* TODO: Use SQLTCHAR? */,
const py::list& params, const std::vector<ParamInfo>& paramInfos,
const py::list& params, std::vector<ParamInfo>& paramInfos,
py::list& isStmtPrepared, const bool usePrepare = true) {
LOG("Execute SQL Query - {}", query.c_str());
if (!SQLPrepare_ptr) {
Expand Down
4 changes: 4 additions & 0 deletions mssql_python/pybind/ddbc_bindings.h
Original file line number Diff line number Diff line change
Expand Up @@ -211,6 +211,8 @@ typedef SQLRETURN (SQL_API* SQLFreeStmtFunc)(SQLHSTMT, SQLUSMALLINT);
typedef SQLRETURN (SQL_API* SQLGetDiagRecFunc)(SQLSMALLINT, SQLHANDLE, SQLSMALLINT, SQLWCHAR*, SQLINTEGER*,
SQLWCHAR*, SQLSMALLINT, SQLSMALLINT*);

typedef SQLRETURN (SQL_API* SQLDescribeParamFunc)(SQLHSTMT, SQLUSMALLINT, SQLSMALLINT*, SQLULEN*, SQLSMALLINT*, SQLSMALLINT*);

// DAE APIs
typedef SQLRETURN (SQL_API* SQLParamDataFunc)(SQLHSTMT, SQLPOINTER*);
typedef SQLRETURN (SQL_API* SQLPutDataFunc)(SQLHSTMT, SQLPOINTER, SQLLEN);
Expand Down Expand Up @@ -257,6 +259,8 @@ extern SQLFreeStmtFunc SQLFreeStmt_ptr;
// Diagnostic APIs
extern SQLGetDiagRecFunc SQLGetDiagRec_ptr;

extern SQLDescribeParamFunc SQLDescribeParam_ptr;

// DAE APIs
extern SQLParamDataFunc SQLParamData_ptr;
extern SQLPutDataFunc SQLPutData_ptr;
Expand Down
Loading