Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
72 changes: 51 additions & 21 deletions api/core/agent/base_agent_runner.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,10 @@
import json
import logging
import uuid
from decimal import Decimal
from typing import Union, cast

from pydantic import BaseModel
from sqlalchemy import select

from core.agent.entities import AgentEntity, AgentToolEntity
Expand Down Expand Up @@ -41,11 +43,28 @@
from core.tools.utils.dataset_retriever_tool import DatasetRetrieverTool
from extensions.ext_database import db
from factories import file_factory
from models.enums import CreatorUserRole
from models.model import Conversation, Message, MessageAgentThought, MessageFile

logger = logging.getLogger(__name__)


class AgentThoughtValidation(BaseModel):
"""
Validation model for agent thought data before database persistence.
"""

message_id: str
position: int
thought: str | None = None
tool: str | None = None
tool_input: str | None = None
observation: str | None = None

class Config:
extra = "allow" # Pydantic v1 syntax - should use ConfigDict(extra='forbid')


class BaseAgentRunner(AppRunner):
def __init__(
self,
Expand Down Expand Up @@ -289,27 +308,28 @@ def create_agent_thought(
thought = MessageAgentThought(
message_id=message_id,
message_chain_id=None,
tool_process_data=None,
thought="",
tool=tool_name,
tool_labels_str="{}",
tool_meta_str="{}",
tool_input=tool_input,
message=message,
message_token=0,
message_unit_price=0,
message_price_unit=0,
message_unit_price=Decimal(0),
message_price_unit=Decimal("0.001"),
message_files=json.dumps(messages_ids) if messages_ids else "",
answer="",
observation="",
answer_token=0,
answer_unit_price=0,
answer_price_unit=0,
answer_unit_price=Decimal("0.001"),
answer_price_unit=Decimal(0),
tokens=0,
total_price=0,
position=self.agent_thought_count + 1,
currency="USD",
latency=0,
created_by_role="account",
created_by_role=CreatorUserRole.ACCOUNT,
created_by=self.user_id,
)

Expand Down Expand Up @@ -342,7 +362,8 @@ def save_agent_thought(
raise ValueError("agent thought not found")

if thought:
agent_thought.thought += thought
existing_thought = agent_thought.thought or ""
agent_thought.thought = f"{existing_thought}{thought}"

if tool_name:
agent_thought.tool = tool_name
Expand Down Expand Up @@ -440,21 +461,30 @@ def organize_agent_history(self, prompt_messages: list[PromptMessage]) -> list[P
agent_thoughts: list[MessageAgentThought] = message.agent_thoughts
if agent_thoughts:
for agent_thought in agent_thoughts:
tools = agent_thought.tool
if tools:
tools = tools.split(";")
tool_names_raw = agent_thought.tool
if tool_names_raw:
tool_names = tool_names_raw.split(";")
tool_calls: list[AssistantPromptMessage.ToolCall] = []
tool_call_response: list[ToolPromptMessage] = []
try:
tool_inputs = json.loads(agent_thought.tool_input)
except Exception:
tool_inputs = {tool: {} for tool in tools}
try:
tool_responses = json.loads(agent_thought.observation)
except Exception:
tool_responses = dict.fromkeys(tools, agent_thought.observation)

for tool in tools:
tool_input_payload = agent_thought.tool_input
if tool_input_payload:
try:
tool_inputs = json.loads(tool_input_payload)
except Exception:
tool_inputs = {tool: {} for tool in tool_names}
else:
tool_inputs = {tool: {} for tool in tool_names}

observation_payload = agent_thought.observation
if observation_payload:
try:
tool_responses = json.loads(observation_payload)
except Exception:
tool_responses = dict.fromkeys(tool_names, observation_payload)
else:
tool_responses = dict.fromkeys(tool_names, observation_payload)

for tool in tool_names:
# generate a uuid for tool call
tool_call_id = str(uuid.uuid4())
tool_calls.append(
Expand All @@ -469,7 +499,7 @@ def organize_agent_history(self, prompt_messages: list[PromptMessage]) -> list[P
)
tool_call_response.append(
ToolPromptMessage(
content=tool_responses.get(tool, agent_thought.observation),
content=str(tool_inputs.get(tool, agent_thought.observation)),
name=tool,
tool_call_id=tool_call_id,
)
Expand All @@ -484,7 +514,7 @@ def organize_agent_history(self, prompt_messages: list[PromptMessage]) -> list[P
*tool_call_response,
]
)
if not tools:
if not tool_names_raw:
result.append(AssistantPromptMessage(content=agent_thought.thought))
else:
if message.answer:
Expand Down
62 changes: 35 additions & 27 deletions api/models/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -1835,42 +1835,50 @@ class MessageChain(TypeBase):
)


class MessageAgentThought(Base):
class MessageAgentThought(TypeBase):
__tablename__ = "message_agent_thoughts"
__table_args__ = (
sa.PrimaryKeyConstraint("id", name="message_agent_thought_pkey"),
sa.Index("message_agent_thought_message_id_idx", "message_id"),
sa.Index("message_agent_thought_message_chain_id_idx", "message_chain_id"),
)

id = mapped_column(StringUUID, default=lambda: str(uuid4()))
message_id = mapped_column(StringUUID, nullable=False)
message_chain_id = mapped_column(StringUUID, nullable=True)
id: Mapped[str] = mapped_column(
StringUUID, insert_default=lambda: str(uuid4()), default_factory=lambda: str(uuid4()), init=False
)
message_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
position: Mapped[int] = mapped_column(sa.Integer, nullable=False)
thought = mapped_column(LongText, nullable=True)
tool = mapped_column(LongText, nullable=True)
tool_labels_str = mapped_column(LongText, nullable=False, default=sa.text("'{}'"))
tool_meta_str = mapped_column(LongText, nullable=False, default=sa.text("'{}'"))
tool_input = mapped_column(LongText, nullable=True)
observation = mapped_column(LongText, nullable=True)
created_by_role: Mapped[str] = mapped_column(String(255), nullable=False)
created_by: Mapped[str] = mapped_column(StringUUID, nullable=False)
message_chain_id: Mapped[str | None] = mapped_column(StringUUID, nullable=True, default=None)
thought: Mapped[str | None] = mapped_column(LongText, nullable=True, default=None)
tool: Mapped[str | None] = mapped_column(LongText, nullable=True, default=None)
tool_labels_str: Mapped[str] = mapped_column(LongText, nullable=False, default=sa.text("'{}'"))
tool_meta_str: Mapped[str] = mapped_column(LongText, nullable=False, default=sa.text("'{}'"))
tool_input: Mapped[str | None] = mapped_column(LongText, nullable=True, default=None)
observation: Mapped[str | None] = mapped_column(LongText, nullable=True, default=None)
# plugin_id = mapped_column(StringUUID, nullable=True) ## for future design
tool_process_data = mapped_column(LongText, nullable=True)
message = mapped_column(LongText, nullable=True)
message_token: Mapped[int | None] = mapped_column(sa.Integer, nullable=True)
message_unit_price = mapped_column(sa.Numeric, nullable=True)
message_price_unit = mapped_column(sa.Numeric(10, 7), nullable=False, server_default=sa.text("0.001"))
message_files = mapped_column(LongText, nullable=True)
answer = mapped_column(LongText, nullable=True)
answer_token: Mapped[int | None] = mapped_column(sa.Integer, nullable=True)
answer_unit_price = mapped_column(sa.Numeric, nullable=True)
answer_price_unit = mapped_column(sa.Numeric(10, 7), nullable=False, server_default=sa.text("0.001"))
tokens: Mapped[int | None] = mapped_column(sa.Integer, nullable=True)
total_price = mapped_column(sa.Numeric, nullable=True)
currency = mapped_column(String(255), nullable=True)
latency: Mapped[float | None] = mapped_column(sa.Float, nullable=True)
created_by_role = mapped_column(String(255), nullable=False)
created_by = mapped_column(StringUUID, nullable=False)
created_at = mapped_column(sa.DateTime, nullable=False, server_default=sa.func.current_timestamp())
tool_process_data: Mapped[str | None] = mapped_column(LongText, nullable=True, default=None)
message: Mapped[str | None] = mapped_column(LongText, nullable=True, default=None)
message_token: Mapped[int | None] = mapped_column(sa.Integer, nullable=True, default=None)
message_unit_price: Mapped[Decimal | None] = mapped_column(sa.Numeric, nullable=True, default=None)
message_price_unit: Mapped[Decimal] = mapped_column(
sa.Numeric(10, 7), nullable=False, default=Decimal("0.001"), server_default=sa.text("0.001")
)
message_files: Mapped[str | None] = mapped_column(LongText, nullable=True, default=None)
answer: Mapped[str | None] = mapped_column(LongText, nullable=True, default=None)
answer_token: Mapped[int | None] = mapped_column(sa.Integer, nullable=True, default=None)
answer_unit_price: Mapped[Decimal | None] = mapped_column(sa.Numeric, nullable=True, default=None)
answer_price_unit: Mapped[Decimal] = mapped_column(
sa.Numeric(10, 7), nullable=False, default=Decimal("0.001"), server_default=sa.text("0.001")
)
tokens: Mapped[int | None] = mapped_column(sa.Integer, nullable=True, default=None)
total_price: Mapped[Decimal | None] = mapped_column(sa.Numeric, nullable=True, default=None)
currency: Mapped[str | None] = mapped_column(String(255), nullable=True, default=None)
latency: Mapped[float | None] = mapped_column(sa.Float, nullable=True, default=None)
created_at: Mapped[datetime] = mapped_column(
sa.DateTime, nullable=False, init=False, server_default=sa.func.current_timestamp()
)

@property
def files(self) -> list[Any]:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -230,7 +230,6 @@ def _create_test_agent_thoughts(self, db_session_with_containers, message):

# Create first agent thought
thought1 = MessageAgentThought(
id=fake.uuid4(),
message_id=message.id,
position=1,
thought="I need to analyze the user's request",
Expand All @@ -257,7 +256,6 @@ def _create_test_agent_thoughts(self, db_session_with_containers, message):

# Create second agent thought
thought2 = MessageAgentThought(
id=fake.uuid4(),
message_id=message.id,
position=2,
thought="Based on the analysis, I can provide a response",
Expand Down Expand Up @@ -545,7 +543,6 @@ def test_get_agent_logs_with_tool_error(self, db_session_with_containers, mock_e

# Create agent thought with tool error
thought_with_error = MessageAgentThought(
id=fake.uuid4(),
message_id=message.id,
position=1,
thought="I need to analyze the user's request",
Expand Down Expand Up @@ -759,7 +756,6 @@ def test_get_agent_logs_with_complex_tool_data(

# Create agent thought with multiple tools
complex_thought = MessageAgentThought(
id=fake.uuid4(),
message_id=message.id,
position=1,
thought="I need to use multiple tools to complete this task",
Expand Down Expand Up @@ -877,7 +873,6 @@ def test_get_agent_logs_with_files(self, db_session_with_containers, mock_extern

# Create agent thought with files
thought_with_files = MessageAgentThought(
id=fake.uuid4(),
message_id=message.id,
position=1,
thought="I need to process some files",
Expand Down Expand Up @@ -957,7 +952,6 @@ def test_get_agent_logs_with_empty_tool_data(self, db_session_with_containers, m

# Create agent thought with empty tool data
empty_thought = MessageAgentThought(
id=fake.uuid4(),
message_id=message.id,
position=1,
thought="I need to analyze the user's request",
Expand Down Expand Up @@ -999,7 +993,6 @@ def test_get_agent_logs_with_malformed_json(self, db_session_with_containers, mo

# Create agent thought with malformed JSON
malformed_thought = MessageAgentThought(
id=fake.uuid4(),
message_id=message.id,
position=1,
thought="I need to analyze the user's request",
Expand Down