diff --git a/src/errors.cpp b/src/errors.cpp index 00bc1df5..2a98079b 100644 --- a/src/errors.cpp +++ b/src/errors.cpp @@ -183,6 +183,36 @@ PyObject* RaiseErrorFromHandle(Connection *conn, const char* szFunction, HDBC hd } +inline void CopySqlState(const ODBCCHAR* src, char* dest) +{ + // Copies a SQLSTATE read as SQLWCHAR into a character buffer. We know that SQLSTATEs are + // composed of ASCII characters and we need one standard to compare when choosing + // exceptions. + // + // Strangely, even when the error messages are UTF-8, PostgreSQL and MySQL encode the + // sqlstate as UTF-16LE. We'll simply copy all non-zero bytes, with some checks for + // running off the end of the buffers which will work for ASCII, UTF8, and UTF16 LE & BE. + // It would work for UTF32 if I increase the size of the ODBCCHAR buffer to handle it. + // + // (In the worst case, if a driver does something totally weird, we'll have an incomplete + // SQLSTATE.) + // + + const char* pchSrc = (const char*)src; + const char* pchSrcMax = pchSrc + sizeof(ODBCCHAR) * 5; + char* pchDest = dest; // Where we are copying into dest + char* pchDestMax = dest + 5; // We know a SQLSTATE is 5 characters long + + while (pchDest < pchDestMax && pchSrc < pchSrcMax) + { + if (*pchSrc) + *pchDest++ = *pchSrc; + pchSrc++; + } + *pchDest = 0; +} + + PyObject* GetErrorFromHandle(Connection *conn, const char* szFunction, HDBC hdbc, HSTMT hstmt) { TRACE("In RaiseError(%s)!\n", szFunction); @@ -208,9 +238,6 @@ PyObject* GetErrorFromHandle(Connection *conn, const char* szFunction, HDBC hdbc ODBCCHAR sqlstateT[6]; ODBCCHAR szMsg[1024]; - PyObject* pMsg = 0; - PyObject* pMsgPart = 0; - if (hstmt != SQL_NULL_HANDLE) { nHandleType = SQL_HANDLE_STMT; @@ -232,6 +259,8 @@ PyObject* GetErrorFromHandle(Connection *conn, const char* szFunction, HDBC hdbc SQLSMALLINT iRecord = 1; + Object msg; + for (;;) { szMsg[0] = 0; @@ -251,33 +280,32 @@ PyObject* GetErrorFromHandle(Connection *conn, const char* szFunction, HDBC hdbc // For now, default to UTF-16LE if this is not in the context of a connection. // Note that this will not work if the DM is using a different wide encoding (e.g. UTF-32). - const char *unicode_enc = conn ? conn->unicode_enc.name : "utf-16-le"; + const char *unicode_enc = conn ? conn->metadata_enc.name : "utf-16-le"; Object msgStr(PyUnicode_Decode((char*)szMsg, cchMsg * sizeof(ODBCCHAR), unicode_enc, "strict")); - Object stateStr(PyUnicode_Decode((char*)sqlstateT, 5 * sizeof(ODBCCHAR), unicode_enc, "strict")); - if (cchMsg != 0) + if (cchMsg != 0 && msgStr.Get()) { if (iRecord == 1) { - // This is the first error message, so save the SQLSTATE for determining the exception class and append - // the calling function name. - - memcpy(sqlstate, sqlstateT, sizeof(sqlstate[0]) * _countof(sqlstate)); - - pMsg = PyUnicode_FromFormat("[%V] %V (%ld) (%s)", stateStr.Get(), "00000", msgStr.Get(), "(null)", (long)nNativeError, szFunction); - if (pMsg == 0) + // This is the first error message, so save the SQLSTATE for determining the + // exception class and append the calling function name. + CopySqlState(sqlstateT, sqlstate); + msg = PyUnicode_FromFormat("[%s] %V (%ld) (%s)", sqlstate, msgStr.Get(), "(null)", (long)nNativeError, szFunction); + if (!msg) return 0; } else { // This is not the first error message, so append to the existing one. - pMsgPart = PyString_FromFormat("; [%V] %V (%ld)", stateStr.Get(), "00000", msgStr.Get(), "(null)", (long)nNativeError); - if (pMsgPart == 0) - { - Py_XDECREF(pMsg); - return 0; - } - PyString_ConcatAndDel(&pMsg, pMsgPart); + Object more(PyUnicode_FromFormat("; [%s] %V (%ld)", sqlstate, msgStr.Get(), "(null)", (long)nNativeError)); + if (!more) + break; // Something went wrong, but we'll return the msg we have so far + + Object both(PyUnicode_Concat(msg, more)); + if (!both) + break; + + msg = both.Detach(); } } @@ -289,20 +317,20 @@ PyObject* GetErrorFromHandle(Connection *conn, const char* szFunction, HDBC hdbc #endif } - if (pMsg == 0) + if (!msg || PyUnicode_GetSize(msg.Get()) == 0) { // This only happens using unixODBC. (Haven't tried iODBC yet.) Either the driver or the driver manager is // buggy and has signaled a fault without recording error information. sqlstate[0] = '\0'; - pMsg = PyString_FromString(DEFAULT_ERROR); - if (pMsg == 0) + msg = PyString_FromString(DEFAULT_ERROR); + if (!msg) { PyErr_NoMemory(); return 0; } } - return GetError(sqlstate, 0, pMsg); + return GetError(sqlstate, 0, msg.Detach()); } diff --git a/src/params.cpp b/src/params.cpp index 18f7bc6a..bbdc4d59 100644 --- a/src/params.cpp +++ b/src/params.cpp @@ -1530,7 +1530,7 @@ bool ExecuteMulti(Cursor* cur, PyObject* pSql, PyObject* paramArrayObj) SQLGetStmtAttr(cur->hstmt, SQL_ATTR_APP_PARAM_DESC, &desc, 0, 0); SQLSetDescField(desc, i + 1, SQL_DESC_TYPE, (SQLPOINTER)SQL_C_NUMERIC, 0); SQLSetDescField(desc, i + 1, SQL_DESC_PRECISION, (SQLPOINTER)cur->paramInfos[i].ColumnSize, 0); - SQLSetDescField(desc, i + 1, SQL_DESC_SCALE, (SQLPOINTER)cur->paramInfos[i].DecimalDigits, 0); + SQLSetDescField(desc, i + 1, SQL_DESC_SCALE, (SQLPOINTER)(uintptr_t)cur->paramInfos[i].DecimalDigits, 0); SQLSetDescField(desc, i + 1, SQL_DESC_DATA_PTR, bindptr, 0); } bindptr += cur->paramInfos[i].BufferLength + sizeof(SQLLEN); diff --git a/tests2/sqlservertests.py b/tests2/sqlservertests.py index 410f4069..b15c021c 100755 --- a/tests2/sqlservertests.py +++ b/tests2/sqlservertests.py @@ -1531,6 +1531,12 @@ def test_prepare_cleanup(self): self.cursor.execute("select top 1 name from sysobjects where name = ?", "bogus") self.cursor.fetchone() + def test_exc_integrity(self): + "Make sure an IntegretyError is raised" + # This is really making sure we are properly encoding and comparing the SQLSTATEs. + self.cursor.execute("create table t1(s1 varchar(10) primary key)") + self.cursor.execute("insert into t1 values ('one')") + self.failUnlessRaises(pyodbc.IntegrityError, self.cursor.execute, "insert into t1 values ('one')") def main(): diff --git a/tests3/pgtests.py b/tests3/pgtests.py index 0cf44e52..ea69ad51 100755 --- a/tests3/pgtests.py +++ b/tests3/pgtests.py @@ -4,30 +4,32 @@ from __future__ import print_function -import sys, os, re, uuid +import uuid import unittest from decimal import Decimal from testutils import * _TESTSTR = '0123456789-abcdefghijklmnopqrstuvwxyz-' + def _generate_test_string(length): """ Returns a string of composed of `seed` to make a string `length` characters long. - To enhance performance, there are 3 ways data is read, based on the length of the value, so most data types are - tested with 3 lengths. This function helps us generate the test data. + To enhance performance, there are 3 ways data is read, based on the length of the value, so + most data types are tested with 3 lengths. This function helps us generate the test data. - We use a recognizable data set instead of a single character to make it less likely that "overlap" errors will - be hidden and to help us manually identify where a break occurs. + We use a recognizable data set instead of a single character to make it less likely that + "overlap" errors will be hidden and to help us manually identify where a break occurs. """ if length <= len(_TESTSTR): return _TESTSTR[:length] - c = int((length + len(_TESTSTR)-1) / len(_TESTSTR)) + c = int((length + len(_TESTSTR) - 1) / len(_TESTSTR)) v = _TESTSTR * c return v[:length] + class PGTestCase(unittest.TestCase): INTEGERS = [ -1, 0, 1, 0x7FFFFFFF ] @@ -190,6 +192,16 @@ def test_varchar_many(self): self.assertEqual(v2, row.c2) self.assertEqual(v3, row.c3) + def test_chinese(self): + v = '我的' + self.cursor.execute("SELECT N'我的' AS name") + row = self.cursor.fetchone() + self.assertEqual(row[0], v) + + self.cursor.execute("SELECT N'我的' AS name") + rows = self.cursor.fetchall() + self.assertEqual(rows[0][0], v) + # # bytea # @@ -482,6 +494,14 @@ def test_autocommit(self): othercnxn.autocommit = False self.assertEqual(othercnxn.autocommit, False) + def test_exc_integrity(self): + "Make sure an IntegretyError is raised" + # This is really making sure we are properly encoding and comparing the SQLSTATEs. + self.cursor.execute("create table t1(s1 varchar(10) primary key)") + self.cursor.execute("insert into t1 values ('one')") + self.failUnlessRaises(pyodbc.IntegrityError, self.cursor.execute, "insert into t1 values ('one')") + + def test_cnxn_set_attr_before(self): # I don't have a getattr right now since I don't have a table telling me what kind of # value to expect. For now just make sure it doesn't crash. diff --git a/tests3/sqlservertests.py b/tests3/sqlservertests.py index d7bf7b33..09614bfe 100644 --- a/tests3/sqlservertests.py +++ b/tests3/sqlservertests.py @@ -1357,6 +1357,13 @@ def test_decode_meta(self): self.cursor.execute('select a as "Tipología" from t1') self.assertEqual(self.cursor.description[0][0], "Tipología") + def test_exc_integrity(self): + "Make sure an IntegretyError is raised" + # This is really making sure we are properly encoding and comparing the SQLSTATEs. + self.cursor.execute("create table t1(s1 varchar(10) primary key)") + self.cursor.execute("insert into t1 values ('one')") + self.failUnlessRaises(pyodbc.IntegrityError, self.cursor.execute, "insert into t1 values ('one')") + def main(): from optparse import OptionParser