Skip to content

Commit 5d8a543

Browse files
committed
Add SQLAlchemy 1.4 compatibility for CrateCompiler
> A major initiative in the 1.4 series is to approach the model of both Core SQL statements as well as the ORM Query to allow for an efficient, cacheable model of statement creation and compilation, where the compilation step would be cached, based on a cache key generated by the created statement object, which itself is newly created for each use.
1 parent 9c3b47a commit 5d8a543

File tree

5 files changed

+395
-6
lines changed

5 files changed

+395
-6
lines changed

src/crate/client/sqlalchemy/compiler.py

Lines changed: 332 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -23,10 +23,13 @@
2323
from collections import defaultdict
2424

2525
import sqlalchemy as sa
26-
from sqlalchemy.sql import crud
26+
from sqlalchemy.sql import crud, selectable
2727
from sqlalchemy.sql import compiler
2828
from .types import MutableDict
29-
from .sa_version import SA_1_1, SA_VERSION
29+
from .sa_version import SA_VERSION, SA_1_1, SA_1_4
30+
31+
32+
INSERT_SELECT_WITHOUT_PARENTHESES_MIN_VERSION = (1, 0, 1)
3033

3134

3235
def rewrite_update(clauseelement, multiparams, params):
@@ -74,7 +77,18 @@ def rewrite_update(clauseelement, multiparams, params):
7477
def crate_before_execute(conn, clauseelement, multiparams, params):
7578
is_crate = type(conn.dialect).__name__ == 'CrateDialect'
7679
if is_crate and isinstance(clauseelement, sa.sql.expression.Update):
77-
return rewrite_update(clauseelement, multiparams, params)
80+
if SA_VERSION >= SA_1_4:
81+
multiparams = ([params],)
82+
params = {}
83+
84+
clauseelement, multiparams, params = rewrite_update(clauseelement, multiparams, params)
85+
86+
if SA_VERSION >= SA_1_4:
87+
params = multiparams[0]
88+
multiparams = []
89+
90+
return clauseelement, multiparams, params
91+
7892
return clauseelement, multiparams, params
7993

8094

@@ -189,9 +203,23 @@ def visit_insert(self, insert_stmt, asfrom=False, **kw):
189203
used to compile <sql.expression.Insert> expressions.
190204
191205
this function wraps insert_from_select statements inside
192-
parentheses to be conform with earlier versions of CreateDB.
206+
parentheses to be conform with earlier versions of CreateDB.
207+
208+
According to the changelog, CrateDB >= 1.0.1 already mitigates this requirement:
209+
210+
``INSERT`` statements now support ``SELECT`` statements without parentheses.
211+
https://crate.io/docs/crate/reference/en/4.3/appendices/release-notes/1.0.1.html
193212
"""
194213

214+
# Only CrateDB <= 1.0.0 needs parentheses for ``INSERT INTO ... SELECT ...``.
215+
if self.dialect.server_version_info >= INSERT_SELECT_WITHOUT_PARENTHESES_MIN_VERSION:
216+
return super(CrateCompiler, self).visit_insert(insert_stmt, asfrom=asfrom, **kw)
217+
218+
if SA_VERSION >= SA_1_4:
219+
raise DeprecationWarning(
220+
"CrateDB version < {} not supported with SQLAlchemy 1.4".format(
221+
INSERT_SELECT_WITHOUT_PARENTHESES_MIN_VERSION))
222+
195223
self.stack.append(
196224
{'correlate_froms': set(),
197225
"asfrom_froms": set(),
@@ -288,6 +316,9 @@ def visit_update(self, update_stmt, **kw):
288316
Parts are taken from the SQLCompiler base class.
289317
"""
290318

319+
if SA_VERSION >= SA_1_4:
320+
return self.visit_update_14(update_stmt, **kw)
321+
291322
if not update_stmt.parameters and \
292323
not hasattr(update_stmt, '_crate_specific'):
293324
return super(CrateCompiler, self).visit_update(update_stmt, **kw)
@@ -311,11 +342,14 @@ def visit_update(self, update_stmt, **kw):
311342
update_stmt, table_text
312343
)
313344

314-
crud_params = self._get_crud_params(update_stmt, **kw)
345+
# CrateDB amendment.
346+
crud_params = self._get_crud_params(self, update_stmt, **kw)
315347

316348
text += table_text
317349

318350
text += ' SET '
351+
352+
# CrateDB amendment begin.
319353
include_table = extra_froms and \
320354
self.render_table_with_column_in_update_from
321355

@@ -333,6 +367,7 @@ def visit_update(self, update_stmt, **kw):
333367
set_clauses.append(k + ' = ' + self.process(bindparam))
334368

335369
text += ', '.join(set_clauses)
370+
# CrateDB amendment end.
336371

337372
if self.returning or update_stmt._returning:
338373
if not self.returning:
@@ -368,7 +403,6 @@ def visit_update(self, update_stmt, **kw):
368403

