diff --git a/sql.py b/sql.py index 7ecabf4..f20e2b2 100644 --- a/sql.py +++ b/sql.py @@ -1,6 +1,7 @@ # vim: tabstop=8 expandtab shiftwidth=4 softtabstop=4 import logging +from contextlib import contextmanager from jsonpickle import encode, decode from typing import Any from sqlalchemy import ( @@ -35,27 +36,38 @@ def __init__(self, session, clazz): self.session = session self.clazz = clazz + @contextmanager + def _session_op(self): + try: + yield self.session + self.session.commit() + except: + self.session.rollback() + raise + def get(self, key: str) -> Any: try: - return self.session.query( - self.clazz).filter(self.clazz._key == key).one().value + with self._session_op() as session: + result = session.query(self.clazz).filter(self.clazz._key == key).one().value except NoResultFound: raise KeyError("%s doesn't exists." % key) + return result def remove(self, key: str): try: - self.session.query( - self.clazz).filter(self.clazz._key == key).delete() - self.session.commit() + with self._session_op() as session: + session.query(self.clazz).filter(self.clazz._key == key).delete() except NoResultFound: raise KeyError("%s doesn't exists." % key) def set(self, key: str, value: Any) -> None: - self.session.merge(self.clazz(key, value)) - self.session.commit() + with self._session_op() as session: + session.merge(self.clazz(key, value)) def len(self): - return self.session.query(self.clazz).count() + with self._session_op() as session: + length = session.query(self.clazz).count() + return length def keys(self): return (kv.key for kv in self.session.query(self.clazz).all()) @@ -89,6 +101,8 @@ def __init__(self, bot_config): else: self._engine = create_engine( config[DATA_URL_ENTRY], + pool_recycle=config.get('connection_recycle', 1800), + pool_pre_ping=config.get('connection_ping', True), echo=bot_config.BOT_LOG_LEVEL == logging.DEBUG) self._metadata = MetaData() self._sessionmaker = sessionmaker()