Skip to content

Commit

Permalink
bpo-44859: Improve error handling in sqlite3 and and raise more accur…
Browse files Browse the repository at this point in the history
…ate exceptions. (GH-27654)

* MemoryError is now raised instead of sqlite3.Warning when
  memory is not enough for encoding a statement to UTF-8
  in Connection.__call__() and Cursor.execute().
* UnicodEncodeError is now raised instead of sqlite3.Warning when
  the statement contains surrogate characters
  in Connection.__call__() and Cursor.execute().
* TypeError is now raised instead of ValueError for non-string
  script argument in Cursor.executescript().
* ValueError is now raised for script containing the null
  character instead of truncating it in Cursor.executescript().
* Correctly handle exceptions raised when getting boolean value
  of the result of the progress handler.
* Add many tests covering different corner cases.

Co-authored-by: Erlend Egeberg Aasland <erlend.aasland@innova.no>
  • Loading branch information
serhiy-storchaka and Erlend Egeberg Aasland authored Aug 8, 2021
1 parent ebecffd commit 0eec627
Show file tree
Hide file tree
Showing 10 changed files with 226 additions and 52 deletions.
33 changes: 30 additions & 3 deletions Lib/sqlite3/test/dbapi.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@
import threading
import unittest

from test.support import check_disallow_instantiation, threading_helper
from test.support import check_disallow_instantiation, threading_helper, bigmemtest
from test.support.os_helper import TESTFN, unlink


Expand Down Expand Up @@ -758,9 +758,35 @@ def test_script_error_normal(self):
def test_cursor_executescript_as_bytes(self):
con = sqlite.connect(":memory:")
cur = con.cursor()
with self.assertRaises(ValueError) as cm:
with self.assertRaises(TypeError):
cur.executescript(b"create table test(foo); insert into test(foo) values (5);")
self.assertEqual(str(cm.exception), 'script argument must be unicode.')

def test_cursor_executescript_with_null_characters(self):
con = sqlite.connect(":memory:")
cur = con.cursor()
with self.assertRaises(ValueError):
cur.executescript("""
create table a(i);\0
insert into a(i) values (5);
""")

def test_cursor_executescript_with_surrogates(self):
con = sqlite.connect(":memory:")
cur = con.cursor()
with self.assertRaises(UnicodeEncodeError):
cur.executescript("""
create table a(s);
insert into a(s) values ('\ud8ff');
""")

@unittest.skipUnless(sys.maxsize > 2**32, 'requires 64bit platform')
@bigmemtest(size=2**31, memuse=3, dry_run=False)
def test_cursor_executescript_too_large_script(self, maxsize):
con = sqlite.connect(":memory:")
cur = con.cursor()
for size in 2**31-1, 2**31:
with self.assertRaises(sqlite.DataError):
cur.executescript("create table a(s);".ljust(size))

def test_connection_execute(self):
con = sqlite.connect(":memory:")
Expand Down Expand Up @@ -969,6 +995,7 @@ def suite():
CursorTests,
ExtensionTests,
ModuleTests,
OpenTests,
SqliteOnConflictTests,
ThreadTests,
UninitialisedConnectionTests,
Expand Down
29 changes: 27 additions & 2 deletions Lib/sqlite3/test/hooks.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@
import sqlite3 as sqlite

from test.support.os_helper import TESTFN, unlink

from .userfunctions import with_tracebacks

class CollationTests(unittest.TestCase):
def test_create_collation_not_string(self):
Expand Down Expand Up @@ -145,7 +145,6 @@ def progress():
""")
self.assertTrue(progress_calls)


def test_opcode_count(self):
"""
Test that the opcode argument is respected.
Expand Down Expand Up @@ -198,6 +197,32 @@ def progress():
con.execute("select 1 union select 2 union select 3").fetchall()
self.assertEqual(action, 0, "progress handler was not cleared")

