Skip to content

Commit

Permalink
SA14: Add SQLAlchemy 1.4 compatibility for CrateCompiler
Browse files Browse the repository at this point in the history
- Code reuse was aimed at, but for the SA <1.4 vs. SA >=1.4 split, two
  functions, `visit_update_14` and `_get_crud_params_14`, have been vendored
  separately to accompany `crate.client.sqlalchemy.compiler.CrateCompiler`.
  All adjustments have now been marked inline with `CrateDB amendment`.

- The main query rewriting function for UPDATE statements, `rewrite_update`,
  needed adjustments to account for a different wrapping/nesting of in/out
  parameters.

- The `cresultproxy` module was temporarily taken out of the equation because
  it raised some runtime error.
  • Loading branch information
amotl committed May 25, 2022
1 parent 1d9e151 commit 62fd182
Show file tree
Hide file tree
Showing 5 changed files with 367 additions and 3 deletions.
2 changes: 2 additions & 0 deletions CHANGES.txt
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,8 @@ Unreleased
- Added support for enabling SSL using SQLAlchemy DB URI with parameter
``?ssl=true``.

- Add support for SQLAlchemy 1.4

2020/09/28 0.26.0
=================

Expand Down
317 changes: 314 additions & 3 deletions src/crate/client/sqlalchemy/compiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,9 +23,10 @@
from collections import defaultdict

import sqlalchemy as sa
from sqlalchemy.sql import crud
from sqlalchemy.sql import crud, selectable
from sqlalchemy.sql import compiler
from .types import MutableDict
from .sa_version import SA_VERSION, SA_1_4


def rewrite_update(clauseelement, multiparams, params):
Expand Down Expand Up @@ -73,7 +74,16 @@ def rewrite_update(clauseelement, multiparams, params):
def crate_before_execute(conn, clauseelement, multiparams, params):
is_crate = type(conn.dialect).__name__ == 'CrateDialect'
if is_crate and isinstance(clauseelement, sa.sql.expression.Update):
return rewrite_update(clauseelement, multiparams, params)
if SA_VERSION >= SA_1_4:
multiparams = ([params],)
params = {}

clauseelement, multiparams, params = rewrite_update(clauseelement, multiparams, params)

if SA_VERSION >= SA_1_4:
params = multiparams[0]
multiparams = []

return clauseelement, multiparams, params


