Skip to content

Commit 0ff8f57

Browse files
committed
Add subquery support
1 parent 9805bcc commit 0ff8f57

File tree

5 files changed

+143
-89
lines changed

5 files changed

+143
-89
lines changed

README.md

Lines changed: 0 additions & 50 deletions
This file was deleted.

README.rst

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,20 @@ A simple order by statement
4141
>>> select("id").on_table("users").order_by("id").sql()
4242
(u'SELECT `a`.`id` FROM `users` AS `a` ORDER BY `a`.`id`', ())
4343
44+
A subquery
45+
~~~~~~~~~~
46+
.. code-block:: python
47+
48+
>>> from sqlquery.queryapi import select
49+
>>> select(
50+
"username", "id"
51+
).on_table("users").where(
52+
("id__in", select("id").on_table("banned_users"))
53+
).sql()
54+
(u'SELECT `a`.`username`, `a`.`id` FROM `users` AS `a` WHERE (`a`.`id` IN (SELECT `b`.`id` FROM `banned_users` AS `b`))',
55+
())
56+
57+
4458
A more involved statement
4559
~~~~~~~~~~~~~~~~~~~~~~~~~
4660
.. code-block:: python

sqlquery/_querybuilder.py

Lines changed: 75 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -44,7 +44,7 @@ def logical_xor(conditions):
4444

4545

4646
def count(field):
47-
return SQLFunction(convert_func("count"), field or Scalar('1'))
47+
return SQLFunction(convert_func("count"), field or Literal('1'))
4848

4949

5050
def order_descending(field):
@@ -55,7 +55,7 @@ def order_ascending(field):
5555
return _SQLOrdering(field, SQL_ASC)
5656

5757

58-
class Scalar(object):
58+
class Literal(object):
5959
def __init__(self, value):
6060
self.value = value
6161

@@ -84,14 +84,23 @@ def __init__(self, field, direction):
8484
self.direction = direction
8585

8686

87+
TableOptions = namedtuple(
88+
'TableOptions',
89+
[
90+
'schema',
91+
'name',
92+
'alias'
93+
]
94+
)
95+
96+
8797
JoinOptions = namedtuple(
8898
'JoinOptions',
8999
[
90100
'join_type',
91101
'main_field',
92102
'join_field',
93-
'table',
94-
'table_alias'
103+
'table'
95104
]
96105
)
97106

@@ -114,7 +123,6 @@ def __init__(self, field, direction):
114123
'insert_ignore',
115124
'insert_replace',
116125
'join',
117-
'table_alias',
118126
]
119127
)
120128

@@ -130,11 +138,10 @@ class QueryBuilder(object):
130138
chaining and the ability to easily reuse queries.
131139
"""
132140
def __init__(self, query_data=None):
133-
self._query_data = query_data or _empty_query_data
134-
self._table_alias_gen = itertools.cycle(string.ascii_lowercase)
135-
self._query_data = self._query_data._replace(
136-
table_alias=(next(self._table_alias_gen))
137-
)
141+
if not query_data:
142+
query_data = _empty_query_data
143+
144+
self._query_data = query_data
138145

139146
def _replace(self, **kwargs):
140147
return self.copy(self._query_data._replace(**kwargs))
@@ -184,7 +191,7 @@ def delete(self):
184191
"""
185192
return self._replace(delete=True)
186193

187-
def on_table(self, table):
194+
def on_table(self, table, schema=None):
188195
"""
189196
Identifies the main table the query should be executed upon. E.g. if
190197
`table` were `users` then the equivalent result would be:
@@ -194,6 +201,10 @@ def on_table(self, table):
194201
SELECT * FROM users
195202
196203
"""
204+
if not self._query_data.table:
205+
table = TableOptions(name=table, schema=schema, alias=None)
206+
else:
207+
table = self._query_data.table._replace(table=table, schema=schema)
197208
return self._replace(table=table)
198209

199210
def on_duplicate_key_update(self, **col_values):
@@ -250,7 +261,7 @@ def where(self, *conditions):
250261
assert conditions
251262
return self._replace(where=logical_and(conditions))
252263

