Skip to content

Commit

Permalink
Add SQLAlchemy 1.4 compatibility for CrateCompiler
Browse files Browse the repository at this point in the history
> A major initiative in the 1.4 series is to approach the model of both
Core SQL statements as well as the ORM Query to allow for an efficient,
cacheable model of statement creation and compilation, where the
compilation step would be cached, based on a cache key generated by the
created statement object, which itself is newly created for each use.
  • Loading branch information
amotl committed Dec 18, 2020
1 parent 9c3b47a commit 5d8a543
Show file tree
Hide file tree
Showing 5 changed files with 395 additions and 6 deletions.
338 changes: 332 additions & 6 deletions src/crate/client/sqlalchemy/compiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,10 +23,13 @@
from collections import defaultdict

import sqlalchemy as sa
from sqlalchemy.sql import crud
from sqlalchemy.sql import crud, selectable
from sqlalchemy.sql import compiler
from .types import MutableDict
from .sa_version import SA_1_1, SA_VERSION
from .sa_version import SA_VERSION, SA_1_1, SA_1_4


INSERT_SELECT_WITHOUT_PARENTHESES_MIN_VERSION = (1, 0, 1)


def rewrite_update(clauseelement, multiparams, params):
Expand Down Expand Up @@ -74,7 +77,18 @@ def rewrite_update(clauseelement, multiparams, params):
def crate_before_execute(conn, clauseelement, multiparams, params):
is_crate = type(conn.dialect).__name__ == 'CrateDialect'
if is_crate and isinstance(clauseelement, sa.sql.expression.Update):
return rewrite_update(clauseelement, multiparams, params)
if SA_VERSION >= SA_1_4:
multiparams = ([params],)
params = {}

clauseelement, multiparams, params = rewrite_update(clauseelement, multiparams, params)

if SA_VERSION >= SA_1_4:
params = multiparams[0]
multiparams = []

return clauseelement, multiparams, params

return clauseelement, multiparams, params


Expand Down Expand Up @@ -189,9 +203,23 @@ def visit_insert(self, insert_stmt, asfrom=False, **kw):
used to compile <sql.expression.Insert> expressions.
this function wraps insert_from_select statements inside
parentheses to be conform with earlier versions of CreateDB.
parentheses to be conform with earlier versions of CreateDB.
According to the changelog, CrateDB >= 1.0.1 already mitigates this requirement:
``INSERT`` statements now support ``SELECT`` statements without parentheses.
https://crate.io/docs/crate/reference/en/4.3/appendices/release-notes/1.0.1.html
"""

# Only CrateDB <= 1.0.0 needs parentheses for ``INSERT INTO ... SELECT ...``.
if self.dialect.server_version_info >= INSERT_SELECT_WITHOUT_PARENTHESES_MIN_VERSION:
return super(CrateCompiler, self).visit_insert(insert_stmt, asfrom=asfrom, **kw)

if SA_VERSION >= SA_1_4:
raise DeprecationWarning(
"CrateDB version < {} not supported with SQLAlchemy 1.4".format(
INSERT_SELECT_WITHOUT_PARENTHESES_MIN_VERSION))

self.stack.append(
{'correlate_froms': set(),
"asfrom_froms": set(),
Expand Down Expand Up @@ -288,6 +316,9 @@ def visit_update(self, update_stmt, **kw):
Parts are taken from the SQLCompiler base class.
"""

if SA_VERSION >= SA_1_4:
return self.visit_update_14(update_stmt, **kw)

if not update_stmt.parameters and \
not hasattr(update_stmt, '_crate_specific'):
return super(CrateCompiler, self).visit_update(update_stmt, **kw)
Expand All @@ -311,11 +342,14 @@ def visit_update(self, update_stmt, **kw):
update_stmt, table_text
)

crud_params = self._get_crud_params(update_stmt, **kw)
# CrateDB amendment.
crud_params = self._get_crud_params(self, update_stmt, **kw)

text += table_text

text += ' SET '

# CrateDB amendment begin.
include_table = extra_froms and \
self.render_table_with_column_in_update_from

