Skip to content

Commit

Permalink
Fix tool rules validation
Browse files Browse the repository at this point in the history
  • Loading branch information
mattzh72 committed Dec 20, 2024
1 parent ea88617 commit d8c9d3b
Show file tree
Hide file tree
Showing 3 changed files with 22 additions and 14 deletions.
18 changes: 5 additions & 13 deletions letta/agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,6 @@
MESSAGE_SUMMARY_WARNING_FRAC,
O1_BASE_TOOLS,
REQ_HEARTBEAT_MESSAGE,
STRUCTURED_OUTPUT_MODELS,
)
from letta.errors import ContextWindowExceededError
from letta.helpers import ToolRulesSolver
Expand Down Expand Up @@ -47,7 +46,10 @@
from letta.schemas.user import User as PydanticUser
from letta.services.agent_manager import AgentManager
from letta.services.block_manager import BlockManager
from letta.services.helpers.agent_manager_helper import compile_memory_metadata_block
from letta.services.helpers.agent_manager_helper import (
check_supports_structured_output,
compile_memory_metadata_block,
)
from letta.services.message_manager import MessageManager
from letta.services.passage_manager import PassageManager
from letta.services.source_manager import SourceManager
Expand Down Expand Up @@ -121,7 +123,7 @@ def __init__(

# gpt-4, gpt-3.5-turbo, ...
self.model = self.agent_state.llm_config.model
self.check_tool_rules()
self.supports_structured_output = check_supports_structured_output(model=self.model, tool_rules=agent_state.tool_rules)

# state managers
self.block_manager = BlockManager()
Expand Down Expand Up @@ -152,16 +154,6 @@ def __init__(
# Load last function response from message history
self.last_function_response = self.load_last_function_response()

def check_tool_rules(self):
if self.model not in STRUCTURED_OUTPUT_MODELS:
if len(self.tool_rules_solver.init_tool_rules) > 1:
raise ValueError(
"Multiple initial tools are not supported for non-structured models. Please use only one initial tool rule."
)
self.supports_structured_output = False
else:
self.supports_structured_output = True

def load_last_function_response(self):
"""Load the last function response from message history"""
in_context_messages = self.agent_manager.get_in_context_messages(agent_id=self.agent_state.id, actor=self.user)
Expand Down
5 changes: 5 additions & 0 deletions letta/services/agent_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@
from letta.services.helpers.agent_manager_helper import (
_process_relationship,
_process_tags,
check_supports_structured_output,
compile_system_message,
derive_system_message,
initialize_message_sequence,
Expand Down Expand Up @@ -70,6 +71,10 @@ def create_agent(
if not agent_create.llm_config or not agent_create.embedding_config:
raise ValueError("llm_config and embedding_config are required")

# Check tool rules are valid
if agent_create.tool_rules:
check_supports_structured_output(model=agent_create.llm_config.model, tool_rules=agent_create.tool_rules)

# create blocks (note: cannot be linked into the agent_id is created)
block_ids = list(agent_create.block_ids or []) # Create a local copy to avoid modifying the original
for create_block in agent_create.memory_blocks:
Expand Down
13 changes: 12 additions & 1 deletion letta/services/helpers/agent_manager_helper.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,8 @@
from typing import List, Literal, Optional

from letta import system
from letta.constants import IN_CONTEXT_MEMORY_KEYWORD
from letta.constants import IN_CONTEXT_MEMORY_KEYWORD, STRUCTURED_OUTPUT_MODELS
from letta.helpers import ToolRulesSolver
from letta.orm.agent import Agent as AgentModel
from letta.orm.agents_tags import AgentsTags
from letta.orm.errors import NoResultFound
Expand All @@ -11,6 +12,7 @@
from letta.schemas.enums import MessageRole
from letta.schemas.memory import Memory
from letta.schemas.message import Message, MessageCreate
from letta.schemas.tool_rule import ToolRule
from letta.schemas.user import User
from letta.system import get_initial_boot_messages, get_login_event
from letta.utils import get_local_time
Expand Down Expand Up @@ -247,3 +249,12 @@ def package_initial_message_sequence(
Message(role=message_create.role, text=packed_message, organization_id=actor.organization_id, agent_id=agent_id, model=model)
)
return init_messages


def check_supports_structured_output(model: str, tool_rules: List[ToolRule]) -> bool:
if model not in STRUCTURED_OUTPUT_MODELS:
if len(ToolRulesSolver(tool_rules=tool_rules).init_tool_rules) > 1:
raise ValueError("Multiple initial tools are not supported for non-structured models. Please use only one initial tool rule.")
return False
else:
return True

0 comments on commit d8c9d3b

Please sign in to comment.