Skip to content

Commit 5e3fc8c

Browse files
mathiasritterMathias Ritter
authored andcommitted
SQLAlchemy 2.0 Compatibility
In the do_execute function of SQLAlchemy dialect, an if statement was accessing context.should_autocommit. This property has been removed in SQLAlchemy 2.0. To fix, the if statement can simply be removed as it is no longer needed. SQLAlchemy 2.0 also removed support for parameters as keyword arguments in the Connection.execute function. Instead, parameters need to be passed in a dictionary. Furthermore, Engine.execute was removed. Statements can only be executed on the Connection object, which can be obtained via Engine.begin() or Engine.connect(). Moreover, RowProxy is no longer a “proxy”; is now called Row and behaves like an enhanced named tuple. In order to access rows like a mapping, use Row._mapping. Raw SQL statements must be wrapped into a TextClause by calling text(...), imported from sqlalchemy.
1 parent e4a3f0f commit 5e3fc8c

File tree

4 files changed

+47
-43
lines changed

4 files changed

+47
-43
lines changed

.github/workflows/ci.yml

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -58,6 +58,8 @@ jobs:
5858
- { python: "3.11", trino: "351", sqlalchemy: "~=1.4.0" } # first Trino version
5959
# Test with sqlalchemy 1.3
6060
- { python: "3.11", trino: "latest", sqlalchemy: "~=1.3.0" }
61+
# Test with sqlalchemy 2.0
62+
- { python: "3.11", trino: "latest", sqlalchemy: "~=2.0.0rc1" }
6163
env:
6264
TRINO_VERSION: "${{ matrix.trino }}"
6365
steps:

setup.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,7 @@
2626
version = str(ast.literal_eval(trino_version.group(1)))
2727

2828
kerberos_require = ["requests_kerberos"]
29-
sqlalchemy_require = ["sqlalchemy~=1.3"]
29+
sqlalchemy_require = ["sqlalchemy >= 1.3"]
3030
external_authentication_token_cache_require = ["keyring"]
3131

