Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
70 changes: 50 additions & 20 deletions aworld/agents/llm_agent.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,15 @@
# coding: utf-8
# Copyright (c) 2025 inclusionAI.
import json
import os
import traceback
import uuid
from collections import OrderedDict
from datetime import datetime
from typing import Dict, Any, List, Callable, Optional

import aworld.trace as trace
from aworld.config.conf import TaskConfig, TaskRunMode
from aworld.config.conf import TaskConfig, TaskRunMode, AgentConfig
from aworld.core.agent.agent_desc import get_agent_desc
from aworld.core.agent.base import BaseAgent, AgentResult, is_agent_by_name, is_agent, AgentFactory
from aworld.core.common import ActionResult, Observation, ActionModel, Config, TaskItem
Expand All @@ -18,6 +19,7 @@
MemoryEventType as MemoryType, MemoryEventMessage, ChunkMessage
from aworld.core.exceptions import AWorldRuntimeException
from aworld.core.model_output_parser import ModelOutputParser
from aworld.core.model_output_parser.hermes_tool_parser import HermesToolParser
from aworld.core.tool.tool_desc import get_tool_desc
from aworld.events import eventbus
from aworld.events.util import send_message, send_message_with_future
Expand Down Expand Up @@ -53,19 +55,15 @@ async def parse(self, resp: ModelResponse, **kwargs) -> AgentResult:

results = []
is_call_tool = False
content = '' if resp.content is None else resp.content
if kwargs.get("use_tools_in_prompt"):
tool_calls = []
for tool in self.use_tool_list(content):
tool_calls.append(ToolCall.from_dict({
"id": tool.get("id"),
"function": {
"name": tool.get("tool"),
"arguments": tool.get("arguments")
}
}))
if tool_calls:
resp.tool_calls = tool_calls
if not self.get_parser("tool"):
self.register_parser(HermesToolParser())

# Iterate over all registered parsers to process the response.
# This allows for extensible parsing logic where multiple parsers can contribute to the final result.
for content_parser in self.get_parsers().values():
resp = await content_parser.parse(resp, **kwargs)
content = '' if resp.content is None else resp.content

if resp.tool_calls:
is_call_tool = True
Expand Down Expand Up @@ -131,7 +129,7 @@ def use_tool_list(self, content: str) -> List[Dict[str, Any]]:
return tool_list


class Agent(BaseAgent[Observation, List[ActionModel]]):
class LLMAgent(BaseAgent[Observation, List[ActionModel]]):
"""Basic agent for unified protocol within the framework."""

