Skip to content

Commit 5934726

Browse files
mathiasritterMathias Ritter
authored andcommitted
SQLAlchemy 2.0 Compatibility
In the do_execute function of SQLAlchemy dialect, an if statement was accessing context.should_autocommit. This property has been removed in SQLAlchemy 2.0. To fix, the if statement can simply be removed as it is no longer needed. SQLAlchemy 2.0 also removed support for parameters as keyword arguments in the Connection.execute function. Instead, parameters need to be passed in a dictionary.
1 parent e4a3f0f commit 5934726

File tree

3 files changed

+10
-13
lines changed

3 files changed

+10
-13
lines changed

.github/workflows/ci.yml

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -58,6 +58,8 @@ jobs:
5858
- { python: "3.11", trino: "351", sqlalchemy: "~=1.4.0" } # first Trino version
5959
# Test with sqlalchemy 1.3
6060
- { python: "3.11", trino: "latest", sqlalchemy: "~=1.3.0" }
61+
# Test with sqlalchemy 2.0
62+
- { python: "3.11", trino: "latest", sqlalchemy: "~=2.0.0rc1" }
6163
env:
6264
TRINO_VERSION: "${{ matrix.trino }}"
6365
steps:

setup.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,7 @@
2626
version = str(ast.literal_eval(trino_version.group(1)))
2727

2828
kerberos_require = ["requests_kerberos"]
29-
sqlalchemy_require = ["sqlalchemy~=1.3"]
29+
sqlalchemy_require = ["sqlalchemy >= 1.3"]
3030
external_authentication_token_cache_require = ["keyring"]
3131

3232
# We don't add localstorage_require to all_require as users must explicitly opt in to use keyring.
@@ -78,7 +78,7 @@
7878
"Programming Language :: Python :: Implementation :: PyPy",
7979
"Topic :: Database :: Front-Ends",
8080
],
81-
python_requires='>=3.7',
81+
python_requires=">=3.7",
8282
install_requires=["pytz", "requests", "tzlocal"],
8383
extras_require={
8484
"all": all_require,

trino/sqlalchemy/dialect.py

Lines changed: 6 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -156,7 +156,7 @@ def _get_columns(self, connection: Connection, table_name: str, schema: str = No
156156
ORDER BY "ordinal_position" ASC
157157
"""
158158
).strip()
159-
res = connection.execute(sql.text(query), schema=schema, table=table_name)
159+
res = connection.execute(sql.text(query), { "schema": schema, "table": table_name })
160160
columns = []
161161
for record in res:
162162
column = dict(
@@ -204,7 +204,7 @@ def get_table_names(self, connection: Connection, schema: str = None, **kw) -> L
204204
AND "table_type" = 'BASE TABLE'
205205
"""
206206
).strip()
207-
res = connection.execute(sql.text(query), schema=schema)
207+
res = connection.execute(sql.text(query), { "schema": schema })
208208
return [row.table_name for row in res]
209209

210210
def get_temp_table_names(self, connection: Connection, schema: str = None, **kw) -> List[str]:
@@ -225,7 +225,7 @@ def get_view_names(self, connection: Connection, schema: str = None, **kw) -> Li
225225
AND "table_type" = 'VIEW'
226226
"""
227227
).strip()
228-
res = connection.execute(sql.text(query), schema=schema)
228+
res = connection.execute(sql.text(query), { "schema": schema })
229229
return [row.table_name for row in res]
230230

231231
def get_temp_view_names(self, connection: Connection, schema: str = None, **kw) -> List[str]:
@@ -244,7 +244,7 @@ def get_view_definition(self, connection: Connection, view_name: str, schema: st
244244
AND "table_name" = :view
245245
"""
246246
).strip()
247-
res = connection.execute(sql.text(query), schema=schema, view=view_name)
247+
res = connection.execute(sql.text(query), { "schema": schema, "view": view_name })
248248
return res.scalar()
249249

250250
def get_indexes(self, connection: Connection, table_name: str, schema: str = None, **kw) -> List[Dict[str, Any]]:
@@ -314,7 +314,7 @@ def has_schema(self, connection: Connection, schema: str) -> bool:
314314
WHERE "schema_name" = :schema
315315
"""
316316
).strip()
317-
res = connection.execute(sql.text(query), schema=schema)
317+
res = connection.execute(sql.text(query), { "schema": schema })
318318
return res.first() is not None
319319

320320
def has_table(self, connection: Connection, table_name: str, schema: str = None, **kw) -> bool:
@@ -329,7 +329,7 @@ def has_table(self, connection: Connection, table_name: str, schema: str = None,
329329
AND "table_name" = :table
330330
"""
331331
).strip()
332-
res = connection.execute(sql.text(query), schema=schema, table=table_name)
332+
res = connection.execute(sql.text(query), { "schema": schema, "table": table_name })
333333
return res.first() is not None
334334

335335
def has_sequence(self, connection: Connection, sequence_name: str, schema: str = None, **kw) -> bool:
@@ -363,11 +363,6 @@ def do_execute(
363363
self, cursor: Cursor, statement: str, parameters: Tuple[Any, ...], context: DefaultExecutionContext = None
364364
):
365365
cursor.execute(statement, parameters)
366-
if context and context.should_autocommit:
367-
# SQL statement only submitted to Trino server when cursor.fetch*() is called.
368-
# For DDL (CREATE/ALTER/DROP) and DML (INSERT/UPDATE/DELETE) statement, call cursor.description
369-
# to force submit statement immediately.
370-
cursor.description # noqa
371366

372367
def do_rollback(self, dbapi_connection: trino_dbapi.Connection):
373368
if dbapi_connection.transaction is not None:

0 commit comments

Comments
 (0)