diff --git a/letta/client/client.py b/letta/client/client.py index af2edcca4a..8a9d3e700a 100644 --- a/letta/client/client.py +++ b/letta/client/client.py @@ -233,7 +233,7 @@ def delete_tool(self, id: str): def get_tool_id(self, name: str) -> Optional[str]: raise NotImplementedError - def add_base_tools(self) -> List[Tool]: + def upsert_base_tools(self) -> List[Tool]: raise NotImplementedError def load_data(self, connector: DataConnector, source_name: str): @@ -1466,7 +1466,7 @@ def get_tool_id(self, tool_name: str): raise ValueError(f"Failed to get tool: {response.text}") return response.json() - def add_base_tools(self) -> List[Tool]: + def upsert_base_tools(self) -> List[Tool]: response = requests.post(f"{self.base_url}/{self.api_prefix}/tools/add-base-tools/", headers=self.headers) if response.status_code != 200: raise ValueError(f"Failed to add base tools: {response.text}") diff --git a/letta/functions/function_sets/base.py b/letta/functions/function_sets/base.py index f559bf4a7d..d3ca097b90 100644 --- a/letta/functions/function_sets/base.py +++ b/letta/functions/function_sets/base.py @@ -61,60 +61,6 @@ def conversation_search(self: "Agent", query: str, page: Optional[int] = 0) -> O return results_str -def conversation_search_date(self: "Agent", start_date: str, end_date: str, page: Optional[int] = 0) -> Optional[str]: - """ - Search prior conversation history using a date range. - - Args: - start_date (str): The start of the date range to search, in the format 'YYYY-MM-DD'. - end_date (str): The end of the date range to search, in the format 'YYYY-MM-DD'. - page (int): Allows you to page through results. Only use on a follow-up query. Defaults to 0 (first page). - - Returns: - str: Query result string - """ - import math - from datetime import datetime - - from letta.constants import RETRIEVAL_QUERY_DEFAULT_PAGE_SIZE - from letta.utils import json_dumps - - if page is None or (isinstance(page, str) and page.lower().strip() == "none"): - page = 0 - try: - page = int(page) - if page < 0: - raise ValueError - except: - raise ValueError(f"'page' argument must be an integer") - - # Convert date strings to datetime objects - try: - start_datetime = datetime.strptime(start_date, "%Y-%m-%d").replace(hour=0, minute=0, second=0, microsecond=0) - end_datetime = datetime.strptime(end_date, "%Y-%m-%d").replace(hour=23, minute=59, second=59, microsecond=999999) - except ValueError: - raise ValueError("Dates must be in the format 'YYYY-MM-DD'") - - count = RETRIEVAL_QUERY_DEFAULT_PAGE_SIZE - results = self.message_manager.list_user_messages_for_agent( - # TODO: add paging by page number. currently cursor only works with strings. - agent_id=self.agent_state.id, - actor=self.user, - start_date=start_datetime, - end_date=end_datetime, - limit=count, - ) - total = len(results) - num_pages = math.ceil(total / count) - 1 # 0 index - if len(results) == 0: - results_str = f"No results found." - else: - results_pref = f"Showing {len(results)} of {total} results (page {page}/{num_pages}):" - results_formatted = [f"timestamp: {d['timestamp']}, {d['message']['role']} - {d['message']['content']}" for d in results] - results_str = f"{results_pref} {json_dumps(results_formatted)}" - return results_str - - def archival_memory_insert(self: "Agent", content: str) -> Optional[str]: """ Add to archival memory. Make sure to phrase the memory contents such that it can be easily queried later. diff --git a/letta/server/rest_api/routers/v1/tools.py b/letta/server/rest_api/routers/v1/tools.py index 61b89624d4..15979346c4 100644 --- a/letta/server/rest_api/routers/v1/tools.py +++ b/letta/server/rest_api/routers/v1/tools.py @@ -152,30 +152,15 @@ def update_tool( @router.post("/add-base-tools", response_model=List[Tool], operation_id="add_base_tools") -def add_base_tools( +def upsert_base_tools( server: SyncServer = Depends(get_letta_server), user_id: Optional[str] = Header(None, alias="user_id"), # Extract user_id from header, default to None if not present ): """ - Add base tools + Upsert base tools """ actor = server.user_manager.get_user_or_default(user_id=user_id) - return server.tool_manager.add_base_tools(actor=actor) - - -# NOTE: can re-enable if needed -# @router.post("/{tool_id}/run", response_model=FunctionReturn, operation_id="run_tool") -# def run_tool( -# server: SyncServer = Depends(get_letta_server), -# request: ToolRun = Body(...), -# user_id: Optional[str] = Header(None, alias="user_id"), # Extract user_id from header, default to None if not present -# ): -# """ -# Run an existing tool on provided arguments -# """ -# actor = server.user_manager.get_user_or_default(user_id=user_id) - -# return server.run_tool(tool_id=request.tool_id, tool_args=request.tool_args, user_id=actor.id) + return server.tool_manager.upsert_base_tools(actor=actor) @router.post("/run", response_model=FunctionReturn, operation_id="run_tool_from_source") diff --git a/letta/server/server.py b/letta/server/server.py index 31c87394d7..56fea7b342 100644 --- a/letta/server/server.py +++ b/letta/server/server.py @@ -301,7 +301,7 @@ def __init__( self.default_org = self.organization_manager.create_default_organization() self.default_user = self.user_manager.create_default_user() self.block_manager.add_default_blocks(actor=self.default_user) - self.tool_manager.add_base_tools(actor=self.default_user) + self.tool_manager.upsert_base_tools(actor=self.default_user) # If there is a default org/user # This logic may have to change in the future diff --git a/letta/services/agent_manager.py b/letta/services/agent_manager.py index 99dfa3ae47..aacad8ae49 100644 --- a/letta/services/agent_manager.py +++ b/letta/services/agent_manager.py @@ -1,18 +1,18 @@ -from typing import Dict, List, Optional from datetime import datetime -import numpy as np +from typing import Dict, List, Optional -from sqlalchemy import select, union_all, literal, func, Select +import numpy as np +from sqlalchemy import Select, func, literal, select, union_all from letta.constants import BASE_MEMORY_TOOLS, BASE_TOOLS, MAX_EMBEDDING_DIM from letta.embeddings import embedding_model from letta.log import get_logger from letta.orm import Agent as AgentModel +from letta.orm import AgentPassage from letta.orm import Block as BlockModel from letta.orm import Source as SourceModel +from letta.orm import SourcePassage, SourcesAgents from letta.orm import Tool as ToolModel -from letta.orm import AgentPassage, SourcePassage -from letta.orm import SourcesAgents from letta.orm.errors import NoResultFound from letta.orm.sqlite_functions import adapt_array from letta.schemas.agent import AgentState as PydanticAgentState @@ -77,6 +77,8 @@ def create_agent( tool_names.extend(BASE_TOOLS + BASE_MEMORY_TOOLS) if agent_create.tools: tool_names.extend(agent_create.tools) + # Remove duplicates + tool_names = list(set(tool_names)) tool_ids = agent_create.tool_ids or [] for tool_name in tool_names: @@ -431,7 +433,7 @@ def _build_passage_query( agent_only: bool = False, ) -> Select: """Helper function to build the base passage query with all filters applied. - + Returns the query before any limit or count operations are applied. """ embedded_text = None @@ -448,21 +450,14 @@ def _build_passage_query( if not agent_only: # Include source passages if agent_id is not None: source_passages = ( - select( - SourcePassage, - literal(None).label('agent_id') - ) + select(SourcePassage, literal(None).label("agent_id")) .join(SourcesAgents, SourcesAgents.source_id == SourcePassage.source_id) .where(SourcesAgents.agent_id == agent_id) .where(SourcePassage.organization_id == actor.organization_id) ) else: - source_passages = ( - select( - SourcePassage, - literal(None).label('agent_id') - ) - .where(SourcePassage.organization_id == actor.organization_id) + source_passages = select(SourcePassage, literal(None).label("agent_id")).where( + SourcePassage.organization_id == actor.organization_id ) if source_id: @@ -486,9 +481,9 @@ def _build_passage_query( AgentPassage._created_by_id, AgentPassage._last_updated_by_id, AgentPassage.organization_id, - literal(None).label('file_id'), - literal(None).label('source_id'), - AgentPassage.agent_id + literal(None).label("file_id"), + literal(None).label("source_id"), + AgentPassage.agent_id, ) .where(AgentPassage.agent_id == agent_id) .where(AgentPassage.organization_id == actor.organization_id) @@ -496,11 +491,11 @@ def _build_passage_query( # Combine queries if source_passages is not None and agent_passages is not None: - combined_query = union_all(source_passages, agent_passages).cte('combined_passages') + combined_query = union_all(source_passages, agent_passages).cte("combined_passages") elif agent_passages is not None: - combined_query = agent_passages.cte('combined_passages') + combined_query = agent_passages.cte("combined_passages") elif source_passages is not None: - combined_query = source_passages.cte('combined_passages') + combined_query = source_passages.cte("combined_passages") else: raise ValueError("No passages found") @@ -521,9 +516,7 @@ def _build_passage_query( if embedded_text: if settings.letta_pg_uri_no_default: # PostgreSQL with pgvector - main_query = main_query.order_by( - combined_query.c.embedding.cosine_distance(embedded_text).asc() - ) + main_query = main_query.order_by(combined_query.c.embedding.cosine_distance(embedded_text).asc()) else: # SQLite with custom vector type query_embedding_binary = adapt_array(embedded_text) @@ -531,13 +524,13 @@ def _build_passage_query( main_query = main_query.order_by( func.cosine_distance(combined_query.c.embedding, query_embedding_binary).asc(), combined_query.c.created_at.asc(), - combined_query.c.id.asc() + combined_query.c.id.asc(), ) else: main_query = main_query.order_by( func.cosine_distance(combined_query.c.embedding, query_embedding_binary).asc(), combined_query.c.created_at.desc(), - combined_query.c.id.asc() + combined_query.c.id.asc(), ) else: if query_text: @@ -545,18 +538,12 @@ def _build_passage_query( # Handle cursor-based pagination if cursor: - cursor_query = select(combined_query.c.created_at).where( - combined_query.c.id == cursor - ).scalar_subquery() - + cursor_query = select(combined_query.c.created_at).where(combined_query.c.id == cursor).scalar_subquery() + if ascending: - main_query = main_query.where( - combined_query.c.created_at > cursor_query - ) + main_query = main_query.where(combined_query.c.created_at > cursor_query) else: - main_query = main_query.where( - combined_query.c.created_at < cursor_query - ) + main_query = main_query.where(combined_query.c.created_at < cursor_query) # Add ordering if not already ordered by similarity if not embed_query: @@ -588,7 +575,7 @@ def list_passages( embed_query: bool = False, ascending: bool = True, embedding_config: Optional[EmbeddingConfig] = None, - agent_only: bool = False + agent_only: bool = False, ) -> List[PydanticPassage]: """Lists all passages attached to an agent.""" with self.session_maker() as session: @@ -617,19 +604,18 @@ def list_passages( passages = [] for row in results: data = dict(row._mapping) - if data['agent_id'] is not None: + if data["agent_id"] is not None: # This is an AgentPassage - remove source fields - data.pop('source_id', None) - data.pop('file_id', None) + data.pop("source_id", None) + data.pop("file_id", None) passage = AgentPassage(**data) else: # This is a SourcePassage - remove agent field - data.pop('agent_id', None) + data.pop("agent_id", None) passage = SourcePassage(**data) passages.append(passage) - - return [p.to_pydantic() for p in passages] + return [p.to_pydantic() for p in passages] @enforce_types def passage_size( @@ -645,7 +631,7 @@ def passage_size( embed_query: bool = False, ascending: bool = True, embedding_config: Optional[EmbeddingConfig] = None, - agent_only: bool = False + agent_only: bool = False, ) -> int: """Returns the count of passages matching the given criteria.""" with self.session_maker() as session: @@ -663,7 +649,7 @@ def passage_size( embedding_config=embedding_config, agent_only=agent_only, ) - + # Convert to count query count_query = select(func.count()).select_from(main_query.subquery()) return session.scalar(count_query) or 0 diff --git a/letta/services/tool_manager.py b/letta/services/tool_manager.py index 63240930fe..739bfb382c 100644 --- a/letta/services/tool_manager.py +++ b/letta/services/tool_manager.py @@ -3,6 +3,7 @@ import warnings from typing import List, Optional +from letta.constants import BASE_MEMORY_TOOLS, BASE_TOOLS from letta.functions.functions import derive_openai_json_schema, load_function_set # TODO: Remove this once we translate all of these to the ORM @@ -20,7 +21,6 @@ class ToolManager: BASE_TOOL_NAMES = [ "send_message", "conversation_search", - "conversation_search_date", "archival_memory_insert", "archival_memory_search", ] @@ -133,7 +133,7 @@ def delete_tool_by_id(self, tool_id: str, actor: PydanticUser) -> None: raise ValueError(f"Tool with id {tool_id} not found.") @enforce_types - def add_base_tools(self, actor: PydanticUser) -> List[PydanticTool]: + def upsert_base_tools(self, actor: PydanticUser) -> List[PydanticTool]: """Add default tools in base.py""" module_name = "base" full_module_name = f"letta.functions.function_sets.{module_name}" @@ -154,7 +154,7 @@ def add_base_tools(self, actor: PydanticUser) -> List[PydanticTool]: # create tool in db tools = [] for name, schema in functions_to_schema.items(): - if name in self.BASE_TOOL_NAMES + self.BASE_MEMORY_TOOL_NAMES: + if name in BASE_TOOLS + BASE_MEMORY_TOOLS: # print([str(inspect.getsource(line)) for line in schema["imports"]]) source_code = inspect.getsource(schema["python_function"]) tags = [module_name] diff --git a/scripts/migrate_tools.py b/scripts/migrate_tools.py index 7ea6bac1d6..53578c690e 100644 --- a/scripts/migrate_tools.py +++ b/scripts/migrate_tools.py @@ -1,5 +1,5 @@ -from letta.functions.functions import parse_source_code -from letta.schemas.tool import Tool +from tqdm import tqdm + from letta.schemas.user import User from letta.services.organization_manager import OrganizationManager from letta.services.tool_manager import ToolManager @@ -10,33 +10,8 @@ def deprecated_tool(): orgs = OrganizationManager().list_organizations(cursor=None, limit=5000) -for org in orgs: +for org in tqdm(orgs): if org.name != "default": fake_user = User(id="user-00000000-0000-4000-8000-000000000000", name="fake", organization_id=org.id) - ToolManager().add_base_tools(actor=fake_user) - - source_code = parse_source_code(deprecated_tool) - source_type = "python" - description = "deprecated" - tags = ["deprecated"] - - ToolManager().create_or_update_tool( - Tool( - name="core_memory_append", - source_code=source_code, - source_type=source_type, - description=description, - ), - actor=fake_user, - ) - - ToolManager().create_or_update_tool( - Tool( - name="core_memory_replace", - source_code=source_code, - source_type=source_type, - description=description, - ), - actor=fake_user, - ) + ToolManager().upsert_base_tools(actor=fake_user) diff --git a/tests/test_client_legacy.py b/tests/test_client_legacy.py index e0b51255a0..7c634e5fa3 100644 --- a/tests/test_client_legacy.py +++ b/tests/test_client_legacy.py @@ -11,7 +11,7 @@ from letta import create_client from letta.client.client import LocalClient, RESTClient -from letta.constants import DEFAULT_PRESET +from letta.constants import BASE_MEMORY_TOOLS, BASE_TOOLS, DEFAULT_PRESET from letta.orm import FileMetadata, Source from letta.schemas.agent import AgentState from letta.schemas.embedding_config import EmbeddingConfig @@ -30,7 +30,6 @@ from letta.schemas.message import MessageCreate from letta.schemas.usage import LettaUsageStatistics from letta.services.organization_manager import OrganizationManager -from letta.services.tool_manager import ToolManager from letta.services.user_manager import UserManager from letta.settings import model_settings from tests.helpers.client_helper import upload_file_using_client @@ -336,9 +335,9 @@ def test_list_tools_pagination(client: Union[LocalClient, RESTClient]): def test_list_tools(client: Union[LocalClient, RESTClient]): - tools = client.add_base_tools() + tools = client.upsert_base_tools() tool_names = [t.name for t in tools] - expected = ToolManager.BASE_TOOL_NAMES + ToolManager.BASE_MEMORY_TOOL_NAMES + expected = BASE_TOOLS + BASE_MEMORY_TOOLS assert sorted(tool_names) == sorted(expected) diff --git a/tests/test_managers.py b/tests/test_managers.py index dc5f15ad1b..37c6f2ac08 100644 --- a/tests/test_managers.py +++ b/tests/test_managers.py @@ -2,28 +2,27 @@ import time from datetime import datetime, timedelta -from httpx._transports import default -from numpy import source import pytest from sqlalchemy import delete from sqlalchemy.exc import IntegrityError from letta.config import LettaConfig +from letta.constants import BASE_MEMORY_TOOLS, BASE_TOOLS from letta.embeddings import embedding_model from letta.functions.functions import derive_openai_json_schema, parse_source_code from letta.orm import ( Agent, + AgentPassage, Block, BlocksAgents, FileMetadata, Job, Message, Organization, - AgentPassage, - SourcePassage, SandboxConfig, SandboxEnvironmentVariable, Source, + SourcePassage, SourcesAgents, Tool, ToolsAgents, @@ -202,9 +201,9 @@ def agent_passage_fixture(server: SyncServer, default_user, sarah_agent): organization_id=default_user.organization_id, embedding=[0.1], embedding_config=DEFAULT_EMBEDDING_CONFIG, - metadata_={"type": "test"} + metadata_={"type": "test"}, ), - actor=default_user + actor=default_user, ) yield passage @@ -220,9 +219,9 @@ def source_passage_fixture(server: SyncServer, default_user, default_file, defau organization_id=default_user.organization_id, embedding=[0.1], embedding_config=DEFAULT_EMBEDDING_CONFIG, - metadata_={"type": "test"} + metadata_={"type": "test"}, ), - actor=default_user + actor=default_user, ) yield passage @@ -240,9 +239,9 @@ def create_test_passages(server: SyncServer, default_file, default_user, sarah_a organization_id=default_user.organization_id, embedding=[0.1], embedding_config=DEFAULT_EMBEDDING_CONFIG, - metadata_={"type": "test"} + metadata_={"type": "test"}, ), - actor=default_user + actor=default_user, ) passages.append(passage) if USING_SQLITE: @@ -258,9 +257,9 @@ def create_test_passages(server: SyncServer, default_file, default_user, sarah_a organization_id=default_user.organization_id, embedding=[0.1], embedding_config=DEFAULT_EMBEDDING_CONFIG, - metadata_={"type": "test"} + metadata_={"type": "test"}, ), - actor=default_user + actor=default_user, ) passages.append(passage) if USING_SQLITE: @@ -452,7 +451,7 @@ def agent_passages_setup(server, default_source, default_user, sarah_agent): embedding=[0.1], # Default OpenAI embedding size embedding_config=DEFAULT_EMBEDDING_CONFIG, ), - actor=actor + actor=actor, ) source_passages.append(passage) @@ -467,7 +466,7 @@ def agent_passages_setup(server, default_source, default_user, sarah_agent): embedding=[0.1], # Default OpenAI embedding size embedding_config=DEFAULT_EMBEDDING_CONFIG, ), - actor=actor + actor=actor, ) agent_passages.append(passage) @@ -476,6 +475,7 @@ def agent_passages_setup(server, default_source, default_user, sarah_agent): # Cleanup server.source_manager.delete_source(default_source.id, actor=actor) + # ====================================================================================================================== # AgentManager Tests - Basic # ====================================================================================================================== @@ -940,32 +940,33 @@ def test_get_block_with_label(server: SyncServer, sarah_agent, default_block, de # Agent Manager - Passages Tests # ====================================================================================================================== + def test_agent_list_passages_basic(server, default_user, sarah_agent, agent_passages_setup): """Test basic listing functionality of agent passages""" - + all_passages = server.agent_manager.list_passages(actor=default_user, agent_id=sarah_agent.id) assert len(all_passages) == 5 # 3 source + 2 agent passages def test_agent_list_passages_ordering(server, default_user, sarah_agent, agent_passages_setup): - """Test ordering of agent passages""" + """Test ordering of agent passages""" # Test ascending order asc_passages = server.agent_manager.list_passages(actor=default_user, agent_id=sarah_agent.id, ascending=True) assert len(asc_passages) == 5 for i in range(1, len(asc_passages)): - assert asc_passages[i-1].created_at <= asc_passages[i].created_at + assert asc_passages[i - 1].created_at <= asc_passages[i].created_at # Test descending order desc_passages = server.agent_manager.list_passages(actor=default_user, agent_id=sarah_agent.id, ascending=False) assert len(desc_passages) == 5 for i in range(1, len(desc_passages)): - assert desc_passages[i-1].created_at >= desc_passages[i].created_at + assert desc_passages[i - 1].created_at >= desc_passages[i].created_at def test_agent_list_passages_pagination(server, default_user, sarah_agent, agent_passages_setup): """Test pagination of agent passages""" - + # Test limit limited_passages = server.agent_manager.list_passages(actor=default_user, agent_id=sarah_agent.id, limit=3) assert len(limited_passages) == 3 @@ -973,13 +974,9 @@ def test_agent_list_passages_pagination(server, default_user, sarah_agent, agent # Test cursor-based pagination first_page = server.agent_manager.list_passages(actor=default_user, agent_id=sarah_agent.id, limit=2, ascending=True) assert len(first_page) == 2 - + second_page = server.agent_manager.list_passages( - actor=default_user, - agent_id=sarah_agent.id, - cursor=first_page[-1].id, - limit=2, - ascending=True + actor=default_user, agent_id=sarah_agent.id, cursor=first_page[-1].id, limit=2, ascending=True ) assert len(second_page) == 2 assert first_page[-1].id != second_page[0].id @@ -988,57 +985,38 @@ def test_agent_list_passages_pagination(server, default_user, sarah_agent, agent def test_agent_list_passages_text_search(server, default_user, sarah_agent, agent_passages_setup): """Test text search functionality of agent passages""" - + # Test text search for source passages - source_text_passages = server.agent_manager.list_passages( - actor=default_user, - agent_id=sarah_agent.id, - query_text="Source passage" - ) + source_text_passages = server.agent_manager.list_passages(actor=default_user, agent_id=sarah_agent.id, query_text="Source passage") assert len(source_text_passages) == 3 # Test text search for agent passages - agent_text_passages = server.agent_manager.list_passages( - actor=default_user, - agent_id=sarah_agent.id, - query_text="Agent passage" - ) + agent_text_passages = server.agent_manager.list_passages(actor=default_user, agent_id=sarah_agent.id, query_text="Agent passage") assert len(agent_text_passages) == 2 def test_agent_list_passages_agent_only(server, default_user, sarah_agent, agent_passages_setup): """Test text search functionality of agent passages""" - + # Test text search for agent passages - agent_text_passages = server.agent_manager.list_passages( - actor=default_user, - agent_id=sarah_agent.id, - agent_only=True - ) + agent_text_passages = server.agent_manager.list_passages(actor=default_user, agent_id=sarah_agent.id, agent_only=True) assert len(agent_text_passages) == 2 def test_agent_list_passages_filtering(server, default_user, sarah_agent, default_source, agent_passages_setup): """Test filtering functionality of agent passages""" - + # Test source filtering - source_filtered = server.agent_manager.list_passages( - actor=default_user, - agent_id=sarah_agent.id, - source_id=default_source.id - ) + source_filtered = server.agent_manager.list_passages(actor=default_user, agent_id=sarah_agent.id, source_id=default_source.id) assert len(source_filtered) == 3 # Test date filtering now = datetime.utcnow() future_date = now + timedelta(days=1) past_date = now - timedelta(days=1) - + date_filtered = server.agent_manager.list_passages( - actor=default_user, - agent_id=sarah_agent.id, - start_date=past_date, - end_date=future_date + actor=default_user, agent_id=sarah_agent.id, start_date=past_date, end_date=future_date ) assert len(date_filtered) == 5 @@ -1049,7 +1027,7 @@ def test_agent_list_passages_vector_search(server, default_user, sarah_agent, de # Create passages with known embeddings passages = [] - + # Create passages with different embeddings test_passages = [ "I like red", @@ -1058,7 +1036,7 @@ def test_agent_list_passages_vector_search(server, default_user, sarah_agent, de ] server.agent_manager.attach_source(agent_id=sarah_agent.id, source_id=default_source.id, actor=default_user) - + for i, text in enumerate(test_passages): embedding = embed_model.get_text_embedding(text) if i % 2 == 0: @@ -1067,7 +1045,7 @@ def test_agent_list_passages_vector_search(server, default_user, sarah_agent, de organization_id=default_user.organization_id, agent_id=sarah_agent.id, embedding_config=DEFAULT_EMBEDDING_CONFIG, - embedding=embedding + embedding=embedding, ) else: passage = PydanticPassage( @@ -1075,14 +1053,14 @@ def test_agent_list_passages_vector_search(server, default_user, sarah_agent, de organization_id=default_user.organization_id, source_id=default_source.id, embedding_config=DEFAULT_EMBEDDING_CONFIG, - embedding=embedding + embedding=embedding, ) created_passage = server.passage_manager.create_passage(passage, default_user) passages.append(created_passage) # Query vector similar to "red" embedding query_key = "What's my favorite color?" - + # Test vector search with all passages results = server.agent_manager.list_passages( actor=default_user, @@ -1091,7 +1069,7 @@ def test_agent_list_passages_vector_search(server, default_user, sarah_agent, de embedding_config=DEFAULT_EMBEDDING_CONFIG, embed_query=True, ) - + # Verify results are ordered by similarity assert len(results) == 3 assert results[0].text == "I like red" @@ -1105,9 +1083,9 @@ def test_agent_list_passages_vector_search(server, default_user, sarah_agent, de query_text=query_key, embedding_config=DEFAULT_EMBEDDING_CONFIG, embed_query=True, - agent_only=True + agent_only=True, ) - + # Verify agent-only results assert len(agent_only_results) == 2 assert agent_only_results[0].text == "I like red" @@ -1116,7 +1094,7 @@ def test_agent_list_passages_vector_search(server, default_user, sarah_agent, de def test_list_source_passages_only(server: SyncServer, default_user, default_source, agent_passages_setup): """Test listing passages from a source without specifying an agent.""" - + # List passages by source_id without agent_id source_passages = server.agent_manager.list_passages( actor=default_user, @@ -1180,6 +1158,7 @@ def test_list_organizations_pagination(server: SyncServer): # Passage Manager Tests # ====================================================================================================================== + def test_passage_create_agentic(server: SyncServer, agent_passage_fixture, default_user): """Test creating a passage using agent_passage_fixture fixture""" assert agent_passage_fixture.id is not None @@ -1214,7 +1193,7 @@ def test_passage_create_invalid(server: SyncServer, agent_passage_fixture, defau """Test creating an agent passage.""" assert agent_passage_fixture is not None assert agent_passage_fixture.text == "Hello, I am an agent passage" - + # Try to create an invalid passage (with both agent_id and source_id) with pytest.raises(AssertionError): server.passage_manager.create_passage( @@ -1226,7 +1205,7 @@ def test_passage_create_invalid(server: SyncServer, agent_passage_fixture, defau embedding=[0.1] * 1024, embedding_config=DEFAULT_EMBEDDING_CONFIG, ), - actor=default_user + actor=default_user, ) @@ -1243,19 +1222,21 @@ def test_passage_get_by_id(server: SyncServer, agent_passage_fixture, source_pas assert retrieved.text == source_passage_fixture.text -def test_passage_cascade_deletion(server: SyncServer, agent_passage_fixture, source_passage_fixture, default_user, default_source, sarah_agent): +def test_passage_cascade_deletion( + server: SyncServer, agent_passage_fixture, source_passage_fixture, default_user, default_source, sarah_agent +): """Test that passages are deleted when their parent (agent or source) is deleted.""" # Verify passages exist agent_passage = server.passage_manager.get_passage_by_id(agent_passage_fixture.id, default_user) source_passage = server.passage_manager.get_passage_by_id(source_passage_fixture.id, default_user) assert agent_passage is not None assert source_passage is not None - + # Delete agent and verify its passages are deleted server.agent_manager.delete_agent(sarah_agent.id, default_user) agentic_passages = server.agent_manager.list_passages(actor=default_user, agent_id=sarah_agent.id, agent_only=True) assert len(agentic_passages) == 0 - + # Delete source and verify its passages are deleted server.source_manager.delete_source(default_source.id, default_user) with pytest.raises(NoResultFound): @@ -1320,7 +1301,6 @@ def test_create_tool(server: SyncServer, print_tool, default_user, default_organ assert print_tool.organization_id == default_organization.id - @pytest.mark.skipif(USING_SQLITE, reason="Test not applicable when using SQLite.") def test_create_tool_duplicate_name(server: SyncServer, print_tool, default_user, default_organization): data = print_tool.model_dump(exclude=["id"]) @@ -1481,6 +1461,16 @@ def test_delete_tool_by_id(server: SyncServer, print_tool, default_user): assert len(tools) == 0 +def test_upsert_base_tools(server: SyncServer, default_user): + tools = server.tool_manager.upsert_base_tools(actor=default_user) + expected_tool_names = sorted(BASE_TOOLS + BASE_MEMORY_TOOLS) + assert sorted([t.name for t in tools]) == expected_tool_names + + # Call it again to make sure it doesn't create duplicates + tools = server.tool_manager.upsert_base_tools(actor=default_user) + assert sorted([t.name for t in tools]) == expected_tool_names + + # ====================================================================================================================== # Message Manager Tests # ====================================================================================================================== @@ -1889,6 +1879,7 @@ def test_update_source_no_changes(server: SyncServer, default_user): # Source Manager Tests - Files # ====================================================================================================================== + def test_get_file_by_id(server: SyncServer, default_user, default_source): """Test retrieving a file by ID.""" file_metadata = PydanticFileMetadata( @@ -1960,6 +1951,7 @@ def test_delete_file(server: SyncServer, default_user, default_source): # SandboxConfigManager Tests - Sandbox Configs # ====================================================================================================================== + def test_create_or_update_sandbox_config(server: SyncServer, default_user): sandbox_config_create = SandboxConfigCreate( config=E2BSandboxConfig(), @@ -2039,6 +2031,7 @@ def test_list_sandbox_configs(server: SyncServer, default_user): # SandboxConfigManager Tests - Environment Variables # ====================================================================================================================== + def test_create_sandbox_env_var(server: SyncServer, sandbox_config_fixture, default_user): env_var_create = SandboxEnvironmentVariableCreate(key="TEST_VAR", value="test_value", description="A test environment variable.") created_env_var = server.sandbox_config_manager.create_sandbox_env_var( @@ -2111,6 +2104,7 @@ def test_get_sandbox_env_var_by_key(server: SyncServer, sandbox_env_var_fixture, # JobManager Tests # ====================================================================================================================== + def test_create_job(server: SyncServer, default_user): """Test creating a job.""" job_data = PydanticJob( diff --git a/tests/test_v1_routes.py b/tests/test_v1_routes.py index d82bbc11a3..2865bb2ec4 100644 --- a/tests/test_v1_routes.py +++ b/tests/test_v1_routes.py @@ -272,15 +272,15 @@ def test_update_tool(client, mock_sync_server, update_integers_tool, add_integer ) -def test_add_base_tools(client, mock_sync_server, add_integers_tool): - mock_sync_server.tool_manager.add_base_tools.return_value = [add_integers_tool] +def test_upsert_base_tools(client, mock_sync_server, add_integers_tool): + mock_sync_server.tool_manager.upsert_base_tools.return_value = [add_integers_tool] response = client.post("/v1/tools/add-base-tools", headers={"user_id": "test_user"}) assert response.status_code == 200 assert len(response.json()) == 1 assert response.json()[0]["id"] == add_integers_tool.id - mock_sync_server.tool_manager.add_base_tools.assert_called_once_with( + mock_sync_server.tool_manager.upsert_base_tools.assert_called_once_with( actor=mock_sync_server.user_manager.get_user_or_default.return_value )