Skip to content

Commit

Permalink
Fixes for Unicode diagnostic messages for PostgreSQL & MySQL
Browse files Browse the repository at this point in the history
Some drivers don't encode SQLSTATE and error messages with the same encoding.  So far it looks
like everyone might use UTF16LE.  This code handles UTF8, UTF16LE, and UTF16BE with no
configuration.  (It does not handle UTF32.)

The previous change converted the error message from bytes to Unicode, but missed the handling
of the 2nd error message.  When an attempt to concatenate was made, PyString_ConcatAndDel would
fail causing the HY0000 error.
  • Loading branch information
mkleehammer committed Aug 12, 2017
1 parent b6679b2 commit 8277073
Show file tree
Hide file tree
Showing 5 changed files with 92 additions and 31 deletions.
76 changes: 52 additions & 24 deletions src/errors.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -183,6 +183,36 @@ PyObject* RaiseErrorFromHandle(Connection *conn, const char* szFunction, HDBC hd
}


inline void CopySqlState(const ODBCCHAR* src, char* dest)
{
// Copies a SQLSTATE read as SQLWCHAR into a character buffer. We know that SQLSTATEs are
// composed of ASCII characters and we need one standard to compare when choosing
// exceptions.
//
// Strangely, even when the error messages are UTF-8, PostgreSQL and MySQL encode the
// sqlstate as UTF-16LE. We'll simply copy all non-zero bytes, with some checks for
// running off the end of the buffers which will work for ASCII, UTF8, and UTF16 LE & BE.
// It would work for UTF32 if I increase the size of the ODBCCHAR buffer to handle it.
//
// (In the worst case, if a driver does something totally weird, we'll have an incomplete
// SQLSTATE.)
//

const char* pchSrc = (const char*)src;
const char* pchSrcMax = pchSrc + sizeof(ODBCCHAR) * 5;
char* pchDest = dest; // Where we are copying into dest
char* pchDestMax = dest + 5; // We know a SQLSTATE is 5 characters long

while (pchDest < pchDestMax && pchSrc < pchSrcMax)
{
if (*pchSrc)
*pchDest++ = *pchSrc;
pchSrc++;
}
*pchDest = 0;
}


