Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
32 changes: 27 additions & 5 deletions pymfdata/rdb/usecase.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,12 +9,23 @@
class AsyncSQLAlchemyUnitOfWork(AsyncBaseUnitOfWork):
def __init__(self, engine: AsyncEngine) -> None:
self._engine = engine
self.session: Optional[AsyncSession] = None
self._session: Optional[AsyncSession] = None

@property
def engine(self) -> AsyncEngine:
assert self._engine is not None
return self._engine

@property
def session(self) -> AsyncSession:
assert self._session is not None
return self._session

async def __aenter__(self):
self.session = AsyncSession(self._engine)
self._session = AsyncSession(self.engine)

async def __aexit__(self, exc_type: Optional[Type[Exception]], exc_val: Optional[Exception], traceback):
await super().__aexit__(exc_type, exc_val, traceback)
await self.session.close()

async def commit(self):
Expand All @@ -32,13 +43,24 @@ async def rollback(self):

class SyncSQLAlchemyUnitOfWork(SyncBaseUnitOfWork):
def __init__(self, engine: Engine) -> None:
self.engine = engine
self.session: Optional[Session] = None
self._engine = engine
self._session: Optional[Session] = None

@property
def engine(self) -> Engine:
assert self._engine is not None
return self._engine

@property
def session(self) -> Session:
assert self._session is not None
return self._session

def __enter__(self):
self.session = Session(self.engine)
self._session = Session(self.engine)

def __exit__(self, exc_type: Optional[Type[Exception]], exc_val: Optional[Exception], traceback):
super().__exit__(exc_type, exc_val, traceback)
self.session.close()

def commit(self):
Expand Down