-
Notifications
You must be signed in to change notification settings - Fork 3
Additional queries Embedding for Tool Rag #34
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -84,6 +84,8 @@ class ToolRagAlgorithm(Algorithm): | |
| - max_document_size: the maximal size, in characters, of a single indexed document, or None to disable the size limit. | ||
| - indexed_tool_def_parts: the parts of the MCP tool definition to be used for index construction, such as 'name', | ||
| 'description', 'args', etc. | ||
| You can also include 'additional_queries' (or 'examples') to append example queries for each tool if provided | ||
| via the 'additional_queries' setting (see defaults below). | ||
| - hybrid_mode: True to enable hybrid (sparse + dense) search and False to only enable dense search. | ||
| - analyzer_params: parameters for the Milvus BM25 analyzer. | ||
| - fusion_type: the algorithm for combining the dense and the sparse scores if hybrid mode is activated. Milvus only | ||
|
|
@@ -128,7 +130,8 @@ def get_default_settings(self) -> Dict[str, Any]: | |
| "embedding_model_id": "all-MiniLM-L6-v2", | ||
| "similarity_metric": "COSINE", | ||
| "index_type": "FLAT", | ||
| "indexed_tool_def_parts": ["name", "description"], | ||
| "indexed_tool_def_parts": ["name", "description", "additional_queries"], | ||
|
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Until we evaluate this and see the improvement from embedding the queries, let's leave the default as it was. |
||
|
|
||
|
|
||
| # preprocessing | ||
| "text_preprocessing_operations": None, | ||
|
|
@@ -213,7 +216,6 @@ def _compose_tool_text(self, tool: BaseTool) -> str: | |
| parts_to_include = self._settings["indexed_tool_def_parts"] | ||
| if not parts_to_include: | ||
| raise ValueError("indexed_tool_def_parts must be a non-empty list") | ||
|
|
||
| segments = [] | ||
| for p in parts_to_include: | ||
| if p.lower() == "name": | ||
|
|
@@ -232,11 +234,16 @@ def _compose_tool_text(self, tool: BaseTool) -> str: | |
| tags = tool.tags or [] | ||
| if tags: | ||
| segments.append(f"tags: {' '.join(tags)}") | ||
|
|
||
| elif p.lower() == "additional_queries": | ||
| examples_map = self._settings.get("additional_queries") or {} | ||
| examples_list = examples_map.get(tool.name) or [] | ||
| if examples_list: | ||
| rendered = self._render_examples(examples_list) | ||
| if rendered: | ||
| segments.append(f"ex: {rendered}") | ||
| if not segments: | ||
| raise ValueError(f"The following tool contains none of the fields listed in indexed_tool_def_parts:\n{tool}") | ||
| text = " | ".join(segments) | ||
|
|
||
| # one-pass preprocess + truncation | ||
| text = self._preprocess_text(text) | ||
| text = self._truncate(text) | ||
|
|
@@ -249,6 +256,30 @@ def _create_docs_from_tools(self, tools: List[BaseTool]) -> List[Document]: | |
| documents.append(Document(page_content=page_content, metadata={"name": tool.name})) | ||
| return documents | ||
|
|
||
| def _collect_examples_from_tool_specs(self, tool_specs: Dict[str, Dict[str, Any]]) -> Dict[str, List[str]]: | ||
|
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. I don't think this functionality should be in this module. ToolRagAlgorithm should receive all the info in the tools list. See also my comment below. |
||
| """ | ||
| Build {tool_name: [example1, example2, ...]} from a tools dict where each | ||
| value may contain an 'additional_queries' dict mapping query keys to strings. | ||
| """ | ||
| examples: Dict[str, List[str]] = {} | ||
| for tool_name, spec in (tool_specs or {}).items(): | ||
| if not isinstance(spec, dict): | ||
| continue | ||
| aq = spec.get("additional_queries") | ||
| if isinstance(aq, dict): | ||
| for _, qtext in aq.items(): | ||
| if isinstance(qtext, str) and qtext.strip(): | ||
| examples.setdefault(tool_name, []).append(qtext.strip()) | ||
| # de-duplicate while preserving order | ||
| for k, v in list(examples.items()): | ||
| seen, out = set(), [] | ||
| for s in v: | ||
| if s not in seen: | ||
| seen.add(s) | ||
| out.append(s) | ||
| examples[k] = out | ||
| return examples | ||
|
|
||
| def _index_tools(self, tools: List[BaseTool]) -> None: | ||
| self.tool_name_to_base_tool = {tool.name: tool for tool in tools} | ||
|
|
||
|
|
@@ -308,7 +339,7 @@ def _index_tools(self, tools: List[BaseTool]) -> None: | |
| search_params=search_params, | ||
| ) | ||
|
|
||
| def set_up(self, model: BaseChatModel, tools: List[BaseTool]) -> None: | ||
| def set_up(self, model: BaseChatModel, tools: List[BaseTool], tool_specs: Any) -> None: | ||
|
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Changing the signature of an interface method is a very bad practice. I understand why you did this, but it breaks the abstraction. |
||
| super().set_up(model, tools) | ||
|
|
||
| if self._settings["cross_encoder_model_name"]: | ||
|
|
@@ -320,6 +351,14 @@ def set_up(self, model: BaseChatModel, tools: List[BaseTool]) -> None: | |
| if self._settings["enable_query_decomposition"] or self._settings["enable_query_rewriting"]: | ||
| self.query_rewriting_model = self._get_llm(self._settings["query_rewriting_model_id"]) | ||
|
|
||
| # Build additional_queries mapping from provided specs (accept dict of tool specs or list of QuerySpecifications) | ||
|
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. See my remarks above - the examples should already be inside the BaseTool objects when you receive them, so this code shouldn't be here. |
||
| try: | ||
| examples_map: Dict[str, List[str]] = {} | ||
| if isinstance(tool_specs, dict): | ||
| examples_map = self._collect_examples_from_tool_specs(tool_specs) | ||
| self._settings["additional_queries"] = examples_map | ||
| except Exception: | ||
| pass | ||
| self._index_tools(tools) | ||
|
|
||
| def _threshold_results(self, docs_and_scores: List[Tuple[Document, float]]) -> List[Document]: | ||
|
|
@@ -581,4 +620,4 @@ def _dedup_keep_order(xs: List[str]) -> List[str]: | |
|
|
||
| @staticmethod | ||
| def _strip_numbering(s: str) -> str: | ||
| return re.sub(r"^\s*(?:[-*]|\d+[).:]?)\s*", "", s).strip().rstrip(".") | ||
| return re.sub(r"^\s*(?:[-*]|\d+[).:]?)\s*", "", s).strip().rstrip(".") | ||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -27,6 +27,8 @@ class QuerySpecification(BaseModel): | |
| """ | ||
| id: int | ||
| query: str | ||
| additional_queries: Optional[Dict[str, Any]] = None | ||
| path: Optional[str] = None | ||
|
Comment on lines
+30
to
+31
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Why do you need this? Is this a leftover from some previous work? |
||
| reference_answer: Optional[str] = None | ||
| golden_tools: ToolSet = Field(default_factory=dict) | ||
| additional_tools: Optional[ToolSet] = None | ||
|
|
@@ -395,9 +397,55 @@ def get_queries( | |
| def get_tools_from_queries(queries: List[QuerySpecification]) -> ToolSet: | ||
| tools = {} | ||
|
|
||
| # Base tools from the dataset | ||
| for query_spec in queries: | ||
| tools.update(query_spec.golden_tools) | ||
| if query_spec.additional_tools: | ||
| tools.update(query_spec.additional_tools) | ||
|
|
||
| # Merge per-query additional queries from centralized store under the correct tool entry | ||
| aq = get_additional_query(query_spec.id) | ||
| if isinstance(aq, dict): | ||
| golden_tools = query_spec.golden_tools | ||
| for tool in golden_tools: | ||
| additional_queries = aq.get(tool) | ||
| tools[tool]["additional_queries"] = additional_queries | ||
|
|
||
| return tools | ||
|
|
||
|
|
||
| def load_additional_queries_store(path: str | None = None) -> List[Dict[str, Any]]: | ||
| """ | ||
| Load the centralized additional queries store. | ||
| Expected format: a JSON list of objects {"query_id": int, "additional_queries": {...}}. | ||
| Returns an empty list if the file doesn't exist or cannot be parsed. | ||
| """ | ||
| try: | ||
| store_path = Path(path) if path else (Path("data") / "additional_queries.json") | ||
|
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Don't use hard-coded paths. The directory should be the root data directory (see other examples in this file). The file name should be a configuration setting. |
||
| if not store_path.exists(): | ||
| return [] | ||
| with store_path.open("r", encoding="utf-8") as f: | ||
| loaded = json.load(f) | ||
| return loaded if isinstance(loaded, list) else [] | ||
| except Exception: | ||
| return [] | ||
|
|
||
|
|
||
| def get_additional_query(query_id: int) -> Dict[str, Any] | None: | ||
|
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Shouldn't it be 'get_additional_queries'? |
||
| """ | ||
| Return the additional_queries dict for the given query_id from data/additional_queries.json, | ||
| or None if not found or invalid. | ||
| """ | ||
| store = load_additional_queries_store() | ||
| for item in store: | ||
| if not isinstance(item, dict): | ||
| continue | ||
| if "query_id" not in item or "additional_queries" not in item: | ||
| continue | ||
| try: | ||
| qid = int(item["query_id"]) | ||
| except Exception: | ||
| continue | ||
| if qid == query_id and isinstance(item["additional_queries"], dict): | ||
| return item["additional_queries"] | ||
| return None | ||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -1,9 +1,10 @@ | ||
| import asyncio | ||
| import os | ||
| from re import S | ||
| import time | ||
| import traceback | ||
| from typing import List, Tuple | ||
|
|
||
| from pathlib import Path | ||
| import openai | ||
| from langgraph.errors import GraphRecursionError | ||
| from pydantic import ValidationError | ||
|
|
@@ -17,6 +18,8 @@ | |
| from evaluator.interfaces.algorithm import Algorithm | ||
| from evaluator.utils.csv_logger import CSVLogger | ||
| from evaluator.components.llm_provider import get_llm | ||
| from evaluator.utils.parsing_tools import generate_and_save_additional_queries | ||
| import json as _json | ||
| from dotenv import load_dotenv | ||
|
|
||
| from evaluator.utils.tool_logger import ToolLogger | ||
|
|
@@ -35,13 +38,13 @@ class Evaluator(object): | |
|
|
||
| config: EvaluationConfig | ||
|
|
||
| def __init__(self, config_path: str | None, use_defaults: bool): | ||
| def __init__(self, config_path: str | None, use_defaults: bool, test_with_additional_queries: bool = False): | ||
|
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Please remove this parameter. Everything should be passed to the evaluator via config only. |
||
| try: | ||
| self.config = load_config(config_path, use_defaults=use_defaults) | ||
| except ConfigError as ce: | ||
| log(f"Configuration error: {ce}") | ||
| raise SystemExit(2) | ||
|
|
||
| self.test_with_additional_queries = test_with_additional_queries | ||
| async def run(self) -> None: | ||
|
|
||
| # Set up the necessary components for the experiments: | ||
|
|
@@ -66,6 +69,18 @@ async def run(self) -> None: | |
| # Actually run the experiments | ||
| metadata_columns = ["Experiment ID", "Algorithm ID", "Algorithm Details", "Environment", "Number of Queries"] | ||
| with CSVLogger(metric_collectors, os.getenv("OUTPUT_DIR_PATH"), metadata_columns=metadata_columns) as logger: | ||
| # generate additional queries here (optional) | ||
| try: | ||
| log(f"Generating additional queries...") | ||
| environment = experiment_specs[0][1] | ||
| gen_model_id = self.config.data.additional_queries_model_id | ||
|
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. 'additional_queries_model_id' must be added to config/schema.py and config/defaults.py |
||
| llm = get_llm(model_id=gen_model_id, model_config=self.config.models) | ||
| queries = get_queries(environment, self.config.data) | ||
| generate_and_save_additional_queries(llm, queries) | ||
| except Exception as _: | ||
| log("Skipping additional query generation due to error.") | ||
|
|
||
| # generate queries here | ||
|
Comment on lines
+72
to
+83
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. The problem here is that queries (and tools) can be different for different experiments, and here you only generate additional queries for the tools from the very first experiment.
|
||
| for i, spec in enumerate(experiment_specs): | ||
| algorithm, environment = spec | ||
| log(f"{'-' * 60}\nRunning Experiment {i+1} of {len(experiment_specs)}: {self._spec_to_str(spec)}...\n{'-' * 60}") | ||
|
|
@@ -120,7 +135,7 @@ async def _run_experiment(self, | |
| try: | ||
| for i, query_spec in enumerate(queries): | ||
| log(f"Processing query #{query_spec.id} (Experiment {exp_index} of {total_exp_num}, query {i+1} of {len(queries)})...") | ||
|
|
||
| for mc in metric_collectors: | ||
| mc.prepare_for_measurement(query_spec) | ||
|
|
||
|
|
@@ -195,26 +210,51 @@ async def _set_up_experiment(self, | |
| mcp_proxy_manager: MCPProxyManager, | ||
| ) -> List[QuerySpecification]: | ||
| algorithm, environment = spec | ||
|
|
||
| log(f"Initializing LLM connection: {environment.model_id}") | ||
| llm = get_llm(model_id=environment.model_id, model_config=self.config.models) | ||
|
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Not sure why you moved this line further down, but please move it back here - the idea is that the following line prints "Connection established successfully" only if the LLM connection was indeed created. |
||
| log("Connection established successfully.\n") | ||
|
|
||
| log("Fetching queries for the current experiment...") | ||
| queries = get_queries(environment, self.config.data) | ||
| log(f"Successfully loaded {len(queries)} queries.\n") | ||
| print_iterable_verbose("The following queries will be executed:\n", queries) | ||
|
|
||
| llm = get_llm(model_id=environment.model_id, model_config=self.config.models) | ||
| queries = get_queries(environment, self.config.data) | ||
|
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. This is a duplication of line 216. |
||
| log("Retrieving tool definitions for the current experiment...") | ||
| tool_specs = get_tools_from_queries(queries) | ||
| tools = await mcp_proxy_manager.run_mcp_proxy(tool_specs, init_client=True).get_tools() | ||
| print_iterable_verbose("The following tools will be available during evaluation:\n", tools) | ||
| log(f"The experiment will proceed with {len(tools)} tool(s).\n") | ||
|
|
||
| log("Setting up the algorithm and the metric collectors...") | ||
| algorithm.set_up(llm, tools) | ||
| # Pass queries to algorithms that accept them; fall back for others | ||
| if algorithm.__module__ == "evaluator.algorithms.tool_rag_algorithm": | ||
| algorithm.set_up(llm, tools, tool_specs) | ||
| else: | ||
| algorithm.set_up(llm, tools) | ||
| for mc in metric_collectors: | ||
| mc.set_up() | ||
| log("All set!\n") | ||
| log("Setup complete!") | ||
|
Comment on lines
-218
to
+234
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Any particular reason for this change? :) |
||
|
|
||
| return queries | ||
|
|
||
| if __name__ == "__main__": | ||
| import argparse | ||
| parser = argparse.ArgumentParser(description="Run the Evaluator experiments.") | ||
| parser.add_argument("--config", type=str, default=None, help="Path to evaluation config YAML file") | ||
| parser.add_argument("--defaults", action="store_true", help="Use default config options if set") | ||
| parser.add_argument("--test-with-additional-queries", action="store_true", help="Test with additional queries") | ||
| args = parser.parse_args() | ||
|
|
||
| from evaluator.utils.utils import log | ||
|
|
||
| log("Starting Evaluator main...") | ||
| evaluator = Evaluator( | ||
| config_path=args.config, | ||
| use_defaults=args.defaults, | ||
| test_with_additional_queries=args.test_with_additional_queries | ||
| ) | ||
| try: | ||
| import asyncio | ||
| asyncio.run(evaluator.run()) | ||
| log("Evaluator finished successfully!") | ||
| except Exception as e: | ||
| log(f"Evaluator failed: {e}") | ||
| raise | ||
|
Comment on lines
+237
to
+260
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Please remove this main. |
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Let's just call them 'examples' everywhere. I believe it will make the code more readable and intuitive as "additional queries" is ambiguous.