Expand Down Expand Up @@ -189,6 +199,9 @@ def visit_update(self, update_stmt, **kw):
Parts are taken from the SQLCompiler base class.
"""

if SA_VERSION >= SA_1_4:
return self.visit_update_14(update_stmt, **kw)

if not update_stmt.parameters and \
not hasattr(update_stmt, '_crate_specific'):
return super(CrateCompiler, self).visit_update(update_stmt, **kw)
Expand All @@ -212,11 +225,14 @@ def visit_update(self, update_stmt, **kw):
update_stmt, table_text
)

# CrateDB amendment.
crud_params = self._get_crud_params(update_stmt, **kw)

text += table_text

text += ' SET '

# CrateDB amendment begin.
include_table = extra_froms and \
self.render_table_with_column_in_update_from

Expand All @@ -234,6 +250,7 @@ def visit_update(self, update_stmt, **kw):
set_clauses.append(k + ' = ' + self.process(bindparam))

text += ', '.join(set_clauses)
# CrateDB amendment end.

if self.returning or update_stmt._returning:
if not self.returning:
Expand Down Expand Up @@ -269,7 +286,6 @@ def visit_update(self, update_stmt, **kw):

def _get_crud_params(compiler, stmt, **kw):
""" extract values from crud parameters
taken from SQLAlchemy's crud module (since 1.0.x) and
adapted for Crate dialect"""

Expand Down Expand Up @@ -325,3 +341,298 @@ def _get_crud_params(compiler, stmt, **kw):
values, kw)

return values

def visit_update_14(self, update_stmt, **kw):

compile_state = update_stmt._compile_state_factory(
update_stmt, self, **kw
)
update_stmt = compile_state.statement

toplevel = not self.stack
if toplevel:
self.isupdate = True
if not self.compile_state:
self.compile_state = compile_state

extra_froms = compile_state._extra_froms
is_multitable = bool(extra_froms)

if is_multitable:
# main table might be a JOIN
main_froms = set(selectable._from_objects(update_stmt.table))
render_extra_froms = [
f for f in extra_froms if f not in main_froms
]
correlate_froms = main_froms.union(extra_froms)
else:
render_extra_froms = []
correlate_froms = {update_stmt.table}

self.stack.append(
{
"correlate_froms": correlate_froms,
"asfrom_froms": correlate_froms,
"selectable": update_stmt,
}
)

text = "UPDATE "

if update_stmt._prefixes:
text += self._generate_prefixes(
update_stmt, update_stmt._prefixes, **kw
)

table_text = self.update_tables_clause(
update_stmt, update_stmt.table, render_extra_froms, **kw
)

# CrateDB amendment.
crud_params = _get_crud_params_14(
self, update_stmt, compile_state, **kw
)

if update_stmt._hints:
dialect_hints, table_text = self._setup_crud_hints(
update_stmt, table_text
)
else:
dialect_hints = None

text += table_text

text += " SET "

# CrateDB amendment begin.
include_table = extra_froms and \
self.render_table_with_column_in_update_from

set_clauses = []

for c, expr, value in crud_params:
key = c._compiler_dispatch(self, include_table=include_table)
clause = key + ' = ' + value
set_clauses.append(clause)

for k, v in compile_state._dict_parameters.items():
if isinstance(k, str) and '[' in k:
bindparam = sa.sql.bindparam(k, v)
clause = k + ' = ' + self.process(bindparam)
set_clauses.append(clause)

text += ', '.join(set_clauses)
# CrateDB amendment end.

if self.returning or update_stmt._returning:
if self.returning_precedes_values:
text += " " + self.returning_clause(
update_stmt, self.returning or update_stmt._returning
)

if extra_froms:
extra_from_text = self.update_from_clause(
update_stmt,
update_stmt.table,
render_extra_froms,
dialect_hints,
**kw
)
if extra_from_text:
text += " " + extra_from_text

if update_stmt._where_criteria:
t = self._generate_delimited_and_list(
update_stmt._where_criteria, **kw
)
if t:
text += " WHERE " + t

limit_clause = self.update_limit_clause(update_stmt)
if limit_clause:
text += " " + limit_clause

if (
self.returning or update_stmt._returning
) and not self.returning_precedes_values:
text += " " + self.returning_clause(
update_stmt, self.returning or update_stmt._returning
)

if self.ctes and toplevel:
text = self._render_cte_clause() + text

self.stack.pop(-1)

return text


def _get_crud_params_14(compiler, stmt, compile_state, **kw):
"""create a set of tuples representing column/string pairs for use
in an INSERT or UPDATE statement.
Also generates the Compiled object's postfetch, prefetch, and
returning column collections, used for default handling and ultimately
populating the CursorResult's prefetch_cols() and postfetch_cols()
collections.
"""
from sqlalchemy.sql.crud import _key_getters_for_crud_column
from sqlalchemy.sql.crud import _create_bind_param
from sqlalchemy.sql.crud import REQUIRED
from sqlalchemy.sql.crud import _get_stmt_parameter_tuples_params
from sqlalchemy.sql.crud import _get_multitable_params
from sqlalchemy.sql.crud import _scan_insert_from_select_cols
from sqlalchemy.sql.crud import _scan_cols
from sqlalchemy import exc # noqa: F401
from sqlalchemy.sql.crud import _extend_values_for_multiparams

compiler.postfetch = []
compiler.insert_prefetch = []
compiler.update_prefetch = []
compiler.returning = []

# getters - these are normally just column.key,
# but in the case of mysql multi-table update, the rules for
# .key must conditionally take tablename into account
(
_column_as_key,
_getattr_col_key,
_col_bind_name,
) = getters = _key_getters_for_crud_column(compiler, stmt, compile_state)

compiler._key_getters_for_crud_column = getters

# no parameters in the statement, no parameters in the
# compiled params - return binds for all columns
if compiler.column_keys is None and compile_state._no_parameters:
return [
(
c,
compiler.preparer.format_column(c),
_create_bind_param(compiler, c, None, required=True),
)
for c in stmt.table.columns
]

if compile_state._has_multi_parameters:
spd = compile_state._multi_parameters[0]
stmt_parameter_tuples = list(spd.items())
elif compile_state._ordered_values:
spd = compile_state._dict_parameters
stmt_parameter_tuples = compile_state._ordered_values
elif compile_state._dict_parameters:
spd = compile_state._dict_parameters
stmt_parameter_tuples = list(spd.items())
else:
stmt_parameter_tuples = spd = None

# if we have statement parameters - set defaults in the
# compiled params
if compiler.column_keys is None:
parameters = {}
elif stmt_parameter_tuples:
parameters = dict(
(_column_as_key(key), REQUIRED)
for key in compiler.column_keys
if key not in spd
)
else:
parameters = dict(
(_column_as_key(key), REQUIRED) for key in compiler.column_keys
)

# create a list of column assignment clauses as tuples
values = []

if stmt_parameter_tuples is not None:
_get_stmt_parameter_tuples_params(
compiler,
compile_state,
parameters,
stmt_parameter_tuples,
_column_as_key,
values,
kw,
)

check_columns = {}

# special logic that only occurs for multi-table UPDATE
# statements
if compile_state.isupdate and compile_state.is_multitable:
_get_multitable_params(
compiler,
stmt,
compile_state,
stmt_parameter_tuples,
check_columns,
_col_bind_name,
_getattr_col_key,
values,
kw,
)

if compile_state.isinsert and stmt._select_names:
_scan_insert_from_select_cols(
compiler,
stmt,
compile_state,
parameters,
_getattr_col_key,
_column_as_key,
_col_bind_name,
check_columns,
values,
kw,
)
else:
_scan_cols(
compiler,
stmt,
compile_state,
parameters,
_getattr_col_key,
_column_as_key,
_col_bind_name,
check_columns,
values,
kw,
)

# CrateDB amendment.
# The rewriting logic in `rewrite_update` and `visit_update` needs
# adjustments here in order to prevent `sqlalchemy.exc.CompileError:
# Unconsumed column names: characters_name, data['nested']`
"""
if parameters and stmt_parameter_tuples:
check = (
set(parameters)
.intersection(_column_as_key(k) for k, v in stmt_parameter_tuples)
.difference(check_columns)
)
if check:
raise exc.CompileError(
"Unconsumed column names: %s"
% (", ".join("%s" % (c,) for c in check))
)
"""

if compile_state._has_multi_parameters:
values = _extend_values_for_multiparams(
compiler, stmt, compile_state, values, kw
)
elif not values and compiler.for_executemany:
# convert an "INSERT DEFAULT VALUES"
# into INSERT (firstcol) VALUES (DEFAULT) which can be turned
# into an in-place multi values. This supports
# insert_executemany_returning mode :)
values = [
(
stmt.table.columns[0],
compiler.preparer.format_column(stmt.table.columns[0]),
"DEFAULT",
)
]

return values
5 changes: 5 additions & 0 deletions src/crate/client/sqlalchemy/dialect.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,11 @@
import logging
from datetime import datetime, date

# FIXME: Workaround to be able to use SQLAlchemy 1.4.
# Caveat: This purges the ``cresultproxy`` extension
# at runtime, so it will impose a speed bump.
import crate.client.sqlalchemy.monkey # noqa:F401

from sqlalchemy import types as sqltypes
from sqlalchemy.engine import default, reflection
from sqlalchemy.sql import functions
Expand Down
Loading

0 comments on commit 62fd182

Please sign in to comment.