@@ -140,6 +140,8 @@ SQLParamDataFunc SQLParamData_ptr = nullptr;
140140SQLPutDataFunc SQLPutData_ptr = nullptr ;
141141SQLTablesFunc SQLTables_ptr = nullptr ;
142142
143+ SQLDescribeParamFunc SQLDescribeParam_ptr = nullptr ;
144+
143145namespace {
144146
145147const char * GetSqlCTypeAsString (const SQLSMALLINT cType) {
@@ -212,20 +214,40 @@ std::string DescribeChar(unsigned char ch) {
212214// Given a list of parameters and their ParamInfo, calls SQLBindParameter on each of them with
213215// appropriate arguments
214216SQLRETURN BindParameters (SQLHANDLE hStmt, const py::list& params,
215- const std::vector<ParamInfo>& paramInfos,
217+ std::vector<ParamInfo>& paramInfos,
216218 std::vector<std::shared_ptr<void >>& paramBuffers) {
217219 LOG (" Starting parameter binding. Number of parameters: {}" , params.size ());
218220 for (int paramIndex = 0 ; paramIndex < params.size (); paramIndex++) {
219221 const auto & param = params[paramIndex];
220- const ParamInfo& paramInfo = paramInfos[paramIndex];
222+ ParamInfo& paramInfo = paramInfos[paramIndex];
221223 LOG (" Binding parameter {} - C Type: {}, SQL Type: {}" , paramIndex, paramInfo.paramCType , paramInfo.paramSQLType );
222224 void * dataPtr = nullptr ;
223225 SQLLEN bufferLength = 0 ;
224226 SQLLEN* strLenOrIndPtr = nullptr ;
225227
226228 // TODO: Add more data types like money, guid, interval, TVPs etc.
227229 switch (paramInfo.paramCType ) {
228- case SQL_C_CHAR:
230+ case SQL_C_CHAR: {
231+ if (!py::isinstance<py::str>(param) && !py::isinstance<py::bytearray>(param) &&
232+ !py::isinstance<py::bytes>(param)) {
233+ ThrowStdException (MakeParamMismatchErrorStr (paramInfo.paramCType , paramIndex));
234+ }
235+ if (paramInfo.isDAE ) {
236+ LOG (" Parameter[{}] is marked for DAE streaming" , paramIndex);
237+ dataPtr = const_cast <void *>(reinterpret_cast <const void *>(¶mInfos[paramIndex]));
238+ strLenOrIndPtr = AllocateParamBuffer<SQLLEN>(paramBuffers);
239+ *strLenOrIndPtr = SQL_LEN_DATA_AT_EXEC (0 );
240+ bufferLength = 0 ;
241+ } else {
242+ std::string* strParam =
243+ AllocateParamBuffer<std::string>(paramBuffers, param.cast <std::string>());
244+ dataPtr = const_cast <void *>(static_cast <const void *>(strParam->c_str ()));
245+ bufferLength = strParam->size () + 1 ;
246+ strLenOrIndPtr = AllocateParamBuffer<SQLLEN>(paramBuffers);
247+ *strLenOrIndPtr = SQL_NTS;
248+ }
249+ break ;
250+ }
229251 case SQL_C_BINARY: {
230252 if (!py::isinstance<py::str>(param) && !py::isinstance<py::bytearray>(param) &&
231253 !py::isinstance<py::bytes>(param)) {
@@ -283,11 +305,37 @@ SQLRETURN BindParameters(SQLHANDLE hStmt, const py::list& params,
283305 if (!py::isinstance<py::none>(param)) {
284306 ThrowStdException (MakeParamMismatchErrorStr (paramInfo.paramCType , paramIndex));
285307 }
286- // TODO: This wont work for None values added to BINARY/VARBINARY columns. None values
287- // of binary columns need to have C type = SQL_C_BINARY & SQL type = SQL_BINARY
308+ SQLSMALLINT sqlType = paramInfo.paramSQLType ;
309+ SQLULEN columnSize = paramInfo.columnSize ;
310+ SQLSMALLINT decimalDigits = paramInfo.decimalDigits ;
311+ if (sqlType == SQL_UNKNOWN_TYPE) {
312+ SQLSMALLINT describedType;
313+ SQLULEN describedSize;
314+ SQLSMALLINT describedDigits;
315+ SQLSMALLINT nullable;
316+ RETCODE rc = SQLDescribeParam_ptr (
317+ hStmt,
318+ static_cast <SQLUSMALLINT>(paramIndex + 1 ),
319+ &describedType,
320+ &describedSize,
321+ &describedDigits,
322+ &nullable
323+ );
324+ if (!SQL_SUCCEEDED (rc)) {
325+ LOG (" SQLDescribeParam failed for parameter {} with error code {}" , paramIndex, rc);
326+ return rc;
327+ }
328+ sqlType = describedType;
329+ columnSize = describedSize;
330+ decimalDigits = describedDigits;
331+ }
288332 dataPtr = nullptr ;
289333 strLenOrIndPtr = AllocateParamBuffer<SQLLEN>(paramBuffers);
290334 *strLenOrIndPtr = SQL_NULL_DATA;
335+ bufferLength = 0 ;
336+ paramInfo.paramSQLType = sqlType;
337+ paramInfo.columnSize = columnSize;
338+ paramInfo.decimalDigits = decimalDigits;
291339 break ;
292340 }
293341 case SQL_C_STINYINT:
@@ -767,6 +815,8 @@ DriverHandle LoadDriverOrThrowException() {
767815 SQLPutData_ptr = GetFunctionPointer<SQLPutDataFunc>(handle, " SQLPutData" );
768816 SQLTables_ptr = GetFunctionPointer<SQLTablesFunc>(handle, " SQLTablesW" );
769817
818+ SQLDescribeParam_ptr = GetFunctionPointer<SQLDescribeParamFunc>(handle, " SQLDescribeParam" );
819+
770820 bool success =
771821 SQLAllocHandle_ptr && SQLSetEnvAttr_ptr && SQLSetConnectAttr_ptr &&
772822 SQLSetStmtAttr_ptr && SQLGetConnectAttr_ptr && SQLDriverConnect_ptr &&
@@ -777,7 +827,8 @@ DriverHandle LoadDriverOrThrowException() {
777827 SQLDescribeCol_ptr && SQLMoreResults_ptr && SQLColAttribute_ptr &&
778828 SQLEndTran_ptr && SQLDisconnect_ptr && SQLFreeHandle_ptr &&
779829 SQLFreeStmt_ptr && SQLGetDiagRec_ptr && SQLParamData_ptr &&
780- SQLPutData_ptr && SQLTables_ptr;
830+ SQLPutData_ptr && SQLTables_ptr &&
831+ SQLDescribeParam_ptr;
781832
782833 if (!success) {
783834 ThrowStdException (" Failed to load required function pointers from driver." );
@@ -1072,7 +1123,7 @@ SQLRETURN SQLTables_wrap(SqlHandlePtr StatementHandle,
10721123// be prepared in a previous call.
10731124SQLRETURN SQLExecute_wrap (const SqlHandlePtr statementHandle,
10741125 const std::wstring& query /* TODO: Use SQLTCHAR? */ ,
1075- const py::list& params, const std::vector<ParamInfo>& paramInfos,
1126+ const py::list& params, std::vector<ParamInfo>& paramInfos,
10761127 py::list& isStmtPrepared, const bool usePrepare = true ) {
10771128 LOG (" Execute SQL Query - {}" , query.c_str ());
10781129 if (!SQLPrepare_ptr) {
@@ -1172,23 +1223,51 @@ SQLRETURN SQLExecute_wrap(const SqlHandlePtr statementHandle,
11721223 continue ;
11731224 }
11741225 if (py::isinstance<py::str>(pyObj)) {
1175- std::wstring wstr = pyObj.cast <std::wstring>();
1226+ if (matchedInfo->paramCType == SQL_C_WCHAR) {
1227+ std::wstring wstr = pyObj.cast <std::wstring>();
1228+ const SQLWCHAR* dataPtr = nullptr ;
1229+ size_t totalChars = 0 ;
11761230#if defined(__APPLE__) || defined(__linux__)
1177- auto utf16Buf = WStringToSQLWCHAR (wstr);
1178- const char * dataPtr = reinterpret_cast < const char *>(utf16Buf. data ()) ;
1179- size_t totalBytes = (utf16Buf. size () - 1 ) * sizeof (SQLWCHAR );
1231+ std::vector<SQLWCHAR> sqlwStr = WStringToSQLWCHAR (wstr);
1232+ totalChars = sqlwStr. size () - 1 ;
1233+ dataPtr = sqlwStr. data ( );
11801234#else
1181- const char * dataPtr = reinterpret_cast < const char *>( wstr.data () );
1182- size_t totalBytes = wstr.size () * sizeof ( wchar_t );
1235+ dataPtr = wstr.c_str ( );
1236+ totalChars = wstr.size ();
11831237#endif
1184- const size_t chunkSize = DAE_CHUNK_SIZE;
1185- for (size_t offset = 0 ; offset < totalBytes; offset += chunkSize) {
1186- size_t len = std::min (chunkSize, totalBytes - offset);
1187- rc = SQLPutData_ptr (hStmt, (SQLPOINTER)(dataPtr + offset), static_cast <SQLLEN>(len));
1188- if (!SQL_SUCCEEDED (rc)) {
1189- LOG (" SQLPutData failed at offset {} of {}" , offset, totalBytes);
1190- return rc;
1238+ size_t offset = 0 ;
1239+ size_t chunkChars = DAE_CHUNK_SIZE / sizeof (SQLWCHAR);
1240+ while (offset < totalChars) {
1241+ size_t len = std::min (chunkChars, totalChars - offset);
1242+ size_t lenBytes = len * sizeof (SQLWCHAR);
1243+ if (lenBytes > static_cast <size_t >(std::numeric_limits<SQLLEN>::max ())) {
1244+ ThrowStdException (" Chunk size exceeds maximum allowed by SQLLEN" );
1245+ }
1246+ rc = SQLPutData_ptr (hStmt, (SQLPOINTER)(dataPtr + offset), static_cast <SQLLEN>(lenBytes));
1247+ if (!SQL_SUCCEEDED (rc)) {
1248+ LOG (" SQLPutData failed at offset {} of {}" , offset, totalChars);
1249+ return rc;
1250+ }
1251+ offset += len;
11911252 }
1253+ } else if (matchedInfo->paramCType == SQL_C_CHAR) {
1254+ std::string s = pyObj.cast <std::string>();
1255+ size_t totalBytes = s.size ();
1256+ const char * dataPtr = s.data ();
1257+ size_t offset = 0 ;
1258+ size_t chunkBytes = DAE_CHUNK_SIZE;
1259+ while (offset < totalBytes) {
1260+ size_t len = std::min (chunkBytes, totalBytes - offset);
1261+
1262+ rc = SQLPutData_ptr (hStmt, (SQLPOINTER)(dataPtr + offset), static_cast <SQLLEN>(len));
1263+ if (!SQL_SUCCEEDED (rc)) {
1264+ LOG (" SQLPutData failed at offset {} of {}" , offset, totalBytes);
1265+ return rc;
1266+ }
1267+ offset += len;
1268+ }
1269+ } else {
1270+ ThrowStdException (" Unsupported C type for str in DAE" );
11921271 }
11931272 } else {
11941273 ThrowStdException (" DAE only supported for str or bytes" );
0 commit comments