Commit cbaae91

Query Embedding for Tool Rag

1 parent 4c12645
9 files changed: +169 −25 lines changed

evaluator/algorithms/tool_rag_algorithm.py

Lines changed: 42 additions & 4 deletions
@@ -84,6 +84,8 @@ class ToolRagAlgorithm(Algorithm):
     - max_document_size: the maximal size, in characters, of a single indexed document, or None to disable the size limit.
     - indexed_tool_def_parts: the parts of the MCP tool definition to be used for index construction, such as 'name',
       'description', 'args', etc.
+      You can also include 'additional_queries' to append example queries for each tool if provided
+      via the 'additional_queries' setting (see defaults below).
     - hybrid_mode: True to enable hybrid (sparse + dense) search and False to only enable dense search.
     - analyzer_params: parameters for the Milvus BM25 analyzer.
     - fusion_type: the algorithm for combining the dense and the sparse scores if hybrid mode is activated. Milvus only
@@ -128,7 +130,8 @@ def get_default_settings(self) -> Dict[str, Any]:
             "embedding_model_id": "all-MiniLM-L6-v2",
             "similarity_metric": "COSINE",
             "index_type": "FLAT",
-            "indexed_tool_def_parts": ["name", "description"],
+            "indexed_tool_def_parts": ["name", "description", "additional_queries"],
+
 
             # preprocessing
             "text_preprocessing_operations": None,
@@ -232,6 +235,14 @@ def _compose_tool_text(self, tool: BaseTool) -> str:
                 tags = tool.tags or []
                 if tags:
                     segments.append(f"tags: {' '.join(tags)}")
+            elif p.lower() == "additional_queries":
+                # Append example queries supplied via settings["additional_queries"][tool.name]
+                examples_map = self._settings.get("additional_queries") or {}
+                examples_list = examples_map.get(tool.name) or []
+                if examples_list:
+                    rendered = self._render_examples(examples_list)
+                    if rendered:
+                        segments.append(f"ex: {rendered}")
 
         if not segments:
             raise ValueError(f"The following tool contains none of the fields listed in indexed_tool_def_parts:\n{tool}")
@@ -249,7 +260,7 @@ def _create_docs_from_tools(self, tools: List[BaseTool]) -> List[Document]:
             documents.append(Document(page_content=page_content, metadata={"name": tool.name}))
         return documents
 
-    def _index_tools(self, tools: List[BaseTool]) -> None:
+    def _index_tools(self, tools: List[BaseTool], queries: List[QuerySpecification]) -> None:
         self.tool_name_to_base_tool = {tool.name: tool for tool in tools}
 
         self.embeddings = HuggingFaceEmbeddings(model_name=self._settings["embedding_model_id"])
@@ -308,7 +319,7 @@ def _index_tools(self, tools: List[BaseTool]) -> None:
             search_params=search_params,
         )
 
-    def set_up(self, model: BaseChatModel, tools: List[BaseTool]) -> None:
+    def set_up(self, model: BaseChatModel, tools: List[BaseTool], queries: List[QuerySpecification]) -> None:
         super().set_up(model, tools)
 
         if self._settings["cross_encoder_model_name"]:
@@ -320,7 +331,34 @@ def set_up(self, model: BaseChatModel, tools: List[BaseTool]) -> None:
         if self._settings["enable_query_decomposition"] or self._settings["enable_query_rewriting"]:
             self.query_rewriting_model = self._get_llm(self._settings["query_rewriting_model_id"])
 
-        self._index_tools(tools)
+        # Build the additional_queries mapping from the provided QuerySpecifications so YAML is not required.
+        try:
+            tool_examples: Dict[str, List[str]] = {}
+            for spec in (queries or []):
+                add_q = getattr(spec, "additional_queries", None) or {}
+                # Flatten a wrapper {"additional_queries": {...}} if present
+                if isinstance(add_q, dict) and "additional_queries" in add_q and len(add_q) == 1:
+                    add_q = add_q["additional_queries"]
+                for tool_name, qmap in add_q.items():
+                    if isinstance(qmap, dict):
+                        for qtext in qmap.values():
+                            if isinstance(qtext, str) and qtext.strip():
+                                tool_examples.setdefault(tool_name, []).append(qtext.strip())
+            # Dedupe while preserving order
+            for k, v in list(tool_examples.items()):
+                seen = set()
+                deduped = []
+                for s in v:
+                    if s not in seen:
+                        seen.add(s)
+                        deduped.append(s)
+                tool_examples[k] = deduped
+            if tool_examples:
+                self._settings["additional_queries"] = tool_examples
+        except Exception:
+            pass
+
+        self._index_tools(tools, queries)
 
     def _threshold_results(self, docs_and_scores: List[Tuple[Document, float]]) -> List[Document]:
         """

evaluator/components/data_provider.py

Lines changed: 18 additions & 7 deletions
@@ -27,6 +27,8 @@ class QuerySpecification(BaseModel):
     """
     id: int
     query: str
+    additional_queries: Optional[Dict[str, Any]] = None
+    path: Optional[str] = None
     reference_answer: Optional[str] = None
     golden_tools: ToolSet = Field(default_factory=dict)
     additional_tools: Optional[ToolSet] = None
@@ -313,7 +315,7 @@ def _load_queries_from_single_file(
         root_dataset_path: str or Path,
         experiment_environment: EnvironmentConfig,
         dataset_config: DatasetConfig,
-) -> List[QuerySpecification]:
+) -> Tuple[List[QuerySpecification], List[Dict[str, Any]]]:
     with open(query_file_path, 'r') as f:
         data = json.load(f)
@@ -332,6 +334,13 @@ def _load_queries_from_single_file(
                 log(f"Invalid query spec, skipping this query.")
             else:
                 query = raw_query_spec.get("query")
+                if raw_query_spec.get("additional_queries"):
+                    additional_queries = raw_query_spec.get("additional_queries")
+                    print(f"Additional queries provided: {additional_queries}")
+
+                else:
+                    print(f"No additional queries provided")
+                    additional_queries = None
                 query_id = int(raw_query_spec.get("query_id"))
                 golden_tools, additional_tools = (
                     _parse_raw_query_tool_definitions(raw_query_spec, experiment_environment, dataset_config))
@@ -345,6 +354,8 @@ def _load_queries_from_single_file(
                     QuerySpecification(
                         id=query_id,
                         query=query,
+                        path=str(query_file_path),
+                        additional_queries=additional_queries,
                         reference_answer=reference_answer,
                         golden_tools=golden_tools,
                         additional_tools=additional_tools or None
@@ -362,7 +373,7 @@ def get_queries(
         experiment_environment: EnvironmentConfig,
         dataset_config: DatasetConfig,
         fine_tuning_mode=False
-) -> List[QuerySpecification]:
+) -> Tuple[List[QuerySpecification], List[Dict[str, Any]]]:
    """Load queries from the dataset."""
    root_dataset_path = Path(os.getenv("ROOT_DATASET_PATH"))
    if not root_dataset_path:
@@ -379,14 +390,14 @@ def get_queries(
     queries_num = None if fine_tuning_mode else dataset_config.queries_num
     queries = []
     for path in local_paths:
+        print(f"\n\n")
+        print(f"--------------------------------")
+        print(f"Loading queries from file: {path}")
+        print(f"\n\n")
         remaining_queries_num = None if queries_num is None else queries_num - len(queries)
         if remaining_queries_num == 0:
             break
-        new_queries = _load_queries_from_single_file(path,
-                                                     remaining_queries_num,
-                                                     root_dataset_path,
-                                                     experiment_environment,
-                                                     dataset_config)
+        new_queries = _load_queries_from_single_file(path, remaining_queries_num, root_dataset_path, experiment_environment, dataset_config)
         queries.extend(new_queries)
 
     return queries
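For reference, the shape that _load_queries_from_single_file now reads and that generate_and_save_additional_queries (see evaluator/utils/parsing_tools.py below) writes back is a per-tool dict of numbered paraphrases. A hypothetical raw query spec with that field filled in (the tool name and query texts are illustrative only):

raw_query_spec = {
    "query_id": 17,
    "query": "What is the weather in Paris tomorrow?",
    # Written back by generate_and_save_additional_queries:
    # one dict of numbered paraphrases per golden tool.
    "additional_queries": {
        "get_weather_forecast": {
            "query1": "Give me tomorrow's forecast for Paris.",
            "query2": "Will it rain in Paris tomorrow?",
            "query3": "Paris weather for the next day, please.",
            "query4": "How warm will it get in Paris tomorrow?",
            "query5": "Forecast for Paris, one day ahead."
        }
    }
}

Note also that the return annotations of _load_queries_from_single_file and get_queries now promise Tuple[List[QuerySpecification], List[Dict[str, Any]]], yet both code paths still build and return a plain list (queries.extend(new_queries); return queries), so the annotations do not match the returned value.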

evaluator/components/llm_provider.py

Lines changed: 1 addition & 0 deletions
@@ -21,6 +21,7 @@ def get_llm(model_id: str, model_config: List[ModelConfig], **kwargs) -> BaseChatModel:
 
     log_verbose(f"Connecting to {config.provider_id} server on {config.url} serving {model_id}...")
     stripped_url = str(config.url).strip('/')
+    print(f"\n \n stripped_url: {stripped_url} \n \n")
     if config.provider_id == ProviderId.OLLAMA:
         from langchain_ollama import ChatOllama
         return ChatOllama(model=model_id, base_url=stripped_url, **kwargs)

evaluator/components/mcp_proxy.py

Lines changed: 1 addition & 0 deletions
@@ -110,6 +110,7 @@ def _make_param(entry: dict, required_flag: bool) -> Parameter:
         parameters.append(_make_param(e, required_flag=False))
 
     signature = Signature(parameters)
+    print(f"\n \n doc_lines: {doc_lines} \n \n")
     docstring = "\n".join(doc_lines)
 
     def tool_func(*args, **kwargs):

evaluator/config/yaml/tool_rag_experiments.yaml

Lines changed: 4 additions & 0 deletions
@@ -67,6 +67,10 @@ algorithms:
     module_name: "tool_rag"
     settings:
       indexed_tool_def_parts: ["description"]
+  - label: "Index With Additional Queries"
+    module_name: "tool_rag"
+    settings:
+      indexed_tool_def_parts: ["additional_queries"]
   - label: "Index Tools By Name and Args"
     module_name: "tool_rag"
     settings:
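Note that indexing with indexed_tool_def_parts: ["additional_queries"] alone means a tool with no generated example queries contributes no segments in _compose_tool_text, which raises the ValueError shown in the first file; this experiment configuration therefore presumes that query generation has already populated examples for every tool in the dataset.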

evaluator/evaluator.py

Lines changed: 34 additions & 8 deletions
@@ -3,7 +3,7 @@
 import time
 import traceback
 from typing import List, Tuple
-
+from pathlib import Path
 import openai
 from langgraph.errors import GraphRecursionError
 from pydantic import ValidationError
@@ -17,6 +17,8 @@
 from evaluator.interfaces.algorithm import Algorithm
 from evaluator.utils.csv_logger import CSVLogger
 from evaluator.components.llm_provider import get_llm
+from evaluator.utils.parsing_tools import generate_and_save_additional_queries
+import json as _json
 from dotenv import load_dotenv
 
 from evaluator.utils.tool_logger import ToolLogger
@@ -35,13 +37,13 @@ class Evaluator(object):
 
     config: EvaluationConfig
 
-    def __init__(self, config_path: str | None, use_defaults: bool):
+    def __init__(self, config_path: str | None, use_defaults: bool, test_with_additional_queries: bool = False):
         try:
             self.config = load_config(config_path, use_defaults=use_defaults)
         except ConfigError as ce:
             log(f"Configuration error: {ce}")
             raise SystemExit(2)
-
+        self.test_with_additional_queries = test_with_additional_queries
     async def run(self) -> None:
 
         # Set up the necessary components for the experiments:
@@ -112,15 +114,13 @@ async def _run_experiment(self,
         Runs the specified experiment and returns the number of evaluated queries.
         """
         processed_queries_num = 0
-
         try:
             queries = await self._set_up_experiment(spec, metric_collectors, mcp_proxy_manager)
             algorithm, environment = spec
 
             try:
                 for i, query_spec in enumerate(queries):
                     log(f"Processing query #{query_spec.id} (Experiment {exp_index} of {total_exp_num}, query {i+1} of {len(queries)})...")
-
                     for mc in metric_collectors:
                         mc.prepare_for_measurement(query_spec)
 
@@ -199,22 +199,48 @@ async def _set_up_experiment(self,
         log(f"Initializing LLM connection: {environment.model_id}")
         llm = get_llm(model_id=environment.model_id, model_config=self.config.models)
         log("Connection established successfully.\n")
-
         log("Fetching queries for the current experiment...")
         queries = get_queries(environment, self.config.data)
         log(f"Successfully loaded {len(queries)} queries.\n")
         print_iterable_verbose("The following queries will be executed:\n", queries)
-
+        log(f"Generating additional queries.\n")
+        generate_and_save_additional_queries(llm, queries)
+        queries = get_queries(environment, self.config.data)
         log("Retrieving tool definitions for the current experiment...")
         tool_specs = get_tools_from_queries(queries)
         tools = await mcp_proxy_manager.run_mcp_proxy(tool_specs, init_client=True).get_tools()
         print_iterable_verbose("The following tools will be available during evaluation:\n", tools)
         log(f"The experiment will proceed with {len(tools)} tool(s).\n")
 
         log("Setting up the algorithm and the metric collectors...")
-        algorithm.set_up(llm, tools)
+
+        algorithm.set_up(llm, tools, queries)
         for mc in metric_collectors:
             mc.set_up()
         log("All set!\n")
 
         return queries
+
+if __name__ == "__main__":
+    import argparse
+    parser = argparse.ArgumentParser(description="Run the Evaluator experiments.")
+    parser.add_argument("--config", type=str, default=None, help="Path to evaluation config YAML file")
+    parser.add_argument("--defaults", action="store_true", help="Use default config options if set")
+    parser.add_argument("--test-with-additional-queries", action="store_true", help="Test with additional queries")
+    args = parser.parse_args()
+
+    from evaluator.utils.utils import log
+
+    log("Starting Evaluator main...")
+    evaluator = Evaluator(
+        config_path=args.config,
+        use_defaults=args.defaults,
+        test_with_additional_queries=args.test_with_additional_queries
+    )
+    try:
+        import asyncio
+        asyncio.run(evaluator.run())
+        log("Evaluator finished successfully!")
+    except Exception as e:
+        log(f"Evaluator failed: {e}")
+        raise
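With the new entry point, the evaluator can be launched directly, e.g. python -m evaluator.evaluator --defaults --test-with-additional-queries (assuming the repository layout makes the module runnable this way). Note the generate-then-reload pattern in _set_up_experiment: additional queries are generated and persisted into the dataset files first, and get_queries is then called a second time so that the freshly written additional_queries fields are visible to the indexing path.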

evaluator/metric_collectors/fac_metric_collector.py

Lines changed: 1 addition & 1 deletion
@@ -82,7 +82,7 @@ def __init__(self, settings: Dict, model_config: List[ModelConfig]):
         super().__init__(settings, model_config)
 
         # Metrics storage
-        self.query_results = None
+        self.query_results = []
 
         # judge model configuration
         self.judge_model_url = os.getenv('FAC_JUDGE_MODEL_URL')

evaluator/metric_collectors/tool_selection_metric_collector.py

Lines changed: 8 additions & 5 deletions
@@ -16,10 +16,10 @@ class ToolSelectionMetricCollector(MetricCollector):
     def __init__(self, settings: Dict, model_config: List[ModelConfig]):
         super().__init__(settings, model_config)
 
-        self.total_queries = None
-        self.exact_matches = None
-        self.precision_sum = None
-        self.recall_sum = None
+        self.total_queries = 0
+        self.exact_matches = 0
+        self.precision_sum = 0.0
+        self.recall_sum = 0.0
 
     def get_collected_metrics_names(self) -> List[str]:
         return ["Exact Tool Selection Match Rate",
@@ -96,7 +96,10 @@ def report_results(self) -> Dict[str, Any] or None:
             raise RuntimeError("No measurements registered, cannot produce results.")
 
         results = {
-            "Exact Tool Selection Match Rate": self.exact_matches / self.total_queries,
+            "Exact Tool Selection Match Rate": (
+                (self.exact_matches or 0) / (self.total_queries or 1)
+                if self.total_queries else 0.0
+            ),
             "Tool Selection Precision": self.precision_sum / self.total_queries,
             "Tool Selection Recall": self.recall_sum / self.total_queries,
             "Spurious Tool Calling Rate": 1.0 - (self.precision_sum / self.total_queries),

evaluator/utils/parsing_tools.py

Lines changed: 60 additions & 0 deletions
@@ -0,0 +1,60 @@
+from evaluator.components.llm_provider import query_llm
+from pathlib import Path
+import re
+import json
+from evaluator.utils.utils import print_iterable_verbose, log
+
+def generate_and_save_additional_queries(llm, queries):
+    """
+    For each query in queries, use the provided LLM to generate additional_queries if not present,
+    and save them to the appropriate JSON file for that query (matching by query_id).
+    """
+
+    system_prompt = '''You create 5 additional queries for each tool and only return the additional queries information, given the query implemented, return in the following format as a JSON string:
+    {tool_name: {"query1": "", "query2": "", "query3": "", "query4": "", "query5": ""}} '''
+    curr_file = None
+    for i, query_spec in enumerate(queries):
+        # Skip generation if additional_queries is already present or this file was just processed
+        path = Path(query_spec.path)
+        if getattr(query_spec, 'additional_queries', None) or curr_file == path:
+            log(f"Skipping query_id {getattr(query_spec, 'id', '<N/A>')}: additional queries already present or file already processed.")
+            continue
+        user_prompt = f"tool_name = {getattr(query_spec, 'golden_tools', {}).keys()}, Query= {getattr(query_spec, 'query', None)}"
+        result = query_llm(llm, system_prompt, user_prompt)
+        # Parse the LLM response into a dict of additional queries
+        additional = qwen_model_parsing(result)
+        query_spec.additional_queries = additional
+        # Save the additional queries back into the original query JSON file
+        if path and additional is not None:
+            if path.exists():
+                import json as _json
+                with path.open('r', encoding='utf-8') as f:
+                    orig_queries = _json.load(f)
+                for item in orig_queries:
+                    if (
+                        (item.get("query_id") == query_spec.id)
+                        or (str(item.get("query_id")) == str(query_spec.id))
+                    ):
+                        item["additional_queries"] = additional
+                with path.open('w', encoding='utf-8') as f:
+                    _json.dump(orig_queries, f, indent=2, ensure_ascii=False)
+                log(f"Successfully added additional queries to original file {path}")
+        curr_file = path
+
+def qwen_model_parsing(response: str):
+    """
+    Parse the response from the Qwen model and return the additional queries.
+    """
+    # Strip any leading <think>...</think> reasoning block emitted by the model
+    match = re.search(r"</think>\s*(.*)", response, re.DOTALL)
+    response_text = match.group(1).strip() if match else response
+    # Try to parse the remaining text as a JSON dict of additional queries
+    additional = None
+    response_text = response_text.strip()
+    try:
+        additional = json.loads(response_text)
+    except Exception:
+        additional = None
+    return additional
+
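qwen_model_parsing only strips a leading <think>...</think> block; if the model wraps its JSON answer in a markdown code fence, json.loads fails and the function returns None, silently discarding the generated queries. A hedged sketch of a more forgiving extraction step (the fence-stripping behavior is an assumption about the model's output, not something this commit implements):

import json
import re

def extract_json_payload(response_text: str):
    """Parse JSON from an LLM reply, tolerating an optional markdown code fence."""
    # Drop a ```json ... ``` (or bare ```) fence if the model added one.
    fenced = re.search(r"```(?:json)?\s*(.*?)\s*```", response_text, re.DOTALL)
    candidate = fenced.group(1) if fenced else response_text
    try:
        return json.loads(candidate)
    except json.JSONDecodeError:
        return None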