3232
# We don't add localstorage_require to all_require as users must explicitly opt in to use keyring.
@@ -78,7 +78,7 @@
7878
"Programming Language :: Python :: Implementation :: PyPy",
7979
"Topic :: Database :: Front-Ends",
8080
],
81-
python_requires='>=3.7',
81+
python_requires=">=3.7",
8282
install_requires=["pytz", "requests", "tzlocal"],
8383
extras_require={
8484
"all": all_require,

tests/integration/test_sqlalchemy_integration.py

Lines changed: 36 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -43,10 +43,10 @@ def test_select_query(trino_connection):
4343
rows = result.fetchall()
4444
assert len(rows) == 25
4545
for row in rows:
46-
assert isinstance(row['nationkey'], int)
47-
assert isinstance(row['name'], str)
48-
assert isinstance(row['regionkey'], int)
49-
assert isinstance(row['comment'], str)
46+
assert isinstance(row._mapping['nationkey'] if hasattr(row, '_mapping') else row['nationkey'], int)
47+
assert isinstance(row._mapping['name'] if hasattr(row, '_mapping') else row['name'], str)
48+
assert isinstance(row._mapping['regionkey'] if hasattr(row, '_mapping') else row['regionkey'], int)
49+
assert isinstance(row._mapping['comment'] if hasattr(row, '_mapping') else row['comment'], str)
5050

5151

5252
def assert_column(table, column_name, column_type):
@@ -70,8 +70,8 @@ def test_select_specific_columns(trino_connection):
7070
rows = result.fetchall()
7171
assert len(rows) > 0
7272
for row in rows:
73-
assert isinstance(row['node_id'], str)
74-
assert isinstance(row['state'], str)
73+
assert isinstance(row._mapping['node_id'] if hasattr(row, '_mapping') else row['node_id'], str)
74+
assert isinstance(row._mapping['state'] if hasattr(row, '_mapping') else row['state'], str)
7575

7676

7777
@pytest.mark.skipif(
@@ -82,7 +82,8 @@ def test_select_specific_columns(trino_connection):
8282
def test_define_and_create_table(trino_connection):
8383
engine, conn = trino_connection
8484
if not engine.dialect.has_schema(conn, "test"):
85-
engine.execute(sqla.schema.CreateSchema("test"))
85+
with engine.begin() as connection:
86+
connection.execute(sqla.schema.CreateSchema("test"))
8687
metadata = sqla.MetaData()
8788
try:
8889
sqla.Table('users',
@@ -110,7 +111,8 @@ def test_insert(trino_connection):
110111
engine, conn = trino_connection
111112

112113
if not engine.dialect.has_schema(conn, "test"):
113-
engine.execute(sqla.schema.CreateSchema("test"))
114+
with engine.begin() as connection:
115+
connection.execute(sqla.schema.CreateSchema("test"))
114116
metadata = sqla.MetaData()
115117
try:
116118
users = sqla.Table('users',
@@ -139,7 +141,8 @@ def test_insert(trino_connection):
139141
def test_insert_multiple_statements(trino_connection):
140142
engine, conn = trino_connection
141143
if not engine.dialect.has_schema(conn, "test"):
142-
engine.execute(sqla.schema.CreateSchema("test"))
144+
with engine.begin() as connection:
145+
connection.execute(sqla.schema.CreateSchema("test"))
143146
metadata = sqla.MetaData()
144147
users = sqla.Table('users',
145148
metadata,
@@ -180,10 +183,10 @@ def test_operators(trino_connection):
180183
rows = result.fetchall()
181184
assert len(rows) == 1
182185
for row in rows:
183-
assert isinstance(row['nationkey'], int)
184-
assert isinstance(row['name'], str)
185-
assert isinstance(row['regionkey'], int)
186-
assert isinstance(row['comment'], str)
186+
assert isinstance(row._mapping['nationkey'] if hasattr(row, '_mapping') else row['nationkey'], int)
187+
assert isinstance(row._mapping['name'] if hasattr(row, '_mapping') else row['name'], str)
188+
assert isinstance(row._mapping['regionkey'] if hasattr(row, '_mapping') else row['regionkey'], int)
189+
assert isinstance(row._mapping['comment'] if hasattr(row, '_mapping') else row['comment'], str)
187190

188191

189192
@pytest.mark.skipif(
@@ -216,14 +219,14 @@ def test_textual_sql(trino_connection):
216219
rows = result.fetchall()
217220
assert len(rows) == 3
218221
for row in rows:
219-
assert isinstance(row['custkey'], int)
220-
assert isinstance(row['name'], str)
221-
assert isinstance(row['address'], str)
222-
assert isinstance(row['nationkey'], int)
223-
assert isinstance(row['phone'], str)
224-
assert isinstance(row['acctbal'], float)
225-
assert isinstance(row['mktsegment'], str)
226-
assert isinstance(row['comment'], str)
222+
assert isinstance(row._mapping['custkey'] if hasattr(row, '_mapping') else row['custkey'], int)
223+
assert isinstance(row._mapping['name'] if hasattr(row, '_mapping') else row['name'], str)
224+
assert isinstance(row._mapping['address'] if hasattr(row, '_mapping') else row['address'], str)
225+
assert isinstance(row._mapping['nationkey'] if hasattr(row, '_mapping') else row['nationkey'], int)
226+
assert isinstance(row._mapping['phone'] if hasattr(row, '_mapping') else row['phone'], str)
227+
assert isinstance(row._mapping['acctbal'] if hasattr(row, '_mapping') else row['acctbal'], float)
228+
assert isinstance(row._mapping['mktsegment'] if hasattr(row, '_mapping') else row['mktsegment'], str)
229+
assert isinstance(row._mapping['comment'] if hasattr(row, '_mapping') else row['comment'], str)
227230

228231

229232
@pytest.mark.skipif(
@@ -323,7 +326,8 @@ def test_json_column(trino_connection, json_object):
323326
engine, conn = trino_connection
324327

325328
if not engine.dialect.has_schema(conn, "test"):
326-
engine.execute(sqla.schema.CreateSchema("test"))
329+
with engine.begin() as connection:
330+
connection.execute(sqla.schema.CreateSchema("test"))
327331
metadata = sqla.MetaData()
328332

329333
try:
@@ -351,7 +355,8 @@ def test_get_table_comment(trino_connection):
351355
engine, conn = trino_connection
352356

353357
if not engine.dialect.has_schema(conn, "test"):
354-
engine.execute(sqla.schema.CreateSchema("test"))
358+
with engine.begin() as connection:
359+
connection.execute(sqla.schema.CreateSchema("test"))
355360
metadata = sqla.MetaData()
356361

357362
try:
@@ -378,7 +383,8 @@ def test_get_table_names(trino_connection, schema):
378383
metadata = sqla.MetaData(schema=schema_name)
379384

380385
if not engine.dialect.has_schema(conn, schema_name):
381-
engine.execute(sqla.schema.CreateSchema(schema_name))
386+
with engine.begin() as connection:
387+
connection.execute(sqla.schema.CreateSchema(schema_name))
382388

383389
try:
384390
sqla.Table(
@@ -388,10 +394,10 @@ def test_get_table_names(trino_connection, schema):
388394
)
389395
metadata.create_all(engine)
390396
view_name = schema_name + ".test_view"
391-
conn.execute(f"CREATE VIEW {view_name} AS SELECT * FROM test_get_table_names")
397+
conn.execute(sqla.text(f"CREATE VIEW {view_name} AS SELECT * FROM test_get_table_names"))
392398
assert sqla.inspect(engine).get_table_names(schema_name) == ['test_get_table_names']
393399
finally:
394-
conn.execute(f"DROP VIEW IF EXISTS {view_name}")
400+
conn.execute(sqla.text(f"DROP VIEW IF EXISTS {view_name}"))
395401
metadata.drop_all(engine)
396402

397403

@@ -411,7 +417,8 @@ def test_get_view_names(trino_connection, schema):
411417
metadata = sqla.MetaData(schema=schema_name)
412418

413419
if not engine.dialect.has_schema(conn, schema_name):
414-
engine.execute(sqla.schema.CreateSchema(schema_name))
420+
with engine.begin() as connection:
421+
connection.execute(sqla.schema.CreateSchema(schema_name))
415422

416423
try:
417424
sqla.Table(
@@ -421,10 +428,10 @@ def test_get_view_names(trino_connection, schema):
421428
)
422429
metadata.create_all(engine)
423430
view_name = schema_name + ".test_get_view_names"
424-
conn.execute(f"CREATE VIEW {view_name} AS SELECT * FROM test_table")
431+
conn.execute(sqla.text(f"CREATE VIEW {view_name} AS SELECT * FROM test_table"))
425432
assert sqla.inspect(engine).get_view_names(schema_name) == ['test_get_view_names']
426433
finally:
427-
conn.execute(f"DROP VIEW IF EXISTS {view_name}")
434+
conn.execute(sqla.text(f"DROP VIEW IF EXISTS {view_name}"))
428435
metadata.drop_all(engine)
429436

430437

trino/sqlalchemy/dialect.py

Lines changed: 7 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -156,7 +156,7 @@ def _get_columns(self, connection: Connection, table_name: str, schema: str = No
156156
ORDER BY "ordinal_position" ASC
157157
"""
158158
).strip()
159-
res = connection.execute(sql.text(query), schema=schema, table=table_name)
159+
res = connection.execute(sql.text(query), { "schema": schema, "table": table_name })
160160
columns = []
161161
for record in res:
162162
column = dict(
@@ -204,7 +204,7 @@ def get_table_names(self, connection: Connection, schema: str = None, **kw) -> L
204204
AND "table_type" = 'BASE TABLE'
205205
"""
206206
).strip()
207-
res = connection.execute(sql.text(query), schema=schema)
207+
res = connection.execute(sql.text(query), { "schema": schema })
208208
return [row.table_name for row in res]
209209

210210
def get_temp_table_names(self, connection: Connection, schema: str = None, **kw) -> List[str]:
@@ -225,7 +225,7 @@ def get_view_names(self, connection: Connection, schema: str = None, **kw) -> Li
225225
AND "table_type" = 'VIEW'
226226
"""
227227
).strip()
228-
res = connection.execute(sql.text(query), schema=schema)
228+
res = connection.execute(sql.text(query), { "schema": schema })
229229
return [row.table_name for row in res]
230230

231231
def get_temp_view_names(self, connection: Connection, schema: str = None, **kw) -> List[str]:
@@ -244,7 +244,7 @@ def get_view_definition(self, connection: Connection, view_name: str, schema: st
244244
AND "table_name" = :view
245245
"""
246246
).strip()
247-
res = connection.execute(sql.text(query), schema=schema, view=view_name)
247+
res = connection.execute(sql.text(query), { "schema": schema, "view": view_name })
248248
return res.scalar()
249249

250250
def get_indexes(self, connection: Connection, table_name: str, schema: str = None, **kw) -> List[Dict[str, Any]]:
@@ -296,7 +296,7 @@ def get_table_comment(self, connection: Connection, table_name: str, schema: str
296296
try:
297297
res = connection.execute(
298298
sql.text(query),
299-
catalog_name=catalog_name, schema_name=schema_name, table_name=table_name
299+
{ "catalog_name": catalog_name, "schema_name": schema_name, "table_name": table_name }
300300
)
301301
return dict(text=res.scalar())
302302
except error.TrinoQueryError as e:
@@ -314,7 +314,7 @@ def has_schema(self, connection: Connection, schema: str) -> bool:
314314
WHERE "schema_name" = :schema
315315
"""
316316
).strip()
317-
res = connection.execute(sql.text(query), schema=schema)
317+
res = connection.execute(sql.text(query), { "schema": schema })
318318
return res.first() is not None
319319

320320
def has_table(self, connection: Connection, table_name: str, schema: str = None, **kw) -> bool:
@@ -329,7 +329,7 @@ def has_table(self, connection: Connection, table_name: str, schema: str = None,
329329
AND "table_name" = :table
330330
"""
331331
).strip()
332-
res = connection.execute(sql.text(query), schema=schema, table=table_name)
332+
res = connection.execute(sql.text(query), { "schema": schema, "table": table_name })
333333
return res.first() is not None
334334

335335
def has_sequence(self, connection: Connection, sequence_name: str, schema: str = None, **kw) -> bool:
@@ -363,11 +363,6 @@ def do_execute(
363363
self, cursor: Cursor, statement: str, parameters: Tuple[Any, ...], context: DefaultExecutionContext = None
364364
):
365365
cursor.execute(statement, parameters)
366-
if context and context.should_autocommit:
367-
# SQL statement only submitted to Trino server when cursor.fetch*() is called.
368-
# For DDL (CREATE/ALTER/DROP) and DML (INSERT/UPDATE/DELETE) statement, call cursor.description
369-
# to force submit statement immediately.
370-
cursor.description # noqa
371366

372367
def do_rollback(self, dbapi_connection: trino_dbapi.Connection):
373368
if dbapi_connection.transaction is not None:

0 commit comments

Comments
 (0)