@@ -44,7 +44,7 @@ def logical_xor(conditions):
44
44
45
45
46
46
def count (field ):
47
- return SQLFunction (convert_func ("count" ), field or Scalar ('1' ))
47
+ return SQLFunction (convert_func ("count" ), field or Literal ('1' ))
48
48
49
49
50
50
def order_descending (field ):
@@ -55,7 +55,7 @@ def order_ascending(field):
55
55
return _SQLOrdering (field , SQL_ASC )
56
56
57
57
58
- class Scalar (object ):
58
+ class Literal (object ):
59
59
def __init__ (self , value ):
60
60
self .value = value
61
61
@@ -84,14 +84,23 @@ def __init__(self, field, direction):
84
84
self .direction = direction
85
85
86
86
87
+ TableOptions = namedtuple (
88
+ 'TableOptions' ,
89
+ [
90
+ 'schema' ,
91
+ 'name' ,
92
+ 'alias'
93
+ ]
94
+ )
95
+
96
+
87
97
JoinOptions = namedtuple (
88
98
'JoinOptions' ,
89
99
[
90
100
'join_type' ,
91
101
'main_field' ,
92
102
'join_field' ,
93
- 'table' ,
94
- 'table_alias'
103
+ 'table'
95
104
]
96
105
)
97
106
@@ -114,7 +123,6 @@ def __init__(self, field, direction):
114
123
'insert_ignore' ,
115
124
'insert_replace' ,
116
125
'join' ,
117
- 'table_alias' ,
118
126
]
119
127
)
120
128
@@ -130,11 +138,10 @@ class QueryBuilder(object):
130
138
chaining and the ability to easily reuse queries.
131
139
"""
132
140
def __init__ (self , query_data = None ):
133
- self ._query_data = query_data or _empty_query_data
134
- self ._table_alias_gen = itertools .cycle (string .ascii_lowercase )
135
- self ._query_data = self ._query_data ._replace (
136
- table_alias = (next (self ._table_alias_gen ))
137
- )
141
+ if not query_data :
142
+ query_data = _empty_query_data
143
+
144
+ self ._query_data = query_data
138
145
139
146
def _replace (self , ** kwargs ):
140
147
return self .copy (self ._query_data ._replace (** kwargs ))
@@ -184,7 +191,7 @@ def delete(self):
184
191
"""
185
192
return self ._replace (delete = True )
186
193
187
- def on_table (self , table ):
194
+ def on_table (self , table , schema = None ):
188
195
"""
189
196
Identifies the main table the query should be executed upon. E.g. if
190
197
`table` were `users` then the equivalent result would be:
@@ -194,6 +201,10 @@ def on_table(self, table):
194
201
SELECT * FROM users
195
202
196
203
"""
204
+ if not self ._query_data .table :
205
+ table = TableOptions (name = table , schema = schema , alias = None )
206
+ else :
207
+ table = self ._query_data .table ._replace (table = table , schema = schema )
197
208
return self ._replace (table = table )
198
209
199
210
def on_duplicate_key_update (self , ** col_values ):
@@ -250,7 +261,7 @@ def where(self, *conditions):
250
261
assert conditions
251
262
return self ._replace (where = logical_and (conditions ))
252
263
253
- def join (self , join_table , main_field , join_field = None ):
264
+ def join (self , join_table , main_field , join_field = None , schema = None ):
254
265
"""
255
266
Joins the current query with the given *join_table* on *join_field*
256
267
which is a field on *join_table* and *main_field* which is a field
@@ -264,8 +275,11 @@ def join(self, join_table, main_field, join_field=None):
264
275
join_type = SQL_JOIN_TYPE_INNER ,
265
276
main_field = main_field ,
266
277
join_field = join_field or main_field ,
267
- table = join_table ,
268
- table_alias = next (self ._table_alias_gen )
278
+ table = TableOptions (
279
+ name = join_table ,
280
+ schema = schema ,
281
+ alias = None
282
+ )
269
283
)
270
284
)
271
285
@@ -363,50 +377,64 @@ def _query_joiner(query, iterable, join_with=", "):
363
377
query .append (join_with )
364
378
365
379
366
- class QueryString (list ):
367
- def append (self , value , spaced_left = False , spaced_right = False ):
368
- super (QueryString , self ).append (value )
369
-
370
-
371
380
class SQLCompiler (object ):
372
- def __init__ (self , query_data ):
381
+ def __init__ (self , query_data , alias_gen = None ):
382
+ # generate the aliases
383
+ if alias_gen :
384
+ self .alias_gen = alias_gen
385
+ else :
386
+ self .alias_gen = itertools .cycle (string .ascii_lowercase )
387
+
388
+ query_data = query_data ._replace (
389
+ table = query_data .table ._replace (alias = next (self .alias_gen ))
390
+ )
391
+ if query_data .join :
392
+ query_data = query_data ._replace (
393
+ join = query_data .join ._replace (
394
+ table = query_data .join .table ._replace (
395
+ alias = next (self .alias_gen )
396
+ )
397
+ )
398
+ )
373
399
self .query_data = query_data
374
400
375
401
def _encode_main_table_name (self , include_alias = True ):
376
402
return encode_table_name (
377
- self .query_data .table ,
378
- self .query_data .table_alias ,
403
+ self .query_data .table .name ,
404
+ self .query_data .table .alias ,
405
+ self .query_data .table .schema ,
379
406
include_alias = include_alias
380
407
)
381
408
382
409
def _encode_join_table_name (self ):
383
410
return encode_table_name (
384
- self .query_data .join .table ,
385
- self .query_data .join .table_alias ,
411
+ self .query_data .join .table .name ,
412
+ self .query_data .join .table .alias ,
413
+ self .query_data .join .table .schema ,
386
414
include_alias = True
387
415
)
388
416
389
417
def _encode_field (self , field ):
390
418
return encode_field (
391
419
field ,
392
- self .query_data .table ,
393
- self .query_data .table_alias ,
420
+ self .query_data .table . name ,
421
+ self .query_data .table . alias ,
394
422
include_alias = True
395
423
)
396
424
397
425
def _encode_join_field (self , field ):
398
426
return encode_field (
399
427
field ,
400
- self .query_data .join .table ,
401
- self .query_data .join .table_alias ,
428
+ self .query_data .join .table . name ,
429
+ self .query_data .join .table . alias ,
402
430
include_alias = True
403
431
)
404
432
405
433
def _smart_encode_field (self , field ):
406
434
if (
407
435
self .query_data .join and
408
436
isinstance (field , string_types ) and
409
- field .startswith (self .query_data .join .table + '.' )
437
+ field .startswith (self .query_data .join .table . name + '.' )
410
438
):
411
439
return self ._encode_join_field (field )
412
440
@@ -527,18 +555,25 @@ def _generate_single_where_clause(self, field_op, value):
527
555
clause = []
528
556
with in_brackets (clause ):
529
557
clause .extend ([field , convert_op (op )])
530
- args = [value ]
531
- if (
558
+ if isinstance (value , QueryBuilder ):
559
+ with in_brackets (clause ):
560
+ sql , sql_args = SQLCompiler (value ._query_data ,
561
+ self .alias_gen )._raw_sql ()
562
+ clause .extend (sql )
563
+ args = list (sql_args )
564
+ elif (
532
565
not isinstance (value , string_types ) and
533
566
isinstance (value , collections .Iterable )
534
567
):
535
- clause .append (u"({})" .format (u"," .join ([u"%s" ] * len (value ))))
568
+ args = list (value )
569
+ clause .append (u"({})" .format (u"," .join ([u"%s" ] * len (args ))))
536
570
elif value is None :
537
571
clause .append (SQL_NULL )
538
572
# we get rid of the value as it is represented as null
539
- del args [: ]
573
+ args = [ ]
540
574
else :
541
575
clause .append (u"%s" )
576
+ args = [value ]
542
577
543
578
return clause , args
544
579
@@ -644,7 +679,7 @@ def _generate_query_operation(self):
644
679
645
680
raise InvalidQueryException
646
681
647
- def sql (self ):
682
+ def _raw_sql (self ):
648
683
if not self .query_data .table :
649
684
raise Exception ("requires both select and from" )
650
685
@@ -659,8 +694,12 @@ def sql(self):
659
694
sql , sql_args = zip (
660
695
main , where , group_by , having , order_by , offset , limit
661
696
)
697
+ return itertools .chain (* sql ), tuple (itertools .chain (* sql_args ))
698
+
699
+ def sql (self ):
700
+ sql , sql_args = self ._raw_sql ()
662
701
663
702
return (
664
- serialize_query_tokens (itertools . chain ( * sql ) ),
665
- tuple ( itertools . chain ( * sql_args ))
703
+ serialize_query_tokens (sql ),
704
+ sql_args
666
705
)
0 commit comments