diff --git a/src/cursor.cpp b/src/cursor.cpp index b0dace58..f2620594 100644 --- a/src/cursor.cpp +++ b/src/cursor.cpp @@ -684,11 +684,11 @@ static PyObject* execute(Cursor* cur, PyObject* pSql, PyObject* params, bool ski if (ret == SQL_NEED_DATA) { szLastFunction = "SQLPutData"; - if (PyBytes_Check(pInfo->pObject) + if (pInfo->pObject && (PyBytes_Check(pInfo->pObject) #if PY_VERSION_HEX >= 0x02060000 || PyByteArray_Check(pInfo->pObject) #endif - ) + )) { char *(*pGetPtr)(PyObject*); Py_ssize_t (*pGetLen)(PyObject*); @@ -711,7 +711,7 @@ static PyObject* execute(Cursor* cur, PyObject* pSql, PyObject* params, bool ski do { - SQLLEN remaining = min(pInfo->maxlength, cb - offset); + SQLLEN remaining = pInfo->maxlength ? min(pInfo->maxlength, cb - offset) : cb; TRACE("SQLPutData [%d] (%d) %.10s\n", offset, remaining, &p[offset]); Py_BEGIN_ALLOW_THREADS ret = SQLPutData(cur->hstmt, (SQLPOINTER)&p[offset], remaining); @@ -723,7 +723,7 @@ static PyObject* execute(Cursor* cur, PyObject* pSql, PyObject* params, bool ski while (offset < cb); } #if PY_MAJOR_VERSION < 3 - else if (PyBuffer_Check(pInfo->pObject)) + else if (pInfo->pObject && PyBuffer_Check(pInfo->pObject)) { // Buffers can have multiple segments, so we might need multiple writes. Looping through buffers isn't // difficult, but we've wrapped it up in an iterator object to keep this loop simple. @@ -741,6 +741,65 @@ static PyObject* execute(Cursor* cur, PyObject* pSql, PyObject* params, bool ski } } #endif + else if (pInfo->ParameterType == SQL_SS_TABLE) + { + // TVP + // Need to convert its columns into the bound row buffers + int hasTvpRows = 0; + if (pInfo->curTvpRow < PySequence_Length(pInfo->pObject)) + { + PyObject *tvpRow = PySequence_GetItem(pInfo->pObject, pInfo->curTvpRow); + Py_XDECREF(tvpRow); + for (Py_ssize_t i = 0; i < PySequence_Size(tvpRow); i++) + { + struct ParamInfo newParam; + struct ParamInfo *prevParam = pInfo->nested + i; + PyObject *cell = PySequence_GetItem(tvpRow, i); + Py_XDECREF(cell); + memset(&newParam, 0, sizeof(newParam)); + if (!GetParameterInfo(cur, i, cell, newParam, true)) + { + // Error converting object + FreeParameterData(cur); + return NULL; + } + if (newParam.ValueType != prevParam->ValueType || + newParam.ParameterType != prevParam->ParameterType) + { + FreeParameterData(cur); + return RaiseErrorV(0, ProgrammingError, "Type mismatch between TVP row values"); + } + if (prevParam->allocated) + pyodbc_free(prevParam->ParameterValuePtr); + Py_XDECREF(prevParam->pObject); + newParam.BufferLength = newParam.StrLen_or_Ind; + newParam.StrLen_or_Ind = SQL_DATA_AT_EXEC; + Py_INCREF(cell); + newParam.pObject = cell; + *prevParam = newParam; + if(prevParam->ParameterValuePtr == &newParam.Data) + { + prevParam->ParameterValuePtr = &prevParam->Data; + } + } + pInfo->curTvpRow++; + hasTvpRows = 1; + } + Py_BEGIN_ALLOW_THREADS + ret = SQLPutData(cur->hstmt, hasTvpRows ? (SQLPOINTER)1 : 0, hasTvpRows); + Py_END_ALLOW_THREADS + if (!SQL_SUCCEEDED(ret)) + return RaiseErrorFromHandle(cur->cnxn, "SQLPutData", cur->cnxn->hdbc, cur->hstmt); + } + else + { + // TVP column sent as DAE + Py_BEGIN_ALLOW_THREADS + ret = SQLPutData(cur->hstmt, pInfo->ParameterValuePtr, pInfo->BufferLength); + Py_END_ALLOW_THREADS + if (!SQL_SUCCEEDED(ret)) + return RaiseErrorFromHandle(cur->cnxn, "SQLPutData", cur->cnxn->hdbc, cur->hstmt); + } ret = SQL_NEED_DATA; } } diff --git a/src/cursor.h b/src/cursor.h index 4b3176cb..c39a8d3a 100644 --- a/src/cursor.h +++ b/src/cursor.h @@ -64,6 +64,10 @@ struct ParamInfo // written to each SQLPutData call. (It is not clear if they are limited // like SQLBindParameter or not.) + // For TVPs, the nested descriptors and current row. + struct ParamInfo *nested; + SQLLEN curTvpRow; + // Optional data. If used, ParameterValuePtr will point into this. union { diff --git a/src/params.cpp b/src/params.cpp index 1fc62587..bdeb9119 100644 --- a/src/params.cpp +++ b/src/params.cpp @@ -11,8 +11,8 @@ #include "wrapper.h" #include "textenc.h" #include "pyodbcmodule.h" -#include "params.h" #include "cursor.h" +#include "params.h" #include "connection.h" #include "buffer.h" #include "errors.h" @@ -598,6 +598,8 @@ static void FreeInfos(ParamInfo* a, Py_ssize_t count) { if (a[i].allocated) pyodbc_free(a[i].ParameterValuePtr); + if (a[i].ParameterType == SQL_SS_TABLE && a[i].nested) + FreeInfos(a[i].nested, a[i].maxlength); Py_XDECREF(a[i].pObject); } pyodbc_free(a); @@ -626,17 +628,17 @@ static bool GetNullBinaryInfo(Cursor* cur, Py_ssize_t index, ParamInfo& info) #if PY_MAJOR_VERSION >= 3 -static bool GetBytesInfo(Cursor* cur, Py_ssize_t index, PyObject* param, ParamInfo& info) +static bool GetBytesInfo(Cursor* cur, Py_ssize_t index, PyObject* param, ParamInfo& info, bool isTVP) { // The Python 3 version that writes bytes as binary data. Py_ssize_t cb = PyBytes_GET_SIZE(param); info.ValueType = SQL_C_BINARY; - info.ColumnSize = (SQLUINTEGER)max(cb, 1); + info.ColumnSize = isTVP?0:(SQLUINTEGER)max(cb, 1); SQLLEN maxlength = cur->cnxn->GetMaxLength(info.ValueType); - if (maxlength == 0 || cb <= maxlength) + if (maxlength == 0 || cb <= maxlength || isTVP) { info.ParameterType = SQL_VARBINARY; info.StrLen_or_Ind = cb; @@ -660,7 +662,7 @@ static bool GetBytesInfo(Cursor* cur, Py_ssize_t index, PyObject* param, ParamIn #endif #if PY_MAJOR_VERSION < 3 -static bool GetStrInfo(Cursor* cur, Py_ssize_t index, PyObject* param, ParamInfo& info) +static bool GetStrInfo(Cursor* cur, Py_ssize_t index, PyObject* param, ParamInfo& info, bool isTVP) { const TextEnc& enc = cur->cnxn->str_enc; @@ -668,7 +670,7 @@ static bool GetStrInfo(Cursor* cur, Py_ssize_t index, PyObject* param, ParamInfo Py_ssize_t cch = PyString_GET_SIZE(param); - info.ColumnSize = (SQLUINTEGER)max(cch, 1); + info.ColumnSize = isTVP?0:(SQLUINTEGER)max(cch, 1); Object encoded; @@ -699,7 +701,7 @@ static bool GetStrInfo(Cursor* cur, Py_ssize_t index, PyObject* param, ParamInfo info.pObject = encoded.Detach(); SQLLEN maxlength = cur->cnxn->GetMaxLength(info.ValueType); - if (maxlength == 0 || cb <= maxlength) + if (maxlength == 0 || cb <= maxlength || isTVP) { info.ParameterType = (enc.ctype == SQL_C_CHAR) ? SQL_VARCHAR : SQL_WVARCHAR; info.ParameterValuePtr = PyBytes_AS_STRING(info.pObject); @@ -720,7 +722,7 @@ static bool GetStrInfo(Cursor* cur, Py_ssize_t index, PyObject* param, ParamInfo #endif -static bool GetUnicodeInfo(Cursor* cur, Py_ssize_t index, PyObject* param, ParamInfo& info) +static bool GetUnicodeInfo(Cursor* cur, Py_ssize_t index, PyObject* param, ParamInfo& info, bool isTVP) { const TextEnc& enc = cur->cnxn->unicode_enc; @@ -750,13 +752,13 @@ static bool GetUnicodeInfo(Cursor* cur, Py_ssize_t index, PyObject* param, Param denom = 4; } - info.ColumnSize = (SQLUINTEGER)max(cb / denom, 1); + info.ColumnSize = isTVP?0:(SQLUINTEGER)max(cb / denom, 1); info.pObject = encoded.Detach(); SQLLEN maxlength = cur->cnxn->GetMaxLength(enc.ctype); - if (maxlength == 0 || cb <= maxlength) + if (maxlength == 0 || cb <= maxlength || isTVP) { info.ParameterType = (enc.ctype == SQL_C_CHAR) ? SQL_VARCHAR : SQL_WVARCHAR; info.ParameterValuePtr = PyBytes_AS_STRING(info.pObject); @@ -850,27 +852,45 @@ static bool GetTimeInfo(Cursor* cur, Py_ssize_t index, PyObject* param, ParamInf } #if PY_MAJOR_VERSION < 3 -static bool GetIntInfo(Cursor* cur, Py_ssize_t index, PyObject* param, ParamInfo& info) +static bool GetIntInfo(Cursor* cur, Py_ssize_t index, PyObject* param, ParamInfo& info, bool isTVP) { - info.Data.l = PyInt_AsLong(param); + if(isTVP) + { + PyErr_Clear(); + info.Data.i64 = (INT64)PyLong_AsLongLong(param); + if (!PyErr_Occurred()) + { + info.ValueType = SQL_C_SBIGINT; + info.ParameterType = SQL_BIGINT; + info.ParameterValuePtr = &info.Data.i64; + info.StrLen_or_Ind = 8; + + return true; + } + + return false; + } + info.Data.i64 = (INT64)PyLong_AsLongLong(param); + #if LONG_BIT == 64 info.ValueType = SQL_C_SBIGINT; // info.ValueType = SQL_C_LONG; info.ParameterType = SQL_BIGINT; + info.StrLen_or_Ind = 8; #elif LONG_BIT == 32 info.ValueType = SQL_C_LONG; info.ParameterType = SQL_INTEGER; + info.StrLen_or_Ind = 4; #else #error Unexpected LONG_BIT value #endif - info.ParameterValuePtr = &info.Data.l; return true; } #endif -static bool GetLongInfo(Cursor* cur, Py_ssize_t index, PyObject* param, ParamInfo& info) +static bool GetLongInfo(Cursor* cur, Py_ssize_t index, PyObject* param, ParamInfo& info, bool isTVP) { // Try to use integer when possible. BIGINT is not always supported and is a "special // case" for some drivers. @@ -883,12 +903,18 @@ static bool GetLongInfo(Cursor* cur, Py_ssize_t index, PyObject* param, ParamInf // it is a 'long int', but some drivers run into trouble at high values. We'll use // SQL_INTEGER as an optimization for smaller values and rely on BIGINT - info.Data.l = PyLong_AsLong(param); - if (!PyErr_Occurred() && (info.Data.l <= 0x7FFFFFFF)) + INT64 val = (INT64)PyLong_AsLongLong(param); + + if (!PyErr_Occurred() && !isTVP && (val >= -2147483648 && val <= 2147483647)) { - info.ValueType = SQL_C_LONG; - info.ParameterType = SQL_INTEGER; - info.ParameterValuePtr = &info.Data.l; + info.Data.l = PyLong_AsLong(param); + if (!PyErr_Occurred()) + { + info.ValueType = SQL_C_LONG; + info.ParameterType = SQL_INTEGER; + info.ParameterValuePtr = &info.Data.l; + info.StrLen_or_Ind = 4; + } } else { @@ -899,6 +925,7 @@ static bool GetLongInfo(Cursor* cur, Py_ssize_t index, PyObject* param, ParamInf info.ValueType = SQL_C_SBIGINT; info.ParameterType = SQL_BIGINT; info.ParameterValuePtr = &info.Data.i64; + info.StrLen_or_Ind = 8; } } @@ -914,6 +941,7 @@ static bool GetFloatInfo(Cursor* cur, Py_ssize_t index, PyObject* param, ParamIn info.ParameterType = SQL_DOUBLE; info.ParameterValuePtr = &info.Data.dbl; info.ColumnSize = 15; + info.StrLen_or_Ind = sizeof(double); return true; } @@ -1015,6 +1043,7 @@ static bool GetUUIDInfo(Cursor* cur, Py_ssize_t index, PyObject* param, ParamInf if (!b) return false; memcpy(info.ParameterValuePtr, PyBytes_AS_STRING(b.Get()), sizeof(SQLGUID)); + info.StrLen_or_Ind = sizeof(SQLGUID); return true; } @@ -1117,19 +1146,20 @@ static bool GetBufferInfo(Cursor* cur, Py_ssize_t index, PyObject* param, ParamI #endif #if PY_VERSION_HEX >= 0x02060000 -static bool GetByteArrayInfo(Cursor* cur, Py_ssize_t index, PyObject* param, ParamInfo& info) +static bool GetByteArrayInfo(Cursor* cur, Py_ssize_t index, PyObject* param, ParamInfo& info, bool isTVP) { info.ValueType = SQL_C_BINARY; Py_ssize_t cb = PyByteArray_Size(param); SQLLEN maxlength = cur->cnxn->GetMaxLength(info.ValueType); - if (maxlength == 0 || cb <= maxlength) + + if (maxlength == 0 || cb <= maxlength || isTVP) { info.ParameterType = SQL_VARBINARY; info.ParameterValuePtr = (SQLPOINTER)PyByteArray_AsString(param); info.BufferLength = (SQLINTEGER)cb; - info.ColumnSize = (SQLUINTEGER)max(cb, 1); + info.ColumnSize = isTVP?0:(SQLUINTEGER)max(cb, 1); info.StrLen_or_Ind = (SQLINTEGER)cb; } else @@ -1147,7 +1177,49 @@ static bool GetByteArrayInfo(Cursor* cur, Py_ssize_t index, PyObject* param, Par } #endif -static bool GetParameterInfo(Cursor* cur, Py_ssize_t index, PyObject* param, ParamInfo& info) + +// TVP +static bool GetTableInfo(Cursor *cur, Py_ssize_t index, PyObject* param, ParamInfo& info) +{ + int nskip = 0; + Py_ssize_t nrows = PySequence_Size(param); + if (nrows > 0) + { + PyObject *cell0 = PySequence_GetItem(param, 0); + Py_XDECREF(cell0); + if (PyBytes_Check(cell0) || PyUnicode_Check(cell0)) + { + SQLHDESC desc; + PyObject *tvpname = PyCodec_Encode(cell0, "UTF-16LE", 0); + SQLGetStmtAttr(cur->hstmt, SQL_ATTR_IMP_PARAM_DESC, &desc, 0, 0); + SQLSetDescFieldW(desc, index + 1, SQL_CA_SS_TYPE_NAME, (SQLPOINTER)PyBytes_AsString(tvpname), PyBytes_Size(tvpname)); + nskip++; + } + } + nrows -= nskip; + + if (!nskip) + { + // Need to describe in order to fill in IPD with the TVP's type name, because user has not provided it + SQLSMALLINT tvptype; + SQLDescribeParam(cur->hstmt, index + 1, &tvptype, 0, 0, 0); + } + + info.pObject = param; + Py_INCREF(param); + info.ValueType = SQL_C_BINARY; + info.ParameterType = SQL_SS_TABLE; + info.ColumnSize = nrows; + info.DecimalDigits = 0; + info.ParameterValuePtr = &info; + info.BufferLength = 0; + info.curTvpRow = nskip; + info.StrLen_or_Ind = SQL_DATA_AT_EXEC; + info.allocated = false; + return true; +} + +bool GetParameterInfo(Cursor* cur, Py_ssize_t index, PyObject* param, ParamInfo& info, bool isTVP) { // Determines the type of SQL parameter that will be used for this parameter based on the Python data type. // @@ -1161,14 +1233,14 @@ static bool GetParameterInfo(Cursor* cur, Py_ssize_t index, PyObject* param, Par #if PY_MAJOR_VERSION >= 3 if (PyBytes_Check(param)) - return GetBytesInfo(cur, index, param, info); + return GetBytesInfo(cur, index, param, info, isTVP); #else if (PyBytes_Check(param)) - return GetStrInfo(cur, index, param, info); + return GetStrInfo(cur, index, param, info, isTVP); #endif if (PyUnicode_Check(param)) - return GetUnicodeInfo(cur, index, param, info); + return GetUnicodeInfo(cur, index, param, info, isTVP); if (PyBool_Check(param)) return GetBooleanInfo(cur, index, param, info); @@ -1183,19 +1255,19 @@ static bool GetParameterInfo(Cursor* cur, Py_ssize_t index, PyObject* param, Par return GetTimeInfo(cur, index, param, info); if (PyLong_Check(param)) - return GetLongInfo(cur, index, param, info); + return GetLongInfo(cur, index, param, info, isTVP); if (PyFloat_Check(param)) return GetFloatInfo(cur, index, param, info); #if PY_VERSION_HEX >= 0x02060000 if (PyByteArray_Check(param)) - return GetByteArrayInfo(cur, index, param, info); + return GetByteArrayInfo(cur, index, param, info, isTVP); #endif #if PY_MAJOR_VERSION < 3 if (PyInt_Check(param)) - return GetIntInfo(cur, index, param, info); + return GetIntInfo(cur, index, param, info, isTVP); if (PyBuffer_Check(param)) return GetBufferInfo(cur, index, param, info); @@ -1218,6 +1290,9 @@ static bool GetParameterInfo(Cursor* cur, Py_ssize_t index, PyObject* param, Par if (cls != 0) return GetUUIDInfo(cur, index, param, info, cls); + if (PySequence_Check(param)) + return GetTableInfo(cur, index, param, info); + RaiseErrorV("HY105", ProgrammingError, "Invalid parameter type. param-index=%zd param-type=%s", index, Py_TYPE(param)->tp_name); return false; } @@ -1333,7 +1408,7 @@ bool BindParameter(Cursor* cur, Py_ssize_t index, ParamInfo& info) SQLRETURN ret = -1; Py_BEGIN_ALLOW_THREADS ret = SQLBindParameter(cur->hstmt, (SQLUSMALLINT)(index + 1), SQL_PARAM_INPUT, - info.ValueType, sqltype, colsize, scale, info.ParameterValuePtr, info.BufferLength, &info.StrLen_or_Ind); + info.ValueType, sqltype, colsize, scale, sqltype == SQL_SS_TABLE ? 0 : info.ParameterValuePtr, info.BufferLength, &info.StrLen_or_Ind); Py_END_ALLOW_THREADS; if (GetConnection(cur)->hdbc == SQL_NULL_HANDLE) @@ -1349,6 +1424,96 @@ bool BindParameter(Cursor* cur, Py_ssize_t index, ParamInfo& info) return false; } + // This is a TVP. Enter and bind its parameters, allocate descriptors for its columns (all as DAE) + if (sqltype == SQL_SS_TABLE) + { + SQLHDESC desc; + SQLGetStmtAttr(cur->hstmt, SQL_ATTR_APP_PARAM_DESC, &desc, 0, 0); + SQLSetDescField(desc, index + 1, SQL_DESC_DATA_PTR, (SQLPOINTER)info.ParameterValuePtr, 0); + + int err = 0; + ret = SQLSetStmtAttr(cur->hstmt, SQL_SOPT_SS_PARAM_FOCUS, (SQLPOINTER)(index + 1), SQL_IS_INTEGER); + if (!SQL_SUCCEEDED(ret)) + { + RaiseErrorFromHandle(cur->cnxn, "SQLSetStmtAttr", GetConnection(cur)->hdbc, cur->hstmt); + return false; + } + + Py_ssize_t i = PySequence_Size(info.pObject) - info.ColumnSize; + Py_ssize_t ncols = 0; + while (i < PySequence_Size(info.pObject)) + { + PyObject *row = PySequence_GetItem(info.pObject, i); + Py_XDECREF(row); + if (!PySequence_Check(row)) + { + RaiseErrorV(0, ProgrammingError, "A TVP's rows must be Sequence objects."); + err = 1; + break; + } + if(ncols && ncols != PySequence_Size(row)) + { + RaiseErrorV(0, ProgrammingError, "A TVP's rows must all be the same size."); + err = 1; + break; + } + ncols = PySequence_Size(row); + i++; + } + if (!ncols) + { + // TVP has no columns --- is null + info.nested = 0; + } + else + { + PyObject *row = PySequence_GetItem(info.pObject, PySequence_Size(info.pObject) - info.ColumnSize); + Py_XDECREF(row); + + info.nested = (ParamInfo*)pyodbc_malloc(ncols * sizeof(ParamInfo)); + info.maxlength = ncols; + memset(info.nested, 0, ncols * sizeof(ParamInfo)); + + for(i=0;ihstmt, (SQLUSMALLINT)(i + 1), SQL_PARAM_INPUT, + info.nested[i].ValueType, info.nested[i].ParameterType, + info.nested[i].ColumnSize, info.nested[i].DecimalDigits, + info.nested + i, info.nested[i].BufferLength, &info.nested[i].StrLen_or_Ind); + Py_END_ALLOW_THREADS; + if (GetConnection(cur)->hdbc == SQL_NULL_HANDLE) + { + // The connection was closed by another thread in the ALLOW_THREADS block above. + RaiseErrorV(0, ProgrammingError, "The cursor's connection was closed."); + return false; + } + + if (!SQL_SUCCEEDED(ret)) + { + RaiseErrorFromHandle(cur->cnxn, "SQLBindParameter", GetConnection(cur)->hdbc, cur->hstmt); + return false; + } + } + } + + ret = SQLSetStmtAttr(cur->hstmt, SQL_SOPT_SS_PARAM_FOCUS, 0, SQL_IS_INTEGER); + if (!SQL_SUCCEEDED(ret)) + { + RaiseErrorFromHandle(cur->cnxn, "SQLSetStmtAttr", GetConnection(cur)->hdbc, cur->hstmt); + return false; + } + + if (err) + return false; + } + return true; } @@ -1498,7 +1663,7 @@ bool PrepareAndBind(Cursor* cur, PyObject* pSql, PyObject* original_params, bool for (Py_ssize_t i = 0; i < cParams; i++) { Object param(PySequence_GetItem(original_params, i + params_offset)); - if (!GetParameterInfo(cur, i, param, cur->paramInfos[i])) + if (!GetParameterInfo(cur, i, param, cur->paramInfos[i], false)) { FreeInfos(cur->paramInfos, cParams); cur->paramInfos = 0; diff --git a/src/params.h b/src/params.h index 3782a7ee..c8e77de2 100644 --- a/src/params.h +++ b/src/params.h @@ -8,6 +8,7 @@ struct Cursor; bool PrepareAndBind(Cursor* cur, PyObject* pSql, PyObject* params, bool skip_first); bool ExecuteMulti(Cursor* cur, PyObject* pSql, PyObject* paramArrayObj); +bool GetParameterInfo(Cursor* cur, Py_ssize_t index, PyObject* param, ParamInfo& info, bool isTVP); void FreeParameterData(Cursor* cur); void FreeParameterInfo(Cursor* cur); diff --git a/src/pyodbc.h b/src/pyodbc.h index 0ce36e0e..6bef3e66 100644 --- a/src/pyodbc.h +++ b/src/pyodbc.h @@ -70,6 +70,18 @@ typedef int Py_ssize_t; #define _countof(a) (sizeof(a) / sizeof(a[0])) #endif +#ifndef SQL_SS_TABLE +#define SQL_SS_TABLE -153 +#endif + +#ifndef SQL_SOPT_SS_PARAM_FOCUS +#define SQL_SOPT_SS_PARAM_FOCUS 1236 +#endif + +#ifndef SQL_CA_SS_TYPE_NAME +#define SQL_CA_SS_TYPE_NAME 1227 +#endif + inline bool IsSet(DWORD grf, DWORD flags) { return (grf & flags) == flags; diff --git a/tests2/sqlservertests.py b/tests2/sqlservertests.py index 68666dfb..d4440a40 100755 --- a/tests2/sqlservertests.py +++ b/tests2/sqlservertests.py @@ -1638,6 +1638,110 @@ def test_emoticons(self): self.assertEqual(result, v) + def test_tvp(self): + # https://github.com/mkleehammer/pyodbc/issues/290 + # + # pyodbc supports queries with table valued parameters in sql server + # + + self.cursor.execute("DROP PROCEDURE IF EXISTS SelectTVP") + self.cursor.commit() + self.cursor.execute("DROP TYPE IF EXISTS TestTVP") + self.cursor.commit() + + query = "CREATE TYPE TestTVP AS TABLE("\ + "c01 VARCHAR(255),"\ + "c02 VARCHAR(MAX),"\ + "c03 VARBINARY(255),"\ + "c04 VARBINARY(MAX),"\ + "c05 BIT,"\ + "c06 DATE,"\ + "c07 TIME,"\ + "c08 DATETIME2(5),"\ + "c09 BIGINT,"\ + "c10 FLOAT,"\ + "c11 NUMERIC(38, 24),"\ + "c12 UNIQUEIDENTIFIER)" + + self.cursor.execute(query) + self.cursor.commit() + self.cursor.execute("CREATE PROCEDURE SelectTVP @TVP TestTVP READONLY AS SELECT * FROM @TVP;") + self.cursor.commit() + + long_string = '' + long_bytearray = [] + for i in range(255): + long_string += chr((i % 95) + 32) + long_bytearray.append(i % 255) + + very_long_string = '' + very_long_bytearray = [] + for i in range(2000000): + very_long_string += chr((i % 95) + 32) + very_long_bytearray.append(i % 255) + + c01 = ['abc', '', long_string] + + c02 = ['abc', '', very_long_string] + + c03 = [bytearray([0xD1, 0xCE, 0xFA, 0xCE]), + bytearray([0x00, 0x01, 0x02, 0x03, 0x04]), + bytearray(long_bytearray)] + + c04 = [bytearray([0x0F, 0xF1, 0xCE, 0xCA, 0xFE]), + bytearray([0x00, 0x01, 0x02, 0x03, 0x04, 0x05]), + bytearray(very_long_bytearray)] + + c05 = [1, 0, 1] + + c06 = [date(1997, 8, 29), + date(1, 1, 1), + date(9999, 12, 31)] + + c07 = [time(9, 13, 39), + time(0, 0, 0), + time(23, 59, 59)] + + c08 = [datetime(2018, 11, 13, 13, 33, 26, 298420), + datetime(1, 1, 1, 0, 0, 0, 0), + datetime(9999, 12, 31, 23, 59, 59, 999990)] + + c09 = [1234567, -9223372036854775808, 9223372036854775807] + + c10 = [3.14, -1.79E+308, 1.79E+308] + + c11 = [Decimal('31234567890123.141243449787580175325274'), + Decimal( '0.000000000000000000000001'), + Decimal('99999999999999.999999999999999999999999')] + + c12 = ['4FE34A93-E574-04CC-200A-353F0D1770B1', + '33F7504C-2BAC-1B83-01D1-7434A7BA6A17', + 'FFFFFFFF-FFFF-FFFF-FFFF-FFFFFFFFFFFF'] + + param_array = [] + + for i in range (3): + param_array.append([c01[i], c02[i], c03[i], c04[i], c05[i], c06[i], c07[i], c08[i], c09[i], c10[i], c11[i], c12[i]]) + + success = True + + try: + result_array = self.cursor.execute("exec SelectTVP ?",[param_array]).fetchall() + except Exception as ex: + print("Failed to execute SelectTVP") + print("Exception: [" + type(ex).__name__ + "]" , ex.args) + + success = False + else: + for r in range(len(result_array)): + for c in range(len(result_array[r])): + if(result_array[r][c] != param_array[r][c]): + print("Mismatch at row " + str(r+1) + ", column " + str(c+1) + "; expected:", param_array[r][c] , " received:", result_array[r][c]) + success = False + + self.assertEqual(success, True) + + def main(): from optparse import OptionParser parser = OptionParser(usage=usage) diff --git a/tests3/sqlservertests.py b/tests3/sqlservertests.py index c2547a92..468bd9ba 100644 --- a/tests3/sqlservertests.py +++ b/tests3/sqlservertests.py @@ -1565,6 +1565,109 @@ def test_emoticons(self): self.assertEqual(result, v) + def test_tvp(self): + # https://github.com/mkleehammer/pyodbc/issues/290 + # + # pyodbc supports queries with table valued parameters in sql server + # + + self.cursor.execute("DROP PROCEDURE IF EXISTS SelectTVP") + self.cursor.commit() + self.cursor.execute("DROP TYPE IF EXISTS TestTVP") + self.cursor.commit() + + query = "CREATE TYPE TestTVP AS TABLE("\ + "c01 VARCHAR(255),"\ + "c02 VARCHAR(MAX),"\ + "c03 VARBINARY(255),"\ + "c04 VARBINARY(MAX),"\ + "c05 BIT,"\ + "c06 DATE,"\ + "c07 TIME,"\ + "c08 DATETIME2(5),"\ + "c09 BIGINT,"\ + "c10 FLOAT,"\ + "c11 NUMERIC(38, 24),"\ + "c12 UNIQUEIDENTIFIER)" + + self.cursor.execute(query) + self.cursor.commit() + self.cursor.execute("CREATE PROCEDURE SelectTVP @TVP TestTVP READONLY AS SELECT * FROM @TVP;") + self.cursor.commit() + + long_string = '' + long_bytearray = [] + for i in range(255): + long_string += chr((i % 95) + 32) + long_bytearray.append(i % 255) + + very_long_string = '' + very_long_bytearray = [] + for i in range(2000000): + very_long_string += chr((i % 95) + 32) + very_long_bytearray.append(i % 255) + + c01 = ['abc', '', long_string] + + c02 = ['abc', '', very_long_string] + + c03 = [bytearray([0xD1, 0xCE, 0xFA, 0xCE]), + bytearray([0x00, 0x01, 0x02, 0x03, 0x04]), + bytearray(long_bytearray)] + + c04 = [bytearray([0x0F, 0xF1, 0xCE, 0xCA, 0xFE]), + bytearray([0x00, 0x01, 0x02, 0x03, 0x04, 0x05]), + bytearray(very_long_bytearray)] + + c05 = [1, 0, 1] + + c06 = [date(1997, 8, 29), + date(1, 1, 1), + date(9999, 12, 31)] + + c07 = [time(9, 13, 39), + time(0, 0, 0), + time(23, 59, 59)] + + c08 = [datetime(2018, 11, 13, 13, 33, 26, 298420), + datetime(1, 1, 1, 0, 0, 0, 0), + datetime(9999, 12, 31, 23, 59, 59, 999990)] + + c09 = [1234567, -9223372036854775808, 9223372036854775807] + + c10 = [3.14, -1.79E+308, 1.79E+308] + + c11 = [Decimal('31234567890123.141243449787580175325274'), + Decimal( '0.000000000000000000000001'), + Decimal('99999999999999.999999999999999999999999')] + + c12 = ['4FE34A93-E574-04CC-200A-353F0D1770B1', + '33F7504C-2BAC-1B83-01D1-7434A7BA6A17', + 'FFFFFFFF-FFFF-FFFF-FFFF-FFFFFFFFFFFF'] + + param_array = [] + + for i in range (3): + param_array.append([c01[i], c02[i], c03[i], c04[i], c05[i], c06[i], c07[i], c08[i], c09[i], c10[i], c11[i], c12[i]]) + + success = True + + try: + result_array = self.cursor.execute("exec SelectTVP ?",[param_array]).fetchall() + except Exception as ex: + print("Failed to execute SelectTVP") + print("Exception: [" + type(ex).__name__ + "]" , ex.args) + + success = False + else: + for r in range(len(result_array)): + for c in range(len(result_array[r])): + if(result_array[r][c] != param_array[r][c]): + print("Mismatch at row " + str(r+1) + ", column " + str(c+1) + "; expected:", param_array[r][c] , " received:", result_array[r][c]) + success = False + + self.assertEqual(success, True) + def main(): from optparse import OptionParser parser = OptionParser(usage=usage)