@with_tracebacks(['bad_progress', 'ZeroDivisionError'])
def test_error_in_progress_handler(self):
con = sqlite.connect(":memory:")
def bad_progress():
1 / 0
con.set_progress_handler(bad_progress, 1)
with self.assertRaises(sqlite.OperationalError):
con.execute("""
create table foo(a, b)
""")

@with_tracebacks(['__bool__', 'ZeroDivisionError'])
def test_error_in_progress_handler_result(self):
con = sqlite.connect(":memory:")
class BadBool:
def __bool__(self):
1 / 0
def bad_progress():
return BadBool()
con.set_progress_handler(bad_progress, 1)
with self.assertRaises(sqlite.OperationalError):
con.execute("""
create table foo(a, b)
""")


class TraceCallbackTests(unittest.TestCase):
def test_trace_callback_used(self):
"""
Expand Down
23 changes: 22 additions & 1 deletion Lib/sqlite3/test/regression.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
# 3. This notice may not be removed or altered from any source distribution.

import datetime
import sys
import unittest
import sqlite3 as sqlite
import weakref
Expand Down Expand Up @@ -273,7 +274,7 @@ def test_connection_call(self):
Call a connection with a non-string SQL request: check error handling
of the statement constructor.
"""
self.assertRaises(TypeError, self.con, 1)
self.assertRaises(TypeError, self.con, b"select 1")

def test_collation(self):
def collation_cb(a, b):
Expand Down Expand Up @@ -344,6 +345,26 @@ def test_null_character(self):
self.assertRaises(ValueError, cur.execute, " \0select 2")
self.assertRaises(ValueError, cur.execute, "select 2\0")

def test_surrogates(self):
con = sqlite.connect(":memory:")
self.assertRaises(UnicodeEncodeError, con, "select '\ud8ff'")
self.assertRaises(UnicodeEncodeError, con, "select '\udcff'")
cur = con.cursor()
self.assertRaises(UnicodeEncodeError, cur.execute, "select '\ud8ff'")
self.assertRaises(UnicodeEncodeError, cur.execute, "select '\udcff'")

@unittest.skipUnless(sys.maxsize > 2**32, 'requires 64bit platform')
@support.bigmemtest(size=2**31, memuse=4, dry_run=False)
def test_large_sql(self, maxsize):
# Test two cases: size+1 > INT_MAX and size+1 <= INT_MAX.
for size in (2**31, 2**31-2):
con = sqlite.connect(":memory:")
sql = "select 1".ljust(size)
self.assertRaises(sqlite.DataError, con, sql)
cur = con.cursor()
self.assertRaises(sqlite.DataError, cur.execute, sql)
del sql

