Skip to content

SQLAlchemy 2.0 Compatibility #307

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits into from
Jan 6, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions .github/workflows/ci.yml
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,8 @@ jobs:
- { python: "3.11", trino: "351", sqlalchemy: "~=1.4.0" } # first Trino version
# Test with sqlalchemy 1.3
- { python: "3.11", trino: "latest", sqlalchemy: "~=1.3.0" }
# Test with sqlalchemy 2.0
- { python: "3.11", trino: "latest", sqlalchemy: "~=2.0.0rc1" }
env:
TRINO_VERSION: "${{ matrix.trino }}"
steps:
Expand Down
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@
version = str(ast.literal_eval(trino_version.group(1)))

kerberos_require = ["requests_kerberos"]
sqlalchemy_require = ["sqlalchemy~=1.3"]
sqlalchemy_require = ["sqlalchemy >= 1.3"]
external_authentication_token_cache_require = ["keyring"]

# We don't add localstorage_require to all_require as users must explicitly opt in to use keyring.
Expand Down
65 changes: 36 additions & 29 deletions tests/integration/test_sqlalchemy_integration.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,10 +43,10 @@ def test_select_query(trino_connection):
rows = result.fetchall()
assert len(rows) == 25
for row in rows:
assert isinstance(row['nationkey'], int)
assert isinstance(row['name'], str)
assert isinstance(row['regionkey'], int)
assert isinstance(row['comment'], str)
assert isinstance(row.nationkey, int)
assert isinstance(row.name, str)
assert isinstance(row.regionkey, int)
assert isinstance(row.comment, str)


def assert_column(table, column_name, column_type):
Expand All @@ -70,8 +70,8 @@ def test_select_specific_columns(trino_connection):
rows = result.fetchall()
assert len(rows) > 0
for row in rows:
assert isinstance(row['node_id'], str)
assert isinstance(row['state'], str)
assert isinstance(row.node_id, str)
assert isinstance(row.state, str)


