1
1
from langgraph .graph import StateGraph , END
2
- from typing import TypedDict , Literal
3
- import random
2
+ from typing import TypedDict , Literal , Callable , Dict , List , Any
3
+ import inspect
4
4
import json
5
+ import random
5
6
from langchain_community .chat_models import ChatOllama
6
7
from langchain_core .prompts import PromptTemplate
7
8
from langchain_core .output_parsers import StrOutputParser
8
9
from abc import ABC , abstractmethod
9
10
11
+ # Tool registry to hold information about tools
12
+ tool_registry : List [Dict [str , Any ]] = []
13
+
14
+ # Decorator to register tools
15
+ def register_tool (func : Callable ) -> Callable :
16
+ signature = inspect .signature (func )
17
+ docstring = func .__doc__ or ""
18
+ params = [
19
+ {"name" : param .name , "type" : param .annotation }
20
+ for param in signature .parameters .values ()
21
+ ]
22
+ tool_info = {
23
+ "name" : func .__name__ ,
24
+ "description" : docstring ,
25
+ "parameters" : params
26
+ }
27
+ tool_registry .append (tool_info )
28
+ return func
29
+
30
+ # Define the tools with detailed parameter descriptions in the docstrings
31
+ @register_tool
32
+ def add (a : int , b : int ) -> int :
33
+ """
34
+ :function: add
35
+ :param int a: First number to add
36
+ :param int b: Second number to add
37
+ :return: Sum of a and b
38
+ """
39
+ return a + b
40
+
41
+ @register_tool
42
+ def ls () -> List [str ]:
43
+ """
44
+ :function: ls
45
+ :return: List of filenames in the current directory
46
+ """
47
+ # Fake implementation
48
+ return ["file1.txt" , "file2.txt" , "file3.txt" ]
49
+
50
+ @register_tool
51
+ def filewrite (name : str , content : str ) -> None :
52
+ """
53
+ :function: filewrite
54
+ :param str name: Name of the file
55
+ :param str content: Content to write to the file
56
+ :return: None
57
+ """
58
+ # Fake implementation
59
+ print (f"Writing to { name } : { content } " )
60
+
10
61
# Specify the local language model
11
62
local_llm = "mistral"
12
63
llm = ChatOllama (model = local_llm , format = "json" , temperature = 0 )
@@ -22,7 +73,7 @@ class ToolState(TypedDict):
22
73
history : str
23
74
use_tool : bool
24
75
tool_exec : str
25
-
76
+ tools_list : str
26
77
27
78
# Define the base class for tasks
28
79
class AgentBase (ABC ):
@@ -41,10 +92,14 @@ def execute(self) -> ToolState:
41
92
template = self .get_prompt_template ()
42
93
prompt = PromptTemplate .from_template (template )
43
94
llm_chain = prompt | llm | StrOutputParser ()
44
- generation = llm_chain .invoke ({"history" : self .state ["history" ], "use_tool" : self .state ["use_tool" ]})
95
+ generation = llm_chain .invoke ({
96
+ "history" : self .state ["history" ],
97
+ "use_tool" : self .state ["use_tool" ],
98
+ "tools_list" : self .state ["tools_list" ]
99
+ })
45
100
data = json .loads (generation )
46
- self .state ["use_tool" ] = data .get ("use_tool" , "" )
47
- self .state ["tool_exec" ] = data
101
+ self .state ["use_tool" ] = data .get ("use_tool" , False )
102
+ self .state ["tool_exec" ] = data . get ( "tool_exec" , "" )
48
103
49
104
self .state ["history" ] += "\n " + generation
50
105
self .state ["history" ] = clip_history (self .state ["history" ])
@@ -56,8 +111,9 @@ class ChatAgent(AgentBase):
56
111
def get_prompt_template (self ) -> str :
57
112
return """
58
113
{history}
59
- As ChatAgent, decide we need use tool/py or not
60
- if we don't need tool, just reply, otherwirse, let tool agent to handle
114
+ Available tools: {tools_list}
115
+ As ChatAgent, decide if we need to use a tool or not.
116
+ If we don't need a tool, just reply; otherwise, let the ToolAgent handle it.
61
117
Output the JSON in the format: {{"scenario": "your reply", "use_tool": True/False}}
62
118
"""
63
119
@@ -71,12 +127,16 @@ def get_prompt_template(self) -> str:
71
127
"""
72
128
73
129
def ToolExecute (state : ToolState ) -> ToolState :
74
- choice = self . llm_output ( self . state ["tool_exec" ])
75
- tool_name = choice ["use_tool " ]
76
- args = self . convert_args ( tool_name , choice ["args" ])
130
+ choice = json . loads ( state ["tool_exec" ])
131
+ tool_name = choice ["function " ]
132
+ args = choice ["args" ]
77
133
result = globals ()[tool_name ](* args )
134
+ state ["history" ] += f"\n Executed { tool_name } with result: { result } "
135
+ state ["history" ] = clip_history (state ["history" ])
136
+ state ["use_tool" ] = False
137
+ return state
78
138
79
- # for conditional edges
139
+ # For conditional edges
80
140
def check_use_tool (state : ToolState ) -> Literal ["use tool" , "not use tool" ]:
81
141
if state .get ("use_tool" ) == True :
82
142
return "use tool"
@@ -86,7 +146,7 @@ def check_use_tool(state: ToolState) -> Literal["use tool", "not use tool"]:
86
146
# Define the state machine
87
147
workflow = StateGraph (ToolState )
88
148
89
- # Initialize tasks for DM and Player
149
+ # Initialize tasks for ChatAgent and ToolAgent
90
150
def chat_agent (state : ToolState ) -> ToolState :
91
151
return ChatAgent (state ).execute ()
92
152
@@ -100,7 +160,6 @@ def tool_agent(state: ToolState) -> ToolState:
100
160
workflow .set_entry_point ("chat_agent" )
101
161
102
162
# Define edges between nodes
103
-
104
163
workflow .add_conditional_edges (
105
164
"chat_agent" ,
106
165
check_use_tool ,
@@ -111,19 +170,28 @@ def tool_agent(state: ToolState) -> ToolState:
111
170
)
112
171
113
172
workflow .add_edge ('tool_agent' , 'tool' )
173
+ workflow .add_edge ('tool' , 'chat_agent' )
114
174
115
-
175
+ # Generate the tools list
176
+ tools_list = json .dumps ([
177
+ {
178
+ "name" : tool ["name" ],
179
+ "description" : tool ["description" ]
180
+ }
181
+ for tool in tool_registry
182
+ ])
116
183
117
184
# Compile the workflow into a runnable app
118
185
app = workflow .compile ()
119
186
120
187
# Initialize the state
121
188
initial_state = ToolState (
122
189
history = "help me ls files in current folder" ,
123
- use_tool = False ,
124
- )
190
+ use_tool = False ,
191
+ tool_exec = "" ,
192
+ tools_list = tools_list
193
+ )
125
194
126
195
for s in app .stream (initial_state ):
127
196
# Print the current state
128
- print ("for s in app.stream(initial_state):" )
129
197
print (s )
0 commit comments