PyObject* GetErrorFromHandle(Connection *conn, const char* szFunction, HDBC hdbc, HSTMT hstmt)
{
TRACE("In RaiseError(%s)!\n", szFunction);
Expand All @@ -208,9 +238,6 @@ PyObject* GetErrorFromHandle(Connection *conn, const char* szFunction, HDBC hdbc
ODBCCHAR sqlstateT[6];
ODBCCHAR szMsg[1024];

PyObject* pMsg = 0;
PyObject* pMsgPart = 0;

if (hstmt != SQL_NULL_HANDLE)
{
nHandleType = SQL_HANDLE_STMT;
Expand All @@ -232,6 +259,8 @@ PyObject* GetErrorFromHandle(Connection *conn, const char* szFunction, HDBC hdbc

SQLSMALLINT iRecord = 1;

Object msg;

for (;;)
{
szMsg[0] = 0;
Expand All @@ -251,33 +280,32 @@ PyObject* GetErrorFromHandle(Connection *conn, const char* szFunction, HDBC hdbc

// For now, default to UTF-16LE if this is not in the context of a connection.
// Note that this will not work if the DM is using a different wide encoding (e.g. UTF-32).
const char *unicode_enc = conn ? conn->unicode_enc.name : "utf-16-le";
const char *unicode_enc = conn ? conn->metadata_enc.name : "utf-16-le";
Object msgStr(PyUnicode_Decode((char*)szMsg, cchMsg * sizeof(ODBCCHAR), unicode_enc, "strict"));
Object stateStr(PyUnicode_Decode((char*)sqlstateT, 5 * sizeof(ODBCCHAR), unicode_enc, "strict"));

if (cchMsg != 0)
if (cchMsg != 0 && msgStr.Get())
{
if (iRecord == 1)
{
// This is the first error message, so save the SQLSTATE for determining the exception class and append
// the calling function name.

memcpy(sqlstate, sqlstateT, sizeof(sqlstate[0]) * _countof(sqlstate));

pMsg = PyUnicode_FromFormat("[%V] %V (%ld) (%s)", stateStr.Get(), "00000", msgStr.Get(), "(null)", (long)nNativeError, szFunction);
if (pMsg == 0)
// This is the first error message, so save the SQLSTATE for determining the
// exception class and append the calling function name.
CopySqlState(sqlstateT, sqlstate);
msg = PyUnicode_FromFormat("[%s] %V (%ld) (%s)", sqlstate, msgStr.Get(), "(null)", (long)nNativeError, szFunction);
if (!msg)
return 0;
}
else
{
// This is not the first error message, so append to the existing one.
pMsgPart = PyString_FromFormat("; [%V] %V (%ld)", stateStr.Get(), "00000", msgStr.Get(), "(null)", (long)nNativeError);
if (pMsgPart == 0)
{
Py_XDECREF(pMsg);
return 0;
}
PyString_ConcatAndDel(&pMsg, pMsgPart);
Object more(PyUnicode_FromFormat("; [%s] %V (%ld)", sqlstate, msgStr.Get(), "(null)", (long)nNativeError));
if (!more)
break; // Something went wrong, but we'll return the msg we have so far

Object both(PyUnicode_Concat(msg, more));
if (!both)
break;

msg = both.Detach();
}
}

Expand All @@ -289,20 +317,20 @@ PyObject* GetErrorFromHandle(Connection *conn, const char* szFunction, HDBC hdbc
#endif
}

if (pMsg == 0)
if (!msg || PyUnicode_GetSize(msg.Get()) == 0)
{
// This only happens using unixODBC. (Haven't tried iODBC yet.) Either the driver or the driver manager is
// buggy and has signaled a fault without recording error information.
sqlstate[0] = '\0';
pMsg = PyString_FromString(DEFAULT_ERROR);
if (pMsg == 0)
msg = PyString_FromString(DEFAULT_ERROR);
if (!msg)
{
PyErr_NoMemory();
return 0;
}
}

return GetError(sqlstate, 0, pMsg);
return GetError(sqlstate, 0, msg.Detach());
}


Expand Down
2 changes: 1 addition & 1 deletion src/params.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1530,7 +1530,7 @@ bool ExecuteMulti(Cursor* cur, PyObject* pSql, PyObject* paramArrayObj)
SQLGetStmtAttr(cur->hstmt, SQL_ATTR_APP_PARAM_DESC, &desc, 0, 0);
SQLSetDescField(desc, i + 1, SQL_DESC_TYPE, (SQLPOINTER)SQL_C_NUMERIC, 0);
SQLSetDescField(desc, i + 1, SQL_DESC_PRECISION, (SQLPOINTER)cur->paramInfos[i].ColumnSize, 0);
SQLSetDescField(desc, i + 1, SQL_DESC_SCALE, (SQLPOINTER)cur->paramInfos[i].DecimalDigits, 0);
SQLSetDescField(desc, i + 1, SQL_DESC_SCALE, (SQLPOINTER)(uintptr_t)cur->paramInfos[i].DecimalDigits, 0);
SQLSetDescField(desc, i + 1, SQL_DESC_DATA_PTR, bindptr, 0);
}
bindptr += cur->paramInfos[i].BufferLength + sizeof(SQLLEN);
Expand Down
6 changes: 6 additions & 0 deletions tests2/sqlservertests.py
Original file line number Diff line number Diff line change
Expand Up @@ -1531,6 +1531,12 @@ def test_prepare_cleanup(self):
self.cursor.execute("select top 1 name from sysobjects where name = ?", "bogus")
self.cursor.fetchone()

def test_exc_integrity(self):
"Make sure an IntegretyError is raised"
# This is really making sure we are properly encoding and comparing the SQLSTATEs.
self.cursor.execute("create table t1(s1 varchar(10) primary key)")
self.cursor.execute("insert into t1 values ('one')")
self.failUnlessRaises(pyodbc.IntegrityError, self.cursor.execute, "insert into t1 values ('one')")


def main():
Expand Down
32 changes: 26 additions & 6 deletions tests3/pgtests.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,30 +4,32 @@

from __future__ import print_function

import sys, os, re, uuid
import uuid
import unittest
from decimal import Decimal
from testutils import *

_TESTSTR = '0123456789-abcdefghijklmnopqrstuvwxyz-'


def _generate_test_string(length):
"""
Returns a string of composed of `seed` to make a string `length` characters long.
To enhance performance, there are 3 ways data is read, based on the length of the value, so most data types are
tested with 3 lengths. This function helps us generate the test data.
To enhance performance, there are 3 ways data is read, based on the length of the value, so
most data types are tested with 3 lengths. This function helps us generate the test data.
We use a recognizable data set instead of a single character to make it less likely that "overlap" errors will
be hidden and to help us manually identify where a break occurs.
We use a recognizable data set instead of a single character to make it less likely that
"overlap" errors will be hidden and to help us manually identify where a break occurs.
"""
if length <= len(_TESTSTR):
return _TESTSTR[:length]

c = int((length + len(_TESTSTR)-1) / len(_TESTSTR))
c = int((length + len(_TESTSTR) - 1) / len(_TESTSTR))
v = _TESTSTR * c
return v[:length]


class PGTestCase(unittest.TestCase):

INTEGERS = [ -1, 0, 1, 0x7FFFFFFF ]
Expand Down Expand Up @@ -190,6 +192,16 @@ def test_varchar_many(self):
self.assertEqual(v2, row.c2)
self.assertEqual(v3, row.c3)

def test_chinese(self):
v = '我的'
self.cursor.execute("SELECT N'我的' AS name")
row = self.cursor.fetchone()
self.assertEqual(row[0], v)

self.cursor.execute("SELECT N'我的' AS name")
rows = self.cursor.fetchall()
self.assertEqual(rows[0][0], v)

#
# bytea
#
Expand Down Expand Up @@ -482,6 +494,14 @@ def test_autocommit(self):
othercnxn.autocommit = False
self.assertEqual(othercnxn.autocommit, False)

def test_exc_integrity(self):
"Make sure an IntegretyError is raised"
# This is really making sure we are properly encoding and comparing the SQLSTATEs.
self.cursor.execute("create table t1(s1 varchar(10) primary key)")
self.cursor.execute("insert into t1 values ('one')")
self.failUnlessRaises(pyodbc.IntegrityError, self.cursor.execute, "insert into t1 values ('one')")


def test_cnxn_set_attr_before(self):
# I don't have a getattr right now since I don't have a table telling me what kind of
# value to expect. For now just make sure it doesn't crash.
Expand Down
7 changes: 7 additions & 0 deletions tests3/sqlservertests.py
Original file line number Diff line number Diff line change
Expand Up @@ -1357,6 +1357,13 @@ def test_decode_meta(self):
self.cursor.execute('select a as "Tipología" from t1')
self.assertEqual(self.cursor.description[0][0], "Tipología")

def test_exc_integrity(self):
"Make sure an IntegretyError is raised"
# This is really making sure we are properly encoding and comparing the SQLSTATEs.
self.cursor.execute("create table t1(s1 varchar(10) primary key)")
self.cursor.execute("insert into t1 values ('one')")
self.failUnlessRaises(pyodbc.IntegrityError, self.cursor.execute, "insert into t1 values ('one')")


def main():
from optparse import OptionParser
Expand Down

0 comments on commit 8277073

Please sign in to comment.