Skip to content

Commit edb2d51

Browse files
committed
good, next tool exec
1 parent 4f90f06 commit edb2d51

File tree

1 file changed

+86
-18
lines changed

1 file changed

+86
-18
lines changed

07 Agent Choose Tool/main.py

Lines changed: 86 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,63 @@
11
from langgraph.graph import StateGraph, END
2-
from typing import TypedDict, Literal
3-
import random
2+
from typing import TypedDict, Literal, Callable, Dict, List, Any
3+
import inspect
44
import json
5+
import random
56
from langchain_community.chat_models import ChatOllama
67
from langchain_core.prompts import PromptTemplate
78
from langchain_core.output_parsers import StrOutputParser
89
from abc import ABC, abstractmethod
910

11+
# Tool registry to hold information about tools
tool_registry: List[Dict[str, Any]] = []


def _annotation_name(annotation: Any) -> str:
    """Return a JSON-serializable, human-readable name for a type annotation.

    Raw annotation objects (e.g. ``<class 'int'>`` or ``inspect._empty``)
    cannot be fed to ``json.dumps``; the registry is meant to be serialized
    into the LLM prompt, so store plain strings instead.
    """
    if annotation is inspect.Parameter.empty:
        return "Any"  # unannotated parameter
    if isinstance(annotation, type):
        return annotation.__name__
    return str(annotation)


# Decorator to register tools
def register_tool(func: Callable) -> Callable:
    """Record *func* in the global tool registry and return it unchanged.

    Captures the function's name, docstring and parameter list so the
    registry can later be serialized into a tools description for the LLM.
    """
    signature = inspect.signature(func)
    docstring = func.__doc__ or ""
    params = [
        {"name": param.name, "type": _annotation_name(param.annotation)}
        for param in signature.parameters.values()
    ]
    tool_info = {
        "name": func.__name__,
        "description": docstring,
        "parameters": params,
    }
    tool_registry.append(tool_info)
    return func
29+
30+
# Define the tools with detailed parameter descriptions in the docstrings
@register_tool
def add(a: int, b: int) -> int:
    """
    :function: add
    :param int a: First number to add
    :param int b: Second number to add
    :return: Sum of a and b
    """
    total = a + b
    return total
40+
41+
@register_tool
def ls() -> List[str]:
    """
    :function: ls
    :return: List of filenames in the current directory
    """
    # Fake implementation -- a canned listing stands in for a real directory scan
    listing = ["file1.txt", "file2.txt", "file3.txt"]
    return listing
49+
50+
@register_tool
def filewrite(name: str, content: str) -> None:
    """
    :function: filewrite
    :param str name: Name of the file
    :param str content: Content to write to the file
    :return: None
    """
    # Fake implementation: only logs the would-be write -- no file is created on disk
    print(f"Writing to {name}: {content}")
60+
1061
# Specify the local language model
1162
local_llm = "mistral"
1263
llm = ChatOllama(model=local_llm, format="json", temperature=0)
@@ -22,7 +73,7 @@ class ToolState(TypedDict):
2273
history: str
2374
use_tool: bool
2475
tool_exec: str
25-
76+
tools_list: str
2677

2778
# Define the base class for tasks
2879
class AgentBase(ABC):
@@ -41,10 +92,14 @@ def execute(self) -> ToolState:
4192
template = self.get_prompt_template()
4293
prompt = PromptTemplate.from_template(template)
4394
llm_chain = prompt | llm | StrOutputParser()
44-
generation = llm_chain.invoke({"history": self.state["history"], "use_tool": self.state["use_tool"]})
95+
generation = llm_chain.invoke({
96+
"history": self.state["history"],
97+
"use_tool": self.state["use_tool"],
98+
"tools_list": self.state["tools_list"]
99+
})
45100
data = json.loads(generation)
46-
self.state["use_tool"] = data.get("use_tool", "")
47-
self.state["tool_exec"] = data
101+
self.state["use_tool"] = data.get("use_tool", False)
102+
self.state["tool_exec"] = data.get("tool_exec", "")
48103

