diff --git a/src/connection.cpp b/src/connection.cpp index cc936914..3cd0418d 100644 --- a/src/connection.cpp +++ b/src/connection.cpp @@ -1259,6 +1259,45 @@ static PyObject* Connection_conv_remove(PyObject* self, PyObject* args) Py_RETURN_NONE; } +static char conv_get_doc[] = + "get_output_converter(sqltype) --> \n" + "\n" + "Get the output converter function that was registered with\n" + "add_output_converter. It is safe to call if no converter is\n" + "registered for the type (returns None).\n" + "\n" + "sqltype\n" + " The integer SQL type value being converted, which can be one of the defined\n" + " standard constants (e.g. pyodbc.SQL_VARCHAR) or a database-specific value\n" + " (e.g. -151 for the SQL Server 2008 geometry data type).\n" + ; + +static PyObject* _get_converter(PyObject* self, SQLSMALLINT sqltype) +{ + Connection* cnxn = (Connection*)self; + + if (cnxn->conv_count) + { + for (int i = 0; i < cnxn->conv_count; i++) + { + if (cnxn->conv_types[i] == sqltype) + { + return cnxn->conv_funcs[i]; + } + } + } + Py_RETURN_NONE; +} + +static PyObject* Connection_conv_get(PyObject* self, PyObject* args) +{ + int sqltype; + if (!PyArg_ParseTuple(args, "i", &sqltype)) + return 0; + + return _get_converter(self, (SQLSMALLINT)sqltype); +} + static void NormalizeCodecName(const char* src, char* dest, size_t cbDest) { // Copies the codec name to dest, lowercasing it and replacing underscores with dashes. @@ -1564,6 +1603,7 @@ static struct PyMethodDef Connection_methods[] = { "getinfo", Connection_getinfo, METH_VARARGS, getinfo_doc }, { "add_output_converter", Connection_conv_add, METH_VARARGS, conv_add_doc }, { "remove_output_converter", Connection_conv_remove, METH_VARARGS, conv_remove_doc }, + { "get_output_converter", Connection_conv_get, METH_VARARGS, conv_get_doc }, { "clear_output_converters", Connection_conv_clear, METH_NOARGS, conv_clear_doc }, { "setdecoding", (PyCFunction)Connection_setdecoding, METH_VARARGS|METH_KEYWORDS, setdecoding_doc }, { "setencoding", (PyCFunction)Connection_setencoding, METH_VARARGS|METH_KEYWORDS, 0 }, diff --git a/tests3/sqlservertests.py b/tests3/sqlservertests.py index 0500db85..fcaedeee 100644 --- a/tests3/sqlservertests.py +++ b/tests3/sqlservertests.py @@ -1346,15 +1346,19 @@ def test_none_param(self): def test_output_conversion(self): - def convert(value): + def convert1(value): # The value is the raw bytes (as a bytes object) read from the # database. We'll simply add an X at the beginning at the end. return 'X' + value.decode('latin1') + 'X' + def convert2(value): + # Same as above, but add a Y at the beginning at the end. + return 'Y' + value.decode('latin1') + 'Y' + self.cursor.execute("create table t1(n int, v varchar(10))") self.cursor.execute("insert into t1 values (1, '123.45')") - self.cnxn.add_output_converter(pyodbc.SQL_VARCHAR, convert) + self.cnxn.add_output_converter(pyodbc.SQL_VARCHAR, convert1) value = self.cursor.execute("select v from t1").fetchone()[0] self.assertEqual(value, 'X123.45X') @@ -1364,7 +1368,7 @@ def convert(value): self.assertEqual(value, '123.45') # Same but clear using remove_output_converter. - self.cnxn.add_output_converter(pyodbc.SQL_VARCHAR, convert) + self.cnxn.add_output_converter(pyodbc.SQL_VARCHAR, convert1) value = self.cursor.execute("select v from t1").fetchone()[0] self.assertEqual(value, 'X123.45X') @@ -1372,15 +1376,44 @@ def convert(value): value = self.cursor.execute("select v from t1").fetchone()[0] self.assertEqual(value, '123.45') - # And lastly, clear by passing None for the converter. - self.cnxn.add_output_converter(pyodbc.SQL_VARCHAR, convert) + # Clear via add_output_converter, passing None for the converter function. + self.cnxn.add_output_converter(pyodbc.SQL_VARCHAR, convert1) value = self.cursor.execute("select v from t1").fetchone()[0] self.assertEqual(value, 'X123.45X') self.cnxn.add_output_converter(pyodbc.SQL_VARCHAR, None) value = self.cursor.execute("select v from t1").fetchone()[0] self.assertEqual(value, '123.45') - + + # retrieve and temporarily replace converter (get_output_converter) + # + # case_1: converter already registered + self.cnxn.add_output_converter(pyodbc.SQL_VARCHAR, convert1) + value = self.cursor.execute("select v from t1").fetchone()[0] + self.assertEqual(value, 'X123.45X') + prev_converter = self.cnxn.get_output_converter(pyodbc.SQL_VARCHAR) + self.assertNotEqual(prev_converter, None) + self.cnxn.add_output_converter(pyodbc.SQL_VARCHAR, convert2) + value = self.cursor.execute("select v from t1").fetchone()[0] + self.assertEqual(value, 'Y123.45Y') + self.cnxn.add_output_converter(pyodbc.SQL_VARCHAR, prev_converter) + value = self.cursor.execute("select v from t1").fetchone()[0] + self.assertEqual(value, 'X123.45X') + # + # case_2: no converter already registered + self.cnxn.clear_output_converters() + value = self.cursor.execute("select v from t1").fetchone()[0] + self.assertEqual(value, '123.45') + prev_converter = self.cnxn.get_output_converter(pyodbc.SQL_VARCHAR) + self.assertEqual(prev_converter, None) + self.cnxn.add_output_converter(pyodbc.SQL_VARCHAR, convert2) + value = self.cursor.execute("select v from t1").fetchone()[0] + self.assertEqual(value, 'Y123.45Y') + self.cnxn.add_output_converter(pyodbc.SQL_VARCHAR, prev_converter) + value = self.cursor.execute("select v from t1").fetchone()[0] + self.assertEqual(value, '123.45') + + def test_too_large(self): """Ensure error raised if insert fails due to truncation""" value = 'x' * 1000