Skip to content

Commit 42785d9

Browse files
authored
Merge branch 'main' into bewithgaurav/fix_blank_columns
2 parents 2255e50 + 83ab8ea commit 42785d9

File tree

5 files changed: +297 additions, -27 deletions

eng/pipelines/pr-validation-pipeline.yml

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -92,6 +92,9 @@ jobs:
9292

9393
- script: |
9494
brew update
95+
# Uninstall existing CMake to avoid tap conflicts
96+
brew uninstall cmake --ignore-dependencies || echo "CMake not installed or already removed"
97+
# Install CMake from homebrew/core
9598
brew install cmake
9699
displayName: 'Install CMake'
97100

mssql_python/cursor.py

Lines changed: 10 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -231,7 +231,7 @@ def _map_sql_type(self, param, parameters_list, i):
231231
"""
232232
if param is None:
233233
return (
234-
ddbc_sql_const.SQL_VARCHAR.value, # TODO: Add SQLDescribeParam to get correct type
234+
ddbc_sql_const.SQL_VARCHAR.value,
235235
ddbc_sql_const.SQL_C_DEFAULT.value,
236236
1,
237237
0,
@@ -342,12 +342,15 @@ def _map_sql_type(self, param, parameters_list, i):
342342

343343
# String mapping logic here
344344
is_unicode = self._is_unicode_string(param)
345-
if len(param) > MAX_INLINE_CHAR: # Long strings
345+
346+
# Computes UTF-16 code units (handles surrogate pairs)
347+
utf16_len = sum(2 if ord(c) > 0xFFFF else 1 for c in param)
348+
if utf16_len > MAX_INLINE_CHAR: # Long strings -> DAE
346349
if is_unicode:
347350
return (
348351
ddbc_sql_const.SQL_WLONGVARCHAR.value,
349352
ddbc_sql_const.SQL_C_WCHAR.value,
350-
len(param),
353+
utf16_len,
351354
0,
352355
True,
353356
)
@@ -358,8 +361,9 @@ def _map_sql_type(self, param, parameters_list, i):
358361
0,
359362
True,
360363
)
361-
if is_unicode: # Short Unicode strings
362-
utf16_len = len(param.encode("utf-16-le")) // 2
364+
365+
# Short strings
366+
if is_unicode:
363367
return (
364368
ddbc_sql_const.SQL_WVARCHAR.value,
365369
ddbc_sql_const.SQL_C_WCHAR.value,
@@ -374,7 +378,7 @@ def _map_sql_type(self, param, parameters_list, i):
374378
0,
375379
False,
376380
)
377-
381+
378382
if isinstance(param, bytes):
379383
if len(param) > 8000: # Assuming VARBINARY(MAX) for long byte arrays
380384
return (

mssql_python/pybind/ddbc_bindings.cpp

Lines changed: 99 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -140,6 +140,8 @@ SQLParamDataFunc SQLParamData_ptr = nullptr;
140140
SQLPutDataFunc SQLPutData_ptr = nullptr;
141141
SQLTablesFunc SQLTables_ptr = nullptr;
142142

143+
SQLDescribeParamFunc SQLDescribeParam_ptr = nullptr;
144+
143145
namespace {
144146

145147
const char* GetSqlCTypeAsString(const SQLSMALLINT cType) {
@@ -212,20 +214,40 @@ std::string DescribeChar(unsigned char ch) {
212214
// Given a list of parameters and their ParamInfo, calls SQLBindParameter on each of them with
213215
// appropriate arguments
214216
SQLRETURN BindParameters(SQLHANDLE hStmt, const py::list& params,
215-
const std::vector<ParamInfo>& paramInfos,
217+
std::vector<ParamInfo>& paramInfos,
216218
std::vector<std::shared_ptr<void>>& paramBuffers) {
217219
LOG("Starting parameter binding. Number of parameters: {}", params.size());
218220
for (int paramIndex = 0; paramIndex < params.size(); paramIndex++) {
219221
const auto& param = params[paramIndex];
220-
const ParamInfo& paramInfo = paramInfos[paramIndex];
222+
ParamInfo& paramInfo = paramInfos[paramIndex];
221223
LOG("Binding parameter {} - C Type: {}, SQL Type: {}", paramIndex, paramInfo.paramCType, paramInfo.paramSQLType);
222224
void* dataPtr = nullptr;
223225
SQLLEN bufferLength = 0;
224226
SQLLEN* strLenOrIndPtr = nullptr;
225227

226228
// TODO: Add more data types like money, guid, interval, TVPs etc.
227229
switch (paramInfo.paramCType) {
228-
case SQL_C_CHAR:
230+
case SQL_C_CHAR: {
231+
if (!py::isinstance<py::str>(param) && !py::isinstance<py::bytearray>(param) &&
232+
!py::isinstance<py::bytes>(param)) {
233+
ThrowStdException(MakeParamMismatchErrorStr(paramInfo.paramCType, paramIndex));
234+
}
235+
if (paramInfo.isDAE) {
236+
LOG("Parameter[{}] is marked for DAE streaming", paramIndex);
237+
dataPtr = const_cast<void*>(reinterpret_cast<const void*>(&paramInfos[paramIndex]));
238+
strLenOrIndPtr = AllocateParamBuffer<SQLLEN>(paramBuffers);
239+
*strLenOrIndPtr = SQL_LEN_DATA_AT_EXEC(0);
240+
bufferLength = 0;
241+
} else {
242+
std::string* strParam =
243+
AllocateParamBuffer<std::string>(paramBuffers, param.cast<std::string>());
244+
dataPtr = const_cast<void*>(static_cast<const void*>(strParam->c_str()));
245+
bufferLength = strParam->size() + 1;
246+
strLenOrIndPtr = AllocateParamBuffer<SQLLEN>(paramBuffers);
247+
*strLenOrIndPtr = SQL_NTS;
248+
}
249+
break;
250+
}
229251
case SQL_C_BINARY: {
230252
if (!py::isinstance<py::str>(param) && !py::isinstance<py::bytearray>(param) &&
231253
!py::isinstance<py::bytes>(param)) {
@@ -283,11 +305,37 @@ SQLRETURN BindParameters(SQLHANDLE hStmt, const py::list& params,
283305
if (!py::isinstance<py::none>(param)) {
284306
ThrowStdException(MakeParamMismatchErrorStr(paramInfo.paramCType, paramIndex));
285307
}
286-
// TODO: This wont work for None values added to BINARY/VARBINARY columns. None values
287-
// of binary columns need to have C type = SQL_C_BINARY & SQL type = SQL_BINARY
308+
SQLSMALLINT sqlType = paramInfo.paramSQLType;
309+
SQLULEN columnSize = paramInfo.columnSize;
310+
SQLSMALLINT decimalDigits = paramInfo.decimalDigits;
311+
if (sqlType == SQL_UNKNOWN_TYPE) {
312+
SQLSMALLINT describedType;
313+
SQLULEN describedSize;
314+
SQLSMALLINT describedDigits;
315+
SQLSMALLINT nullable;
316+
RETCODE rc = SQLDescribeParam_ptr(
317+
hStmt,
318+
static_cast<SQLUSMALLINT>(paramIndex + 1),
319+
&describedType,
320+
&describedSize,
321+
&describedDigits,
322+
&nullable
323+
);
324+
if (!SQL_SUCCEEDED(rc)) {
325+
LOG("SQLDescribeParam failed for parameter {} with error code {}", paramIndex, rc);
326+
return rc;
327+
}
328+
sqlType = describedType;
329+
columnSize = describedSize;
330+
decimalDigits = describedDigits;
331+
}
288332
dataPtr = nullptr;
289333
strLenOrIndPtr = AllocateParamBuffer<SQLLEN>(paramBuffers);
290334
*strLenOrIndPtr = SQL_NULL_DATA;
335+
bufferLength = 0;
336+
paramInfo.paramSQLType = sqlType;
337+
paramInfo.columnSize = columnSize;
338+
paramInfo.decimalDigits = decimalDigits;
291339
break;
292340
}
293341
case SQL_C_STINYINT:
@@ -767,6 +815,8 @@ DriverHandle LoadDriverOrThrowException() {
767815
SQLPutData_ptr = GetFunctionPointer<SQLPutDataFunc>(handle, "SQLPutData");
768816
SQLTables_ptr = GetFunctionPointer<SQLTablesFunc>(handle, "SQLTablesW");
769817

818+
SQLDescribeParam_ptr = GetFunctionPointer<SQLDescribeParamFunc>(handle, "SQLDescribeParam");
819+
770820
bool success =
771821
SQLAllocHandle_ptr && SQLSetEnvAttr_ptr && SQLSetConnectAttr_ptr &&
772822
SQLSetStmtAttr_ptr && SQLGetConnectAttr_ptr && SQLDriverConnect_ptr &&
@@ -777,7 +827,8 @@ DriverHandle LoadDriverOrThrowException() {
777827
SQLDescribeCol_ptr && SQLMoreResults_ptr && SQLColAttribute_ptr &&
778828
SQLEndTran_ptr && SQLDisconnect_ptr && SQLFreeHandle_ptr &&
779829
SQLFreeStmt_ptr && SQLGetDiagRec_ptr && SQLParamData_ptr &&
780-
SQLPutData_ptr && SQLTables_ptr;
830+
SQLPutData_ptr && SQLTables_ptr &&
831+
SQLDescribeParam_ptr;
781832

782833
if (!success) {
783834
ThrowStdException("Failed to load required function pointers from driver.");
@@ -1072,7 +1123,7 @@ SQLRETURN SQLTables_wrap(SqlHandlePtr StatementHandle,
10721123
// be prepared in a previous call.
10731124
SQLRETURN SQLExecute_wrap(const SqlHandlePtr statementHandle,
10741125
const std::wstring& query /* TODO: Use SQLTCHAR? */,
1075-
const py::list& params, const std::vector<ParamInfo>& paramInfos,
1126+
const py::list& params, std::vector<ParamInfo>& paramInfos,
10761127
py::list& isStmtPrepared, const bool usePrepare = true) {
10771128
LOG("Execute SQL Query - {}", query.c_str());
10781129
if (!SQLPrepare_ptr) {
@@ -1172,23 +1223,51 @@ SQLRETURN SQLExecute_wrap(const SqlHandlePtr statementHandle,
11721223
continue;
11731224
}
11741225
if (py::isinstance<py::str>(pyObj)) {
1175-
std::wstring wstr = pyObj.cast<std::wstring>();
1226+
if (matchedInfo->paramCType == SQL_C_WCHAR) {
1227+
std::wstring wstr = pyObj.cast<std::wstring>();
1228+
const SQLWCHAR* dataPtr = nullptr;
1229+
size_t totalChars = 0;
11761230
#if defined(__APPLE__) || defined(__linux__)
1177-
auto utf16Buf = WStringToSQLWCHAR(wstr);
1178-
const char* dataPtr = reinterpret_cast<const char*>(utf16Buf.data());
1179-
size_t totalBytes = (utf16Buf.size() - 1) * sizeof(SQLWCHAR);
1231+
std::vector<SQLWCHAR> sqlwStr = WStringToSQLWCHAR(wstr);
1232+
totalChars = sqlwStr.size() - 1;
1233+
dataPtr = sqlwStr.data();
11801234
#else
1181-
const char* dataPtr = reinterpret_cast<const char*>(wstr.data());
1182-
size_t totalBytes = wstr.size() * sizeof(wchar_t);
1235+
dataPtr = wstr.c_str();
1236+
totalChars = wstr.size();
11831237
#endif
1184-
const size_t chunkSize = DAE_CHUNK_SIZE;
1185-
for (size_t offset = 0; offset < totalBytes; offset += chunkSize) {
1186-
size_t len = std::min(chunkSize, totalBytes - offset);
1187-
rc = SQLPutData_ptr(hStmt, (SQLPOINTER)(dataPtr + offset), static_cast<SQLLEN>(len));
1188-
if (!SQL_SUCCEEDED(rc)) {
1189-
LOG("SQLPutData failed at offset {} of {}", offset, totalBytes);
1190-
return rc;
1238+
size_t offset = 0;
1239+
size_t chunkChars = DAE_CHUNK_SIZE / sizeof(SQLWCHAR);
1240+
while (offset < totalChars) {
1241+
size_t len = std::min(chunkChars, totalChars - offset);
1242+
size_t lenBytes = len * sizeof(SQLWCHAR);
1243+
if (lenBytes > static_cast<size_t>(std::numeric_limits<SQLLEN>::max())) {
1244+
ThrowStdException("Chunk size exceeds maximum allowed by SQLLEN");
1245+
}
1246+
rc = SQLPutData_ptr(hStmt, (SQLPOINTER)(dataPtr + offset), static_cast<SQLLEN>(lenBytes));
1247+
if (!SQL_SUCCEEDED(rc)) {
1248+
LOG("SQLPutData failed at offset {} of {}", offset, totalChars);
1249+
return rc;
1250+
}
1251+
offset += len;
11911252
}
1253+
} else if (matchedInfo->paramCType == SQL_C_CHAR) {
1254+
std::string s = pyObj.cast<std::string>();
1255+
size_t totalBytes = s.size();
1256+
const char* dataPtr = s.data();
1257+
size_t offset = 0;
1258+
size_t chunkBytes = DAE_CHUNK_SIZE;
1259+
while (offset < totalBytes) {
1260+
size_t len = std::min(chunkBytes, totalBytes - offset);
1261+
1262+
rc = SQLPutData_ptr(hStmt, (SQLPOINTER)(dataPtr + offset), static_cast<SQLLEN>(len));
1263+
if (!SQL_SUCCEEDED(rc)) {
1264+
LOG("SQLPutData failed at offset {} of {}", offset, totalBytes);
1265+
return rc;
1266+
}
1267+
offset += len;
1268+
}
1269+
} else {
1270+
ThrowStdException("Unsupported C type for str in DAE");
11921271
}
11931272
} else {
11941273
ThrowStdException("DAE only supported for str or bytes");

mssql_python/pybind/ddbc_bindings.h

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -211,6 +211,8 @@ typedef SQLRETURN (SQL_API* SQLFreeStmtFunc)(SQLHSTMT, SQLUSMALLINT);
211211
typedef SQLRETURN (SQL_API* SQLGetDiagRecFunc)(SQLSMALLINT, SQLHANDLE, SQLSMALLINT, SQLWCHAR*, SQLINTEGER*,
212212
SQLWCHAR*, SQLSMALLINT, SQLSMALLINT*);
213213

214+
typedef SQLRETURN (SQL_API* SQLDescribeParamFunc)(SQLHSTMT, SQLUSMALLINT, SQLSMALLINT*, SQLULEN*, SQLSMALLINT*, SQLSMALLINT*);
215+
214216
// DAE APIs
215217
typedef SQLRETURN (SQL_API* SQLParamDataFunc)(SQLHSTMT, SQLPOINTER*);
216218
typedef SQLRETURN (SQL_API* SQLPutDataFunc)(SQLHSTMT, SQLPOINTER, SQLLEN);
@@ -257,6 +259,8 @@ extern SQLFreeStmtFunc SQLFreeStmt_ptr;
257259
// Diagnostic APIs
258260
extern SQLGetDiagRecFunc SQLGetDiagRec_ptr;
259261

262+
extern SQLDescribeParamFunc SQLDescribeParam_ptr;
263+
260264
// DAE APIs
261265
extern SQLParamDataFunc SQLParamData_ptr;
262266
extern SQLPutDataFunc SQLPutData_ptr;

0 commit comments

Comments
 (0)