Commit 2f94d57

feat: Added Tool Fetcher experiment and code.
1 parent 4c12645 commit 2f94d57

2 files changed: +313 −0 lines changed

Lines changed: 265 additions & 0 deletions
@@ -0,0 +1,265 @@
from typing import Dict, List, Any, Optional
import json
import os

from langchain_core.language_models import BaseChatModel
from langchain_core.tools import BaseTool, StructuredTool
from langgraph.prebuilt import create_react_agent
from langchain.docstore.document import Document
from langchain_huggingface import HuggingFaceEmbeddings
from langchain_milvus import Milvus

from evaluator.components.data_provider import QuerySpecification
from evaluator.config.schema import ModelConfig
from evaluator.utils.module_extractor import register_algorithm
from evaluator.interfaces.algorithm import Algorithm, AlgoResponse
from evaluator.utils.utils import log_verbose


# Constants
DEFAULT_EMBEDDING_MODEL = "all-MiniLM-L6-v2"
DEFAULT_SEARCH_K = 8  # LLM dynamic search default
DEFAULT_MAX_RESULT_CHARS = 4000
DEFAULT_DROP_OLD_COLLECTION = True
DEFAULT_COLLECTION_NAME = "tool_fetcher_tools_collection"
MAX_EMBEDDING_TEXT_LENGTH = 2048  # Typical embedding model context limit
MIN_K = 1
MAX_K = 50
# Substring search scoring weights
NAME_MATCH_WEIGHT = 2
DESC_MATCH_WEIGHT = 1


@register_algorithm("tool_fetcher")
class ToolFetcherAlgorithm(Algorithm):
    """
    Single-tool orchestration algorithm.

    Exposes one tool (tool_hub) that the model can call to search and fetch
    other tools dynamically based on a natural-language request. The fetched
    tools are then made available to the agent on subsequent invocations of
    tool_hub during the same query.
    """

    def __init__(self, settings: Dict, model_config: List[ModelConfig], label: str = None):
        super().__init__(settings, model_config, label)
        self._all_tools = None
        self._tool_map = None  # Cache for name->tool lookup
        self._active_tools = None
        self._vector_store = None
        self._embeddings = None

    def set_up(self, model: BaseChatModel, tools: List[BaseTool]) -> None:
        super().set_up(model, tools)
        self._all_tools = tools
        self._tool_map = {t.name: t for t in tools}  # Build lookup cache once
        self._active_tools = []
        self._build_vector_index(tools)

    def _build_vector_index(self, tools: List[BaseTool]) -> None:
        """Build Milvus vector index for tool retrieval. Falls back to None on failure."""
        try:
            embedding_model_id = self._settings.get("embedding_model_id", DEFAULT_EMBEDDING_MODEL)
            log_verbose(f"Initializing embeddings with model: {embedding_model_id}")
            self._embeddings = HuggingFaceEmbeddings(model_name=embedding_model_id)

            milvus_uri = os.getenv("MILVUS_URL") or "http://localhost:19530"
            collection = self._settings.get("collection_name", DEFAULT_COLLECTION_NAME)
            drop_old = bool(self._settings.get("drop_old_collection", DEFAULT_DROP_OLD_COLLECTION))

            docs = [
                Document(
                    page_content=f"name: {t.name or ''} | desc: {getattr(t, 'description', '') or ''}"[:MAX_EMBEDDING_TEXT_LENGTH],
                    metadata={"name": t.name or ""}
                )
                for t in tools
            ]

            log_verbose(f"Building Milvus collection: {collection} (drop_old={drop_old})")
            self._vector_store = Milvus.from_documents(
                documents=docs,
                embedding=self._embeddings,
                collection_name=collection,
                connection_args={"uri": milvus_uri},
                drop_old=drop_old,
                index_params={"index_type": "FLAT", "metric_type": "COSINE"},
                search_params={"metric_type": "COSINE"},
            )
        except Exception as e:
            log_verbose(f"Vector store initialization failed: {e}. Falling back to substring search.")
            self._vector_store = None

    def _clamp_k(self, k: Optional[int], default: int) -> int:
        """Clamp k value to valid range."""
        try:
            value = int(k) if k is not None else default
        except (ValueError, TypeError):
            value = default
        return max(MIN_K, min(value, MAX_K))

    def _search_tools(self, query: str, limit: int) -> List[BaseTool]:
        """Search tools using vector similarity or substring matching."""
        if not self._all_tools:
            return []

        # Try vector search first
        if self._vector_store is not None:
            try:
                results = self._vector_store.similarity_search_with_score(query or "", k=limit)
                ordered = [
                    self._tool_map[doc.metadata["name"]]
                    for doc, _score in results
                    if doc.metadata.get("name") in self._tool_map
                ]
                if ordered:
                    return ordered
            except Exception as e:
                log_verbose(f"Vector search failed: {e}. Falling back to substring search.")

        # Fallback: substring search
        q = (query or "").strip().lower()
        ranked = []
        for tool in self._all_tools:
            name = (tool.name or "").lower()
            desc = (getattr(tool, "description", "") or "").lower()
            score = (NAME_MATCH_WEIGHT if q in name else 0) + (DESC_MATCH_WEIGHT if q in desc else 0)
            if score > 0:
                ranked.append((score, tool))

        ranked.sort(key=lambda x: x[0], reverse=True)
        return [t for _, t in ranked[:limit]] or self._all_tools[:limit]

    def _handle_search(self, query: str, k: Optional[int]) -> str:
        """Handle tool search action."""
        default_k = self._settings.get("default_search_k", DEFAULT_SEARCH_K)
        limit = self._clamp_k(k, default_k)
        matches = self._search_tools(query, limit)

        existing_names = {t.name for t in self._active_tools}
        newly_added = []
        for t in matches:
            if t.name not in existing_names:
                self._active_tools.append(t)
                newly_added.append(t.name)

        return json.dumps({
            "mode": "search",
            "fetched": newly_added,
            "active": [t.name for t in self._active_tools],
        })

    def _handle_call(self, tool_name: str, tool_input: str) -> str:
        """Handle tool invocation action."""
        tool = self._tool_map.get(tool_name)

        if tool is None:
            return json.dumps({"mode": "call", "error": f"tool '{tool_name}' not found"})

        # Parse input as JSON if possible
        try:
            parsed = json.loads(tool_input) if tool_input else tool_input
        except json.JSONDecodeError:
            parsed = tool_input

        # Add to active tools
        if tool.name not in {t.name for t in self._active_tools}:
            self._active_tools.append(tool)

        # Log tool usage
        self._log_tool_usage(tool.name)

        # Invoke tool
        try:
            result = tool.invoke(parsed)
            result_str = json.dumps(result) if isinstance(result, (dict, list)) else str(result)
        except Exception as e:
            log_verbose(f"Tool invocation failed for {tool.name}: {e}")
            return json.dumps({"mode": "call", "tool": tool.name, "error": str(e)})

        max_chars = self._settings.get("max_result_chars", DEFAULT_MAX_RESULT_CHARS)
        return json.dumps({
            "mode": "call",
            "tool": tool.name,
            "result": result_str[:max_chars],
        })

    def _log_tool_usage(self, tool_name: str) -> None:
        """Log tool usage to file if TOOL_LOG_PATH is set."""
        try:
            log_path = os.getenv("TOOL_LOG_PATH")
            if log_path:
                with open(log_path, "a") as f:
                    f.write(f"[TOOL] {tool_name}\n")
        except Exception as e:
            log_verbose(f"Tool logging failed: {e}")

    def _make_tool_hub(self) -> BaseTool:
        """
        Create the single tool_hub tool for searching and calling other tools.

        The returned closure captures self for accessing instance state (_all_tools,
        _active_tools, _settings, etc.) and delegates to _handle_search/_handle_call.
        """
        def run(action: str = "", query: str = "", k: Optional[int] = None,
                tool_name: str = "", tool_input: str = "") -> str:
            act = (action or "").strip().lower()

            if act in ("search", "find", "fetch") or (not act and query):
                return self._handle_search(query, k)

            if act == "call" or tool_name:
                return self._handle_call(tool_name, tool_input)

            return json.dumps({"error": "invalid action; use 'search' or 'call'"})

        return StructuredTool.from_function(
            name="tool_hub",
            description=(
                "IMPORTANT: This is the ONLY tool you can call directly. All other tools must be accessed through tool_hub.\n\n"
                "To complete any task:\n"
                "1. FIRST search for relevant tools: action='search', query='description of what you need', k=8\n"
                "2. THEN call the found tools: action='call', tool_name='exact_tool_name', tool_input='{\"param\": \"value\"}'\n\n"
                "The search will return a list of available tools. You must then call each tool using action='call'."
            ),
            func=run,
        )

    async def process_query(self, query_spec: QuerySpecification) -> AlgoResponse:
        """Process query using the tool hub pattern."""
        if self._all_tools is None:
            raise RuntimeError("process_query called before set_up")

        # Reset active tools
        self._active_tools = []

        # Create agent with tool_hub
        hub = self._make_tool_hub()
        agent = create_react_agent(self._model, [hub])

        # No additional guidance - rely solely on tool_hub's built-in description
        # to isolate retrieval quality from prompt engineering effects
        response = await self._invoke_agent_on_query(agent, query_spec.query)

        # Return tools that the agent actually retrieved during execution
        retrieved = [t.name for t in self._active_tools]
        return response, retrieved

    def tear_down(self) -> None:
        """Clean up resources."""
        self._all_tools = None
        self._tool_map = None
        self._active_tools = None
        self._vector_store = None
        self._embeddings = None

    def get_default_settings(self) -> Dict[str, Any]:
        """Return default configuration settings."""
        return {
            "embedding_model_id": DEFAULT_EMBEDDING_MODEL,
            "default_search_k": DEFAULT_SEARCH_K,
            "drop_old_collection": DEFAULT_DROP_OLD_COLLECTION,
            "collection_name": DEFAULT_COLLECTION_NAME,
            "max_result_chars": DEFAULT_MAX_RESULT_CHARS,
        }
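
The description embedded in tool_hub above is the only guidance the agent receives, so the intended protocol is: one search invocation, then one or more call invocations against the fetched tools. The sketch below is illustrative only and not part of the commit; it uses a standalone StructuredTool that shares tool_hub's run signature but stubs out the handlers, just to show the two call shapes. All tool names, inputs, and results are hypothetical.

# Illustrative stub, not the real tool_hub: same signature, hard-coded responses,
# so the two expected invocation shapes can be shown without the evaluator
# framework, Milvus, or real tools.
import json
from typing import Optional

from langchain_core.tools import StructuredTool


def run(action: str = "", query: str = "", k: Optional[int] = None,
        tool_name: str = "", tool_input: str = "") -> str:
    # The real closure delegates to _handle_search / _handle_call; this stub fakes both.
    if action == "search":
        return json.dumps({"mode": "search", "fetched": ["get_weather"], "active": ["get_weather"]})
    if action == "call":
        return json.dumps({"mode": "call", "tool": tool_name, "result": "22C, clear skies"})
    return json.dumps({"error": "invalid action; use 'search' or 'call'"})


hub = StructuredTool.from_function(name="tool_hub", description="demo stub", func=run)

# Step 1: the agent searches for tools matching its task description.
print(hub.invoke({"action": "search", "query": "current weather for a city", "k": 3}))

# Step 2: the agent calls a fetched tool by exact name, passing JSON-encoded arguments.
print(hub.invoke({"action": "call", "tool_name": "get_weather",
                  "tool_input": json.dumps({"city": "Paris"})}))

Against the real algorithm, the same invocations would be issued by the ReAct agent created in process_query, and the responses would come from _handle_search and _handle_call.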
Lines changed: 48 additions & 0 deletions
@@ -0,0 +1,48 @@
data:
  query_file_paths:
    - "https://raw.githubusercontent.com/THUNLP-MT/StableToolBench/refs/heads/master/solvable_queries/test_instruction/G1_category.json"
    - "https://raw.githubusercontent.com/THUNLP-MT/StableToolBench/refs/heads/master/solvable_queries/test_instruction/G1_instruction.json"
    - "https://raw.githubusercontent.com/THUNLP-MT/StableToolBench/refs/heads/master/solvable_queries/test_instruction/G1_tool.json"
  fine_tuning_query_file_paths:
    - "https://raw.githubusercontent.com/THUNLP-MT/StableToolBench/refs/heads/master/solvable_queries/test_instruction/G2_category.json"
    - "https://raw.githubusercontent.com/THUNLP-MT/StableToolBench/refs/heads/master/solvable_queries/test_instruction/G2_instruction.json"
    - "https://raw.githubusercontent.com/THUNLP-MT/StableToolBench/refs/heads/master/solvable_queries/test_instruction/G3_instruction.json"
  tool_file_paths:
    - "https://huggingface.co/datasets/stabletoolbench/ToolEnv2404/resolve/main/toolenv2404_filtered.tar.gz"
  reference_answers_path: "https://huggingface.co/datasets/stabletoolbench/baselines/resolve/main/data_baselines.zip"
  reference_model_id: "chatgpt_cot"
  queries_num: 5  # Small number for quick testing; increase for full evaluation

models:
  - id: "Qwen/Qwen3-8B"
    url: "${QWEN_MODEL_URL}"
    provider_id: "vllm"

environments:
  - model_id: "Qwen/Qwen3-8B"
    irrelevant_tools_ratio: 0.0
    irrelevant_tools_from_same_categories: true

algorithms:
  - label: "Tool Fetcher"
    module_name: "tool_fetcher"
    settings:
      # Dense embedding model for indexing tools (same as tool_rag for fair comparison)
      embedding_model_id: "all-MiniLM-L6-v2"
      # Default k used when the LLM calls tool_hub search without specifying k
      default_search_k: 8
      # Max characters returned from tool invocations (prevents context overflow)
      max_result_chars: 4000
      # Note: drop_old_collection and collection_name omitted (using defaults)

metric_collectors:
  - module_name: "tool_selection_metric_collector"
    settings: {}
  - module_name: "tool_retrieval_metric_collector"
    settings:
      ks: [1, 3]
      ap_rel_threshold: 1.0
  - module_name: "efficiency_metric_collector"
    settings: {}
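
How the evaluator resolves the ${QWEN_MODEL_URL} placeholder is not shown in this commit; a common approach is environment-variable expansion before the YAML is parsed. The sketch below is a minimal stand-in for that step, assuming PyYAML is installed and the config is saved locally under a hypothetical file name.

# Hypothetical loader sketch, not part of the commit. Assumes ${...} placeholders
# are meant to be filled from the environment before parsing.
import os
import yaml  # PyYAML

with open("tool_fetcher_experiment.yaml") as f:  # hypothetical file name
    raw = f.read()

config = yaml.safe_load(os.path.expandvars(raw))  # ${QWEN_MODEL_URL} -> value from the environment

print(config["models"][0]["url"])                               # resolved model endpoint
print(config["algorithms"][0]["settings"]["default_search_k"])  # 8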
