Skip to content

Research self restart #164

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
wants to merge 1 commit into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 11 additions & 1 deletion business_objects/general.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
from typing import Any, Dict, List, Optional, Union
from sqlalchemy.orm.session import make_transient as make_transient_original
from ..session import session, engine
from ..session import request_id_ctx_var
from ..session import request_id_ctx_var, exit_on_timeout
from ..session import check_session_and_rollback as check_and_roll
from ..enums import Tablenames, try_parse_enum_value
import traceback
Expand All @@ -18,6 +18,7 @@
session_lookup = {}


@exit_on_timeout
def get_ctx_token() -> Any:
global session_lookup
session_uuid = str(uuid.uuid4())
Expand Down Expand Up @@ -46,6 +47,7 @@ def get_session_lookup(exclude_last_x_seconds: int = 5) -> Dict[str, Dict[str, A
]


@exit_on_timeout
def reset_ctx_token(
ctx_token: Any,
remove_db: Optional[bool] = False,
Expand Down Expand Up @@ -86,21 +88,25 @@ def __close_in_context(session_uuid: str):
del session_lookup[session_uuid]


@exit_on_timeout
def add(entity: Any, with_commit: bool = False) -> None:
session.add(entity)
flush_or_commit(with_commit)


@exit_on_timeout
def add_all(entities: List[Any], with_commit: bool = False) -> None:
session.add_all(entities)
flush_or_commit(with_commit)


@exit_on_timeout
def delete(entity: Any, with_commit: bool = False) -> None:
session.delete(entity)
flush_or_commit(with_commit)


@exit_on_timeout
def commit() -> None:
session.commit()

Expand Down Expand Up @@ -132,18 +138,22 @@ def flush_or_commit(commit: bool = False) -> None:
session.flush()


@exit_on_timeout
def execute(sql: Any, *args) -> Any:
return session.execute(sql, *args)


@exit_on_timeout
def execute_all(sql: str) -> List[Any]:
return session.execute(sql).all()


@exit_on_timeout
def execute_first(sql: str) -> Any:
return session.execute(sql).first()


@exit_on_timeout
def execute_distinct_count(count_sql: str) -> int:
return session.execute(count_sql).first().distinct_count

Expand Down
7 changes: 6 additions & 1 deletion business_objects/user.py
Original file line number Diff line number Diff line change
@@ -1,22 +1,25 @@
from datetime import datetime
from . import general, organization, team_member
from .. import User, enums
from ..session import session
from ..session import session, exit_on_timeout
from typing import List, Optional
from sqlalchemy import sql


from ..util import prevent_sql_injection


@exit_on_timeout
def get(user_id: str) -> User:
return session.query(User).get(user_id)


@exit_on_timeout
def get_by_id_list(user_ids: List[str]) -> List[User]:
return session.query(User).filter(User.id.in_(user_ids)).all()


@exit_on_timeout
def get_all(
organization_id: Optional[str] = None, user_role: Optional[enums.UserRoles] = None
) -> List[User]:
Expand All @@ -28,10 +31,12 @@ def get_all(
return query.all()


@exit_on_timeout
def get_count_assigned() -> int:
return session.query(User.id).filter(User.organization_id != None).count()


@exit_on_timeout
def get_migration_user() -> str:
query = """
SELECT u.id
Expand Down
27 changes: 24 additions & 3 deletions session.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,12 @@
from typing import Any
import os
import sys
import asyncio
import docker
from contextvars import ContextVar
from sqlalchemy import create_engine
from sqlalchemy.orm import scoped_session, sessionmaker
from sqlalchemy.exc import PendingRollbackError
from sqlalchemy.exc import PendingRollbackError, TimeoutError
import traceback

from . import daemon
Expand Down Expand Up @@ -32,11 +35,12 @@ def get_request_id():

engine = create_engine(
os.getenv("POSTGRES"),
pool_size=pool_size,
max_overflow=pool_max_overflow,
pool_size=3, # pool_size,
max_overflow=0, # pool_max_overflow,
pool_recycle=pool_recycle,
pool_use_lifo=pool_use_lifo,
pool_pre_ping=pool_pre_ping,
pool_timeout=5,
)

session = scoped_session(
Expand All @@ -52,6 +56,23 @@ def get_request_id():
"""


def exit_on_timeout(f):
def safe_execution(*args, **kwargs):
try:
return f(*args, **kwargs)
except TimeoutError as e:
client = docker.from_env()
container = client.containers.get("cognition-gateway")
loop = asyncio.get_event_loop()
print(f"TimeoutError in {f.__name__}: {e}", flush=True)
traceback.print_exc()
loop.stop()
container.restart()
sys.exit(1)

return safe_execution


def check_session_and_rollback():
try:
_ = session.connection()
Expand Down