diff --git a/doc/src/sql.rst b/doc/src/sql.rst index f044372dd..a2ed15e00 100644 --- a/doc/src/sql.rst +++ b/doc/src/sql.rst @@ -51,7 +51,8 @@ from the query parameters:: from psycopg2 import sql cur.execute( - sql.SQL("insert into %s values (%%s, %%s)") % [sql.Identifier('my_table')], + sql.SQL("insert into {} values (%s, %s)") + .format(sql.Identifier('my_table')), [10, 20]) diff --git a/lib/sql.py b/lib/sql.py index ff7faa1ac..23f66a616 100644 --- a/lib/sql.py +++ b/lib/sql.py @@ -23,24 +23,27 @@ # FITNESS FOR A PARTICULAR PURPOSE. See the GNU Lesser General Public # License for more details. -import re import sys -import collections +import string from psycopg2 import extensions as ext +_formatter = string.Formatter() + + class Composable(object): """ Abstract base class for objects that can be used to compose an SQL string. - Composables can be passed directly to `~cursor.execute()` and + `!Composable` objects can be passed directly to `~cursor.execute()` and `~cursor.executemany()`. - Composables can be joined using the ``+`` operator: the result will be - a `Composed` instance containing the objects joined. The operator ``*`` is - also supported with an integer argument: the result is a `!Composed` - instance containing the left argument repeated as many times as requested. + `!Composable` objects can be joined using the ``+`` operator: the result + will be a `Composed` instance containing the objects joined. The operator + ``*`` is also supported with an integer argument: the result is a + `!Composed` instance containing the left argument repeated as many times as + requested. .. automethod:: as_string """ @@ -144,21 +147,22 @@ def join(self, joiner): class SQL(Composable): """ - A `Composable` representing a snippet of SQL string to be included verbatim. + A `Composable` representing a snippet of SQL statement. - `!SQL` supports the ``%`` operator to incorporate variable parts of a query - into a template: the operator takes a sequence or mapping of `Composable` - (according to the style of the placeholders in the *string*) and returning - a `Composed` object. + `!SQL` exposes `join()` and `format()` methods useful to create a template + where to merge variable parts of a query (for instance field or table + names). Example:: - >>> query = sql.SQL("select %s from %s") % [ + >>> query = sql.SQL("select {} from {}").format( ... sql.SQL(', ').join([sql.Identifier('foo'), sql.Identifier('bar')]), - ... sql.Identifier('table')] + ... sql.Identifier('table')) >>> print(query.as_string(conn)) select "foo", "bar" from "table" + .. automethod:: format + .. automethod:: join """ def __init__(self, string): @@ -169,12 +173,73 @@ def __init__(self, string): def __repr__(self): return "sql.SQL(%r)" % (self._wrapped,) - def __mod__(self, args): - return _compose(self._wrapped, args) - def as_string(self, conn_or_curs): return self._wrapped + def format(self, *args, **kwargs): + """ + Merge `Composable` objects into a template. + + :param `Composable` args: parameters to replace to numbered + (``{0}``, ``{1}``) or auto-numbered (``{}``) placeholders + :param `Composable` kwargs: parameters to replace to named (``{name}``) + placeholders + :return: the union of the `!SQL` string with placeholders replaced + :rtype: `Composed` + + The method is similar to the Python `str.format()` method: the string + template supports auto-numbered (``{}``), numbered (``{0}``, + ``{1}``...), and named placeholders (``{name}``), with positional + arguments replacing the numbered placeholders and keywords replacing + the named ones. However placeholder modifiers (``{{0!r}}``, + ``{{0:<10}}``) are not supported. Only `!Composable` objects can be + passed to the template. + + Example:: + + >>> print(sql.SQL("select * from {} where {} = %s") + ... .format(sql.Identifier('people'), sql.Identifier('id')) + ... .as_string(conn)) + select * from "people" where "id" = %s + + >>> print(sql.SQL("select * from {tbl} where {pkey} = %s") + ... .format(tbl=sql.Identifier('people'), pkey=sql.Identifier('id')) + ... .as_string(conn)) + select * from "people" where "id" = %s + + """ + rv = [] + autonum = 0 + for pre, name, spec, conv in _formatter.parse(self._wrapped): + if spec: + raise ValueError("no format specification supported by SQL") + if conv: + raise ValueError("no format conversion supported by SQL") + if pre: + rv.append(SQL(pre)) + + if name is None: + continue + + if name.isdigit(): + if autonum: + raise ValueError( + "cannot switch from automatic field numbering to manual") + rv.append(args[int(name)]) + autonum = None + + elif not name: + if autonum is None: + raise ValueError( + "cannot switch from manual field numbering to automatic") + rv.append(args[autonum]) + autonum += 1 + + else: + rv.append(kwargs[name]) + + return Composed(rv) + def join(self, seq): """ Join a sequence of `Composable` or a `Composed` and return a `!Composed`. @@ -183,7 +248,8 @@ def join(self, seq): Example:: - >>> snip - sql.SQL(', ').join(map(sql.Identifier, ['foo', 'bar', 'baz'])) + >>> snip = sql.SQL(', ').join( + ... sql.Identifier(n) for n in ['foo', 'bar', 'baz']) >>> print(snip.as_string(conn)) "foo", "bar", "baz" """ @@ -331,123 +397,6 @@ def as_string(self, conn_or_curs): return "%s" -re_compose = re.compile(""" - % # percent sign - (?: - ([%s]) # either % or s - | \( ([^\)]+) \) s # or a (named)s placeholder (named captured) - ) - """, re.VERBOSE) - - -def _compose(sql, args=None): - """ - Merge an SQL string with some variable parts. - - The *sql* string can contain placeholders such as `%s` or `%(name)s`. - If the string must contain a literal ``%`` symbol use ``%%``. Note that, - unlike `~cursor.execute()`, the replacement ``%%`` |=>| ``%`` is *always* - performed, even if there is no argument. - - .. |=>| unicode:: 0x21D2 .. double right arrow - - *args* must be a sequence or mapping (according to the placeholder style) - of `Composable` instances. - - The value returned is a `Composed` instance obtained replacing the - arguments to the query placeholders. - """ - if args is None: - args = () - - phs = list(re_compose.finditer(sql)) - - # check placeholders consistent - counts = {'%': 0, 's': 0, None: 0} - for ph in phs: - counts[ph.group(1)] += 1 - - npos = counts['s'] - nnamed = counts[None] - - if npos and nnamed: - raise ValueError( - "the sql string contains both named and positional placeholders") - - elif npos: - if not isinstance(args, collections.Sequence): - raise TypeError( - "the sql string expects values in a sequence, got %s instead" - % type(args).__name__) - - if len(args) != npos: - raise ValueError( - "the sql string expects %s values, got %s" % (npos, len(args))) - - return _compose_seq(sql, phs, args) - - elif nnamed: - if not isinstance(args, collections.Mapping): - raise TypeError( - "the sql string expects values in a mapping, got %s instead" - % type(args)) - - return _compose_map(sql, phs, args) - - else: - if isinstance(args, collections.Sequence) and args: - raise ValueError( - "the sql string expects no value, got %s instead" % len(args)) - # If args are a mapping, no placeholder is an acceptable case - - # Convert %% into % - return _compose_seq(sql, phs, ()) - - -def _compose_seq(sql, phs, args): - rv = [] - j = 0 - for i, ph in enumerate(phs): - if i: - rv.append(SQL(sql[phs[i - 1].end():ph.start()])) - else: - rv.append(SQL(sql[0:ph.start()])) - - if ph.group(1) == 's': - rv.append(args[j]) - j += 1 - else: - rv.append(SQL('%')) - - if phs: - rv.append(SQL(sql[phs[-1].end():])) - else: - rv.append(SQL(sql)) - - return Composed(rv) - - -def _compose_map(sql, phs, args): - rv = [] - for i, ph in enumerate(phs): - if i: - rv.append(SQL(sql[phs[i - 1].end():ph.start()])) - else: - rv.append(SQL(sql[0:ph.start()])) - - if ph.group(2): - rv.append(args[ph.group(2)]) - else: - rv.append(SQL('%')) - - if phs: - rv.append(SQL(sql[phs[-1].end():])) - else: - rv.append(sql) - - return Composed(rv) - - # Alias PH = Placeholder diff --git a/tests/test_sql.py b/tests/test_sql.py index 5482ccb5b..c2268fda8 100755 --- a/tests/test_sql.py +++ b/tests/test_sql.py @@ -25,65 +25,89 @@ import datetime as dt from testutils import unittest, ConnectingTestCase +import psycopg2 from psycopg2 import sql -class ComposeTests(ConnectingTestCase): +class SqlFormatTests(ConnectingTestCase): def test_pos(self): - s = sql.SQL("select %s from %s") \ - % (sql.Identifier('field'), sql.Identifier('table')) + s = sql.SQL("select {} from {}").format( + sql.Identifier('field'), sql.Identifier('table')) + s1 = s.as_string(self.conn) + self.assert_(isinstance(s1, str)) + self.assertEqual(s1, 'select "field" from "table"') + + def test_pos_spec(self): + s = sql.SQL("select {0} from {1}").format( + sql.Identifier('field'), sql.Identifier('table')) + s1 = s.as_string(self.conn) + self.assert_(isinstance(s1, str)) + self.assertEqual(s1, 'select "field" from "table"') + + s = sql.SQL("select {1} from {0}").format( + sql.Identifier('table'), sql.Identifier('field')) s1 = s.as_string(self.conn) self.assert_(isinstance(s1, str)) self.assertEqual(s1, 'select "field" from "table"') def test_dict(self): - s = sql.SQL("select %(f)s from %(t)s") \ - % {'f': sql.Identifier('field'), 't': sql.Identifier('table')} + s = sql.SQL("select {f} from {t}").format( + f=sql.Identifier('field'), t=sql.Identifier('table')) s1 = s.as_string(self.conn) self.assert_(isinstance(s1, str)) self.assertEqual(s1, 'select "field" from "table"') def test_unicode(self): - s = sql.SQL(u"select %s from %s") \ - % (sql.Identifier(u'field'), sql.Identifier('table')) + s = sql.SQL(u"select {} from {}").format( + sql.Identifier(u'field'), sql.Identifier('table')) s1 = s.as_string(self.conn) self.assert_(isinstance(s1, unicode)) self.assertEqual(s1, u'select "field" from "table"') def test_compose_literal(self): - s = sql.SQL("select %s;") % [sql.Literal(dt.date(2016, 12, 31))] + s = sql.SQL("select {};").format(sql.Literal(dt.date(2016, 12, 31))) s1 = s.as_string(self.conn) self.assertEqual(s1, "select '2016-12-31'::date;") def test_compose_empty(self): - s = sql.SQL("select foo;") % () + s = sql.SQL("select foo;").format() s1 = s.as_string(self.conn) self.assertEqual(s1, "select foo;") def test_percent_escape(self): - s = sql.SQL("42 %% %s") % [sql.Literal(7)] + s = sql.SQL("42 % {}").format(sql.Literal(7)) s1 = s.as_string(self.conn) self.assertEqual(s1, "42 % 7") - s = sql.SQL("42 %% 7") % [] - s1 = s.as_string(self.conn) - self.assertEqual(s1, "42 % 7") + def test_braces_escape(self): + s = sql.SQL("{{{}}}").format(sql.Literal(7)) + self.assertEqual(s.as_string(self.conn), "{7}") + s = sql.SQL("{{1,{}}}").format(sql.Literal(7)) + self.assertEqual(s.as_string(self.conn), "{1,7}") def test_compose_badnargs(self): - self.assertRaises(ValueError, sql.SQL("select foo;").__mod__, [10]) - self.assertRaises(ValueError, sql.SQL("select %s;").__mod__, []) - self.assertRaises(ValueError, sql.SQL("select %s;").__mod__, [10, 20]) + self.assertRaises(IndexError, sql.SQL("select {};").format) + self.assertRaises(ValueError, sql.SQL("select {} {1};").format, 10, 20) + self.assertRaises(ValueError, sql.SQL("select {0} {};").format, 10, 20) def test_compose_bad_args_type(self): - self.assertRaises(TypeError, sql.SQL("select %s;").__mod__, {'a': 10}) - self.assertRaises(TypeError, sql.SQL("select %(x)s;").__mod__, [10]) + self.assertRaises(IndexError, sql.SQL("select {};").format, a=10) + self.assertRaises(KeyError, sql.SQL("select {x};").format, 10) + + def test_must_be_composable(self): + self.assertRaises(TypeError, sql.SQL("select {};").format, 'foo') + self.assertRaises(TypeError, sql.SQL("select {};").format, 10) + + def test_no_modifiers(self): + self.assertRaises(ValueError, sql.SQL("select {a!r};").format, a=10) + self.assertRaises(ValueError, sql.SQL("select {a:<};").format, a=10) def test_must_be_adaptable(self): class Foo(object): pass - self.assertRaises(TypeError, - sql.SQL("select %s;").__mod__, [Foo()]) + self.assertRaises(psycopg2.ProgrammingError, + sql.SQL("select {};").format(sql.Literal(Foo())).as_string, self.conn) def test_execute(self): cur = self.conn.cursor() @@ -93,11 +117,10 @@ def test_execute(self): foo text, bar text, "ba'z" text) """) cur.execute( - sql.SQL("insert into %s (id, %s) values (%%s, %s)") % [ + sql.SQL("insert into {} (id, {}) values (%s, {})").format( sql.Identifier('test_compose'), sql.SQL(', ').join(map(sql.Identifier, ['foo', 'bar', "ba'z"])), - (sql.PH() * 3).join(', '), - ], + (sql.PH() * 3).join(', ')), (10, 'a', 'b', 'c')) cur.execute("select * from test_compose") @@ -111,11 +134,10 @@ def test_executemany(self): foo text, bar text, "ba'z" text) """) cur.executemany( - sql.SQL("insert into %s (id, %s) values (%%s, %s)") % [ + sql.SQL("insert into {} (id, {}) values (%s, {})").format( sql.Identifier('test_compose'), sql.SQL(', ').join(map(sql.Identifier, ['foo', 'bar', "ba'z"])), - (sql.PH() * 3).join(', '), - ], + (sql.PH() * 3).join(', ')), [(10, 'a', 'b', 'c'), (20, 'd', 'e', 'f')]) cur.execute("select * from test_compose") @@ -169,6 +191,13 @@ def test_repr(self): sql.Literal(dt.date(2017, 1, 1)).as_string(self.conn), "'2017-01-01'::date") + def test_must_be_adaptable(self): + class Foo(object): + pass + + self.assertRaises(psycopg2.ProgrammingError, + sql.Literal(Foo()).as_string, self.conn) + class SQLTests(ConnectingTestCase): def test_class(self):