369404
def _get_crud_params(compiler, stmt, **kw):
370405
""" extract values from crud parameters
371-
372406
taken from SQLAlchemy's crud module (since 1.0.x) and
373407
adapted for Crate dialect"""
374408

@@ -428,3 +462,295 @@ def _get_crud_params(compiler, stmt, **kw):
428462
values, kw)
429463

430464
return values
465+
466+
def visit_update_14(self, update_stmt, **kw):
467+
468+
compile_state = update_stmt._compile_state_factory(
469+
update_stmt, self, **kw
470+
)
471+
update_stmt = compile_state.statement
472+
473+
toplevel = not self.stack
474+
if toplevel:
475+
self.isupdate = True
476+
if not self.compile_state:
477+
self.compile_state = compile_state
478+
479+
extra_froms = compile_state._extra_froms
480+
is_multitable = bool(extra_froms)
481+
482+
if is_multitable:
483+
# main table might be a JOIN
484+
main_froms = set(selectable._from_objects(update_stmt.table))
485+
render_extra_froms = [
486+
f for f in extra_froms if f not in main_froms
487+
]
488+
correlate_froms = main_froms.union(extra_froms)
489+
else:
490+
render_extra_froms = []
491+
correlate_froms = {update_stmt.table}
492+
493+
self.stack.append(
494+
{
495+
"correlate_froms": correlate_froms,
496+
"asfrom_froms": correlate_froms,
497+
"selectable": update_stmt,
498+
}
499+
)
500+
501+
text = "UPDATE "
502+
503+
if update_stmt._prefixes:
504+
text += self._generate_prefixes(
505+
update_stmt, update_stmt._prefixes, **kw
506+
)
507+
508+
table_text = self.update_tables_clause(
509+
update_stmt, update_stmt.table, render_extra_froms, **kw
510+
)
511+
512+
# CrateDB amendment.
513+
crud_params = _get_crud_params_14(
514+
self, update_stmt, compile_state, **kw
515+
)
516+
517+
if update_stmt._hints:
518+
dialect_hints, table_text = self._setup_crud_hints(
519+
update_stmt, table_text
520+
)
521+
else:
522+
dialect_hints = None
523+
524+
text += table_text
525+
526+
text += " SET "
527+
528+
# CrateDB amendment begin.
529+
include_table = extra_froms and \
530+
self.render_table_with_column_in_update_from
531+
532+
set_clauses = []
533+
534+
for c, expr, value in crud_params:
535+
key = c._compiler_dispatch(self, include_table=include_table)
536+
clause = key + ' = ' + value
537+
set_clauses.append(clause)
538+
539+
for k, v in compile_state._dict_parameters.items():
540+
if isinstance(k, str) and '[' in k:
541+
bindparam = sa.sql.bindparam(k, v)
542+
clause = k + ' = ' + self.process(bindparam)
543+
set_clauses.append(clause)
544+
545+
text += ', '.join(set_clauses)
546+
# CrateDB amendment end.
547+
548+
if self.returning or update_stmt._returning:
549+
if self.returning_precedes_values:
550+
text += " " + self.returning_clause(
551+
update_stmt, self.returning or update_stmt._returning
552+
)
553+
554+
if extra_froms:
555+
extra_from_text = self.update_from_clause(
556+
update_stmt,
557+
update_stmt.table,
558+
render_extra_froms,
559+
dialect_hints,
560+
**kw
561+
)
562+
if extra_from_text:
563+
text += " " + extra_from_text
564+
565+
if update_stmt._where_criteria:
566+
t = self._generate_delimited_and_list(
567+
update_stmt._where_criteria, **kw
568+
)
569+
if t:
570+
text += " WHERE " + t
571+
572+
limit_clause = self.update_limit_clause(update_stmt)
573+
if limit_clause:
574+
text += " " + limit_clause
575+
576+
if (
577+
self.returning or update_stmt._returning
578+
) and not self.returning_precedes_values:
579+
text += " " + self.returning_clause(
580+
update_stmt, self.returning or update_stmt._returning
581+
)
582+
583+
if self.ctes and toplevel:
584+
text = self._render_cte_clause() + text
585+
586+
self.stack.pop(-1)
587+
588+
return text
589+
590+
591+
def _get_crud_params_14(compiler, stmt, compile_state, **kw):
592+
"""create a set of tuples representing column/string pairs for use
593+
in an INSERT or UPDATE statement.
594+
595+
Also generates the Compiled object's postfetch, prefetch, and
596+
returning column collections, used for default handling and ultimately
597+
populating the CursorResult's prefetch_cols() and postfetch_cols()
598+
collections.
599+
600+
"""
601+
from sqlalchemy.sql.crud import _key_getters_for_crud_column
602+
from sqlalchemy.sql.crud import _create_bind_param
603+
from sqlalchemy.sql.crud import REQUIRED
604+
from sqlalchemy.sql.crud import _get_stmt_parameter_tuples_params
605+
from sqlalchemy.sql.crud import _get_multitable_params
606+
from sqlalchemy.sql.crud import _scan_insert_from_select_cols
607+
from sqlalchemy.sql.crud import _scan_cols
608+
from sqlalchemy import exc
609+
from sqlalchemy.sql.crud import _extend_values_for_multiparams
610+
611+
compiler.postfetch = []
612+
compiler.insert_prefetch = []
613+
compiler.update_prefetch = []
614+
compiler.returning = []
615+
616+
# getters - these are normally just column.key,
617+
# but in the case of mysql multi-table update, the rules for
618+
# .key must conditionally take tablename into account
619+
(
620+
_column_as_key,
621+
_getattr_col_key,
622+
_col_bind_name,
623+
) = getters = _key_getters_for_crud_column(compiler, stmt, compile_state)
624+
625+
compiler._key_getters_for_crud_column = getters
626+
627+
# no parameters in the statement, no parameters in the
628+
# compiled params - return binds for all columns
629+
if compiler.column_keys is None and compile_state._no_parameters:
630+
return [
631+
(
632+
c,
633+
compiler.preparer.format_column(c),
634+
_create_bind_param(compiler, c, None, required=True),
635+
)
636+
for c in stmt.table.columns
637+
]
638+
639+
if compile_state._has_multi_parameters:
640+
spd = compile_state._multi_parameters[0]
641+
stmt_parameter_tuples = list(spd.items())
642+
elif compile_state._ordered_values:
643+
spd = compile_state._dict_parameters
644+
stmt_parameter_tuples = compile_state._ordered_values
645+
elif compile_state._dict_parameters:
646+
spd = compile_state._dict_parameters
647+
stmt_parameter_tuples = list(spd.items())
648+
else:
649+
stmt_parameter_tuples = spd = None
650+
651+
# if we have statement parameters - set defaults in the
652+
# compiled params
653+
if compiler.column_keys is None:
654+
parameters = {}
655+
elif stmt_parameter_tuples:
656+
parameters = dict(
657+
(_column_as_key(key), REQUIRED)
658+
for key in compiler.column_keys
659+
if key not in spd
660+
)
661+
else:
662+
parameters = dict(
663+
(_column_as_key(key), REQUIRED) for key in compiler.column_keys
664+
)
665+
666+
# create a list of column assignment clauses as tuples
667+
values = []
668+
669+
if stmt_parameter_tuples is not None:
670+
_get_stmt_parameter_tuples_params(
671+
compiler,
672+
compile_state,
673+
parameters,
674+
stmt_parameter_tuples,
675+
_column_as_key,
676+
values,
677+
kw,
678+
)
679+
680+
check_columns = {}
681+
682+
# special logic that only occurs for multi-table UPDATE
683+
# statements
684+
if compile_state.isupdate and compile_state.is_multitable:
685+
_get_multitable_params(
686+
compiler,
687+
stmt,
688+
compile_state,
689+
stmt_parameter_tuples,
690+
check_columns,
691+
_col_bind_name,
692+
_getattr_col_key,
693+
values,
694+
kw,
695+
)
696+
697+
if compile_state.isinsert and stmt._select_names:
698+
_scan_insert_from_select_cols(
699+
compiler,
700+
stmt,
701+
compile_state,
702+
parameters,
703+
_getattr_col_key,
704+
_column_as_key,
705+
_col_bind_name,
706+
check_columns,
707+
values,
708+
kw,
709+
)
710+
else:
711+
_scan_cols(
712+
compiler,
713+
stmt,
714+
compile_state,
715+
parameters,
716+
_getattr_col_key,
717+
_column_as_key,
718+
_col_bind_name,
719+
check_columns,
720+
values,
721+
kw,
722+
)
723+
724+
# CrateDB amendment.
725+
"""
726+
if parameters and stmt_parameter_tuples:
727+
check = (
728+
set(parameters)
729+
.intersection(_column_as_key(k) for k, v in stmt_parameter_tuples)
730+
.difference(check_columns)
731+
)
732+
if check:
733+
raise exc.CompileError(
734+
"Unconsumed column names: %s"
735+
% (", ".join("%s" % (c,) for c in check))
736+
)
737+
"""
738+
739+
if compile_state._has_multi_parameters:
740+
values = _extend_values_for_multiparams(
741+
compiler, stmt, compile_state, values, kw
742+
)
743+
elif not values and compiler.for_executemany:
744+
# convert an "INSERT DEFAULT VALUES"
745+
# into INSERT (firstcol) VALUES (DEFAULT) which can be turned
746+
# into an in-place multi values. This supports
747+
# insert_executemany_returning mode :)
748+
values = [
749+
(
750+
stmt.table.columns[0],
751+
compiler.preparer.format_column(stmt.table.columns[0]),
752+
"DEFAULT",
753+
)
754+
]
755+
756+
return values

0 commit comments

Comments
 (0)