Expand All @@ -333,6 +367,7 @@ def visit_update(self, update_stmt, **kw):
set_clauses.append(k + ' = ' + self.process(bindparam))

text += ', '.join(set_clauses)
# CrateDB amendment end.

if self.returning or update_stmt._returning:
if not self.returning:
Expand Down Expand Up @@ -368,7 +403,6 @@ def visit_update(self, update_stmt, **kw):

def _get_crud_params(compiler, stmt, **kw):
""" extract values from crud parameters
taken from SQLAlchemy's crud module (since 1.0.x) and
adapted for Crate dialect"""

Expand Down Expand Up @@ -428,3 +462,295 @@ def _get_crud_params(compiler, stmt, **kw):
values, kw)

return values

def visit_update_14(self, update_stmt, **kw):

compile_state = update_stmt._compile_state_factory(
update_stmt, self, **kw
)
update_stmt = compile_state.statement

toplevel = not self.stack
if toplevel:
self.isupdate = True
if not self.compile_state:
self.compile_state = compile_state

extra_froms = compile_state._extra_froms
is_multitable = bool(extra_froms)

if is_multitable:
# main table might be a JOIN
main_froms = set(selectable._from_objects(update_stmt.table))
render_extra_froms = [
f for f in extra_froms if f not in main_froms
]
correlate_froms = main_froms.union(extra_froms)
else:
render_extra_froms = []
correlate_froms = {update_stmt.table}

self.stack.append(
{
"correlate_froms": correlate_froms,
"asfrom_froms": correlate_froms,
"selectable": update_stmt,
}
)

text = "UPDATE "

if update_stmt._prefixes:
text += self._generate_prefixes(
update_stmt, update_stmt._prefixes, **kw
)

table_text = self.update_tables_clause(
update_stmt, update_stmt.table, render_extra_froms, **kw
)

# CrateDB amendment.
crud_params = _get_crud_params_14(
self, update_stmt, compile_state, **kw
)

if update_stmt._hints:
dialect_hints, table_text = self._setup_crud_hints(
update_stmt, table_text
)
else:
dialect_hints = None

text += table_text

text += " SET "

# CrateDB amendment begin.
include_table = extra_froms and \
self.render_table_with_column_in_update_from

set_clauses = []

for c, expr, value in crud_params:
key = c._compiler_dispatch(self, include_table=include_table)
clause = key + ' = ' + value
set_clauses.append(clause)

for k, v in compile_state._dict_parameters.items():
if isinstance(k, str) and '[' in k:
bindparam = sa.sql.bindparam(k, v)
clause = k + ' = ' + self.process(bindparam)
set_clauses.append(clause)

text += ', '.join(set_clauses)
# CrateDB amendment end.

if self.returning or update_stmt._returning:
if self.returning_precedes_values:
text += " " + self.returning_clause(
update_stmt, self.returning or update_stmt._returning
)

if extra_froms:
extra_from_text = self.update_from_clause(
update_stmt,
update_stmt.table,
render_extra_froms,
dialect_hints,
**kw
)
if extra_from_text:
text += " " + extra_from_text

if update_stmt._where_criteria:
t = self._generate_delimited_and_list(
update_stmt._where_criteria, **kw
)
if t:
text += " WHERE " + t

limit_clause = self.update_limit_clause(update_stmt)
if limit_clause:
text += " " + limit_clause

if (
self.returning or update_stmt._returning
) and not self.returning_precedes_values:
text += " " + self.returning_clause(
update_stmt, self.returning or update_stmt._returning
)

if self.ctes and toplevel:
text = self._render_cte_clause() + text

self.stack.pop(-1)

return text


