Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 3 additions & 3 deletions poetry.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

114 changes: 62 additions & 52 deletions pymfdata/rdb/repository.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
from contextlib import AbstractAsyncContextManager, AbstractContextManager
from typing import Callable, final, Iterator, List, Protocol, TypeVar, Optional
from typing import Callable, final, Iterator, get_args, List, Protocol, Optional, Type, TypeVar
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.future import select
from sqlalchemy.inspection import inspect
from sqlalchemy.orm import Session, Query
from sqlalchemy.sql.selectable import Select

Expand All @@ -12,33 +13,35 @@


class AsyncRepository(Protocol[_MT, _T]):
_model: _MT
_session_factory: Callable[..., AbstractAsyncContextManager]
_pk_column: str

@property
def _model(self):
return get_args(self.__orig_bases__[0])[0]

@property
def _pk_column(self) -> str:
return inspect(self._model).primary_key[0].name

async def delete_by_pk(self, pk: _T) -> bool:
item = await self.find_by_pk(pk)
if item is not None:
session: AsyncSession
async with self._session_factory() as session:
session: AsyncSession
async with self._session_factory() as session:
item = await self.find_by_pk(session, pk)
if item is not None:
await session.delete(item)
await session.commit()

return True
return False
return True

async def find_by_pk(self, pk: _T) -> Optional[_MT]:
return await self.find_by_col(**{self._pk_column: pk})
return False

@final
async def find_by_col(self, **kwargs) -> Optional[_MT]:
if not await self.is_exists(**kwargs):
return None
async def find_by_pk(self, session: AsyncSession, pk: _T) -> Optional[_MT]:
return await self.find_by_col(session, **{self._pk_column: pk})

session: AsyncSession
async with self._session_factory() as session:
item = await session.execute(self._gen_stmt_for_param(**kwargs))
return item.unique().scalars().one()
@final
async def find_by_col(self, session: AsyncSession, **kwargs) -> Optional[_MT]:
item = await session.execute(self._gen_stmt_for_param(**kwargs))
return item.unique().scalars().one_or_none()

@final
def _gen_stmt_for_param(self, **kwargs) -> Select:
Expand All @@ -61,59 +64,66 @@ async def find_all(self, **kwargs) -> List[_MT]:
async def is_exists(self, **kwargs) -> bool:
session: AsyncSession
async with self._session_factory() as session:
return await session.execute(self._gen_stmt_for_param(**kwargs).exists().select())
result = await session.execute(self._gen_stmt_for_param(**kwargs).exists().select())
return result.scalar()

@final
async def save(self, item: Base):
async def save(self, item: _MT):
session: AsyncSession
async with self._session_factory() as session:
session.add(item)
await session.commit()
await session.refresh(item)

async def update_by_pk(self, pk: _T, req: dict) -> bool:
item = await self.find_by_pk(pk)
if item is not None:
session: AsyncSession
async with self._session_factory() as session:
session: AsyncSession
async with self._session_factory() as session:
item = await self.find_by_pk(session, pk)
if item is not None:
for k, v in req.items():
if v is not None:
setattr(item, k, v)

await session.commit()
await session.refresh(item)

return True
return False
return True
return False


class SyncRepository(Protocol[_MT, _T]):
_model: _MT
_session_factory: Callable[..., AbstractContextManager]
_pk_column: str

@property
def _model(self):
return get_args(self.__orig_bases__[0])[0]

@property
def _pk_column(self) -> str:
return inspect(self._model).primary_key[0].name

@final
def count(self, **kwargs) -> int:
return self._gen_query_for_param(**kwargs).count()

def delete_by_pk(self, pk: _T) -> bool:
item = self.find_by_pk(pk)
if item is not None:
session: Session
with self._session_factory() as session:
session: Session
with self._session_factory() as session:
item = self.find_by_pk(session, pk)
if item is not None:
session.delete(item)
session.commit()

return True
return False
return True
return False

def find_by_pk(self, pk: _T) -> Optional[_MT]:
return self.find_by_col(**{self._pk_column: pk})
def find_by_pk(self, session: Session, pk: _T) -> Optional[_MT]:
return self.find_by_col(session, **{self._pk_column: pk})

@final
def find_by_col(self, **kwargs) -> Optional[_MT]:
if not self.is_exists(**kwargs):
return None

with self._session_factory() as session:
query = self._gen_query_for_param(session, **kwargs)
return query.one()
def find_by_col(self, session: Session, **kwargs) -> Optional[_MT]:
query = self._gen_query_for_param(session, **kwargs)
return query.one_or_none()

@final
def _gen_query_for_param(self, session: Session, **kwargs) -> Query:
Expand Down Expand Up @@ -144,16 +154,16 @@ def save(self, item: Base):
session.commit()

def update_by_pk(self, pk: _T, req: dict) -> bool:
item = self.find_by_pk(pk)
if item is not None:
session: Session
with self._session_factory() as session:
session: Session
with self._session_factory() as session:
item = self.find_by_pk(session, pk)
if item is not None:
for k, v in req.items():
if v is not None:
setattr(item, k, v)

await session.commit()
await session.refresh(item)
session.commit()
session.refresh(item)

return True
return False
return True
return False