def test_commit_cursor_reset(self):
"""
Connection.commit() did reset cursors, which made sqlite3
Expand Down
52 changes: 50 additions & 2 deletions Lib/sqlite3/test/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,11 +23,14 @@
import datetime
import unittest
import sqlite3 as sqlite
import sys
try:
import zlib
except ImportError:
zlib = None

from test import support


class SqliteTypeTests(unittest.TestCase):
def setUp(self):
Expand All @@ -45,14 +48,20 @@ def test_string(self):
row = self.cur.fetchone()
self.assertEqual(row[0], "Österreich")

def test_string_with_null_character(self):
self.cur.execute("insert into test(s) values (?)", ("a\0b",))
self.cur.execute("select s from test")
row = self.cur.fetchone()
self.assertEqual(row[0], "a\0b")

def test_small_int(self):
self.cur.execute("insert into test(i) values (?)", (42,))
self.cur.execute("select i from test")
row = self.cur.fetchone()
self.assertEqual(row[0], 42)

def test_large_int(self):
num = 2**40
num = 123456789123456789
self.cur.execute("insert into test(i) values (?)", (num,))
self.cur.execute("select i from test")
row = self.cur.fetchone()
Expand All @@ -78,6 +87,45 @@ def test_unicode_execute(self):
row = self.cur.fetchone()
self.assertEqual(row[0], "Österreich")

def test_too_large_int(self):
for value in 2**63, -2**63-1, 2**64:
with self.assertRaises(OverflowError):
self.cur.execute("insert into test(i) values (?)", (value,))
self.cur.execute("select i from test")
row = self.cur.fetchone()
self.assertIsNone(row)

def test_string_with_surrogates(self):
for value in 0xd8ff, 0xdcff:
with self.assertRaises(UnicodeEncodeError):
self.cur.execute("insert into test(s) values (?)", (chr(value),))
self.cur.execute("select s from test")
row = self.cur.fetchone()
self.assertIsNone(row)

@unittest.skipUnless(sys.maxsize > 2**32, 'requires 64bit platform')
@support.bigmemtest(size=2**31, memuse=4, dry_run=False)
def test_too_large_string(self, maxsize):
with self.assertRaises(sqlite.InterfaceError):
self.cur.execute("insert into test(s) values (?)", ('x'*(2**31-1),))
with self.assertRaises(OverflowError):
self.cur.execute("insert into test(s) values (?)", ('x'*(2**31),))
self.cur.execute("select 1 from test")
row = self.cur.fetchone()
self.assertIsNone(row)

@unittest.skipUnless(sys.maxsize > 2**32, 'requires 64bit platform')
@support.bigmemtest(size=2**31, memuse=3, dry_run=False)
def test_too_large_blob(self, maxsize):
with self.assertRaises(sqlite.InterfaceError):
self.cur.execute("insert into test(s) values (?)", (b'x'*(2**31-1),))
with self.assertRaises(OverflowError):
self.cur.execute("insert into test(s) values (?)", (b'x'*(2**31),))
self.cur.execute("select 1 from test")
row = self.cur.fetchone()
self.assertIsNone(row)


class DeclTypesTests(unittest.TestCase):
class Foo:
def __init__(self, _val):
Expand Down Expand Up @@ -163,7 +211,7 @@ def test_small_int(self):

def test_large_int(self):
# default
num = 2**40
num = 123456789123456789
self.cur.execute("insert into test(i) values (?)", (num,))
self.cur.execute("select i from test")
row = self.cur.fetchone()
Expand Down
48 changes: 37 additions & 11 deletions Lib/sqlite3/test/userfunctions.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,28 +33,37 @@
from test.support import bigmemtest


def with_tracebacks(strings):
def with_tracebacks(strings, traceback=True):
"""Convenience decorator for testing callback tracebacks."""
strings.append('Traceback')
if traceback:
strings.append('Traceback')

def decorator(func):
@functools.wraps(func)
def wrapper(self, *args, **kwargs):
# First, run the test with traceback enabled.
sqlite.enable_callback_tracebacks(True)
buf = io.StringIO()
with contextlib.redirect_stderr(buf):
with check_tracebacks(self, strings):
func(self, *args, **kwargs)
tb = buf.getvalue()
for s in strings:
self.assertIn(s, tb)

# Then run the test with traceback disabled.
sqlite.enable_callback_tracebacks(False)
func(self, *args, **kwargs)
return wrapper
return decorator

@contextlib.contextmanager
def check_tracebacks(self, strings):
"""Convenience context manager for testing callback tracebacks."""
sqlite.enable_callback_tracebacks(True)
try:
buf = io.StringIO()
with contextlib.redirect_stderr(buf):
yield
tb = buf.getvalue()
for s in strings:
self.assertIn(s, tb)
finally:
sqlite.enable_callback_tracebacks(False)

def func_returntext():
return "foo"
def func_returntextwithnull():
Expand Down Expand Up @@ -408,9 +417,26 @@ def md5sum(t):
del x,y
gc.collect()

def test_func_return_too_large_int(self):
cur = self.con.cursor()
for value in 2**63, -2**63-1, 2**64:
self.con.create_function("largeint", 0, lambda value=value: value)
with check_tracebacks(self, ['OverflowError']):
with self.assertRaises(sqlite.DataError):
cur.execute("select largeint()")

def test_func_return_text_with_surrogates(self):
cur = self.con.cursor()
self.con.create_function("pychr", 1, chr)
for value in 0xd8ff, 0xdcff:
with check_tracebacks(self,
['UnicodeEncodeError', 'surrogates not allowed']):
with self.assertRaises(sqlite.OperationalError):
cur.execute("select pychr(?)", (value,))

@unittest.skipUnless(sys.maxsize > 2**32, 'requires 64bit platform')
@bigmemtest(size=2**31, memuse=3, dry_run=False)
def test_large_text(self, size):
def test_func_return_too_large_text(self, size):
cur = self.con.cursor()
for size in 2**31-1, 2**31:
self.con.create_function("largetext", 0, lambda size=size: "b" * size)
Expand All @@ -419,7 +445,7 @@ def test_large_text(self, size):

@unittest.skipUnless(sys.maxsize > 2**32, 'requires 64bit platform')
@bigmemtest(size=2**31, memuse=2, dry_run=False)
def test_large_blob(self, size):
def test_func_return_too_large_blob(self, size):
cur = self.con.cursor()
for size in 2**31-1, 2**31:
self.con.create_function("largeblob", 0, lambda size=size: b"b" * size)
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
Improve error handling in :mod:`sqlite3` and raise more accurate exceptions.

* :exc:`MemoryError` is now raised instead of :exc:`sqlite3.Warning` when memory is not enough for encoding a statement to UTF-8 in ``Connection.__call__()`` and ``Cursor.execute()``.
* :exc:`UnicodEncodeError` is now raised instead of :exc:`sqlite3.Warning` when the statement contains surrogate characters in ``Connection.__call__()`` and ``Cursor.execute()``.
* :exc:`TypeError` is now raised instead of :exc:`ValueError` for non-string script argument in ``Cursor.executescript()``.
* :exc:`ValueError` is now raised for script containing the null character instead of truncating it in ``Cursor.executescript()``.
* Correctly handle exceptions raised when getting boolean value of the result of the progress handler.
* Add many tests covering different corner cases.
31 changes: 30 additions & 1 deletion Modules/_sqlite/clinic/cursor.c.h
Original file line number Diff line number Diff line change
Expand Up @@ -119,6 +119,35 @@ PyDoc_STRVAR(pysqlite_cursor_executescript__doc__,
#define PYSQLITE_CURSOR_EXECUTESCRIPT_METHODDEF \
{"executescript", (PyCFunction)pysqlite_cursor_executescript, METH_O, pysqlite_cursor_executescript__doc__},

static PyObject *
pysqlite_cursor_executescript_impl(pysqlite_Cursor *self,
const char *sql_script);

static PyObject *
pysqlite_cursor_executescript(pysqlite_Cursor *self, PyObject *arg)
{
PyObject *return_value = NULL;
const char *sql_script;

if (!PyUnicode_Check(arg)) {
_PyArg_BadArgument("executescript", "argument", "str", arg);
goto exit;
}
Py_ssize_t sql_script_length;
sql_script = PyUnicode_AsUTF8AndSize(arg, &sql_script_length);
if (sql_script == NULL) {
goto exit;
}
if (strlen(sql_script) != (size_t)sql_script_length) {
PyErr_SetString(PyExc_ValueError, "embedded null character");
goto exit;
}
return_value = pysqlite_cursor_executescript_impl(self, sql_script);

exit:
return return_value;
}

PyDoc_STRVAR(pysqlite_cursor_fetchone__doc__,
"fetchone($self, /)\n"
"--\n"
Expand Down Expand Up @@ -270,4 +299,4 @@ pysqlite_cursor_close(pysqlite_Cursor *self, PyTypeObject *cls, PyObject *const
exit:
return return_value;
}
/*[clinic end generated code: output=7b216aba2439f5cf input=a9049054013a1b77]*/
/*[clinic end generated code: output=ace31a7481aa3f41 input=a9049054013a1b77]*/
Loading

0 comments on commit 0eec627

Please sign in to comment.