49104
self.state["history"] += "\n" + generation
50105
self.state["history"] = clip_history(self.state["history"])
@@ -56,8 +111,9 @@ class ChatAgent(AgentBase):
56111
def get_prompt_template(self) -> str:
57112
return """
58113
{history}
59-
As ChatAgent, decide we need use tool/py or not
60-
if we don't need tool, just reply, otherwirse, let tool agent to handle
114+
Available tools: {tools_list}
115+
As ChatAgent, decide if we need to use a tool or not.
116+
If we don't need a tool, just reply; otherwise, let the ToolAgent handle it.
61117
Output the JSON in the format: {{"scenario": "your reply", "use_tool": True/False}}
62118
"""
63119

@@ -71,12 +127,16 @@ def get_prompt_template(self) -> str:
71127
"""
72128

73129
def ToolExecute(state: ToolState) -> ToolState:
    """Execute the tool chosen by the ToolAgent and record the result.

    Expects ``state["tool_exec"]`` to be a JSON string of the form
    ``{"function": <tool name>, "args": [...]}``.

    :raises ValueError: if the requested name is not a registered tool.
    """
    choice = json.loads(state["tool_exec"])
    tool_name = choice["function"]
    args = choice["args"]
    # Security: only dispatch to functions explicitly registered via
    # @register_tool -- never let LLM output name an arbitrary module global.
    registered = {tool["name"] for tool in tool_registry}
    if tool_name not in registered:
        raise ValueError(f"Unknown tool requested: {tool_name!r}")
    result = globals()[tool_name](*args)
    state["history"] += f"\nExecuted {tool_name} with result: {result}"
    state["history"] = clip_history(state["history"])
    state["use_tool"] = False
    return state
78138

79-
# for conditional edges
139+
# For conditional edges
80140
def check_use_tool(state: ToolState) -> Literal["use tool", "not use tool"]:
81141
if state.get("use_tool") == True:
82142
return "use tool"
@@ -86,7 +146,7 @@ def check_use_tool(state: ToolState) -> Literal["use tool", "not use tool"]:
86146
# Define the state machine
87147
workflow = StateGraph(ToolState)
88148

89-
# Initialize tasks for DM and Player
149+
# Initialize tasks for ChatAgent and ToolAgent
90150
def chat_agent(state: ToolState) -> ToolState:
    """Graph node wrapper: run one ChatAgent step over the shared state."""
    agent = ChatAgent(state)
    return agent.execute()
92152

@@ -100,7 +160,6 @@ def tool_agent(state: ToolState) -> ToolState:
100160
workflow.set_entry_point("chat_agent")
101161

102162
# Define edges between nodes
103-
104163
workflow.add_conditional_edges(
105164
"chat_agent",
106165
check_use_tool,
@@ -111,19 +170,28 @@ def tool_agent(state: ToolState) -> ToolState:
111170
)
112171

113172
# Route: tool_agent -> tool, then loop back to chat_agent after execution
# NOTE(review): the 'tool' node is presumably ToolExecute -- its add_node
# call is elsewhere in the file; confirm the wiring there.
workflow.add_edge('tool_agent', 'tool')
workflow.add_edge('tool', 'chat_agent')
114174

115-
175+
# Generate the tools list shown to the LLM so it knows which tools exist
# and what arguments each one takes.
tools_list = json.dumps([
    {
        "name": tool["name"],
        "description": tool["description"],
        # Include the parameter schema so the model can produce valid args;
        # str() keeps any raw type annotations JSON-serializable.
        "parameters": [
            {"name": p["name"], "type": str(p["type"])}
            for p in tool.get("parameters", [])
        ],
    }
    for tool in tool_registry
])
116183

117184
# Compile the workflow into a runnable app
app = workflow.compile()

# Initialize the state; tools_list is injected so the agents' prompt
# templates can describe the available tools
initial_state = ToolState(
    history="help me ls files in current folder",
    use_tool=False,
    tool_exec="",
    tools_list=tools_list
)

for s in app.stream(initial_state):
    # Print the current state after each graph step
    print(s)

0 commit comments

Comments
 (0)