def _get_crud_params_14(compiler, stmt, compile_state, **kw):
"""create a set of tuples representing column/string pairs for use
in an INSERT or UPDATE statement.
Also generates the Compiled object's postfetch, prefetch, and
returning column collections, used for default handling and ultimately
populating the CursorResult's prefetch_cols() and postfetch_cols()
collections.
"""
from sqlalchemy.sql.crud import _key_getters_for_crud_column
from sqlalchemy.sql.crud import _create_bind_param
from sqlalchemy.sql.crud import REQUIRED
from sqlalchemy.sql.crud import _get_stmt_parameter_tuples_params
from sqlalchemy.sql.crud import _get_multitable_params
from sqlalchemy.sql.crud import _scan_insert_from_select_cols
from sqlalchemy.sql.crud import _scan_cols
from sqlalchemy import exc
from sqlalchemy.sql.crud import _extend_values_for_multiparams

compiler.postfetch = []
compiler.insert_prefetch = []
compiler.update_prefetch = []
compiler.returning = []

# getters - these are normally just column.key,
# but in the case of mysql multi-table update, the rules for
# .key must conditionally take tablename into account
(
_column_as_key,
_getattr_col_key,
_col_bind_name,
) = getters = _key_getters_for_crud_column(compiler, stmt, compile_state)

compiler._key_getters_for_crud_column = getters

# no parameters in the statement, no parameters in the
# compiled params - return binds for all columns
if compiler.column_keys is None and compile_state._no_parameters:
return [
(
c,
compiler.preparer.format_column(c),
_create_bind_param(compiler, c, None, required=True),
)
for c in stmt.table.columns
]

if compile_state._has_multi_parameters:
spd = compile_state._multi_parameters[0]
stmt_parameter_tuples = list(spd.items())
elif compile_state._ordered_values:
spd = compile_state._dict_parameters
stmt_parameter_tuples = compile_state._ordered_values
elif compile_state._dict_parameters:
spd = compile_state._dict_parameters
stmt_parameter_tuples = list(spd.items())
else:
stmt_parameter_tuples = spd = None

# if we have statement parameters - set defaults in the
# compiled params
if compiler.column_keys is None:
parameters = {}
elif stmt_parameter_tuples:
parameters = dict(
(_column_as_key(key), REQUIRED)
for key in compiler.column_keys
if key not in spd
)
else:
parameters = dict(
(_column_as_key(key), REQUIRED) for key in compiler.column_keys
)

# create a list of column assignment clauses as tuples
values = []

if stmt_parameter_tuples is not None:
_get_stmt_parameter_tuples_params(
compiler,
compile_state,
parameters,
stmt_parameter_tuples,
_column_as_key,
values,
kw,
)

check_columns = {}

# special logic that only occurs for multi-table UPDATE
# statements
if compile_state.isupdate and compile_state.is_multitable:
_get_multitable_params(
compiler,
stmt,
compile_state,
stmt_parameter_tuples,
check_columns,
_col_bind_name,
_getattr_col_key,
values,
kw,
)

if compile_state.isinsert and stmt._select_names:
_scan_insert_from_select_cols(
compiler,
stmt,
compile_state,
parameters,
_getattr_col_key,
_column_as_key,
_col_bind_name,
check_columns,
values,
kw,
)
else:
_scan_cols(
compiler,
stmt,
compile_state,
parameters,
_getattr_col_key,
_column_as_key,
_col_bind_name,
check_columns,
values,
kw,
)

# CrateDB amendment.
"""
if parameters and stmt_parameter_tuples:
check = (
set(parameters)
.intersection(_column_as_key(k) for k, v in stmt_parameter_tuples)
.difference(check_columns)
)
if check:
raise exc.CompileError(
"Unconsumed column names: %s"
% (", ".join("%s" % (c,) for c in check))
)
"""

if compile_state._has_multi_parameters:
values = _extend_values_for_multiparams(
compiler, stmt, compile_state, values, kw
)
elif not values and compiler.for_executemany:
# convert an "INSERT DEFAULT VALUES"
# into INSERT (firstcol) VALUES (DEFAULT) which can be turned
# into an in-place multi values. This supports
# insert_executemany_returning mode :)
values = [
(
stmt.table.columns[0],
compiler.preparer.format_column(stmt.table.columns[0]),
"DEFAULT",
)
]

return values
Loading

0 comments on commit 5d8a543

Please sign in to comment.