Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add dynamic buffers in GetDiagRecs and GetErrorFromHandle and prevent buffer overflow on errors longer than 1024 characters #881

Closed
Closed
4 changes: 4 additions & 0 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,6 +90,10 @@ def main():

'ext_modules': [Extension('pyodbc', sorted(files), **settings)],

'data_files': [
('', ['src/pyodbc.pyi']) # places pyodbc.pyi alongside pyodbc.py in site-packages
],

'license': 'MIT',

'classifiers': ['Development Status :: 5 - Production/Stable',
Expand Down
40 changes: 33 additions & 7 deletions src/cursor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -568,17 +568,22 @@ static int GetDiagRecs(Cursor* cur)
{
// Retrieves all diagnostic records from the cursor and assigns them to the "messages" attribute.

PyObject* msg_list;

SQLSMALLINT iRecNumber = 1;
PyObject* msg_list; // the "messages" as a Python list of diagnostic records

SQLSMALLINT iRecNumber = 1; // the index of the diagnostic records (1-based)
ODBCCHAR cSQLState[6]; // five-character SQLSTATE code (plus terminating NULL)
SQLINTEGER iNativeError;
ODBCCHAR cMessageText[10240]; // PRINT statements can be large, hopefully 10K bytes will be enough
SQLSMALLINT iMessageLen = 1023;
ODBCCHAR *cMessageText = (ODBCCHAR*) pyodbc_malloc((iMessageLen + 1) * sizeof(ODBCCHAR));
SQLSMALLINT iTextLength;

SQLRETURN ret;
char sqlstate_ascii[6] = ""; // ASCII version of the SQLState
char sqlstate_ascii[6] = ""; // ASCII version of the SQLState

if (!cMessageText) {
PyErr_NoMemory();
return 0;
}

msg_list = PyList_New(0);
if (!msg_list)
Expand All @@ -594,12 +599,30 @@ static int GetDiagRecs(Cursor* cur)
Py_BEGIN_ALLOW_THREADS
ret = SQLGetDiagRecW(
SQL_HANDLE_STMT, cur->hstmt, iRecNumber, (SQLWCHAR*)cSQLState, &iNativeError,
(SQLWCHAR*)cMessageText, (short)(_countof(cMessageText)-1), &iTextLength
(SQLWCHAR*)cMessageText, iMessageLen, &iTextLength
);
Py_END_ALLOW_THREADS
if (!SQL_SUCCEEDED(ret))
break;

// If needed, allocate a bigger error message buffer and retry.
if (iTextLength > iMessageLen - 1) {
iMessageLen = iTextLength + 1;
if (!pyodbc_realloc((BYTE**) &cMessageText, (iMessageLen + 1) * sizeof(ODBCCHAR))) {
pyodbc_free(cMessageText);
PyErr_NoMemory();
return 0;
}
Py_BEGIN_ALLOW_THREADS
ret = SQLGetDiagRecW(
SQL_HANDLE_STMT, cur->hstmt, iRecNumber, (SQLWCHAR*)cSQLState, &iNativeError,
(SQLWCHAR*)cMessageText, iMessageLen, &iTextLength
);
Py_END_ALLOW_THREADS
if (!SQL_SUCCEEDED(ret))
break;
}

cSQLState[5] = 0; // Not always NULL terminated (MS Access)
CopySqlState(cSQLState, sqlstate_ascii);
PyObject* msg_class = PyUnicode_FromFormat("[%s] (%ld)", sqlstate_ascii, (long)iNativeError);
Expand All @@ -616,7 +639,7 @@ static int GetDiagRecs(Cursor* cur)
msg_value = PyBytes_FromStringAndSize((char*)cMessageText, iTextLength * sizeof(ODBCCHAR));
}

PyObject* msg_tuple = PyTuple_New(2);
PyObject* msg_tuple = PyTuple_New(2); // the message as a Python tuple of class and value

if (msg_class && msg_value && msg_tuple)
{
Expand All @@ -635,9 +658,12 @@ static int GetDiagRecs(Cursor* cur)

iRecNumber++;
}
pyodbc_free(cMessageText);

Py_XDECREF(cur->messages);
cur->messages = msg_list; // cur->messages now owns the msg_list reference

return 0;
}