253-
def join(self, join_table, main_field, join_field=None):
264+
def join(self, join_table, main_field, join_field=None, schema=None):
254265
"""
255266
Joins the current query with the given *join_table* on *join_field*
256267
which is a field on *join_table* and *main_field* which is a field
@@ -264,8 +275,11 @@ def join(self, join_table, main_field, join_field=None):
264275
join_type=SQL_JOIN_TYPE_INNER,
265276
main_field=main_field,
266277
join_field=join_field or main_field,
267-
table=join_table,
268-
table_alias=next(self._table_alias_gen)
278+
table=TableOptions(
279+
name=join_table,
280+
schema=schema,
281+
alias=None
282+
)
269283
)
270284
)
271285

@@ -363,50 +377,64 @@ def _query_joiner(query, iterable, join_with=", "):
363377
query.append(join_with)
364378

365379

366-
class QueryString(list):
367-
def append(self, value, spaced_left=False, spaced_right=False):
368-
super(QueryString, self).append(value)
369-
370-
371380
class SQLCompiler(object):
372-
def __init__(self, query_data):
381+
def __init__(self, query_data, alias_gen=None):
382+
# generate the aliases
383+
if alias_gen:
384+
self.alias_gen = alias_gen
385+
else:
386+
self.alias_gen = itertools.cycle(string.ascii_lowercase)
387+
388+
query_data = query_data._replace(
389+
table=query_data.table._replace(alias=next(self.alias_gen))
390+
)
391+
if query_data.join:
392+
query_data = query_data._replace(
393+
join=query_data.join._replace(
394+
table=query_data.join.table._replace(
395+
alias=next(self.alias_gen)
396+
)
397+
)
398+
)
373399
self.query_data = query_data
374400

375401
def _encode_main_table_name(self, include_alias=True):
376402
return encode_table_name(
377-
self.query_data.table,
378-
self.query_data.table_alias,
403+
self.query_data.table.name,
404+
self.query_data.table.alias,
405+
self.query_data.table.schema,
379406
include_alias=include_alias
380407
)
381408

382409
def _encode_join_table_name(self):
383410
return encode_table_name(
384-
self.query_data.join.table,
385-
self.query_data.join.table_alias,
411+
self.query_data.join.table.name,
412+
self.query_data.join.table.alias,
413+
self.query_data.join.table.schema,
386414
include_alias=True
387415
)
388416

389417
def _encode_field(self, field):
390418
return encode_field(
391419
field,
392-
self.query_data.table,
393-
self.query_data.table_alias,
420+
self.query_data.table.name,
421+
self.query_data.table.alias,
394422
include_alias=True
395423
)
396424

397425
def _encode_join_field(self, field):
398426
return encode_field(
399427
field,
400-
self.query_data.join.table,
401-
self.query_data.join.table_alias,
428+
self.query_data.join.table.name,
429+
self.query_data.join.table.alias,
402430
include_alias=True
403431
)
404432

405433
def _smart_encode_field(self, field):
406434
if (
407435
self.query_data.join and
408436
isinstance(field, string_types) and
409-
field.startswith(self.query_data.join.table + '.')
437+
field.startswith(self.query_data.join.table.name + '.')
410438
):
411439
return self._encode_join_field(field)
412440

@@ -527,18 +555,25 @@ def _generate_single_where_clause(self, field_op, value):
527555
clause = []
528556
with in_brackets(clause):
529557
clause.extend([field, convert_op(op)])
530-
args = [value]
531-
if (
558+
if isinstance(value, QueryBuilder):
559+
with in_brackets(clause):
560+
sql, sql_args = SQLCompiler(value._query_data,
561+
self.alias_gen)._raw_sql()
562+
clause.extend(sql)
563+
args = list(sql_args)
564+
elif (
532565
not isinstance(value, string_types) and
533566
isinstance(value, collections.Iterable)
534567
):
535-
clause.append(u"({})".format(u",".join([u"%s"] * len(value))))
568+
args = list(value)
569+
clause.append(u"({})".format(u",".join([u"%s"] * len(args))))
536570
elif value is None:
537571
clause.append(SQL_NULL)
538572
# we get rid of the value as it is represented as null
539-
del args[:]
573+
args = []
540574
else:
541575
clause.append(u"%s")
576+
args = [value]
542577

543578
return clause, args
544579

@@ -644,7 +679,7 @@ def _generate_query_operation(self):
644679

645680
raise InvalidQueryException
646681

647-
def sql(self):
682+
def _raw_sql(self):
648683
if not self.query_data.table:
649684
raise Exception("requires both select and from")
650685

@@ -659,8 +694,12 @@ def sql(self):
659694
sql, sql_args = zip(
660695
main, where, group_by, having, order_by, offset, limit
661696
)
697+
return itertools.chain(*sql), tuple(itertools.chain(*sql_args))
698+
699+
def sql(self):
700+
sql, sql_args = self._raw_sql()
662701

663702
return (
664-
serialize_query_tokens(itertools.chain(*sql)),
665-
tuple(itertools.chain(*sql_args))
703+
serialize_query_tokens(sql),
704+
sql_args
666705
)

sqlquery/sqlencoding.py

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -64,12 +64,18 @@ def encode_field(field, table_name, table_alias, include_alias=True):
6464
return quoted(prefix) + '.' + quoted(field)
6565

6666

67-
def encode_table_name(table_name, table_alias, include_alias=True):
67+
def encode_table_name(table_name, table_alias, table_schema,
68+
include_alias=True):
69+
if table_schema:
70+
name = "{}.{}".format(quoted(table_schema), quoted(table_name))
71+
else:
72+
name = quoted(table_name)
73+
6874
if not include_alias:
69-
return quoted(table_name)
75+
return name
7076

7177
return (
72-
quoted(table_name)
78+
name
7379
+ " AS "
7480
+ quoted(table_alias)
7581
)

tests/test_query_builder.py

Lines changed: 45 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -95,6 +95,18 @@ def test__generate_select_count_field(self):
9595
(serialize_query_tokens(sql), args)
9696
)
9797

98+
def test__generate_select_with_schema(self):
99+
compiler = self.builder.select(
100+
'users'
101+
).on_table("table", schema='myschema').compiler()
102+
103+
sql, args = compiler._generate_select()
104+
105+
self.assertEqual(
106+
("SELECT `a`.`users` FROM `myschema`.`table` AS `a`", []),
107+
(serialize_query_tokens(sql), args)
108+
)
109+
98110
def test__generate_select_no_element_raises(self):
99111
with self.assertRaises(InvalidQueryException):
100112
self.builder.select().on_table(
@@ -135,6 +147,39 @@ def test__generate_all_where_ops(self):
135147
args
136148
)
137149

150+
def test__generate_where_in_literals(self):
151+
compiler = self.basic_select.where(
152+
("test1__in", [1, 2, 3])
153+
).compiler()
154+
155+
sql, args = compiler._generate_where()
156+
157+
self.assertEqual(
158+
"WHERE (`a`.`test1` IN (%s,%s,%s))",
159+
serialize_query_tokens(sql)
160+
)
161+
self.assertEqual(args, [1, 2, 3])
162+
163+
def test__generate_where_in_query(self):
164+
compiler = self.basic_select.where(
165+
(
166+
"test1__in",
167+
self.builder.select("id").on_table("table2").where(
168+
("test2__eq", 4)
169+
)
170+
)
171+
).compiler()
172+
173+
sql, args = compiler._generate_where()
174+
175+
self.assertEqual(
176+
"WHERE (`a`.`test1` IN (SELECT `b`.`id` FROM `table2` AS `b` "
177+
"WHERE (`b`.`test2` <=> %s)))"
178+
,
179+
serialize_query_tokens(sql)
180+
)
181+
self.assertEqual(args, [4])
182+
138183
def test__generate_is_null(self):
139184
for (op, sql_op) in {"is": "IS NULL", "isnot": "IS NOT NULL"}.items():
140185
compiler = self.basic_select.where(

0 commit comments

Comments
 (0)