Skip to content

Commit 0b1d8cb

Browse files
committed
Add minimal SQLAlchemy 1.4 compatibility for CrateCompiler
> A major initiative in the 1.4 series is to approach the model of both Core SQL statements as well as the ORM Query to allow for an efficient, cacheable model of statement creation and compilation, where the compilation step would be cached, based on a cache key generated by the created statement object, which itself is newly created for each use. While there is probably more to it, this patch tries to at least make the CrateCompiler work with the new subsystem infrastructure.
1 parent 8ba3b99 commit 0b1d8cb

File tree

2 files changed

+62
-15
lines changed

2 files changed

+62
-15
lines changed

src/crate/client/sqlalchemy/compiler.py

Lines changed: 61 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,7 @@
2626
from sqlalchemy.sql import crud
2727
from sqlalchemy.sql import compiler
2828
from .types import MutableDict
29-
from .sa_version import SA_1_1, SA_VERSION
29+
from .sa_version import SA_1_1, SA_1_4, SA_VERSION
3030

3131

3232
def rewrite_update(clauseelement, multiparams, params):
@@ -198,7 +198,20 @@ def visit_insert(self, insert_stmt, asfrom=False, **kw):
198198
"selectable": insert_stmt})
199199

200200
self.isinsert = True
201-
crud_params = crud._get_crud_params(self, insert_stmt, **kw)
201+
202+
if SA_VERSION >= SA_1_4:
203+
# Minimal patch to be compatible with SQLAlchemy 1.4.
204+
# For a more thorough implementation, please follow
205+
# https://github.com/crate/crate-python/pull/391.
206+
compile_state = insert_stmt._compile_state_factory(
207+
insert_stmt, self, **kw
208+
)
209+
insert_stmt = compile_state.statement
210+
crud_params = crud._get_crud_params(self, insert_stmt, compile_state, **kw)
211+
_has_multi_parameters = compile_state._has_multi_parameters
212+
else:
213+
crud_params = crud._get_crud_params(self, insert_stmt, **kw)
214+
_has_multi_parameters = insert_stmt._has_multi_parameters
202215

203216
if not crud_params and \
204217
not self.dialect.supports_default_values and \
@@ -207,7 +220,7 @@ def visit_insert(self, insert_stmt, asfrom=False, **kw):
207220
"The '%s' dialect with current database version settings does "
208221
"not support empty inserts." % self.dialect.name)
209222

210-
if insert_stmt._has_multi_parameters:
223+
if _has_multi_parameters:
211224
if not self.dialect.supports_multivalues_insert:
212225
raise NotImplementedError(
213226
"The '%s' dialect with current database "
@@ -262,7 +275,7 @@ def visit_insert(self, insert_stmt, asfrom=False, **kw):
262275
text += " (%s)" % self.process(self._insert_from_select, **kw)
263276
elif not crud_params and supports_default_values:
264277
text += " DEFAULT VALUES"
265-
elif insert_stmt._has_multi_parameters:
278+
elif _has_multi_parameters:
266279
text += " VALUES %s" % (
267280
", ".join(
268281
"(%s)" % (
@@ -294,7 +307,16 @@ def visit_update(self, update_stmt, **kw):
294307

295308
self.isupdate = True
296309

297-
extra_froms = update_stmt._extra_froms
310+
if SA_VERSION >= SA_1_4:
311+
# Minimal patch to be compatible with SQLAlchemy 1.4.
312+
# For a more thorough implementation, please follow
313+
# https://github.com/crate/crate-python/pull/391.
314+
compile_state = update_stmt._compile_state_factory(
315+
update_stmt, compiler, **kw
316+
)
317+
extra_froms = compile_state._extra_froms
318+
else:
319+
extra_froms = update_stmt._extra_froms
298320

299321
text = 'UPDATE '
300322

@@ -351,10 +373,21 @@ def visit_update(self, update_stmt, **kw):
351373
if extra_from_text:
352374
text += " " + extra_from_text
353375

354-
if update_stmt._whereclause is not None:
355-
t = self.process(update_stmt._whereclause)
356-
if t:
357-
text += " WHERE " + t
376+
if SA_VERSION >= SA_1_4:
377+
# Minimal patch to be compatible with SQLAlchemy 1.4.
378+
# For a more thorough implementation, please follow
379+
# https://github.com/crate/crate-python/pull/391.
380+
if update_stmt._where_criteria:
381+
t = self._generate_delimited_and_list(
382+
update_stmt._where_criteria, **kw
383+
)
384+
if t:
385+
text += " WHERE " + t
386+
else:
387+
if update_stmt._whereclause is not None:
388+
t = self.process(update_stmt._whereclause)
389+
if t:
390+
text += " WHERE " + t
358391

359392
limit_clause = self.update_limit_clause(update_stmt)
360393
if limit_clause:
@@ -384,20 +417,33 @@ def _get_crud_params(compiler, stmt, **kw):
384417
required=True))
385418
for c in stmt.table.columns]
386419

387-
if stmt._has_multi_parameters:
388-
stmt_parameters = stmt.parameters[0]
389-
else:
390-
stmt_parameters = stmt.parameters
391-
392420
# getters - these are normally just column.key,
393421
# but in the case of mysql multi-table update, the rules for
394422
# .key must conditionally take tablename into account
395-
if SA_VERSION >= SA_1_1:
423+
if SA_VERSION >= SA_1_4:
424+
# Minimal patch to be compatible with SQLAlchemy 1.4.
425+
# For a more thorough implementation, please follow
426+
# https://github.com/crate/crate-python/pull/391.
427+
compile_state = stmt._compile_state_factory(
428+
stmt, compiler, **kw
429+
)
430+
stmt = compile_state.statement
431+
_column_as_key, _getattr_col_key, _col_bind_name = \
432+
crud._key_getters_for_crud_column(compiler, stmt, compile_state)
433+
_has_multi_parameters = compile_state._has_multi_parameters
434+
elif SA_VERSION >= SA_1_1:
396435
_column_as_key, _getattr_col_key, _col_bind_name = \
397436
crud._key_getters_for_crud_column(compiler, stmt)
437+
_has_multi_parameters = stmt._has_multi_parameters
398438
else:
399439
_column_as_key, _getattr_col_key, _col_bind_name = \
400440
crud._key_getters_for_crud_column(compiler)
441+
_has_multi_parameters = stmt._has_multi_parameters
442+
443+
if _has_multi_parameters:
444+
stmt_parameters = stmt.parameters[0]
445+
else:
446+
stmt_parameters = stmt.parameters
401447

402448
# if we have statement parameters - set defaults in the
403449
# compiled params

src/crate/client/sqlalchemy/sa_version.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,3 +25,4 @@
2525
SA_VERSION = V(sa.__version__)
2626

2727
SA_1_1 = V('1.1a0')
28+
SA_1_4 = V('1.4.0b1')

0 commit comments

Comments
 (0)