diff --git a/letta/server/server.py b/letta/server/server.py index 1a5de01e96..8f87e611a0 100644 --- a/letta/server/server.py +++ b/letta/server/server.py @@ -156,6 +156,11 @@ def run_command(self, user_id: str, agent_id: str, command: str) -> Union[str, N raise NotImplementedError +from contextlib import contextmanager + +from rich.console import Console +from rich.panel import Panel +from rich.text import Text from sqlalchemy import create_engine from sqlalchemy.orm import sessionmaker @@ -166,6 +171,37 @@ def run_command(self, user_id: str, agent_id: str, command: str) -> Union[str, N config = LettaConfig.load() + +def print_sqlite_schema_error(): + """Print a formatted error message for SQLite schema issues""" + console = Console() + error_text = Text() + error_text.append("Existing SQLite DB schema is invalid, and schema migrations are not supported for SQLite. ", style="bold red") + error_text.append("To have migrations supported between Letta versions, please run Letta with Docker (", style="white") + error_text.append("https://docs.letta.com/server/docker", style="blue underline") + error_text.append(") or use Postgres by setting ", style="white") + error_text.append("LETTA_PG_URI", style="yellow") + error_text.append(".\n\n", style="white") + error_text.append("If you wish to keep using SQLite, you can reset your database by removing the DB file with ", style="white") + error_text.append("rm ~/.letta/sqlite.db", style="yellow") + error_text.append(" or downgrade to your previous version of Letta.", style="white") + + console.print(Panel(error_text, border_style="red")) + + +@contextmanager +def db_error_handler(): + """Context manager for handling database errors""" + try: + yield + except Exception as e: + # Handle other SQLAlchemy errors + print(e) + print_sqlite_schema_error() + # raise ValueError(f"SQLite DB error: {str(e)}") + exit(1) + + if settings.letta_pg_uri_no_default: config.recall_storage_type = "postgres" config.recall_storage_uri = settings.letta_pg_uri_no_default @@ -178,6 +214,30 @@ def run_command(self, user_id: str, agent_id: str, command: str) -> Union[str, N # TODO: don't rely on config storage engine = create_engine("sqlite:///" + os.path.join(config.recall_storage_path, "sqlite.db")) + # Store the original connect method + original_connect = engine.connect + + def wrapped_connect(*args, **kwargs): + with db_error_handler(): + # Get the connection + connection = original_connect(*args, **kwargs) + + # Store the original execution method + original_execute = connection.execute + + # Wrap the execute method of the connection + def wrapped_execute(*args, **kwargs): + with db_error_handler(): + return original_execute(*args, **kwargs) + + # Replace the connection's execute method + connection.execute = wrapped_execute + + return connection + + # Replace the engine's connect method + engine.connect = wrapped_connect + Base.metadata.create_all(bind=engine) SessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=engine) @@ -379,7 +439,9 @@ def initialize_agent(self, agent_id, interface: Union[AgentInterface, None] = No if agent_state.agent_type == AgentType.memgpt_agent: agent = Agent(agent_state=agent_state, interface=interface, user=actor, initial_message_sequence=initial_message_sequence) elif agent_state.agent_type == AgentType.offline_memory_agent: - agent = OfflineMemoryAgent(agent_state=agent_state, interface=interface, user=actor, initial_message_sequence=initial_message_sequence) + agent = OfflineMemoryAgent( + agent_state=agent_state, interface=interface, user=actor, initial_message_sequence=initial_message_sequence + ) else: assert initial_message_sequence is None, f"Initial message sequence is not supported for O1Agents" agent = O1Agent(agent_state=agent_state, interface=interface, user=actor) @@ -500,8 +562,8 @@ def _command(self, user_id: str, agent_id: str, command: str) -> LettaUsageStati letta_agent.attach_source( user=self.user_manager.get_user_by_id(user_id=user_id), source_id=data_source, - source_manager=letta_agent.source_manager, - ms=self.ms + source_manager=letta_agent.source_manager, + ms=self.ms, ) elif command.lower() == "dump" or command.lower().startswith("dump "): @@ -1267,7 +1329,10 @@ def get_agent_archival_cursor( # iterate over records records = letta_agent.passage_manager.list_passages( - actor=self.default_user, agent_id=agent_id, cursor=cursor, limit=limit, + actor=self.default_user, + agent_id=agent_id, + cursor=cursor, + limit=limit, ) return records @@ -1914,7 +1979,7 @@ def run_tool_from_source( date=get_utc_time(), status="error", function_return=error_msg, - stdout=[''], + stdout=[""], stderr=[traceback.format_exc()], ) diff --git a/poetry.lock b/poetry.lock index e6cf38e0dd..dda0b0bac7 100644 --- a/poetry.lock +++ b/poetry.lock @@ -6222,4 +6222,4 @@ tests = ["wikipedia"] [metadata] lock-version = "2.0" python-versions = "<4.0,>=3.10" -content-hash = "78621cd10122e3b41658020d0711b61ffa379b251e2a012854f8a92cc37ff3c0" +content-hash = "9c623c4d8c98b3fe724518428bb48ae85f8152453f200f767e13f48c59e0fe13" diff --git a/pyproject.toml b/pyproject.toml index a0eb14abbe..e0531b10c4 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -69,6 +69,7 @@ pathvalidate = "^3.2.1" langchain-community = {version = "^0.3.7", optional = true} langchain = {version = "^0.3.7", optional = true} sentry-sdk = {extras = ["fastapi"], version = "2.19.1"} +rich = "^13.9.4" brotli = "^1.1.0" grpcio = "^1.68.1" grpcio-tools = "^1.68.1"