def __init__(self,
Expand Down Expand Up @@ -171,6 +169,27 @@ def __init__(self,
tool_aggregate_func: Aggregation strategy for multiple tool results.
event_handler_name: Custom handlers for certain types of events.
"""
if conf is None:
model_name = os.getenv("LLM_MODEL_NAME")
api_key = os.getenv("LLM_API_KEY")
base_url = os.getenv("LLM_BASE_URL")

assert api_key and model_name, (
"LLM_MODEL_NAME and LLM_API_KEY (environment variables) must be set, "
"or pass AgentConfig explicitly"
)
logger.info(f"AgentConfig is empty, using env variables:\n"
f"LLM_API_KEY={'*' * min(len(api_key), 8) if api_key else 'Not set'}\n"
f"LLM_BASE_URL={base_url}\n"
f"LLM_MODEL_NAME={model_name}")

conf = AgentConfig(
llm_provider=os.getenv("LLM_PROVIDER", "openai"),
llm_model_name=model_name,
llm_api_key=api_key,
llm_base_url=base_url,
llm_temperature=float(os.getenv("LLM_TEMPERATURE", "0.7")),
)
super(Agent, self).__init__(name, conf, desc, agent_id,
task=task,
tool_names=tool_names,
Expand Down Expand Up @@ -376,7 +395,8 @@ async def async_messages_transform(self,
# Maintain the order of tool calls
for tool_call_id in last_tool_calls:
if tool_call_id not in tool_calls_map:
raise AWorldRuntimeException(f"tool_calls mismatch! {tool_call_id} not found in {tool_calls_map}, messages: {messages}")
raise AWorldRuntimeException(
f"tool_calls mismatch! {tool_call_id} not found in {tool_calls_map}, messages: {messages}")
messages.append(tool_calls_map.get(tool_call_id))
tool_calls_map = {}
last_tool_calls = []
Expand All @@ -398,7 +418,8 @@ async def async_messages_transform(self,
if not self.use_tools_in_prompt and history.metadata.get('tool_calls'):
messages.append({'role': history.metadata['role'], 'content': history.content,
'tool_calls': [history.metadata['tool_calls']]})
last_tool_calls.extend([tool_call.get('id') for tool_call in history.metadata['tool_calls']])
last_tool_calls.extend(
[tool_call.get('id') for tool_call in history.metadata['tool_calls']])
else:
messages.append({'role': history.metadata['role'], 'content': history.content,
"tool_call_id": history.metadata.get("tool_call_id")})
Expand Down Expand Up @@ -543,7 +564,8 @@ async def async_policy(self, observation: Observation, info: Dict[str, Any] = {}

try:
events = []
async for event in run_hooks(message.context, HookPoint.PRE_LLM_CALL, hook_from=self.id(), payload=observation):
async for event in run_hooks(message.context, HookPoint.PRE_LLM_CALL, hook_from=self.id(),
payload=observation):
events.append(event)
except Exception as e:
logger.error(f"{self.id()} failed to run PRE_LLM_CALL hooks: {e}, traceback is {traceback.format_exc()}")
Expand Down Expand Up @@ -581,7 +603,8 @@ async def async_policy(self, observation: Observation, info: Dict[str, Any] = {}

try:
events = []
async for event in run_hooks(message.context, HookPoint.POST_LLM_CALL, hook_from=self.id(), payload=llm_response):
async for event in run_hooks(message.context, HookPoint.POST_LLM_CALL, hook_from=self.id(),
payload=llm_response):
events.append(event)
except Exception as e:
logger.error(
Expand All @@ -590,6 +613,7 @@ async def async_policy(self, observation: Observation, info: Dict[str, Any] = {}
else:
logger.error(f"{self.id()} failed to get LLM response")
raise RuntimeError(f"{self.id()} failed to get LLM response")

logger.info(f"agent_result: {agent_result}")

if self.is_agent_finished(llm_response, agent_result):
Expand Down Expand Up @@ -644,7 +668,8 @@ async def execution_tools(self, actions: List[ActionModel], message: Message = N
# tool hooks
try:
events = []
async for event in run_hooks(context=message.context, hook_point=HookPoint.POST_TOOL_CALL, hook_from=self.id(), payload=act_result):
async for event in run_hooks(context=message.context, hook_point=HookPoint.POST_TOOL_CALL,
hook_from=self.id(), payload=act_result):
events.append(event)
except Exception:
logger.debug(traceback.format_exc())
Expand Down Expand Up @@ -819,7 +844,8 @@ async def custom_system_prompt(self, context: Context, content: str, tool_list:
system_prompt += PTC_NEURON_PROMPT
return system_prompt

async def _add_message_to_memory(self, payload: Any, message_type: MemoryType, context: Context, skip_summary: bool = False):
async def _add_message_to_memory(self, payload: Any, message_type: MemoryType, context: Context,
skip_summary: bool = False):
memory_msg = MemoryEventMessage(
payload=payload,
agent=self,
Expand Down Expand Up @@ -939,3 +965,7 @@ async def process_by_ptc(self, tools, context: Context):
tool["function"]["description"] = "[allow_code_execution]" + tool["function"]["description"]
logger.debug(f"ptc augmented tool: {tool['function']['description']}")



# For backward compatibility (and because it is the more widely used name),
# `Agent` remains an alias for `LLMAgent`.
Agent = LLMAgent
24 changes: 0 additions & 24 deletions aworld/core/agent/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -105,30 +105,6 @@ def __init__(
wait_tool_result: Whether wait on the results of the tool.
sandbox: Sandbox instance for tool execution, advanced usage.
"""
if conf is None:
model_name = os.getenv("LLM_MODEL_NAME")
api_key = os.getenv("LLM_API_KEY")
base_url = os.getenv("LLM_BASE_URL")

assert api_key and model_name, (
"LLM_MODEL_NAME and LLM_API_KEY (environment variables) must be set, "
"or pass AgentConfig explicitly"
)
logger.info(f"AgentConfig is empty, using env variables:\n"
f"LLM_API_KEY={'*' * min(len(api_key), 8) if api_key else 'Not set'}\n"
f"LLM_BASE_URL={base_url}\n"
f"LLM_MODEL_NAME={model_name}")

conf = AgentConfig(
llm_provider=os.getenv("LLM_PROVIDER", "openai"),
llm_model_name=model_name,
llm_api_key=api_key,
llm_base_url=base_url,
llm_temperature=float(os.getenv("LLM_TEMPERATURE", "0.7")),
)
else:
self.conf = conf

if isinstance(conf, ConfigDict):
pass
elif isinstance(conf, Dict):
Expand Down
20 changes: 0 additions & 20 deletions aworld/core/model_output_parser.py

This file was deleted.

8 changes: 8 additions & 0 deletions aworld/core/model_output_parser/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
# coding: utf-8
# Copyright (c) 2025 inclusionAI.

from .base_content_parser import BaseContentParser
from .model_output_parser import ModelOutputParser

__all__ = ["ModelOutputParser", "BaseContentParser"]

22 changes: 22 additions & 0 deletions aworld/core/model_output_parser/base_content_parser.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
# coding: utf-8
# Copyright (c) 2025 inclusionAI.
import abc
from typing import Any

from aworld.logs.util import logger
from aworld.models.model_response import ModelResponse


class BaseContentParser(abc.ABC):
    """Abstract contract shared by all concrete model-output content parsers.

    A parser inspects a ``ModelResponse`` and extracts one category of
    structured content (tool calls, reasoning, code, ...). The category is
    advertised through :attr:`parser_type` so parsers can be registered and
    looked up by kind.
    """

    @property
    @abc.abstractmethod
    def parser_type(self) -> str:
        """Category handled by this parser, e.g. 'tool', 'reasoning', 'code'."""
        ...

    @abc.abstractmethod
    async def parse(self, resp: ModelResponse, **kwargs) -> Any:
        """Extract structured data from *resp* and return the result."""
        ...
59 changes: 59 additions & 0 deletions aworld/core/model_output_parser/default_parsers.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,59 @@
# coding: utf-8
# Copyright (c) 2025 inclusionAI.

from typing import Any, Dict, List, Optional
from aworld.core.model_output_parser.base_content_parser import BaseContentParser
from aworld.models.model_response import ModelResponse


class ToolParser(BaseContentParser):
    """Default parser for tool calls.

    Placeholder implementation that returns the response unchanged; concrete
    subclasses (e.g. a Hermes-format parser) override :meth:`parse` with real
    extraction logic.
    """

    @property
    def parser_type(self) -> str:
        return "tool"

    # Declared ``async`` to honor the BaseContentParser contract
    # (``async def parse``): the agent awaits each registered parser.
    async def parse(self, resp: ModelResponse, **kwargs) -> Any:
        """Return *resp* unchanged (no-op default)."""
        return resp


class ReasoningParser(BaseContentParser):
    """Default parser for reasoning/thinking process.

    Placeholder implementation that returns the response unchanged.
    """

    @property
    def parser_type(self) -> str:
        return "thinking"

    # Declared ``async`` to honor the BaseContentParser contract
    # (``async def parse``): the agent awaits each registered parser.
    async def parse(self, resp: ModelResponse, **kwargs) -> Any:
        """Return *resp* unchanged (no-op default)."""
        return resp


class CodeParser(BaseContentParser):
    """Default parser for code blocks.

    Placeholder implementation that returns the response unchanged.
    """

    @property
    def parser_type(self) -> str:
        return "code"

    # Declared ``async`` to honor the BaseContentParser contract
    # (``async def parse``): the agent awaits each registered parser.
    async def parse(self, resp: ModelResponse, **kwargs) -> Any:
        """Return *resp* unchanged (no-op default)."""
        return resp


class JsonParser(BaseContentParser):
    """Default parser for JSON content.

    Placeholder implementation that returns the response unchanged.
    """

    @property
    def parser_type(self) -> str:
        return "json"

    # Declared ``async`` to honor the BaseContentParser contract
    # (``async def parse``): the agent awaits each registered parser.
    async def parse(self, resp: ModelResponse, **kwargs) -> Any:
        """Return *resp* unchanged (no-op default)."""
        return resp

44 changes: 44 additions & 0 deletions aworld/core/model_output_parser/hermes_tool_parser.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,44 @@
# coding: utf-8
# Copyright (c) 2025 inclusionAI.
import json
import re
import uuid

from aworld.core.model_output_parser.base_content_parser import BaseContentParser
from aworld.core.model_output_parser.default_parsers import ToolParser
from aworld.logs.util import logger
from aworld.models.model_response import ModelResponse, ToolCall, Function



class HermesToolParser(ToolParser):
    """Tool-call parser for the Hermes ``<tool_call>...</tool_call>`` format.

    Adapted from https://github.com/vllm-project/vllm/blob/v0.9.1/vllm/entrypoints/openai/tool_parsers/hermes_tool_parser.py.
    """

    def __init__(self) -> None:
        # Delimiters that wrap each JSON-encoded tool call in the raw content.
        self.tool_call_start_token: str = "<tool_call>"
        self.tool_call_end_token: str = "</tool_call>"
        self.tool_call_regex = re.compile(r"<tool_call>(.*?)</tool_call>", re.DOTALL)

    async def parse(self, resp: ModelResponse, **kwargs) -> ModelResponse:
        """Apply Hermes tool-call extraction to ``resp``.

        Without this override the inherited placeholder ``parse`` returns the
        response untouched, so a registered HermesToolParser would never run
        its extraction logic. Strips the ``<tool_call>`` markup out of
        ``resp.content`` and populates ``resp.tool_calls`` only when at least
        one valid call was decoded.
        """
        content, tool_calls = await self.extract_tool_calls(resp.content or "")
        resp.content = content
        if tool_calls:
            resp.tool_calls = tool_calls
        return resp

    async def extract_tool_calls(self, content: str) -> tuple[str, list[ToolCall]]:
        """Return ``(content stripped of tool-call markup, extracted calls)``.

        Malformed JSON inside a ``<tool_call>`` block is logged and skipped so
        one bad call does not discard the others.
        """
        if self.tool_call_start_token not in content or self.tool_call_end_token not in content:
            return content, []

        matches = self.tool_call_regex.findall(content)
        function_calls = []
        for match in matches:
            try:
                function_call = json.loads(match)
                name, arguments = function_call["name"], function_call["arguments"]
                function_calls.append(Function(name=name, arguments=json.dumps(arguments, ensure_ascii=False)))
            except Exception as e:
                logger.error(f"Failed to decode tool call: {e}")

        # Remove the tool-call markup from the visible content
        # (the original ran this substitution twice; once is sufficient).
        content = self.tool_call_regex.sub("", content)
        tool_calls = []
        if function_calls:
            tool_calls = [ToolCall(id=f"toolcall_{uuid.uuid4().hex}", function=tool_call) for tool_call in function_calls]
            logger.info(f"{len(tool_calls)} tool calls extracted: {tool_calls}")

        return content, tool_calls
Loading