Skip to content

Commit

Permalink
SA14: Adjust CrateDB dialect compiler patch to SqlAlchemy 1.4.36
Browse files Browse the repository at this point in the history
The original code for `visit_update` and `_get_crud_params` from
SQLAlchemy 1.4.0b1 has been vendored into the CrateDB dialect the other day,
in order to amend it due to dialect-specific purposes.

This patch reflects the changes from SA 1.4.0b1 to SA 1.4.36 on this code.
  • Loading branch information
amotl committed May 25, 2022
1 parent 62fd182 commit bcf9fff
Show file tree
Hide file tree
Showing 4 changed files with 30 additions and 10 deletions.
7 changes: 7 additions & 0 deletions lgtm.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
queries:

# Suppress some LGTM warnings.

# A module is imported with the "import" and "import from" statements.
# https://lgtm.com/rules/1818040193/
- exclude: py/import-and-import-from
29 changes: 21 additions & 8 deletions src/crate/client/sqlalchemy/compiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,8 +23,7 @@
from collections import defaultdict

import sqlalchemy as sa
from sqlalchemy.sql import crud, selectable
from sqlalchemy.sql import compiler
from sqlalchemy.sql import compiler, crud, selectable
from .types import MutableDict
from .sa_version import SA_VERSION, SA_1_4

Expand Down Expand Up @@ -400,6 +399,10 @@ def visit_update_14(self, update_stmt, **kw):
else:
dialect_hints = None

if update_stmt._independent_ctes:
for cte in update_stmt._independent_ctes:
cte._compiler_dispatch(self, **kw)

text += table_text

text += " SET "
Expand Down Expand Up @@ -459,8 +462,9 @@ def visit_update_14(self, update_stmt, **kw):
update_stmt, self.returning or update_stmt._returning
)

if self.ctes and toplevel:
text = self._render_cte_clause() + text
if self.ctes:
nesting_level = len(self.stack) if not toplevel else None
text = self._render_cte_clause(nesting_level=nesting_level) + text

self.stack.pop(-1)

Expand All @@ -481,7 +485,7 @@ def _get_crud_params_14(compiler, stmt, compile_state, **kw):
from sqlalchemy.sql.crud import _create_bind_param
from sqlalchemy.sql.crud import REQUIRED
from sqlalchemy.sql.crud import _get_stmt_parameter_tuples_params
from sqlalchemy.sql.crud import _get_multitable_params
from sqlalchemy.sql.crud import _get_update_multitable_params
from sqlalchemy.sql.crud import _scan_insert_from_select_cols
from sqlalchemy.sql.crud import _scan_cols
from sqlalchemy import exc # noqa: F401
Expand Down Expand Up @@ -561,7 +565,7 @@ def _get_crud_params_14(compiler, stmt, compile_state, **kw):
# special logic that only occurs for multi-table UPDATE
# statements
if compile_state.isupdate and compile_state.is_multitable:
_get_multitable_params(
_get_update_multitable_params(
compiler,
stmt,
compile_state,
Expand Down Expand Up @@ -620,9 +624,18 @@ def _get_crud_params_14(compiler, stmt, compile_state, **kw):

if compile_state._has_multi_parameters:
values = _extend_values_for_multiparams(
compiler, stmt, compile_state, values, kw
compiler,
stmt,
compile_state,
values,
_column_as_key,
kw,
)
elif not values and compiler.for_executemany:
elif (
not values
and compiler.for_executemany # noqa: W503
and compiler.dialect.supports_default_metavalue # noqa: W503
):
# convert an "INSERT DEFAULT VALUES"
# into INSERT (firstcol) VALUES (DEFAULT) which can be turned
# into an in-place multi values. This supports
Expand Down
2 changes: 1 addition & 1 deletion src/crate/client/sqlalchemy/dialect.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@
# FIXME: Workaround to be able to use SQLAlchemy 1.4.
# Caveat: This purges the ``cresultproxy`` extension
# at runtime, so it will impose a speed bump.
import crate.client.sqlalchemy.monkey # noqa:F401
import crate.client.sqlalchemy.monkey # noqa:F401, lgtm[py/unused-import]

from sqlalchemy import types as sqltypes
from sqlalchemy.engine import default, reflection
Expand Down
2 changes: 1 addition & 1 deletion src/crate/client/sqlalchemy/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -168,7 +168,7 @@ class Any(expression.ColumnElement):

def __init__(self, left, right, operator=operators.eq):
self.type = sqltypes.Boolean()
self.left = expression._literal_as_binds(left)
self.left = expression.literal(left)
self.right = right
self.operator = operator

Expand Down

0 comments on commit bcf9fff

Please sign in to comment.