diff --git a/Lib/sqlite3/test/dbapi.py b/Lib/sqlite3/test/dbapi.py index 408f9945f2c970..5d7e5bba05bc45 100644 --- a/Lib/sqlite3/test/dbapi.py +++ b/Lib/sqlite3/test/dbapi.py @@ -26,7 +26,7 @@ import threading import unittest -from test.support import check_disallow_instantiation, threading_helper +from test.support import check_disallow_instantiation, threading_helper, bigmemtest from test.support.os_helper import TESTFN, unlink @@ -758,9 +758,35 @@ def test_script_error_normal(self): def test_cursor_executescript_as_bytes(self): con = sqlite.connect(":memory:") cur = con.cursor() - with self.assertRaises(ValueError) as cm: + with self.assertRaises(TypeError): cur.executescript(b"create table test(foo); insert into test(foo) values (5);") - self.assertEqual(str(cm.exception), 'script argument must be unicode.') + + def test_cursor_executescript_with_null_characters(self): + con = sqlite.connect(":memory:") + cur = con.cursor() + with self.assertRaises(ValueError): + cur.executescript(""" + create table a(i);\0 + insert into a(i) values (5); + """) + + def test_cursor_executescript_with_surrogates(self): + con = sqlite.connect(":memory:") + cur = con.cursor() + with self.assertRaises(UnicodeEncodeError): + cur.executescript(""" + create table a(s); + insert into a(s) values ('\ud8ff'); + """) + + @unittest.skipUnless(sys.maxsize > 2**32, 'requires 64bit platform') + @bigmemtest(size=2**31, memuse=3, dry_run=False) + def test_cursor_executescript_too_large_script(self, maxsize): + con = sqlite.connect(":memory:") + cur = con.cursor() + for size in 2**31-1, 2**31: + with self.assertRaises(sqlite.DataError): + cur.executescript("create table a(s);".ljust(size)) def test_connection_execute(self): con = sqlite.connect(":memory:") @@ -969,6 +995,7 @@ def suite(): CursorTests, ExtensionTests, ModuleTests, + OpenTests, SqliteOnConflictTests, ThreadTests, UninitialisedConnectionTests, diff --git a/Lib/sqlite3/test/hooks.py b/Lib/sqlite3/test/hooks.py index 1be6d380abd20a..43e3810d13df18 100644 --- a/Lib/sqlite3/test/hooks.py +++ b/Lib/sqlite3/test/hooks.py @@ -24,7 +24,7 @@ import sqlite3 as sqlite from test.support.os_helper import TESTFN, unlink - +from .userfunctions import with_tracebacks class CollationTests(unittest.TestCase): def test_create_collation_not_string(self): @@ -145,7 +145,6 @@ def progress(): """) self.assertTrue(progress_calls) - def test_opcode_count(self): """ Test that the opcode argument is respected. @@ -198,6 +197,32 @@ def progress(): con.execute("select 1 union select 2 union select 3").fetchall() self.assertEqual(action, 0, "progress handler was not cleared") + @with_tracebacks(['bad_progress', 'ZeroDivisionError']) + def test_error_in_progress_handler(self): + con = sqlite.connect(":memory:") + def bad_progress(): + 1 / 0 + con.set_progress_handler(bad_progress, 1) + with self.assertRaises(sqlite.OperationalError): + con.execute(""" + create table foo(a, b) + """) + + @with_tracebacks(['__bool__', 'ZeroDivisionError']) + def test_error_in_progress_handler_result(self): + con = sqlite.connect(":memory:") + class BadBool: + def __bool__(self): + 1 / 0 + def bad_progress(): + return BadBool() + con.set_progress_handler(bad_progress, 1) + with self.assertRaises(sqlite.OperationalError): + con.execute(""" + create table foo(a, b) + """) + + class TraceCallbackTests(unittest.TestCase): def test_trace_callback_used(self): """ diff --git a/Lib/sqlite3/test/regression.py b/Lib/sqlite3/test/regression.py index 6c093d7c2c36e0..ddf36e71819445 100644 --- a/Lib/sqlite3/test/regression.py +++ b/Lib/sqlite3/test/regression.py @@ -21,6 +21,7 @@ # 3. This notice may not be removed or altered from any source distribution. import datetime +import sys import unittest import sqlite3 as sqlite import weakref @@ -273,7 +274,7 @@ def test_connection_call(self): Call a connection with a non-string SQL request: check error handling of the statement constructor. """ - self.assertRaises(TypeError, self.con, 1) + self.assertRaises(TypeError, self.con, b"select 1") def test_collation(self): def collation_cb(a, b): @@ -344,6 +345,26 @@ def test_null_character(self): self.assertRaises(ValueError, cur.execute, " \0select 2") self.assertRaises(ValueError, cur.execute, "select 2\0") + def test_surrogates(self): + con = sqlite.connect(":memory:") + self.assertRaises(UnicodeEncodeError, con, "select '\ud8ff'") + self.assertRaises(UnicodeEncodeError, con, "select '\udcff'") + cur = con.cursor() + self.assertRaises(UnicodeEncodeError, cur.execute, "select '\ud8ff'") + self.assertRaises(UnicodeEncodeError, cur.execute, "select '\udcff'") + + @unittest.skipUnless(sys.maxsize > 2**32, 'requires 64bit platform') + @support.bigmemtest(size=2**31, memuse=4, dry_run=False) + def test_large_sql(self, maxsize): + # Test two cases: size+1 > INT_MAX and size+1 <= INT_MAX. + for size in (2**31, 2**31-2): + con = sqlite.connect(":memory:") + sql = "select 1".ljust(size) + self.assertRaises(sqlite.DataError, con, sql) + cur = con.cursor() + self.assertRaises(sqlite.DataError, cur.execute, sql) + del sql + def test_commit_cursor_reset(self): """ Connection.commit() did reset cursors, which made sqlite3 diff --git a/Lib/sqlite3/test/types.py b/Lib/sqlite3/test/types.py index 4f0e4f6d268392..b8926ffee22e87 100644 --- a/Lib/sqlite3/test/types.py +++ b/Lib/sqlite3/test/types.py @@ -23,11 +23,14 @@ import datetime import unittest import sqlite3 as sqlite +import sys try: import zlib except ImportError: zlib = None +from test import support + class SqliteTypeTests(unittest.TestCase): def setUp(self): @@ -45,6 +48,12 @@ def test_string(self): row = self.cur.fetchone() self.assertEqual(row[0], "Österreich") + def test_string_with_null_character(self): + self.cur.execute("insert into test(s) values (?)", ("a\0b",)) + self.cur.execute("select s from test") + row = self.cur.fetchone() + self.assertEqual(row[0], "a\0b") + def test_small_int(self): self.cur.execute("insert into test(i) values (?)", (42,)) self.cur.execute("select i from test") @@ -52,7 +61,7 @@ def test_small_int(self): self.assertEqual(row[0], 42) def test_large_int(self): - num = 2**40 + num = 123456789123456789 self.cur.execute("insert into test(i) values (?)", (num,)) self.cur.execute("select i from test") row = self.cur.fetchone() @@ -78,6 +87,45 @@ def test_unicode_execute(self): row = self.cur.fetchone() self.assertEqual(row[0], "Österreich") + def test_too_large_int(self): + for value in 2**63, -2**63-1, 2**64: + with self.assertRaises(OverflowError): + self.cur.execute("insert into test(i) values (?)", (value,)) + self.cur.execute("select i from test") + row = self.cur.fetchone() + self.assertIsNone(row) + + def test_string_with_surrogates(self): + for value in 0xd8ff, 0xdcff: + with self.assertRaises(UnicodeEncodeError): + self.cur.execute("insert into test(s) values (?)", (chr(value),)) + self.cur.execute("select s from test") + row = self.cur.fetchone() + self.assertIsNone(row) + + @unittest.skipUnless(sys.maxsize > 2**32, 'requires 64bit platform') + @support.bigmemtest(size=2**31, memuse=4, dry_run=False) + def test_too_large_string(self, maxsize): + with self.assertRaises(sqlite.InterfaceError): + self.cur.execute("insert into test(s) values (?)", ('x'*(2**31-1),)) + with self.assertRaises(OverflowError): + self.cur.execute("insert into test(s) values (?)", ('x'*(2**31),)) + self.cur.execute("select 1 from test") + row = self.cur.fetchone() + self.assertIsNone(row) + + @unittest.skipUnless(sys.maxsize > 2**32, 'requires 64bit platform') + @support.bigmemtest(size=2**31, memuse=3, dry_run=False) + def test_too_large_blob(self, maxsize): + with self.assertRaises(sqlite.InterfaceError): + self.cur.execute("insert into test(s) values (?)", (b'x'*(2**31-1),)) + with self.assertRaises(OverflowError): + self.cur.execute("insert into test(s) values (?)", (b'x'*(2**31),)) + self.cur.execute("select 1 from test") + row = self.cur.fetchone() + self.assertIsNone(row) + + class DeclTypesTests(unittest.TestCase): class Foo: def __init__(self, _val): @@ -163,7 +211,7 @@ def test_small_int(self): def test_large_int(self): # default - num = 2**40 + num = 123456789123456789 self.cur.execute("insert into test(i) values (?)", (num,)) self.cur.execute("select i from test") row = self.cur.fetchone() diff --git a/Lib/sqlite3/test/userfunctions.py b/Lib/sqlite3/test/userfunctions.py index 9681dbdde2b092..b4d5181777ebdf 100644 --- a/Lib/sqlite3/test/userfunctions.py +++ b/Lib/sqlite3/test/userfunctions.py @@ -33,28 +33,37 @@ from test.support import bigmemtest -def with_tracebacks(strings): +def with_tracebacks(strings, traceback=True): """Convenience decorator for testing callback tracebacks.""" - strings.append('Traceback') + if traceback: + strings.append('Traceback') def decorator(func): @functools.wraps(func) def wrapper(self, *args, **kwargs): # First, run the test with traceback enabled. - sqlite.enable_callback_tracebacks(True) - buf = io.StringIO() - with contextlib.redirect_stderr(buf): + with check_tracebacks(self, strings): func(self, *args, **kwargs) - tb = buf.getvalue() - for s in strings: - self.assertIn(s, tb) # Then run the test with traceback disabled. - sqlite.enable_callback_tracebacks(False) func(self, *args, **kwargs) return wrapper return decorator +@contextlib.contextmanager +def check_tracebacks(self, strings): + """Convenience context manager for testing callback tracebacks.""" + sqlite.enable_callback_tracebacks(True) + try: + buf = io.StringIO() + with contextlib.redirect_stderr(buf): + yield + tb = buf.getvalue() + for s in strings: + self.assertIn(s, tb) + finally: + sqlite.enable_callback_tracebacks(False) + def func_returntext(): return "foo" def func_returntextwithnull(): @@ -408,9 +417,26 @@ def md5sum(t): del x,y gc.collect() + def test_func_return_too_large_int(self): + cur = self.con.cursor() + for value in 2**63, -2**63-1, 2**64: + self.con.create_function("largeint", 0, lambda value=value: value) + with check_tracebacks(self, ['OverflowError']): + with self.assertRaises(sqlite.DataError): + cur.execute("select largeint()") + + def test_func_return_text_with_surrogates(self): + cur = self.con.cursor() + self.con.create_function("pychr", 1, chr) + for value in 0xd8ff, 0xdcff: + with check_tracebacks(self, + ['UnicodeEncodeError', 'surrogates not allowed']): + with self.assertRaises(sqlite.OperationalError): + cur.execute("select pychr(?)", (value,)) + @unittest.skipUnless(sys.maxsize > 2**32, 'requires 64bit platform') @bigmemtest(size=2**31, memuse=3, dry_run=False) - def test_large_text(self, size): + def test_func_return_too_large_text(self, size): cur = self.con.cursor() for size in 2**31-1, 2**31: self.con.create_function("largetext", 0, lambda size=size: "b" * size) @@ -419,7 +445,7 @@ def test_large_text(self, size): @unittest.skipUnless(sys.maxsize > 2**32, 'requires 64bit platform') @bigmemtest(size=2**31, memuse=2, dry_run=False) - def test_large_blob(self, size): + def test_func_return_too_large_blob(self, size): cur = self.con.cursor() for size in 2**31-1, 2**31: self.con.create_function("largeblob", 0, lambda size=size: b"b" * size) diff --git a/Misc/NEWS.d/next/Library/2021-08-07-17-28-56.bpo-44859.CCopjk.rst b/Misc/NEWS.d/next/Library/2021-08-07-17-28-56.bpo-44859.CCopjk.rst new file mode 100644 index 00000000000000..ec9f774d66b8c4 --- /dev/null +++ b/Misc/NEWS.d/next/Library/2021-08-07-17-28-56.bpo-44859.CCopjk.rst @@ -0,0 +1,8 @@ +Improve error handling in :mod:`sqlite3` and raise more accurate exceptions. + +* :exc:`MemoryError` is now raised instead of :exc:`sqlite3.Warning` when memory is not enough for encoding a statement to UTF-8 in ``Connection.__call__()`` and ``Cursor.execute()``. +* :exc:`UnicodEncodeError` is now raised instead of :exc:`sqlite3.Warning` when the statement contains surrogate characters in ``Connection.__call__()`` and ``Cursor.execute()``. +* :exc:`TypeError` is now raised instead of :exc:`ValueError` for non-string script argument in ``Cursor.executescript()``. +* :exc:`ValueError` is now raised for script containing the null character instead of truncating it in ``Cursor.executescript()``. +* Correctly handle exceptions raised when getting boolean value of the result of the progress handler. +* Add many tests covering different corner cases. diff --git a/Modules/_sqlite/clinic/cursor.c.h b/Modules/_sqlite/clinic/cursor.c.h index d2c453b38b4b9e..07e15870146cf7 100644 --- a/Modules/_sqlite/clinic/cursor.c.h +++ b/Modules/_sqlite/clinic/cursor.c.h @@ -119,6 +119,35 @@ PyDoc_STRVAR(pysqlite_cursor_executescript__doc__, #define PYSQLITE_CURSOR_EXECUTESCRIPT_METHODDEF \ {"executescript", (PyCFunction)pysqlite_cursor_executescript, METH_O, pysqlite_cursor_executescript__doc__}, +static PyObject * +pysqlite_cursor_executescript_impl(pysqlite_Cursor *self, + const char *sql_script); + +static PyObject * +pysqlite_cursor_executescript(pysqlite_Cursor *self, PyObject *arg) +{ + PyObject *return_value = NULL; + const char *sql_script; + + if (!PyUnicode_Check(arg)) { + _PyArg_BadArgument("executescript", "argument", "str", arg); + goto exit; + } + Py_ssize_t sql_script_length; + sql_script = PyUnicode_AsUTF8AndSize(arg, &sql_script_length); + if (sql_script == NULL) { + goto exit; + } + if (strlen(sql_script) != (size_t)sql_script_length) { + PyErr_SetString(PyExc_ValueError, "embedded null character"); + goto exit; + } + return_value = pysqlite_cursor_executescript_impl(self, sql_script); + +exit: + return return_value; +} + PyDoc_STRVAR(pysqlite_cursor_fetchone__doc__, "fetchone($self, /)\n" "--\n" @@ -270,4 +299,4 @@ pysqlite_cursor_close(pysqlite_Cursor *self, PyTypeObject *cls, PyObject *const exit: return return_value; } -/*[clinic end generated code: output=7b216aba2439f5cf input=a9049054013a1b77]*/ +/*[clinic end generated code: output=ace31a7481aa3f41 input=a9049054013a1b77]*/ diff --git a/Modules/_sqlite/connection.c b/Modules/_sqlite/connection.c index 0dab3e85160e82..67160c4c449aa1 100644 --- a/Modules/_sqlite/connection.c +++ b/Modules/_sqlite/connection.c @@ -997,6 +997,14 @@ static int _progress_handler(void* user_arg) ret = _PyObject_CallNoArg((PyObject*)user_arg); if (!ret) { + /* abort query if error occurred */ + rc = -1; + } + else { + rc = PyObject_IsTrue(ret); + Py_DECREF(ret); + } + if (rc < 0) { pysqlite_state *state = pysqlite_get_state(NULL); if (state->enable_callback_tracebacks) { PyErr_Print(); @@ -1004,12 +1012,6 @@ static int _progress_handler(void* user_arg) else { PyErr_Clear(); } - - /* abort query if error occurred */ - rc = 1; - } else { - rc = (int)PyObject_IsTrue(ret); - Py_DECREF(ret); } PyGILState_Release(gilstate); diff --git a/Modules/_sqlite/cursor.c b/Modules/_sqlite/cursor.c index 2f4494690f9557..7308f3062da4b9 100644 --- a/Modules/_sqlite/cursor.c +++ b/Modules/_sqlite/cursor.c @@ -728,21 +728,21 @@ pysqlite_cursor_executemany_impl(pysqlite_Cursor *self, PyObject *sql, /*[clinic input] _sqlite3.Cursor.executescript as pysqlite_cursor_executescript - sql_script as script_obj: object + sql_script: str / Executes multiple SQL statements at once. Non-standard. [clinic start generated code]*/ static PyObject * -pysqlite_cursor_executescript(pysqlite_Cursor *self, PyObject *script_obj) -/*[clinic end generated code: output=115a8132b0f200fe input=ba3ec59df205e362]*/ +pysqlite_cursor_executescript_impl(pysqlite_Cursor *self, + const char *sql_script) +/*[clinic end generated code: output=8fd726dde1c65164 input=1ac0693dc8db02a8]*/ { _Py_IDENTIFIER(commit); - const char* script_cstr; sqlite3_stmt* statement; int rc; - Py_ssize_t sql_len; + size_t sql_len; PyObject* result; if (!check_cursor(self)) { @@ -751,21 +751,12 @@ pysqlite_cursor_executescript(pysqlite_Cursor *self, PyObject *script_obj) self->reset = 0; - if (PyUnicode_Check(script_obj)) { - script_cstr = PyUnicode_AsUTF8AndSize(script_obj, &sql_len); - if (!script_cstr) { - return NULL; - } - - int max_length = sqlite3_limit(self->connection->db, - SQLITE_LIMIT_LENGTH, -1); - if (sql_len >= max_length) { - PyErr_SetString(self->connection->DataError, - "query string is too large"); - return NULL; - } - } else { - PyErr_SetString(PyExc_ValueError, "script argument must be unicode."); + sql_len = strlen(sql_script); + int max_length = sqlite3_limit(self->connection->db, + SQLITE_LIMIT_LENGTH, -1); + if (sql_len >= (unsigned)max_length) { + PyErr_SetString(self->connection->DataError, + "query string is too large"); return NULL; } @@ -782,7 +773,7 @@ pysqlite_cursor_executescript(pysqlite_Cursor *self, PyObject *script_obj) Py_BEGIN_ALLOW_THREADS rc = sqlite3_prepare_v2(self->connection->db, - script_cstr, + sql_script, (int)sql_len + 1, &statement, &tail); @@ -816,8 +807,8 @@ pysqlite_cursor_executescript(pysqlite_Cursor *self, PyObject *script_obj) if (*tail == (char)0) { break; } - sql_len -= (tail - script_cstr); - script_cstr = tail; + sql_len -= (tail - sql_script); + sql_script = tail; } error: diff --git a/Modules/_sqlite/statement.c b/Modules/_sqlite/statement.c index 983df2d50c975d..2d5c72d13b7edb 100644 --- a/Modules/_sqlite/statement.c +++ b/Modules/_sqlite/statement.c @@ -56,9 +56,6 @@ pysqlite_statement_create(pysqlite_Connection *connection, PyObject *sql) Py_ssize_t size; const char *sql_cstr = PyUnicode_AsUTF8AndSize(sql, &size); if (sql_cstr == NULL) { - PyErr_Format(connection->Warning, - "SQL is of wrong type ('%s'). Must be string.", - Py_TYPE(sql)->tp_name); return NULL; }