Skip to content

Commit 87586e6

Browse files
committed
Better tool call checks
- Added a better tool call checks - Remove prompt for confirming tool runs *will return with config later*
1 parent 156f1f9 commit 87586e6

File tree

2 files changed

+54
-32
lines changed

2 files changed

+54
-32
lines changed

functions/system/ddg_search.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -9,9 +9,9 @@
99
print("No module named 'duckduckgo_search' found")
1010

1111

12-
def ddg_search(query: str, max_results: int = 5) -> str:
12+
def ddg_search(query: str, results: int = 5) -> str:
1313
if searcher is not None:
14-
return json.dumps(searcher.text(query, max_results=max_results))
14+
return json.dumps(searcher.text(query, max_results=results))
1515
else:
1616
return "Cannot load the duckduckgo search module!"
1717

@@ -26,7 +26,7 @@ def ddg_search(query: str, max_results: int = 5) -> str:
2626
"type": "object",
2727
"properties": {
2828
"query": {"type": "string", "description": "The search query to look for"},
29-
"max_results": {"type": "int", "description": "The number of results to get, defaults to 5 results"}
29+
"results": {"type": "int", "description": "The number of results to get, defaults to 5 results"}
3030
},
3131
"required": ["query"],
3232
},

tool_calling.py

Lines changed: 51 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
import glob
2+
import inspect
23
import json
34
import os
45
from datetime import datetime
@@ -53,16 +54,50 @@ def glob_import(glob_str: str) -> Tuple[List[dict], Dict]:
5354
functions.extend(user_functions)
5455

5556

57+
def print_assistant_messages(responses: list):
58+
printable = [resp["content"] for resp in responses if resp["role"] == "assistant" and resp["content"] != ""]
59+
print("\n".join(printable))
60+
61+
62+
def is_valid_func_call(fn_name: str, fn_args: dict) -> bool:
63+
# Checks if the function exists
64+
fnc = get_actual_function(fn_name)
65+
if not fnc:
66+
return False
67+
params = inspect.signature(fnc).parameters
68+
fnc_args = params.keys()
69+
model_args = fn_args.keys()
70+
71+
# Check for extra arguments we aren't expecting
72+
extra = [arg for arg in model_args if arg not in fnc_args]
73+
if len(extra) != 0:
74+
return False
75+
76+
# See if any arguments are missing
77+
missing = fnc_args - model_args
78+
# No args are missing
79+
if len(missing) == 0:
80+
# Check that the keys match
81+
return list(params.keys()) == list(fn_args.keys())
82+
# Check if the "missing" parameters have no default value
83+
for key in missing:
84+
param = params.get(key)
85+
if param.default == inspect.Parameter.empty:
86+
return False
87+
return True
88+
89+
5690
def print_func_calls(responses: list):
5791
print("\nFunctions to call (invalid functions will be ignored!): ")
58-
functions = function_library.keys()
92+
5993
for response in responses:
6094
fn_call = response.get("function_call", None)
6195
if fn_call:
6296
name = fn_call["name"]
97+
args = json.loads(fn_call["arguments"])
6398
print(f"\n{name}")
64-
print(f" Args: {fn_call["arguments"]}")
65-
print(f" Valid: {("Y" if name in functions else "N")}")
99+
print(f" Args: {args}")
100+
print(f" Valid: {("Y" if is_valid_func_call(name, args) else "N")}")
66101
print("\n")
67102

68103

@@ -85,12 +120,14 @@ def execute_functions(responses: list) -> list:
85120
fn_args = json.loads(fn_call["arguments"])
86121
fnc = get_actual_function(fn_name)
87122
if fnc:
88-
fn_res = fnc(**fn_args)
89-
90-
if fn_res:
91-
messages.append(
92-
{"role": "function", "name": fn_name, "content": fn_res}
93-
)
123+
if is_valid_func_call(fn_name, fn_args):
124+
fn_res = fnc(**fn_args)
125+
if fn_res:
126+
messages.append(
127+
{"role": "function", "name": fn_name, "content": fn_res}
128+
)
129+
else:
130+
messages.append({"role": "function", "name": fn_name, "content": "This function either does not exist, or the parameters provided were invalid."})
94131
return messages
95132

96133

@@ -124,7 +161,6 @@ def main():
124161
messages.append({"role": "user", "content": prompt})
125162

126163
print("Prompting the backend for function calls...")
127-
#print(messages[1]["content"])
128164

129165
finished = False
130166

@@ -139,33 +175,19 @@ def main():
139175

140176
# Add AI response/function call requests to context
141177
messages.extend(responses)
142-
178+
print_assistant_messages(responses)
143179
# If there are no function calls, this will break the loop after this conversation turn
144180
if not has_func_calls(responses):
145181
finished = True
146-
# print(messages[-1]["content"])
147182
else:
148183
# Print all function calls the model is requesting
149184
print_func_calls(responses)
150-
# Ask for confirmation before running any functions
151-
if confirm_input():
152-
# Execute functions and add their responses to the context
153-
func_responses = execute_functions(responses)
154-
messages.extend(func_responses)
155-
else:
156-
messages.append(
157-
{
158-
"role": "user",
159-
"content": "The user denied access to your tool call. Try to complete the task without tools if you can",
160-
}
161-
)
185+
# Execute functions and add their responses to the context
186+
func_responses = execute_functions(responses)
187+
messages.extend(func_responses)
162188

163-
# Get the AI's response after tool calls and print it
164-
for responses in llm.chat(messages=messages, functions=functions):
165-
pass
166-
messages.extend(responses)
167-
print(messages[-1]["content"])
168189
except KeyboardInterrupt:
190+
# print(json.dumps(messages, indent=2))
169191
running = False
170192
exit("Exiting...")
171193

0 commit comments

Comments
 (0)