Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

bpo-44859: Improve error handling in sqlite3 and change some errors #27654

Merged
merged 2 commits into from
Aug 8, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
33 changes: 30 additions & 3 deletions Lib/sqlite3/test/dbapi.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@
import threading
import unittest

from test.support import check_disallow_instantiation, threading_helper
from test.support import check_disallow_instantiation, threading_helper, bigmemtest
from test.support.os_helper import TESTFN, unlink


Expand Down Expand Up @@ -758,9 +758,35 @@ def test_script_error_normal(self):
def test_cursor_executescript_as_bytes(self):
con = sqlite.connect(":memory:")
cur = con.cursor()
with self.assertRaises(ValueError) as cm:
with self.assertRaises(TypeError):
cur.executescript(b"create table test(foo); insert into test(foo) values (5);")
self.assertEqual(str(cm.exception), 'script argument must be unicode.')

def test_cursor_executescript_with_null_characters(self):
con = sqlite.connect(":memory:")
cur = con.cursor()
with self.assertRaises(ValueError):
cur.executescript("""
create table a(i);\0
insert into a(i) values (5);
""")

def test_cursor_executescript_with_surrogates(self):
con = sqlite.connect(":memory:")
cur = con.cursor()
with self.assertRaises(UnicodeEncodeError):
cur.executescript("""
create table a(s);
insert into a(s) values ('\ud8ff');
""")

@unittest.skipUnless(sys.maxsize > 2**32, 'requires 64bit platform')
@bigmemtest(size=2**31, memuse=3, dry_run=False)
def test_cursor_executescript_too_large_script(self, maxsize):
con = sqlite.connect(":memory:")
cur = con.cursor()
for size in 2**31-1, 2**31:
with self.assertRaises(sqlite.DataError):
cur.executescript("create table a(s);".ljust(size))

def test_connection_execute(self):
con = sqlite.connect(":memory:")
Expand Down Expand Up @@ -969,6 +995,7 @@ def suite():
CursorTests,
ExtensionTests,
ModuleTests,
OpenTests,
SqliteOnConflictTests,
ThreadTests,
UninitialisedConnectionTests,
Expand Down
29 changes: 27 additions & 2 deletions Lib/sqlite3/test/hooks.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@
import sqlite3 as sqlite

from test.support.os_helper import TESTFN, unlink

from .userfunctions import with_tracebacks

class CollationTests(unittest.TestCase):
def test_create_collation_not_string(self):
Expand Down Expand Up @@ -145,7 +145,6 @@ def progress():
""")
self.assertTrue(progress_calls)


def test_opcode_count(self):
"""
Test that the opcode argument is respected.
Expand Down Expand Up @@ -198,6 +197,32 @@ def progress():
con.execute("select 1 union select 2 union select 3").fetchall()
self.assertEqual(action, 0, "progress handler was not cleared")

@with_tracebacks(['bad_progress', 'ZeroDivisionError'])
def test_error_in_progress_handler(self):
con = sqlite.connect(":memory:")
def bad_progress():
1 / 0
con.set_progress_handler(bad_progress, 1)
with self.assertRaises(sqlite.OperationalError):
con.execute("""
create table foo(a, b)
""")

@with_tracebacks(['__bool__', 'ZeroDivisionError'])
def test_error_in_progress_handler_result(self):
con = sqlite.connect(":memory:")
class BadBool:
def __bool__(self):
1 / 0
def bad_progress():
return BadBool()
con.set_progress_handler(bad_progress, 1)
with self.assertRaises(sqlite.OperationalError):
con.execute("""
create table foo(a, b)
""")


class TraceCallbackTests(unittest.TestCase):
def test_trace_callback_used(self):
"""
Expand Down
23 changes: 22 additions & 1 deletion Lib/sqlite3/test/regression.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
# 3. This notice may not be removed or altered from any source distribution.

import datetime
import sys
import unittest
import sqlite3 as sqlite
import weakref
Expand Down Expand Up @@ -273,7 +274,7 @@ def test_connection_call(self):
Call a connection with a non-string SQL request: check error handling
of the statement constructor.
"""
self.assertRaises(TypeError, self.con, 1)
self.assertRaises(TypeError, self.con, b"select 1")

def test_collation(self):
def collation_cb(a, b):
Expand Down Expand Up @@ -344,6 +345,26 @@ def test_null_character(self):
self.assertRaises(ValueError, cur.execute, " \0select 2")
self.assertRaises(ValueError, cur.execute, "select 2\0")

def test_surrogates(self):
con = sqlite.connect(":memory:")
self.assertRaises(UnicodeEncodeError, con, "select '\ud8ff'")
self.assertRaises(UnicodeEncodeError, con, "select '\udcff'")
cur = con.cursor()
self.assertRaises(UnicodeEncodeError, cur.execute, "select '\ud8ff'")
self.assertRaises(UnicodeEncodeError, cur.execute, "select '\udcff'")