Expand Down
38 changes: 31 additions & 7 deletions src/errors.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -217,7 +217,13 @@ PyObject* GetErrorFromHandle(Connection *conn, const char* szFunction, HDBC hdbc
SQLSMALLINT cchMsg;

ODBCCHAR sqlstateT[6];
ODBCCHAR szMsg[1024];
SQLSMALLINT msgLen = 1023;
ODBCCHAR *szMsg = (ODBCCHAR*) pyodbc_malloc((msgLen + 1) * sizeof(ODBCCHAR));

if (!szMsg) {
PyErr_NoMemory();
return 0;
}

if (hstmt != SQL_NULL_HANDLE)
{
Expand Down Expand Up @@ -251,11 +257,26 @@ PyObject* GetErrorFromHandle(Connection *conn, const char* szFunction, HDBC hdbc

SQLRETURN ret;
Py_BEGIN_ALLOW_THREADS
ret = SQLGetDiagRecW(nHandleType, h, iRecord, (SQLWCHAR*)sqlstateT, &nNativeError, (SQLWCHAR*)szMsg, (short)(_countof(szMsg)-1), &cchMsg);
ret = SQLGetDiagRecW(nHandleType, h, iRecord, (SQLWCHAR*)sqlstateT, &nNativeError, (SQLWCHAR*)szMsg, msgLen, &cchMsg);
Py_END_ALLOW_THREADS
if (!SQL_SUCCEEDED(ret))
break;

// If needed, allocate a bigger error message buffer and retry.
if (cchMsg > msgLen - 1) {
msgLen = cchMsg + 1;
if (!pyodbc_realloc((BYTE**) &szMsg, (msgLen + 1) * sizeof(ODBCCHAR))) {
Comment on lines +267 to +268
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

note: I intentionally did some defensive "double +1" here and in other places. A small cost to avoid potential overflow.

PyErr_NoMemory();
pyodbc_free(szMsg);
return 0;
}
Py_BEGIN_ALLOW_THREADS
ret = SQLGetDiagRecW(nHandleType, h, iRecord, (SQLWCHAR*)sqlstateT, &nNativeError, (SQLWCHAR*)szMsg, msgLen, &cchMsg);
Py_END_ALLOW_THREADS
if (!SQL_SUCCEEDED(ret))
break;
}

// Not always NULL terminated (MS Access)
sqlstateT[5] = 0;

Expand All @@ -272,8 +293,11 @@ PyObject* GetErrorFromHandle(Connection *conn, const char* szFunction, HDBC hdbc
// exception class and append the calling function name.
CopySqlState(sqlstateT, sqlstate);
msg = PyUnicode_FromFormat("[%s] %V (%ld) (%s)", sqlstate, msgStr.Get(), "(null)", (long)nNativeError, szFunction);
if (!msg)
if (!msg) {
PyErr_NoMemory();
pyodbc_free(szMsg);
return 0;
}
}
else
{
Expand All @@ -298,6 +322,9 @@ PyObject* GetErrorFromHandle(Connection *conn, const char* szFunction, HDBC hdbc
#endif
}

// Raw message buffer not needed anymore
pyodbc_free(szMsg);

if (!msg || PyUnicode_GetSize(msg.Get()) == 0)
{
// This only happens using unixODBC. (Haven't tried iODBC yet.) Either the driver or the driver manager is
Expand All @@ -317,14 +344,11 @@ PyObject* GetErrorFromHandle(Connection *conn, const char* szFunction, HDBC hdbc

static bool GetSqlState(HSTMT hstmt, char* szSqlState)
{
SQLCHAR szMsg[300];
SQLSMALLINT cbMsg = (SQLSMALLINT)(_countof(szMsg) - 1);
SQLINTEGER nNative;
SQLSMALLINT cchMsg;
SQLRETURN ret;

Py_BEGIN_ALLOW_THREADS
ret = SQLGetDiagRec(SQL_HANDLE_STMT, hstmt, 1, (SQLCHAR*)szSqlState, &nNative, szMsg, cbMsg, &cchMsg);
ret = SQLGetDiagField(SQL_HANDLE_STMT, hstmt, 1, SQL_DIAG_SQLSTATE, (SQLCHAR*)szSqlState, 5, &cchMsg);
Py_END_ALLOW_THREADS
return SQL_SUCCEEDED(ret);
}
Expand Down
4 changes: 4 additions & 0 deletions src/pyodbc.h
Original file line number Diff line number Diff line change
Expand Up @@ -148,6 +148,10 @@ inline void DebugTrace(const char* szFmt, ...) { UNUSED(szFmt); }
#define pyodbc_free free
// #endif

// issue #880: entry missing from iODBC sqltypes.h
#ifndef BYTE
typedef unsigned char BYTE;
#endif
bool pyodbc_realloc(BYTE** pp, size_t newlen);
// A wrapper around realloc with a safer interface. If it is successful, *pp is updated to the
// new pointer value. If not successful, it is not modified. (It is easy to forget and lose
Expand Down
Loading