@pytest.mark.skipif(
Expand All @@ -82,7 +82,8 @@ def test_select_specific_columns(trino_connection):
def test_define_and_create_table(trino_connection):
engine, conn = trino_connection
if not engine.dialect.has_schema(conn, "test"):
engine.execute(sqla.schema.CreateSchema("test"))
with engine.begin() as connection:
connection.execute(sqla.schema.CreateSchema("test"))
metadata = sqla.MetaData()
try:
sqla.Table('users',
Expand Down Expand Up @@ -110,7 +111,8 @@ def test_insert(trino_connection):
engine, conn = trino_connection

if not engine.dialect.has_schema(conn, "test"):
engine.execute(sqla.schema.CreateSchema("test"))
with engine.begin() as connection:
connection.execute(sqla.schema.CreateSchema("test"))
metadata = sqla.MetaData()
try:
users = sqla.Table('users',
Expand Down Expand Up @@ -139,7 +141,8 @@ def test_insert(trino_connection):
def test_insert_multiple_statements(trino_connection):
engine, conn = trino_connection
if not engine.dialect.has_schema(conn, "test"):
engine.execute(sqla.schema.CreateSchema("test"))
with engine.begin() as connection:
connection.execute(sqla.schema.CreateSchema("test"))
metadata = sqla.MetaData()
users = sqla.Table('users',
metadata,
Expand Down Expand Up @@ -180,10 +183,10 @@ def test_operators(trino_connection):
rows = result.fetchall()
assert len(rows) == 1
for row in rows:
assert isinstance(row['nationkey'], int)
assert isinstance(row['name'], str)
assert isinstance(row['regionkey'], int)
assert isinstance(row['comment'], str)
assert isinstance(row.nationkey, int)
assert isinstance(row.name, str)
assert isinstance(row.regionkey, int)
assert isinstance(row.comment, str)


@pytest.mark.skipif(
Expand Down Expand Up @@ -216,14 +219,14 @@ def test_textual_sql(trino_connection):
rows = result.fetchall()
assert len(rows) == 3
for row in rows:
assert isinstance(row['custkey'], int)
assert isinstance(row['name'], str)
assert isinstance(row['address'], str)
assert isinstance(row['nationkey'], int)
assert isinstance(row['phone'], str)
assert isinstance(row['acctbal'], float)
assert isinstance(row['mktsegment'], str)
assert isinstance(row['comment'], str)
assert isinstance(row.custkey, int)
assert isinstance(row.name, str)
assert isinstance(row.address, str)
assert isinstance(row.nationkey, int)
assert isinstance(row.phone, str)
assert isinstance(row.acctbal, float)
assert isinstance(row.mktsegment, str)
assert isinstance(row.comment, str)


@pytest.mark.skipif(
Expand Down Expand Up @@ -323,7 +326,8 @@ def test_json_column(trino_connection, json_object):
engine, conn = trino_connection

if not engine.dialect.has_schema(conn, "test"):
engine.execute(sqla.schema.CreateSchema("test"))
with engine.begin() as connection:
connection.execute(sqla.schema.CreateSchema("test"))
metadata = sqla.MetaData()

try:
Expand Down Expand Up @@ -351,7 +355,8 @@ def test_get_table_comment(trino_connection):
engine, conn = trino_connection

if not engine.dialect.has_schema(conn, "test"):
engine.execute(sqla.schema.CreateSchema("test"))
with engine.begin() as connection:
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is this change required for sqlalchemy 2.0.0?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

connection.execute(sqla.schema.CreateSchema("test"))
metadata = sqla.MetaData()

try:
Expand All @@ -378,7 +383,8 @@ def test_get_table_names(trino_connection, schema):
metadata = sqla.MetaData(schema=schema_name)

if not engine.dialect.has_schema(conn, schema_name):
engine.execute(sqla.schema.CreateSchema(schema_name))
with engine.begin() as connection:
connection.execute(sqla.schema.CreateSchema(schema_name))

try:
sqla.Table(
Expand All @@ -388,10 +394,10 @@ def test_get_table_names(trino_connection, schema):
)
metadata.create_all(engine)
view_name = schema_name + ".test_view"
conn.execute(f"CREATE VIEW {view_name} AS SELECT * FROM test_get_table_names")
conn.execute(sqla.text(f"CREATE VIEW {view_name} AS SELECT * FROM test_get_table_names"))
assert sqla.inspect(engine).get_table_names(schema_name) == ['test_get_table_names']
finally:
conn.execute(f"DROP VIEW IF EXISTS {view_name}")
conn.execute(sqla.text(f"DROP VIEW IF EXISTS {view_name}"))
metadata.drop_all(engine)


Expand All @@ -411,7 +417,8 @@ def test_get_view_names(trino_connection, schema):
metadata = sqla.MetaData(schema=schema_name)

if not engine.dialect.has_schema(conn, schema_name):
engine.execute(sqla.schema.CreateSchema(schema_name))
with engine.begin() as connection:
connection.execute(sqla.schema.CreateSchema(schema_name))

try:
sqla.Table(
Expand All @@ -421,10 +428,10 @@ def test_get_view_names(trino_connection, schema):
)
metadata.create_all(engine)
view_name = schema_name + ".test_get_view_names"
conn.execute(f"CREATE VIEW {view_name} AS SELECT * FROM test_table")
conn.execute(sqla.text(f"CREATE VIEW {view_name} AS SELECT * FROM test_table"))
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This conn object is a dbapi connection (see fixture trino_connection) so why do we need this sqla.text here?

Copy link
Member Author

@mathiasritter mathiasritter Jan 4, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The conn object is a SQLAlchemy Connection instance:

@pytest.fixture
def trino_connection(run_trino, request):
_, host, port = run_trino
engine = sqla.create_engine(f"trino://test@{host}:{port}/{request.param}",
connect_args={"source": "test", "max_attempts": 1})
yield engine, engine.connect()

SQLAlchemy 2.0 requires this raw statement to be wrapped into text.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ah, 🤦 - should've seen what class the fixture is coming from. Thanks.

assert sqla.inspect(engine).get_view_names(schema_name) == ['test_get_view_names']
finally:
conn.execute(f"DROP VIEW IF EXISTS {view_name}")
conn.execute(sqla.text(f"DROP VIEW IF EXISTS {view_name}"))
metadata.drop_all(engine)


Expand Down
19 changes: 7 additions & 12 deletions trino/sqlalchemy/dialect.py
Original file line number Diff line number Diff line change
Expand Up @@ -156,7 +156,7 @@ def _get_columns(self, connection: Connection, table_name: str, schema: str = No
ORDER BY "ordinal_position" ASC
"""
).strip()
res = connection.execute(sql.text(query), schema=schema, table=table_name)
res = connection.execute(sql.text(query), {"schema": schema, "table": table_name})
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is this change required for sqlalchemy 2.0.0?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

columns = []
for record in res:
column = dict(
Expand Down Expand Up @@ -204,7 +204,7 @@ def get_table_names(self, connection: Connection, schema: str = None, **kw) -> L
AND "table_type" = 'BASE TABLE'
"""
).strip()
res = connection.execute(sql.text(query), schema=schema)
res = connection.execute(sql.text(query), {"schema": schema})
return [row.table_name for row in res]

def get_temp_table_names(self, connection: Connection, schema: str = None, **kw) -> List[str]:
Expand All @@ -225,7 +225,7 @@ def get_view_names(self, connection: Connection, schema: str = None, **kw) -> Li
AND "table_type" = 'VIEW'
"""
).strip()
res = connection.execute(sql.text(query), schema=schema)
res = connection.execute(sql.text(query), {"schema": schema})
return [row.table_name for row in res]

def get_temp_view_names(self, connection: Connection, schema: str = None, **kw) -> List[str]:
Expand All @@ -244,7 +244,7 @@ def get_view_definition(self, connection: Connection, view_name: str, schema: st
AND "table_name" = :view
"""
).strip()
res = connection.execute(sql.text(query), schema=schema, view=view_name)
res = connection.execute(sql.text(query), {"schema": schema, "view": view_name})
return res.scalar()

def get_indexes(self, connection: Connection, table_name: str, schema: str = None, **kw) -> List[Dict[str, Any]]:
Expand Down Expand Up @@ -296,7 +296,7 @@ def get_table_comment(self, connection: Connection, table_name: str, schema: str
try:
res = connection.execute(
sql.text(query),
catalog_name=catalog_name, schema_name=schema_name, table_name=table_name
{"catalog_name": catalog_name, "schema_name": schema_name, "table_name": table_name}
)
return dict(text=res.scalar())
except error.TrinoQueryError as e:
Expand All @@ -314,7 +314,7 @@ def has_schema(self, connection: Connection, schema: str) -> bool:
WHERE "schema_name" = :schema
"""
).strip()
res = connection.execute(sql.text(query), schema=schema)
res = connection.execute(sql.text(query), {"schema": schema})
return res.first() is not None

def has_table(self, connection: Connection, table_name: str, schema: str = None, **kw) -> bool:
Expand All @@ -329,7 +329,7 @@ def has_table(self, connection: Connection, table_name: str, schema: str = None,
AND "table_name" = :table
"""
).strip()
res = connection.execute(sql.text(query), schema=schema, table=table_name)
res = connection.execute(sql.text(query), {"schema": schema, "table": table_name})
return res.first() is not None

def has_sequence(self, connection: Connection, sequence_name: str, schema: str = None, **kw) -> bool:
Expand Down Expand Up @@ -363,11 +363,6 @@ def do_execute(
self, cursor: Cursor, statement: str, parameters: Tuple[Any, ...], context: DefaultExecutionContext = None
):
cursor.execute(statement, parameters)
if context and context.should_autocommit:
# SQL statement only submitted to Trino server when cursor.fetch*() is called.
# For DDL (CREATE/ALTER/DROP) and DML (INSERT/UPDATE/DELETE) statement, call cursor.description
# to force submit statement immediately.
cursor.description # noqa

def do_rollback(self, dbapi_connection: trino_dbapi.Connection):
if dbapi_connection.transaction is not None:
Expand Down