Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 12 additions & 2 deletions pymfdata/rdb/repository.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@


class AsyncRepository(Protocol[_MT, _T]):
session: AsyncSession
_session: AsyncSession

@property
def _model(self):
Expand All @@ -22,6 +22,11 @@ def _model(self):
def _pk_column(self) -> str:
return inspect(self._model).primary_key[0].name

@property
def session(self) -> AsyncSession:
assert self._session is not None
return self._session

async def delete(self, item: _MT):
await self.session.delete(item)

Expand Down Expand Up @@ -68,7 +73,7 @@ def update(self, item: _MT, req: dict):


class SyncRepository(Protocol[_MT, _T]):
session: Session
_session: Session

@property
def _model(self):
Expand All @@ -78,6 +83,11 @@ def _model(self):
def _pk_column(self) -> str:
return inspect(self._model).primary_key[0].name

@property
def session(self) -> Session:
assert self._session is not None
return self._session

@final
def count(self, **kwargs) -> int:
return self._gen_query_for_param(**kwargs).count()
Expand Down
4 changes: 2 additions & 2 deletions tests/rdb/domain/repository.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,9 +6,9 @@

class AsyncMemoRepository(AsyncRepository[MemoEntity, int]):
def __init__(self, session: Optional[AsyncSession]) -> None:
self.session = session
self._session = session


class SyncMemoRepository(SyncRepository[MemoEntity, int]):
def __init__(self, session: Optional[Session]) -> None:
self.session = session
self._session = session