Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

chore: Clean up upserting base tools #2274

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions letta/client/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -233,7 +233,7 @@ def delete_tool(self, id: str):
def get_tool_id(self, name: str) -> Optional[str]:
raise NotImplementedError

def add_base_tools(self) -> List[Tool]:
def upsert_base_tools(self) -> List[Tool]:
raise NotImplementedError

def load_data(self, connector: DataConnector, source_name: str):
Expand Down Expand Up @@ -1466,7 +1466,7 @@ def get_tool_id(self, tool_name: str):
raise ValueError(f"Failed to get tool: {response.text}")
return response.json()

def add_base_tools(self) -> List[Tool]:
def upsert_base_tools(self) -> List[Tool]:
response = requests.post(f"{self.base_url}/{self.api_prefix}/tools/add-base-tools/", headers=self.headers)
if response.status_code != 200:
raise ValueError(f"Failed to add base tools: {response.text}")
Expand Down
54 changes: 0 additions & 54 deletions letta/functions/function_sets/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,60 +61,6 @@ def conversation_search(self: "Agent", query: str, page: Optional[int] = 0) -> O
return results_str


def conversation_search_date(self: "Agent", start_date: str, end_date: str, page: Optional[int] = 0) -> Optional[str]:
"""
Search prior conversation history using a date range.

Args:
start_date (str): The start of the date range to search, in the format 'YYYY-MM-DD'.
end_date (str): The end of the date range to search, in the format 'YYYY-MM-DD'.
page (int): Allows you to page through results. Only use on a follow-up query. Defaults to 0 (first page).

Returns:
str: Query result string
"""
import math
from datetime import datetime

from letta.constants import RETRIEVAL_QUERY_DEFAULT_PAGE_SIZE
from letta.utils import json_dumps

if page is None or (isinstance(page, str) and page.lower().strip() == "none"):
page = 0
try:
page = int(page)
if page < 0:
raise ValueError
except:
raise ValueError(f"'page' argument must be an integer")

# Convert date strings to datetime objects
try:
start_datetime = datetime.strptime(start_date, "%Y-%m-%d").replace(hour=0, minute=0, second=0, microsecond=0)
end_datetime = datetime.strptime(end_date, "%Y-%m-%d").replace(hour=23, minute=59, second=59, microsecond=999999)
except ValueError:
raise ValueError("Dates must be in the format 'YYYY-MM-DD'")

count = RETRIEVAL_QUERY_DEFAULT_PAGE_SIZE
results = self.message_manager.list_user_messages_for_agent(
# TODO: add paging by page number. currently cursor only works with strings.
agent_id=self.agent_state.id,
actor=self.user,
start_date=start_datetime,
end_date=end_datetime,
limit=count,
)
total = len(results)
num_pages = math.ceil(total / count) - 1 # 0 index
if len(results) == 0:
results_str = f"No results found."
else:
results_pref = f"Showing {len(results)} of {total} results (page {page}/{num_pages}):"
results_formatted = [f"timestamp: {d['timestamp']}, {d['message']['role']} - {d['message']['content']}" for d in results]
results_str = f"{results_pref} {json_dumps(results_formatted)}"
return results_str


def archival_memory_insert(self: "Agent", content: str) -> Optional[str]:
"""
Add to archival memory. Make sure to phrase the memory contents such that it can be easily queried later.
Expand Down
21 changes: 3 additions & 18 deletions letta/server/rest_api/routers/v1/tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -152,30 +152,15 @@ def update_tool(


@router.post("/add-base-tools", response_model=List[Tool], operation_id="add_base_tools")
def add_base_tools(
def upsert_base_tools(
server: SyncServer = Depends(get_letta_server),
user_id: Optional[str] = Header(None, alias="user_id"), # Extract user_id from header, default to None if not present
):
"""
Add base tools
Upsert base tools
"""
actor = server.user_manager.get_user_or_default(user_id=user_id)
return server.tool_manager.add_base_tools(actor=actor)


# NOTE: can re-enable if needed
# @router.post("/{tool_id}/run", response_model=FunctionReturn, operation_id="run_tool")
# def run_tool(
# server: SyncServer = Depends(get_letta_server),
# request: ToolRun = Body(...),
# user_id: Optional[str] = Header(None, alias="user_id"), # Extract user_id from header, default to None if not present
# ):
# """
# Run an existing tool on provided arguments
# """
# actor = server.user_manager.get_user_or_default(user_id=user_id)

# return server.run_tool(tool_id=request.tool_id, tool_args=request.tool_args, user_id=actor.id)
return server.tool_manager.upsert_base_tools(actor=actor)


