Skip to content

Commit 0b9b005

Browse files
DDL execution method for SQLA2
1 parent f4fb5b0 commit 0b9b005

File tree

5 files changed

+43
-11
lines changed

5 files changed

+43
-11
lines changed

cardinal_pythonlib/sqlalchemy/dump.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -265,8 +265,9 @@ def dump_orm_object_as_insert_sql(
265265
core_pkcol = table.columns.get(orm_pkcol.name)
266266
pkval = getattr(obj, orm_pkcol.name)
267267
query = query.where(core_pkcol == pkval)
268-
cursor = engine.execute(query)
269-
row = cursor.fetchone() # should only be one...
268+
with engine.begin() as connection:
269+
cursor = connection.execute(query)
270+
row = cursor.fetchone() # should only be one...
270271
row_dict = dict(row)
271272
statement = table.insert(values=row_dict)
272273
insert_str = get_literal_query(statement, bind=engine)

cardinal_pythonlib/sqlalchemy/engine_func.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -104,8 +104,9 @@ def get_sqlserver_product_version(engine: "Engine") -> Tuple[int, ...]:
104104
"instances."
105105
)
106106
sql = "SELECT CAST(SERVERPROPERTY('ProductVersion') AS VARCHAR)"
107-
rp = engine.execute(sql) # type: Result
108-
row = rp.fetchone()
107+
with engine.begin() as connection:
108+
rp = connection.execute(sql) # type: Result
109+
row = rp.fetchone()
109110
dotted_version = row[0] # type: str # e.g. '12.0.5203.0'
110111
return tuple(int(x) for x in dotted_version.split("."))
111112

cardinal_pythonlib/sqlalchemy/orm_schema.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -29,11 +29,13 @@
2929
import logging
3030
from typing import Type, TYPE_CHECKING
3131

32-
from cardinal_pythonlib.sqlalchemy.session import get_safe_url_from_engine
3332
from sqlalchemy.engine.base import Engine
3433
from sqlalchemy.orm import DeclarativeMeta
3534
from sqlalchemy.schema import CreateTable
3635

36+
from cardinal_pythonlib.sqlalchemy.schema import execute_ddl
37+
from cardinal_pythonlib.sqlalchemy.session import get_safe_url_from_engine
38+
3739
if TYPE_CHECKING:
3840
from sqlalchemy.sql.schema import Table
3941

@@ -73,5 +75,4 @@ def create_table_from_orm_class(
7375
creator = CreateTable(
7476
table, include_foreign_key_constraints=include_foreign_key_constraints
7577
)
76-
with engine.begin() as conn: # though DML/DDL doesn't really need a COMMIT
77-
conn.execute(creator)
78+
execute_ddl(engine, ddl=creator)

cardinal_pythonlib/sqlalchemy/schema.py

Lines changed: 30 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -59,6 +59,7 @@
5959
Table,
6060
)
6161
from sqlalchemy.sql import sqltypes, text
62+
from sqlalchemy.sql.ddl import DDLElement
6263
from sqlalchemy.sql.sqltypes import BigInteger, TypeEngine
6364
from sqlalchemy.sql.visitors import Visitable
6465

@@ -337,6 +338,33 @@ def get_effective_int_pk_col(table_: Table) -> Optional[str]:
337338
)
338339

339340

341+
# =============================================================================
342+
# Execute DDL
343+
# =============================================================================
344+
345+
346+
def execute_ddl(
347+
engine: Engine, sql: str = None, ddl: DDLElement = None
348+
) -> None:
349+
"""
350+
Execute DDL, either from a plain SQL string, or from an SQLAlchemy DDL
351+
element.
352+
353+
Previously we would use DDL(sql, bind=engine).execute(), but this has gone
354+
in SQLAlchemy 2.0.
355+
356+
If you want dialect-conditional execution, create the DDL object with e.g.
357+
ddl = DDL(sql).execute_if(dialect=SqlaDialectName.SQLSERVER), and pass that
358+
DDL object to this function.
359+
"""
360+
assert bool(sql) ^ bool(ddl) # one or the other.
361+
if sql:
362+
ddl = DDL(sql)
363+
with engine.connect() as connection:
364+
# DDL doesn't need a COMMIT.
365+
connection.execute(ddl)
366+
367+
340368
# =============================================================================
341369
# Indexes
342370
# =============================================================================
@@ -565,8 +593,7 @@ def quote(identifier: str) -> str:
565593
colnames=", ".join(quote(c) for c in colnames),
566594
)
567595
)
568-
# DDL(sql, bind=engine).execute_if(dialect=SqlaDialectName.MYSQL)
569-
DDL(sql, bind=engine).execute()
596+
execute_ddl(engine, sql=sql)
570597

571598
elif is_mssql: # Microsoft SQL Server
572599
# https://msdn.microsoft.com/library/ms187317(SQL.130).aspx
@@ -628,7 +655,7 @@ def quote(identifier: str) -> str:
628655
)
629656
# Executing serial COMMITs or a ROLLBACK won't help here if
630657
# this transaction is due to Python DBAPI default behaviour.
631-
DDL(sql, bind=engine).execute()
658+
execute_ddl(engine, sql=sql)
632659

633660
# The reversal procedure is DROP FULLTEXT INDEX ON tablename;
634661

cardinal_pythonlib/sqlalchemy/sqlserver.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,7 @@
3636
quote_identifier,
3737
SqlaDialectName,
3838
)
39+
from cardinal_pythonlib.sqlalchemy.schema import execute_ddl
3940
from cardinal_pythonlib.sqlalchemy.session import get_engine_from_session
4041

4142

@@ -48,7 +49,8 @@ def _exec_ddl_if_sqlserver(engine: Engine, sql: str) -> None:
4849
"""
4950
Execute DDL only if we are running on Microsoft SQL Server.
5051
"""
51-
DDL(sql, bind=engine).execute_if(dialect=SqlaDialectName.SQLSERVER)
52+
ddl = DDL(sql).execute_if(dialect=SqlaDialectName.SQLSERVER)
53+
execute_ddl(engine, ddl=ddl)
5254

5355

5456
@contextmanager

0 commit comments

Comments
 (0)