Skip to content

Commit

Permalink
Merge pull request #163 from okorienev/sqlalchemy_2_0_support
Browse files Browse the repository at this point in the history
Add sqlalchemy>2 support
  • Loading branch information
kindermax authored Sep 13, 2024
2 parents 13a538b + 3f1c798 commit fe6be9e
Show file tree
Hide file tree
Showing 5 changed files with 105 additions and 52 deletions.
40 changes: 34 additions & 6 deletions hiku/sources/sqlalchemy.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,11 +9,13 @@
Callable,
Dict,
Iterable,
Mapping,
)

import sqlalchemy
from sqlalchemy.sql import Select
from sqlalchemy.sql.elements import BinaryExpression
from sqlalchemy.sql.expression import ColumnElement

from ..types import (
String,
Expand Down Expand Up @@ -43,6 +45,27 @@
from sqlalchemy.engine import RowProxy as Row


if SQLALCHEMY_VERSION >= (2, 0):

def _process_select_params(
params: List[ColumnElement],
) -> Iterable:
return params

def _process_result_row(row: Row) -> Mapping:
return row._mapping

else:

def _process_select_params(
params: List[ColumnElement],
) -> Iterable:
return (params,)

def _process_result_row(row: Row) -> Mapping:
return row


def _translate_type(
column: sqlalchemy.Column,
) -> Optional[Union[IntegerMeta, StringMeta]]:
Expand Down Expand Up @@ -114,14 +137,17 @@ def select_expr(
) -> Tuple[Select, Callable]:
columns = [self.from_clause.c[f.name] for f in fields_]
expr = (
sqlalchemy.select([self.primary_key] + columns)
sqlalchemy.select(
*_process_select_params([self.primary_key] + columns)
)
.select_from(self.from_clause)
.where(self.in_impl(self.primary_key, ids))
)

def result_proc(rows: List[Row]) -> List:
rows_map = {
row[self.primary_key]: [row[c] for c in columns] for row in rows
row[self.primary_key]: [row[c] for c in columns]
for row in map(_process_result_row, rows)
}

nulls = [None for _ in fields_]
Expand Down Expand Up @@ -210,10 +236,12 @@ def select_expr(self, ids: Iterable) -> Optional[Select]:
filtered_ids = [i for i in set(ids) if i is not None]
if filtered_ids:
return sqlalchemy.select(
[
self.from_column.label("from_column"),
self.to_column.label("to_column"),
]
*_process_select_params(
[
self.from_column.label("from_column"),
self.to_column.label("to_column"),
]
)
).where(self.in_impl(self.from_column, filtered_ids))
else:
return None
Expand Down
24 changes: 14 additions & 10 deletions tests/test_cache.py
Original file line number Diff line number Diff line change
Expand Up @@ -117,16 +117,20 @@ def set_many(self, items: t.Dict[str, t.Any], ttl: int) -> None:

def setup_db(db_engine):
metadata.create_all(db_engine)
for r in [c._asdict() for c in DB["companies"].values()]:
db_engine.execute(company_table.insert(), r)
for r in [c._asdict() for c in DB["users"].values()]:
db_engine.execute(user_table.insert(), r)
for r in [p._asdict() for p in DB["attributes"].values()]:
db_engine.execute(attribute_table.insert(), r)
for r in [p._asdict() for p in DB["attribute_values"].values()]:
db_engine.execute(attribute_value_table.insert(), r)
for r in [p._asdict() for p in DB["products"].values()]:
db_engine.execute(product_table.insert(), r)
with db_engine.begin() as db_conn:
for r in [c._asdict() for c in DB["companies"].values()]:
db_conn.execute(company_table.insert(), r)
for r in [c._asdict() for c in DB["users"].values()]:
db_conn.execute(user_table.insert(), r)
for r in [p._asdict() for p in DB["attributes"].values()]:
db_conn.execute(attribute_table.insert(), r)
for r in [p._asdict() for p in DB["attribute_values"].values()]:
db_conn.execute(attribute_value_table.insert(), r)
for r in [p._asdict() for p in DB["products"].values()]:
db_conn.execute(product_table.insert(), r)

if hasattr(db_conn, "commit"):
db_conn.commit()


class Product(t.NamedTuple):
Expand Down
29 changes: 17 additions & 12 deletions tests/test_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -340,12 +340,16 @@ def test_links_requires_list_sa():

def setup_db(db_engine):
metadata.create_all(db_engine)
for row in data["artist"]:
db_engine.execute(artist_table.insert(), row)
for row in data["album"]:
db_engine.execute(album_table.insert(), row)
for row in data["song"]:
db_engine.execute(song_table.insert(), row)
with db_engine.begin() as db_conn:
for row in data["artist"]:
db_conn.execute(artist_table.insert(), row)
for row in data["album"]:
db_conn.execute(album_table.insert(), row)
for row in data["song"]:
db_conn.execute(song_table.insert(), row)

if hasattr(db_conn, "commit"):
db_conn.commit()

sa_engine = create_engine(
"sqlite://",
Expand All @@ -371,12 +375,13 @@ def song_info_fields(ctx, fields, ids):
def get_fields(id_):
album_id = id_["album_id"]
artist_id = id_["artist_id"]
album = db.execute(
album_table.select().where(album_table.c.id == album_id)
).first()
artist = db.execute(
artist_table.select().where(artist_table.c.id == artist_id)
).first()
with db.begin() as db_conn:
album = db_conn.execute(
album_table.select().where(album_table.c.id == album_id)
).first()
artist = db_conn.execute(
artist_table.select().where(artist_table.c.id == artist_id)
).first()

for f in fields:
if f.name == "album_name":
Expand Down
38 changes: 21 additions & 17 deletions tests/test_source_sqlalchemy.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,23 +50,27 @@ async def wrapper(*args, **kwargs):

def setup_db(db_engine):
metadata.create_all(db_engine)
for r in [
{"id": 0, "name": "bar0", "type": 4},
{"id": 4, "name": "bar1", "type": 1},
{"id": 5, "name": "bar2", "type": 2},
{"id": 6, "name": "bar3", "type": 3},
]:
db_engine.execute(bar_table.insert(), r)
for r in [
{"name": "foo1", "count": 5, "bar_id": None},
{"name": "foo2", "count": 10, "bar_id": 5},
{"name": "foo3", "count": 15, "bar_id": 4},
{"name": "foo4", "count": 20, "bar_id": 6},
{"name": "foo5", "count": 25, "bar_id": 5},
{"name": "foo6", "count": 30, "bar_id": 4},
{"name": "foo7", "count": 35, "bar_id": 0},
]:
db_engine.execute(foo_table.insert(), r)
with db_engine.begin() as db_conn:
for r in [
{"id": 0, "name": "bar0", "type": 4},
{"id": 4, "name": "bar1", "type": 1},
{"id": 5, "name": "bar2", "type": 2},
{"id": 6, "name": "bar3", "type": 3},
]:
db_conn.execute(bar_table.insert(), r)
for r in [
{"name": "foo1", "count": 5, "bar_id": None},
{"name": "foo2", "count": 10, "bar_id": 5},
{"name": "foo3", "count": 15, "bar_id": 4},
{"name": "foo4", "count": 20, "bar_id": 6},
{"name": "foo5", "count": 25, "bar_id": 5},
{"name": "foo6", "count": 30, "bar_id": 4},
{"name": "foo7", "count": 35, "bar_id": 0},
]:
db_conn.execute(foo_table.insert(), r)

if hasattr(db_conn, "commit"):
db_conn.commit()


def graph_factory(
Expand Down
26 changes: 19 additions & 7 deletions tox.ini
Original file line number Diff line number Diff line change
@@ -1,16 +1,28 @@
[tox]
envlist = py37,pypy3,py38,py39,py310,py311,p312
envlist = py37,pypy3,py38,py39,py310,py311,p312,sqla2,sqla13

[testenv]
groups = dev,test
commands =
pytest tests {posargs}

[testenv:sqla2]
groups = dev,test
commands =
python -I -m pip install 'sqlalchemy>=2.0.0'
pytest tests {posargs}

[testenv:sqla13]
groups = dev,test
commands =
python -I -m pip install 'sqlalchemy>=1.3,<1.4'
pytest tests {posargs}

[gh-actions]
python =
3.7: py37
3.8: py38
3.9: py39
3.10: py310
3.11: py311
3.12: py312
3.7: py37,sqla13
3.8: py38,sqla13
3.9: py39,sqla13
3.10: py310,sqla2
3.11: py311,sqla2
3.12: py312,sqla2

0 comments on commit fe6be9e

Please sign in to comment.