@router.post("/run", response_model=FunctionReturn, operation_id="run_tool_from_source")
Expand Down
2 changes: 1 addition & 1 deletion letta/server/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -301,7 +301,7 @@ def __init__(
self.default_org = self.organization_manager.create_default_organization()
self.default_user = self.user_manager.create_default_user()
self.block_manager.add_default_blocks(actor=self.default_user)
self.tool_manager.add_base_tools(actor=self.default_user)
self.tool_manager.upsert_base_tools(actor=self.default_user)

# If there is a default org/user
# This logic may have to change in the future
Expand Down
78 changes: 32 additions & 46 deletions letta/services/agent_manager.py
Original file line number Diff line number Diff line change
@@ -1,18 +1,18 @@
from typing import Dict, List, Optional
from datetime import datetime
import numpy as np
from typing import Dict, List, Optional

from sqlalchemy import select, union_all, literal, func, Select
import numpy as np
from sqlalchemy import Select, func, literal, select, union_all

from letta.constants import BASE_MEMORY_TOOLS, BASE_TOOLS, MAX_EMBEDDING_DIM
from letta.embeddings import embedding_model
from letta.log import get_logger
from letta.orm import Agent as AgentModel
from letta.orm import AgentPassage
from letta.orm import Block as BlockModel
from letta.orm import Source as SourceModel
from letta.orm import SourcePassage, SourcesAgents
from letta.orm import Tool as ToolModel
from letta.orm import AgentPassage, SourcePassage
from letta.orm import SourcesAgents
from letta.orm.errors import NoResultFound
from letta.orm.sqlite_functions import adapt_array
from letta.schemas.agent import AgentState as PydanticAgentState
Expand Down Expand Up @@ -77,6 +77,8 @@ def create_agent(
tool_names.extend(BASE_TOOLS + BASE_MEMORY_TOOLS)
if agent_create.tools:
tool_names.extend(agent_create.tools)
# Remove duplicates
tool_names = list(set(tool_names))

tool_ids = agent_create.tool_ids or []
for tool_name in tool_names:
Expand Down Expand Up @@ -431,7 +433,7 @@ def _build_passage_query(
agent_only: bool = False,
) -> Select:
"""Helper function to build the base passage query with all filters applied.

Returns the query before any limit or count operations are applied.
"""
embedded_text = None
Expand All @@ -448,21 +450,14 @@ def _build_passage_query(
if not agent_only: # Include source passages
if agent_id is not None:
source_passages = (
select(
SourcePassage,
literal(None).label('agent_id')
)
select(SourcePassage, literal(None).label("agent_id"))
.join(SourcesAgents, SourcesAgents.source_id == SourcePassage.source_id)
.where(SourcesAgents.agent_id == agent_id)
.where(SourcePassage.organization_id == actor.organization_id)
)
else:
source_passages = (
select(
SourcePassage,
literal(None).label('agent_id')
)
.where(SourcePassage.organization_id == actor.organization_id)
source_passages = select(SourcePassage, literal(None).label("agent_id")).where(
SourcePassage.organization_id == actor.organization_id
)

if source_id:
Expand All @@ -486,21 +481,21 @@ def _build_passage_query(
AgentPassage._created_by_id,
AgentPassage._last_updated_by_id,
AgentPassage.organization_id,
literal(None).label('file_id'),
literal(None).label('source_id'),
AgentPassage.agent_id
literal(None).label("file_id"),
literal(None).label("source_id"),
AgentPassage.agent_id,
)
.where(AgentPassage.agent_id == agent_id)
.where(AgentPassage.organization_id == actor.organization_id)
)

# Combine queries
if source_passages is not None and agent_passages is not None:
combined_query = union_all(source_passages, agent_passages).cte('combined_passages')
combined_query = union_all(source_passages, agent_passages).cte("combined_passages")
elif agent_passages is not None:
combined_query = agent_passages.cte('combined_passages')
combined_query = agent_passages.cte("combined_passages")
elif source_passages is not None:
combined_query = source_passages.cte('combined_passages')
combined_query = source_passages.cte("combined_passages")
else:
raise ValueError("No passages found")

Expand All @@ -521,42 +516,34 @@ def _build_passage_query(
if embedded_text:
if settings.letta_pg_uri_no_default:
# PostgreSQL with pgvector
main_query = main_query.order_by(
combined_query.c.embedding.cosine_distance(embedded_text).asc()
)
main_query = main_query.order_by(combined_query.c.embedding.cosine_distance(embedded_text).asc())
else:
# SQLite with custom vector type
query_embedding_binary = adapt_array(embedded_text)
if ascending:
main_query = main_query.order_by(
func.cosine_distance(combined_query.c.embedding, query_embedding_binary).asc(),
combined_query.c.created_at.asc(),
combined_query.c.id.asc()
combined_query.c.id.asc(),
)
else:
main_query = main_query.order_by(
func.cosine_distance(combined_query.c.embedding, query_embedding_binary).asc(),
combined_query.c.created_at.desc(),
combined_query.c.id.asc()
combined_query.c.id.asc(),
)
else:
if query_text:
main_query = main_query.where(func.lower(combined_query.c.text).contains(func.lower(query_text)))

# Handle cursor-based pagination
if cursor:
cursor_query = select(combined_query.c.created_at).where(
combined_query.c.id == cursor
).scalar_subquery()

cursor_query = select(combined_query.c.created_at).where(combined_query.c.id == cursor).scalar_subquery()

if ascending:
main_query = main_query.where(
combined_query.c.created_at > cursor_query
)
main_query = main_query.where(combined_query.c.created_at > cursor_query)
else:
main_query = main_query.where(
combined_query.c.created_at < cursor_query
)
main_query = main_query.where(combined_query.c.created_at < cursor_query)

# Add ordering if not already ordered by similarity
if not embed_query:
Expand Down Expand Up @@ -588,7 +575,7 @@ def list_passages(
embed_query: bool = False,
ascending: bool = True,
embedding_config: Optional[EmbeddingConfig] = None,
agent_only: bool = False
agent_only: bool = False,
) -> List[PydanticPassage]:
"""Lists all passages attached to an agent."""
with self.session_maker() as session:
Expand Down Expand Up @@ -617,19 +604,18 @@ def list_passages(
passages = []
for row in results:
data = dict(row._mapping)
if data['agent_id'] is not None:
if data["agent_id"] is not None:
# This is an AgentPassage - remove source fields
data.pop('source_id', None)
data.pop('file_id', None)
data.pop("source_id", None)
data.pop("file_id", None)
passage = AgentPassage(**data)
else:
# This is a SourcePassage - remove agent field
data.pop('agent_id', None)
data.pop("agent_id", None)
passage = SourcePassage(**data)
passages.append(passage)

return [p.to_pydantic() for p in passages]

return [p.to_pydantic() for p in passages]

@enforce_types
def passage_size(
Expand All @@ -645,7 +631,7 @@ def passage_size(
embed_query: bool = False,
ascending: bool = True,
embedding_config: Optional[EmbeddingConfig] = None,
agent_only: bool = False
agent_only: bool = False,
) -> int:
"""Returns the count of passages matching the given criteria."""
with self.session_maker() as session:
Expand All @@ -663,7 +649,7 @@ def passage_size(
embedding_config=embedding_config,
agent_only=agent_only,
)

# Convert to count query
count_query = select(func.count()).select_from(main_query.subquery())
return session.scalar(count_query) or 0
Expand Down
6 changes: 3 additions & 3 deletions letta/services/tool_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import warnings
from typing import List, Optional

from letta.constants import BASE_MEMORY_TOOLS, BASE_TOOLS
from letta.functions.functions import derive_openai_json_schema, load_function_set

# TODO: Remove this once we translate all of these to the ORM
Expand All @@ -20,7 +21,6 @@ class ToolManager:
BASE_TOOL_NAMES = [
"send_message",
"conversation_search",
"conversation_search_date",
"archival_memory_insert",
"archival_memory_search",
]
Expand Down Expand Up @@ -133,7 +133,7 @@ def delete_tool_by_id(self, tool_id: str, actor: PydanticUser) -> None:
raise ValueError(f"Tool with id {tool_id} not found.")

@enforce_types
def add_base_tools(self, actor: PydanticUser) -> List[PydanticTool]:
def upsert_base_tools(self, actor: PydanticUser) -> List[PydanticTool]:
"""Add default tools in base.py"""
module_name = "base"
full_module_name = f"letta.functions.function_sets.{module_name}"
Expand All @@ -154,7 +154,7 @@ def add_base_tools(self, actor: PydanticUser) -> List[PydanticTool]:
# create tool in db
tools = []
for name, schema in functions_to_schema.items():
if name in self.BASE_TOOL_NAMES + self.BASE_MEMORY_TOOL_NAMES:
if name in BASE_TOOLS + BASE_MEMORY_TOOLS:
# print([str(inspect.getsource(line)) for line in schema["imports"]])
source_code = inspect.getsource(schema["python_function"])
tags = [module_name]
Expand Down
33 changes: 4 additions & 29 deletions scripts/migrate_tools.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from letta.functions.functions import parse_source_code
from letta.schemas.tool import Tool
from tqdm import tqdm

from letta.schemas.user import User
from letta.services.organization_manager import OrganizationManager
from letta.services.tool_manager import ToolManager
Expand All @@ -10,33 +10,8 @@ def deprecated_tool():


orgs = OrganizationManager().list_organizations(cursor=None, limit=5000)
for org in orgs:
for org in tqdm(orgs):
if org.name != "default":
mattzh72 marked this conversation as resolved.
Show resolved Hide resolved
fake_user = User(id="user-00000000-0000-4000-8000-000000000000", name="fake", organization_id=org.id)

ToolManager().add_base_tools(actor=fake_user)

source_code = parse_source_code(deprecated_tool)
source_type = "python"
description = "deprecated"
tags = ["deprecated"]

ToolManager().create_or_update_tool(
Tool(
name="core_memory_append",
source_code=source_code,
source_type=source_type,
description=description,
),
actor=fake_user,
)

ToolManager().create_or_update_tool(
Tool(
name="core_memory_replace",
source_code=source_code,
source_type=source_type,
description=description,
),
actor=fake_user,
)
ToolManager().upsert_base_tools(actor=fake_user)
Loading
Loading