From 82770734e845fdd1137762e250461708f837d736 Mon Sep 17 00:00:00 2001 From: Michael Kleehammer Date: Sat, 12 Aug 2017 16:44:06 -0500 Subject: [PATCH] Fixes for Unicode diagnostic messages for PostgreSQL & MySQL Some drivers don't encode SQLSTATE and error messages with the same encoding. So far it looks like everyone might use UTF16LE. This code handles UTF8, UTF16LE, and UTF16BE with no configuration. (It does not handle UTF32.) The previous change converted the error message from bytes to Unicode, but missed the handling of the 2nd error message. When an attempt to concatenate was made, PyString_ConcatAndDel would fail causing the HY0000 error. --- src/errors.cpp | 76 +++++++++++++++++++++++++++------------- src/params.cpp | 2 +- tests2/sqlservertests.py | 6 ++++ tests3/pgtests.py | 32 +++++++++++++---- tests3/sqlservertests.py | 7 ++++ 5 files changed, 92 insertions(+), 31 deletions(-) diff --git a/src/errors.cpp b/src/errors.cpp index 00bc1df5..2a98079b 100644 --- a/src/errors.cpp +++ b/src/errors.cpp @@ -183,6 +183,36 @@ PyObject* RaiseErrorFromHandle(Connection *conn, const char* szFunction, HDBC hd } +inline void CopySqlState(const ODBCCHAR* src, char* dest) +{ + // Copies a SQLSTATE read as SQLWCHAR into a character buffer. We know that SQLSTATEs are + // composed of ASCII characters and we need one standard to compare when choosing + // exceptions. + // + // Strangely, even when the error messages are UTF-8, PostgreSQL and MySQL encode the + // sqlstate as UTF-16LE. We'll simply copy all non-zero bytes, with some checks for + // running off the end of the buffers which will work for ASCII, UTF8, and UTF16 LE & BE. + // It would work for UTF32 if I increase the size of the ODBCCHAR buffer to handle it. + // + // (In the worst case, if a driver does something totally weird, we'll have an incomplete + // SQLSTATE.) + // + + const char* pchSrc = (const char*)src; + const char* pchSrcMax = pchSrc + sizeof(ODBCCHAR) * 5; + char* pchDest = dest; // Where we are copying into dest + char* pchDestMax = dest + 5; // We know a SQLSTATE is 5 characters long + + while (pchDest < pchDestMax && pchSrc < pchSrcMax) + { + if (*pchSrc) + *pchDest++ = *pchSrc; + pchSrc++; + } + *pchDest = 0; +} + + PyObject* GetErrorFromHandle(Connection *conn, const char* szFunction, HDBC hdbc, HSTMT hstmt) { TRACE("In RaiseError(%s)!\n", szFunction); @@ -208,9 +238,6 @@ PyObject* GetErrorFromHandle(Connection *conn, const char* szFunction, HDBC hdbc ODBCCHAR sqlstateT[6]; ODBCCHAR szMsg[1024]; - PyObject* pMsg = 0; - PyObject* pMsgPart = 0; - if (hstmt != SQL_NULL_HANDLE) { nHandleType = SQL_HANDLE_STMT; @@ -232,6 +259,8 @@ PyObject* GetErrorFromHandle(Connection *conn, const char* szFunction, HDBC hdbc SQLSMALLINT iRecord = 1; + Object msg; + for (;;) { szMsg[0] = 0; @@ -251,33 +280,32 @@ PyObject* GetErrorFromHandle(Connection *conn, const char* szFunction, HDBC hdbc // For now, default to UTF-16LE if this is not in the context of a connection. // Note that this will not work if the DM is using a different wide encoding (e.g. UTF-32). - const char *unicode_enc = conn ? conn->unicode_enc.name : "utf-16-le"; + const char *unicode_enc = conn ? conn->metadata_enc.name : "utf-16-le"; Object msgStr(PyUnicode_Decode((char*)szMsg, cchMsg * sizeof(ODBCCHAR), unicode_enc, "strict")); - Object stateStr(PyUnicode_Decode((char*)sqlstateT, 5 * sizeof(ODBCCHAR), unicode_enc, "strict")); - if (cchMsg != 0) + if (cchMsg != 0 && msgStr.Get()) { if (iRecord == 1) { - // This is the first error message, so save the SQLSTATE for determining the exception class and append - // the calling function name. - - memcpy(sqlstate, sqlstateT, sizeof(sqlstate[0]) * _countof(sqlstate)); - - pMsg = PyUnicode_FromFormat("[%V] %V (%ld) (%s)", stateStr.Get(), "00000", msgStr.Get(), "(null)", (long)nNativeError, szFunction); - if (pMsg == 0) + // This is the first error message, so save the SQLSTATE for determining the + // exception class and append the calling function name. + CopySqlState(sqlstateT, sqlstate); + msg = PyUnicode_FromFormat("[%s] %V (%ld) (%s)", sqlstate, msgStr.Get(), "(null)", (long)nNativeError, szFunction); + if (!msg) return 0; } else { // This is not the first error message, so append to the existing one. - pMsgPart = PyString_FromFormat("; [%V] %V (%ld)", stateStr.Get(), "00000", msgStr.Get(), "(null)", (long)nNativeError); - if (pMsgPart == 0) - { - Py_XDECREF(pMsg); - return 0; - } - PyString_ConcatAndDel(&pMsg, pMsgPart); + Object more(PyUnicode_FromFormat("; [%s] %V (%ld)", sqlstate, msgStr.Get(), "(null)", (long)nNativeError)); + if (!more) + break; // Something went wrong, but we'll return the msg we have so far + + Object both(PyUnicode_Concat(msg, more)); + if (!both) + break; + + msg = both.Detach(); } } @@ -289,20 +317,20 @@ PyObject* GetErrorFromHandle(Connection *conn, const char* szFunction, HDBC hdbc #endif } - if (pMsg == 0) + if (!msg || PyUnicode_GetSize(msg.Get()) == 0) { // This only happens using unixODBC. (Haven't tried iODBC yet.) Either the driver or the driver manager is // buggy and has signaled a fault without recording error information. sqlstate[0] = '\0'; - pMsg = PyString_FromString(DEFAULT_ERROR); - if (pMsg == 0) + msg = PyString_FromString(DEFAULT_ERROR); + if (!msg) { PyErr_NoMemory(); return 0; } } - return GetError(sqlstate, 0, pMsg); + return GetError(sqlstate, 0, msg.Detach()); } diff --git a/src/params.cpp b/src/params.cpp index 18f7bc6a..bbdc4d59 100644 --- a/src/params.cpp +++ b/src/params.cpp @@ -1530,7 +1530,7 @@ bool ExecuteMulti(Cursor* cur, PyObject* pSql, PyObject* paramArrayObj) SQLGetStmtAttr(cur->hstmt, SQL_ATTR_APP_PARAM_DESC, &desc, 0, 0); SQLSetDescField(desc, i + 1, SQL_DESC_TYPE, (SQLPOINTER)SQL_C_NUMERIC, 0); SQLSetDescField(desc, i + 1, SQL_DESC_PRECISION, (SQLPOINTER)cur->paramInfos[i].ColumnSize, 0); - SQLSetDescField(desc, i + 1, SQL_DESC_SCALE, (SQLPOINTER)cur->paramInfos[i].DecimalDigits, 0); + SQLSetDescField(desc, i + 1, SQL_DESC_SCALE, (SQLPOINTER)(uintptr_t)cur->paramInfos[i].DecimalDigits, 0); SQLSetDescField(desc, i + 1, SQL_DESC_DATA_PTR, bindptr, 0); } bindptr += cur->paramInfos[i].BufferLength + sizeof(SQLLEN); diff --git a/tests2/sqlservertests.py b/tests2/sqlservertests.py index 410f4069..b15c021c 100755 --- a/tests2/sqlservertests.py +++ b/tests2/sqlservertests.py @@ -1531,6 +1531,12 @@ def test_prepare_cleanup(self): self.cursor.execute("select top 1 name from sysobjects where name = ?", "bogus") self.cursor.fetchone() + def test_exc_integrity(self): + "Make sure an IntegretyError is raised" + # This is really making sure we are properly encoding and comparing the SQLSTATEs. + self.cursor.execute("create table t1(s1 varchar(10) primary key)") + self.cursor.execute("insert into t1 values ('one')") + self.failUnlessRaises(pyodbc.IntegrityError, self.cursor.execute, "insert into t1 values ('one')") def main(): diff --git a/tests3/pgtests.py b/tests3/pgtests.py index 0cf44e52..ea69ad51 100755 --- a/tests3/pgtests.py +++ b/tests3/pgtests.py @@ -4,30 +4,32 @@ from __future__ import print_function -import sys, os, re, uuid +import uuid import unittest from decimal import Decimal from testutils import * _TESTSTR = '0123456789-abcdefghijklmnopqrstuvwxyz-' + def _generate_test_string(length): """ Returns a string of composed of `seed` to make a string `length` characters long. - To enhance performance, there are 3 ways data is read, based on the length of the value, so most data types are - tested with 3 lengths. This function helps us generate the test data. + To enhance performance, there are 3 ways data is read, based on the length of the value, so + most data types are tested with 3 lengths. This function helps us generate the test data. - We use a recognizable data set instead of a single character to make it less likely that "overlap" errors will - be hidden and to help us manually identify where a break occurs. + We use a recognizable data set instead of a single character to make it less likely that + "overlap" errors will be hidden and to help us manually identify where a break occurs. """ if length <= len(_TESTSTR): return _TESTSTR[:length] - c = int((length + len(_TESTSTR)-1) / len(_TESTSTR)) + c = int((length + len(_TESTSTR) - 1) / len(_TESTSTR)) v = _TESTSTR * c return v[:length] + class PGTestCase(unittest.TestCase): INTEGERS = [ -1, 0, 1, 0x7FFFFFFF ] @@ -190,6 +192,16 @@ def test_varchar_many(self): self.assertEqual(v2, row.c2) self.assertEqual(v3, row.c3) + def test_chinese(self): + v = '我的' + self.cursor.execute("SELECT N'我的' AS name") + row = self.cursor.fetchone() + self.assertEqual(row[0], v) + + self.cursor.execute("SELECT N'我的' AS name") + rows = self.cursor.fetchall() + self.assertEqual(rows[0][0], v) + # # bytea # @@ -482,6 +494,14 @@ def test_autocommit(self): othercnxn.autocommit = False self.assertEqual(othercnxn.autocommit, False) + def test_exc_integrity(self): + "Make sure an IntegretyError is raised" + # This is really making sure we are properly encoding and comparing the SQLSTATEs. + self.cursor.execute("create table t1(s1 varchar(10) primary key)") + self.cursor.execute("insert into t1 values ('one')") + self.failUnlessRaises(pyodbc.IntegrityError, self.cursor.execute, "insert into t1 values ('one')") + + def test_cnxn_set_attr_before(self): # I don't have a getattr right now since I don't have a table telling me what kind of # value to expect. For now just make sure it doesn't crash. diff --git a/tests3/sqlservertests.py b/tests3/sqlservertests.py index d7bf7b33..09614bfe 100644 --- a/tests3/sqlservertests.py +++ b/tests3/sqlservertests.py @@ -1357,6 +1357,13 @@ def test_decode_meta(self): self.cursor.execute('select a as "Tipología" from t1') self.assertEqual(self.cursor.description[0][0], "Tipología") + def test_exc_integrity(self): + "Make sure an IntegretyError is raised" + # This is really making sure we are properly encoding and comparing the SQLSTATEs. + self.cursor.execute("create table t1(s1 varchar(10) primary key)") + self.cursor.execute("insert into t1 values ('one')") + self.failUnlessRaises(pyodbc.IntegrityError, self.cursor.execute, "insert into t1 values ('one')") + def main(): from optparse import OptionParser