@unittest.skipUnless(sys.maxsize > 2**32, 'requires 64bit platform')
@support.bigmemtest(size=2**31, memuse=4, dry_run=False)
def test_large_sql(self, maxsize):
# Test two cases: size+1 > INT_MAX and size+1 <= INT_MAX.
for size in (2**31, 2**31-2):
con = sqlite.connect(":memory:")
sql = "select 1".ljust(size)
self.assertRaises(sqlite.DataError, con, sql)
cur = con.cursor()
self.assertRaises(sqlite.DataError, cur.execute, sql)
del sql

def test_commit_cursor_reset(self):
"""
Connection.commit() did reset cursors, which made sqlite3
Expand Down
52 changes: 50 additions & 2 deletions Lib/sqlite3/test/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,11 +23,14 @@
import datetime
import unittest
import sqlite3 as sqlite
import sys
try:
import zlib
except ImportError:
zlib = None

from test import support


class SqliteTypeTests(unittest.TestCase):
def setUp(self):
Expand All @@ -45,14 +48,20 @@ def test_string(self):
row = self.cur.fetchone()
self.assertEqual(row[0], "Österreich")

def test_string_with_null_character(self):
self.cur.execute("insert into test(s) values (?)", ("a\0b",))
self.cur.execute("select s from test")
row = self.cur.fetchone()
self.assertEqual(row[0], "a\0b")

def test_small_int(self):
self.cur.execute("insert into test(i) values (?)", (42,))
self.cur.execute("select i from test")
row = self.cur.fetchone()
self.assertEqual(row[0], 42)

def test_large_int(self):
num = 2**40
num = 123456789123456789
self.cur.execute("insert into test(i) values (?)", (num,))
self.cur.execute("select i from test")
row = self.cur.fetchone()
Expand All @@ -78,6 +87,45 @@ def test_unicode_execute(self):
row = self.cur.fetchone()
self.assertEqual(row[0], "Österreich")

def test_too_large_int(self):
for value in 2**63, -2**63-1, 2**64:
with self.assertRaises(OverflowError):
self.cur.execute("insert into test(i) values (?)", (value,))
self.cur.execute("select i from test")
row = self.cur.fetchone()
self.assertIsNone(row)

def test_string_with_surrogates(self):
for value in 0xd8ff, 0xdcff:
with self.assertRaises(UnicodeEncodeError):
self.cur.execute("insert into test(s) values (?)", (chr(value),))
self.cur.execute("select s from test")
row = self.cur.fetchone()
self.assertIsNone(row)

@unittest.skipUnless(sys.maxsize > 2**32, 'requires 64bit platform')
@support.bigmemtest(size=2**31, memuse=4, dry_run=False)
def test_too_large_string(self, maxsize):
with self.assertRaises(sqlite.InterfaceError):
self.cur.execute("insert into test(s) values (?)", ('x'*(2**31-1),))
with self.assertRaises(OverflowError):
self.cur.execute("insert into test(s) values (?)", ('x'*(2**31),))
self.cur.execute("select 1 from test")
row = self.cur.fetchone()
self.assertIsNone(row)

@unittest.skipUnless(sys.maxsize > 2**32, 'requires 64bit platform')
@support.bigmemtest(size=2**31, memuse=3, dry_run=False)
def test_too_large_blob(self, maxsize):
with self.assertRaises(sqlite.InterfaceError):
self.cur.execute("insert into test(s) values (?)", (b'x'*(2**31-1),))
with self.assertRaises(OverflowError):
self.cur.execute("insert into test(s) values (?)", (b'x'*(2**31),))
self.cur.execute("select 1 from test")
row = self.cur.fetchone()
self.assertIsNone(row)


class DeclTypesTests(unittest.TestCase):
class Foo:
def __init__(self, _val):
Expand Down Expand Up @@ -163,7 +211,7 @@ def test_small_int(self):

def test_large_int(self):
# default
num = 2**40
num = 123456789123456789
self.cur.execute("insert into test(i) values (?)", (num,))
self.cur.execute("select i from test")
row = self.cur.fetchone()
Expand Down
48 changes: 37 additions & 11 deletions Lib/sqlite3/test/userfunctions.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,28 +33,37 @@
from test.support import bigmemtest


def with_tracebacks(strings):
def with_tracebacks(strings, traceback=True):
"""Convenience decorator for testing callback tracebacks."""
strings.append('Traceback')
if traceback:
strings.append('Traceback')

def decorator(func):
@functools.wraps(func)
def wrapper(self, *args, **kwargs):
# First, run the test with traceback enabled.
sqlite.enable_callback_tracebacks(True)
buf = io.StringIO()
with contextlib.redirect_stderr(buf):
with check_tracebacks(self, strings):
func(self, *args, **kwargs)
tb = buf.getvalue()
for s in strings:
self.assertIn(s, tb)

# Then run the test with traceback disabled.
sqlite.enable_callback_tracebacks(False)
func(self, *args, **kwargs)
return wrapper
return decorator

@contextlib.contextmanager
def check_tracebacks(self, strings):
"""Convenience context manager for testing callback tracebacks."""
sqlite.enable_callback_tracebacks(True)
try:
buf = io.StringIO()
with contextlib.redirect_stderr(buf):
yield
tb = buf.getvalue()
for s in strings:
self.assertIn(s, tb)
finally:
sqlite.enable_callback_tracebacks(False)

def func_returntext():
return "foo"
def func_returntextwithnull():
Expand Down Expand Up @@ -408,9 +417,26 @@ def md5sum(t):
del x,y
gc.collect()

def test_func_return_too_large_int(self):
cur = self.con.cursor()
for value in 2**63, -2**63-1, 2**64:
self.con.create_function("largeint", 0, lambda value=value: value)
with check_tracebacks(self, ['OverflowError']):
with self.assertRaises(sqlite.DataError):
cur.execute("select largeint()")

def test_func_return_text_with_surrogates(self):
cur = self.con.cursor()
self.con.create_function("pychr", 1, chr)
for value in 0xd8ff, 0xdcff:
with check_tracebacks(self,
['UnicodeEncodeError', 'surrogates not allowed']):
with self.assertRaises(sqlite.OperationalError):
cur.execute("select pychr(?)", (value,))

@unittest.skipUnless(sys.maxsize > 2**32, 'requires 64bit platform')
@bigmemtest(size=2**31, memuse=3, dry_run=False)
def test_large_text(self, size):
def test_func_return_too_large_text(self, size):
cur = self.con.cursor()
for size in 2**31-1, 2**31:
self.con.create_function("largetext", 0, lambda size=size: "b" * size)
Expand All @@ -419,7 +445,7 @@ def test_large_text(self, size):

@unittest.skipUnless(sys.maxsize > 2**32, 'requires 64bit platform')
@bigmemtest(size=2**31, memuse=2, dry_run=False)
def test_large_blob(self, size):
def test_func_return_too_large_blob(self, size):
cur = self.con.cursor()
for size in 2**31-1, 2**31:
self.con.create_function("largeblob", 0, lambda size=size: b"b" * size)
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
Improve error handling in :mod:`sqlite3` and raise more accurate exceptions.

* :exc:`MemoryError` is now raised instead of :exc:`sqlite3.Warning` when memory is not enough for encoding a statement to UTF-8 in ``Connection.__call__()`` and ``Cursor.execute()``.
* :exc:`UnicodEncodeError` is now raised instead of :exc:`sqlite3.Warning` when the statement contains surrogate characters in ``Connection.__call__()`` and ``Cursor.execute()``.
* :exc:`TypeError` is now raised instead of :exc:`ValueError` for non-string script argument in ``Cursor.executescript()``.
* :exc:`ValueError` is now raised for script containing the null character instead of truncating it in ``Cursor.executescript()``.
* Correctly handle exceptions raised when getting boolean value of the result of the progress handler.
* Add many tests covering different corner cases.
31 changes: 30 additions & 1 deletion Modules/_sqlite/clinic/cursor.c.h
Original file line number Diff line number Diff line change
Expand Up @@ -119,6 +119,35 @@ PyDoc_STRVAR(pysqlite_cursor_executescript__doc__,
#define PYSQLITE_CURSOR_EXECUTESCRIPT_METHODDEF \
{"executescript", (PyCFunction)pysqlite_cursor_executescript, METH_O, pysqlite_cursor_executescript__doc__},

static PyObject *
pysqlite_cursor_executescript_impl(pysqlite_Cursor *self,
const char *sql_script);

static PyObject *
pysqlite_cursor_executescript(pysqlite_Cursor *self, PyObject *arg)
{
PyObject *return_value = NULL;
const char *sql_script;

if (!PyUnicode_Check(arg)) {
_PyArg_BadArgument("executescript", "argument", "str", arg);
goto exit;
}
Py_ssize_t sql_script_length;
sql_script = PyUnicode_AsUTF8AndSize(arg, &sql_script_length);
if (sql_script == NULL) {
goto exit;
}
if (strlen(sql_script) != (size_t)sql_script_length) {
PyErr_SetString(PyExc_ValueError, "embedded null character");
goto exit;
}
return_value = pysqlite_cursor_executescript_impl(self, sql_script);

exit:
return return_value;
}

PyDoc_STRVAR(pysqlite_cursor_fetchone__doc__,
"fetchone($self, /)\n"
"--\n"
Expand Down Expand Up @@ -270,4 +299,4 @@ pysqlite_cursor_close(pysqlite_Cursor *self, PyTypeObject *cls, PyObject *const
exit:
return return_value;
}
/*[clinic end generated code: output=7b216aba2439f5cf input=a9049054013a1b77]*/
/*[clinic end generated code: output=ace31a7481aa3f41 input=a9049054013a1b77]*/
Loading