23
23
from collections import defaultdict
24
24
25
25
import sqlalchemy as sa
26
- from sqlalchemy .sql import crud
26
+ from sqlalchemy .sql import crud , selectable
27
27
from sqlalchemy .sql import compiler
28
28
from .types import MutableDict
29
- from .sa_version import SA_1_1 , SA_VERSION
29
+ from .sa_version import SA_VERSION , SA_1_1 , SA_1_4
30
+
31
+
32
+ INSERT_SELECT_WITHOUT_PARENTHESES_MIN_VERSION = (1 , 0 , 1 )
30
33
31
34
32
35
def rewrite_update (clauseelement , multiparams , params ):
@@ -74,7 +77,18 @@ def rewrite_update(clauseelement, multiparams, params):
74
77
def crate_before_execute (conn , clauseelement , multiparams , params ):
75
78
is_crate = type (conn .dialect ).__name__ == 'CrateDialect'
76
79
if is_crate and isinstance (clauseelement , sa .sql .expression .Update ):
77
- return rewrite_update (clauseelement , multiparams , params )
80
+ if SA_VERSION >= SA_1_4 :
81
+ multiparams = ([params ],)
82
+ params = {}
83
+
84
+ clauseelement , multiparams , params = rewrite_update (clauseelement , multiparams , params )
85
+
86
+ if SA_VERSION >= SA_1_4 :
87
+ params = multiparams [0 ]
88
+ multiparams = []
89
+
90
+ return clauseelement , multiparams , params
91
+
78
92
return clauseelement , multiparams , params
79
93
80
94
@@ -189,9 +203,23 @@ def visit_insert(self, insert_stmt, asfrom=False, **kw):
189
203
used to compile <sql.expression.Insert> expressions.
190
204
191
205
this function wraps insert_from_select statements inside
192
- parentheses to be conform with earlier versions of CreateDB.
206
+ parentheses to be conform with earlier versions of CreateDB.
207
+
208
+ According to the changelog, CrateDB >= 1.0.1 already mitigates this requirement:
209
+
210
+ ``INSERT`` statements now support ``SELECT`` statements without parentheses.
211
+ https://crate.io/docs/crate/reference/en/4.3/appendices/release-notes/1.0.1.html
193
212
"""
194
213
214
+ # Only CrateDB <= 1.0.0 needs parentheses for ``INSERT INTO ... SELECT ...``.
215
+ if self .dialect .server_version_info >= INSERT_SELECT_WITHOUT_PARENTHESES_MIN_VERSION :
216
+ return super (CrateCompiler , self ).visit_insert (insert_stmt , asfrom = asfrom , ** kw )
217
+
218
+ if SA_VERSION >= SA_1_4 :
219
+ raise DeprecationWarning (
220
+ "CrateDB version < {} not supported with SQLAlchemy 1.4" .format (
221
+ INSERT_SELECT_WITHOUT_PARENTHESES_MIN_VERSION ))
222
+
195
223
self .stack .append (
196
224
{'correlate_froms' : set (),
197
225
"asfrom_froms" : set (),
@@ -288,6 +316,9 @@ def visit_update(self, update_stmt, **kw):
288
316
Parts are taken from the SQLCompiler base class.
289
317
"""
290
318
319
+ if SA_VERSION >= SA_1_4 :
320
+ return self .visit_update_14 (update_stmt , ** kw )
321
+
291
322
if not update_stmt .parameters and \
292
323
not hasattr (update_stmt , '_crate_specific' ):
293
324
return super (CrateCompiler , self ).visit_update (update_stmt , ** kw )
@@ -311,11 +342,14 @@ def visit_update(self, update_stmt, **kw):
311
342
update_stmt , table_text
312
343
)
313
344
314
- crud_params = self ._get_crud_params (update_stmt , ** kw )
345
+ # CrateDB amendment.
346
+ crud_params = self ._get_crud_params (self , update_stmt , ** kw )
315
347
316
348
text += table_text
317
349
318
350
text += ' SET '
351
+
352
+ # CrateDB amendment begin.
319
353
include_table = extra_froms and \
320
354
self .render_table_with_column_in_update_from
321
355
@@ -333,6 +367,7 @@ def visit_update(self, update_stmt, **kw):
333
367
set_clauses .append (k + ' = ' + self .process (bindparam ))
334
368
335
369
text += ', ' .join (set_clauses )
370
+ # CrateDB amendment end.
336
371
337
372
if self .returning or update_stmt ._returning :
338
373
if not self .returning :
@@ -368,7 +403,6 @@ def visit_update(self, update_stmt, **kw):
368
403
369
404
def _get_crud_params (compiler , stmt , ** kw ):
370
405
""" extract values from crud parameters
371
-
372
406
taken from SQLAlchemy's crud module (since 1.0.x) and
373
407
adapted for Crate dialect"""
374
408
@@ -428,3 +462,295 @@ def _get_crud_params(compiler, stmt, **kw):
428
462
values , kw )
429
463
430
464
return values
465
+
466
+ def visit_update_14 (self , update_stmt , ** kw ):
467
+
468
+ compile_state = update_stmt ._compile_state_factory (
469
+ update_stmt , self , ** kw
470
+ )
471
+ update_stmt = compile_state .statement
472
+
473
+ toplevel = not self .stack
474
+ if toplevel :
475
+ self .isupdate = True
476
+ if not self .compile_state :
477
+ self .compile_state = compile_state
478
+
479
+ extra_froms = compile_state ._extra_froms
480
+ is_multitable = bool (extra_froms )
481
+
482
+ if is_multitable :
483
+ # main table might be a JOIN
484
+ main_froms = set (selectable ._from_objects (update_stmt .table ))
485
+ render_extra_froms = [
486
+ f for f in extra_froms if f not in main_froms
487
+ ]
488
+ correlate_froms = main_froms .union (extra_froms )
489
+ else :
490
+ render_extra_froms = []
491
+ correlate_froms = {update_stmt .table }
492
+
493
+ self .stack .append (
494
+ {
495
+ "correlate_froms" : correlate_froms ,
496
+ "asfrom_froms" : correlate_froms ,
497
+ "selectable" : update_stmt ,
498
+ }
499
+ )
500
+
501
+ text = "UPDATE "
502
+
503
+ if update_stmt ._prefixes :
504
+ text += self ._generate_prefixes (
505
+ update_stmt , update_stmt ._prefixes , ** kw
506
+ )
507
+
508
+ table_text = self .update_tables_clause (
509
+ update_stmt , update_stmt .table , render_extra_froms , ** kw
510
+ )
511
+
512
+ # CrateDB amendment.
513
+ crud_params = _get_crud_params_14 (
514
+ self , update_stmt , compile_state , ** kw
515
+ )
516
+
517
+ if update_stmt ._hints :
518
+ dialect_hints , table_text = self ._setup_crud_hints (
519
+ update_stmt , table_text
520
+ )
521
+ else :
522
+ dialect_hints = None
523
+
524
+ text += table_text
525
+
526
+ text += " SET "
527
+
528
+ # CrateDB amendment begin.
529
+ include_table = extra_froms and \
530
+ self .render_table_with_column_in_update_from
531
+
532
+ set_clauses = []
533
+
534
+ for c , expr , value in crud_params :
535
+ key = c ._compiler_dispatch (self , include_table = include_table )
536
+ clause = key + ' = ' + value
537
+ set_clauses .append (clause )
538
+
539
+ for k , v in compile_state ._dict_parameters .items ():
540
+ if isinstance (k , str ) and '[' in k :
541
+ bindparam = sa .sql .bindparam (k , v )
542
+ clause = k + ' = ' + self .process (bindparam )
543
+ set_clauses .append (clause )
544
+
545
+ text += ', ' .join (set_clauses )
546
+ # CrateDB amendment end.
547
+
548
+ if self .returning or update_stmt ._returning :
549
+ if self .returning_precedes_values :
550
+ text += " " + self .returning_clause (
551
+ update_stmt , self .returning or update_stmt ._returning
552
+ )
553
+
554
+ if extra_froms :
555
+ extra_from_text = self .update_from_clause (
556
+ update_stmt ,
557
+ update_stmt .table ,
558
+ render_extra_froms ,
559
+ dialect_hints ,
560
+ ** kw
561
+ )
562
+ if extra_from_text :
563
+ text += " " + extra_from_text
564
+
565
+ if update_stmt ._where_criteria :
566
+ t = self ._generate_delimited_and_list (
567
+ update_stmt ._where_criteria , ** kw
568
+ )
569
+ if t :
570
+ text += " WHERE " + t
571
+
572
+ limit_clause = self .update_limit_clause (update_stmt )
573
+ if limit_clause :
574
+ text += " " + limit_clause
575
+
576
+ if (
577
+ self .returning or update_stmt ._returning
578
+ ) and not self .returning_precedes_values :
579
+ text += " " + self .returning_clause (
580
+ update_stmt , self .returning or update_stmt ._returning
581
+ )
582
+
583
+ if self .ctes and toplevel :
584
+ text = self ._render_cte_clause () + text
585
+
586
+ self .stack .pop (- 1 )
587
+
588
+ return text
589
+
590
+
591
+ def _get_crud_params_14 (compiler , stmt , compile_state , ** kw ):
592
+ """create a set of tuples representing column/string pairs for use
593
+ in an INSERT or UPDATE statement.
594
+
595
+ Also generates the Compiled object's postfetch, prefetch, and
596
+ returning column collections, used for default handling and ultimately
597
+ populating the CursorResult's prefetch_cols() and postfetch_cols()
598
+ collections.
599
+
600
+ """
601
+ from sqlalchemy .sql .crud import _key_getters_for_crud_column
602
+ from sqlalchemy .sql .crud import _create_bind_param
603
+ from sqlalchemy .sql .crud import REQUIRED
604
+ from sqlalchemy .sql .crud import _get_stmt_parameter_tuples_params
605
+ from sqlalchemy .sql .crud import _get_multitable_params
606
+ from sqlalchemy .sql .crud import _scan_insert_from_select_cols
607
+ from sqlalchemy .sql .crud import _scan_cols
608
+ from sqlalchemy import exc
609
+ from sqlalchemy .sql .crud import _extend_values_for_multiparams
610
+
611
+ compiler .postfetch = []
612
+ compiler .insert_prefetch = []
613
+ compiler .update_prefetch = []
614
+ compiler .returning = []
615
+
616
+ # getters - these are normally just column.key,
617
+ # but in the case of mysql multi-table update, the rules for
618
+ # .key must conditionally take tablename into account
619
+ (
620
+ _column_as_key ,
621
+ _getattr_col_key ,
622
+ _col_bind_name ,
623
+ ) = getters = _key_getters_for_crud_column (compiler , stmt , compile_state )
624
+
625
+ compiler ._key_getters_for_crud_column = getters
626
+
627
+ # no parameters in the statement, no parameters in the
628
+ # compiled params - return binds for all columns
629
+ if compiler .column_keys is None and compile_state ._no_parameters :
630
+ return [
631
+ (
632
+ c ,
633
+ compiler .preparer .format_column (c ),
634
+ _create_bind_param (compiler , c , None , required = True ),
635
+ )
636
+ for c in stmt .table .columns
637
+ ]
638
+
639
+ if compile_state ._has_multi_parameters :
640
+ spd = compile_state ._multi_parameters [0 ]
641
+ stmt_parameter_tuples = list (spd .items ())
642
+ elif compile_state ._ordered_values :
643
+ spd = compile_state ._dict_parameters
644
+ stmt_parameter_tuples = compile_state ._ordered_values
645
+ elif compile_state ._dict_parameters :
646
+ spd = compile_state ._dict_parameters
647
+ stmt_parameter_tuples = list (spd .items ())
648
+ else :
649
+ stmt_parameter_tuples = spd = None
650
+
651
+ # if we have statement parameters - set defaults in the
652
+ # compiled params
653
+ if compiler .column_keys is None :
654
+ parameters = {}
655
+ elif stmt_parameter_tuples :
656
+ parameters = dict (
657
+ (_column_as_key (key ), REQUIRED )
658
+ for key in compiler .column_keys
659
+ if key not in spd
660
+ )
661
+ else :
662
+ parameters = dict (
663
+ (_column_as_key (key ), REQUIRED ) for key in compiler .column_keys
664
+ )
665
+
666
+ # create a list of column assignment clauses as tuples
667
+ values = []
668
+
669
+ if stmt_parameter_tuples is not None :
670
+ _get_stmt_parameter_tuples_params (
671
+ compiler ,
672
+ compile_state ,
673
+ parameters ,
674
+ stmt_parameter_tuples ,
675
+ _column_as_key ,
676
+ values ,
677
+ kw ,
678
+ )
679
+
680
+ check_columns = {}
681
+
682
+ # special logic that only occurs for multi-table UPDATE
683
+ # statements
684
+ if compile_state .isupdate and compile_state .is_multitable :
685
+ _get_multitable_params (
686
+ compiler ,
687
+ stmt ,
688
+ compile_state ,
689
+ stmt_parameter_tuples ,
690
+ check_columns ,
691
+ _col_bind_name ,
692
+ _getattr_col_key ,
693
+ values ,
694
+ kw ,
695
+ )
696
+
697
+ if compile_state .isinsert and stmt ._select_names :
698
+ _scan_insert_from_select_cols (
699
+ compiler ,
700
+ stmt ,
701
+ compile_state ,
702
+ parameters ,
703
+ _getattr_col_key ,
704
+ _column_as_key ,
705
+ _col_bind_name ,
706
+ check_columns ,
707
+ values ,
708
+ kw ,
709
+ )
710
+ else :
711
+ _scan_cols (
712
+ compiler ,
713
+ stmt ,
714
+ compile_state ,
715
+ parameters ,
716
+ _getattr_col_key ,
717
+ _column_as_key ,
718
+ _col_bind_name ,
719
+ check_columns ,
720
+ values ,
721
+ kw ,
722
+ )
723
+
724
+ # CrateDB amendment.
725
+ """
726
+ if parameters and stmt_parameter_tuples:
727
+ check = (
728
+ set(parameters)
729
+ .intersection(_column_as_key(k) for k, v in stmt_parameter_tuples)
730
+ .difference(check_columns)
731
+ )
732
+ if check:
733
+ raise exc.CompileError(
734
+ "Unconsumed column names: %s"
735
+ % (", ".join("%s" % (c,) for c in check))
736
+ )
737
+ """
738
+
739
+ if compile_state ._has_multi_parameters :
740
+ values = _extend_values_for_multiparams (
741
+ compiler , stmt , compile_state , values , kw
742
+ )
743
+ elif not values and compiler .for_executemany :
744
+ # convert an "INSERT DEFAULT VALUES"
745
+ # into INSERT (firstcol) VALUES (DEFAULT) which can be turned
746
+ # into an in-place multi values. This supports
747
+ # insert_executemany_returning mode :)
748
+ values = [
749
+ (
750
+ stmt .table .columns [0 ],
751
+ compiler .preparer .format_column (stmt .table .columns [0 ]),
752
+ "DEFAULT" ,
753
+ )
754
+ ]
755
+
756
+ return values
0 commit comments