Skip to content

Commit

Permalink
add get_output_converter method (#496)
Browse files Browse the repository at this point in the history
  • Loading branch information
gordthompson committed Dec 20, 2018
1 parent 61cd5ef commit 6b15677
Show file tree
Hide file tree
Showing 2 changed files with 79 additions and 6 deletions.
40 changes: 40 additions & 0 deletions src/connection.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1259,6 +1259,45 @@ static PyObject* Connection_conv_remove(PyObject* self, PyObject* args)
Py_RETURN_NONE;
}

static char conv_get_doc[] =
"get_output_converter(sqltype) --> <class 'function'>\n"
"\n"
"Get the output converter function that was registered with\n"
"add_output_converter. It is safe to call if no converter is\n"
"registered for the type (returns None).\n"
"\n"
"sqltype\n"
" The integer SQL type value being converted, which can be one of the defined\n"
" standard constants (e.g. pyodbc.SQL_VARCHAR) or a database-specific value\n"
" (e.g. -151 for the SQL Server 2008 geometry data type).\n"
;

static PyObject* _get_converter(PyObject* self, SQLSMALLINT sqltype)
{
Connection* cnxn = (Connection*)self;

if (cnxn->conv_count)
{
for (int i = 0; i < cnxn->conv_count; i++)
{
if (cnxn->conv_types[i] == sqltype)
{
return cnxn->conv_funcs[i];
}
}
}
Py_RETURN_NONE;
}

static PyObject* Connection_conv_get(PyObject* self, PyObject* args)
{
int sqltype;
if (!PyArg_ParseTuple(args, "i", &sqltype))
return 0;

return _get_converter(self, (SQLSMALLINT)sqltype);
}

static void NormalizeCodecName(const char* src, char* dest, size_t cbDest)
{
// Copies the codec name to dest, lowercasing it and replacing underscores with dashes.
Expand Down Expand Up @@ -1564,6 +1603,7 @@ static struct PyMethodDef Connection_methods[] =
{ "getinfo", Connection_getinfo, METH_VARARGS, getinfo_doc },
{ "add_output_converter", Connection_conv_add, METH_VARARGS, conv_add_doc },
{ "remove_output_converter", Connection_conv_remove, METH_VARARGS, conv_remove_doc },
{ "get_output_converter", Connection_conv_get, METH_VARARGS, conv_get_doc },
{ "clear_output_converters", Connection_conv_clear, METH_NOARGS, conv_clear_doc },
{ "setdecoding", (PyCFunction)Connection_setdecoding, METH_VARARGS|METH_KEYWORDS, setdecoding_doc },
{ "setencoding", (PyCFunction)Connection_setencoding, METH_VARARGS|METH_KEYWORDS, 0 },
Expand Down
45 changes: 39 additions & 6 deletions tests3/sqlservertests.py
Original file line number Diff line number Diff line change
Expand Up @@ -1346,15 +1346,19 @@ def test_none_param(self):


def test_output_conversion(self):
def convert(value):
def convert1(value):
# The value is the raw bytes (as a bytes object) read from the
# database. We'll simply add an X at the beginning at the end.
return 'X' + value.decode('latin1') + 'X'

def convert2(value):
# Same as above, but add a Y at the beginning at the end.
return 'Y' + value.decode('latin1') + 'Y'

self.cursor.execute("create table t1(n int, v varchar(10))")
self.cursor.execute("insert into t1 values (1, '123.45')")

self.cnxn.add_output_converter(pyodbc.SQL_VARCHAR, convert)
self.cnxn.add_output_converter(pyodbc.SQL_VARCHAR, convert1)
value = self.cursor.execute("select v from t1").fetchone()[0]
self.assertEqual(value, 'X123.45X')

Expand All @@ -1364,23 +1368,52 @@ def convert(value):
self.assertEqual(value, '123.45')

# Same but clear using remove_output_converter.
self.cnxn.add_output_converter(pyodbc.SQL_VARCHAR, convert)
self.cnxn.add_output_converter(pyodbc.SQL_VARCHAR, convert1)
value = self.cursor.execute("select v from t1").fetchone()[0]
self.assertEqual(value, 'X123.45X')

self.cnxn.remove_output_converter(pyodbc.SQL_VARCHAR)
value = self.cursor.execute("select v from t1").fetchone()[0]
self.assertEqual(value, '123.45')

# And lastly, clear by passing None for the converter.
self.cnxn.add_output_converter(pyodbc.SQL_VARCHAR, convert)
# Clear via add_output_converter, passing None for the converter function.
self.cnxn.add_output_converter(pyodbc.SQL_VARCHAR, convert1)
value = self.cursor.execute("select v from t1").fetchone()[0]
self.assertEqual(value, 'X123.45X')

self.cnxn.add_output_converter(pyodbc.SQL_VARCHAR, None)
value = self.cursor.execute("select v from t1").fetchone()[0]
self.assertEqual(value, '123.45')


# retrieve and temporarily replace converter (get_output_converter)
#
# case_1: converter already registered
self.cnxn.add_output_converter(pyodbc.SQL_VARCHAR, convert1)
value = self.cursor.execute("select v from t1").fetchone()[0]
self.assertEqual(value, 'X123.45X')
prev_converter = self.cnxn.get_output_converter(pyodbc.SQL_VARCHAR)
self.assertNotEqual(prev_converter, None)
self.cnxn.add_output_converter(pyodbc.SQL_VARCHAR, convert2)
value = self.cursor.execute("select v from t1").fetchone()[0]
self.assertEqual(value, 'Y123.45Y')
self.cnxn.add_output_converter(pyodbc.SQL_VARCHAR, prev_converter)
value = self.cursor.execute("select v from t1").fetchone()[0]
self.assertEqual(value, 'X123.45X')
#
# case_2: no converter already registered
self.cnxn.clear_output_converters()
value = self.cursor.execute("select v from t1").fetchone()[0]
self.assertEqual(value, '123.45')
prev_converter = self.cnxn.get_output_converter(pyodbc.SQL_VARCHAR)
self.assertEqual(prev_converter, None)
self.cnxn.add_output_converter(pyodbc.SQL_VARCHAR, convert2)
value = self.cursor.execute("select v from t1").fetchone()[0]
self.assertEqual(value, 'Y123.45Y')
self.cnxn.add_output_converter(pyodbc.SQL_VARCHAR, prev_converter)
value = self.cursor.execute("select v from t1").fetchone()[0]
self.assertEqual(value, '123.45')


def test_too_large(self):
"""Ensure error raised if insert fails due to truncation"""
value = 'x' * 1000
Expand Down

0 comments on commit 6b15677

Please sign in to comment.