Skip to content

Commit b4e532e

Browse files
committed
Address PR comments: refactor retriever-to-tool conversion
- Add abstract get_parameters() method to Retriever base class - Add convert_to_tool() instance method to Retriever class - Implement get_parameters() for all concrete retriever classes - Remove automatic query_text injection in ToolsRetriever - Update example to use new convert_to_tool() method - Remove unnecessary description from ObjectParameter in example
1 parent d7a0104 commit b4e532e

File tree

5 files changed

+906
-30
lines changed

5 files changed

+906
-30
lines changed
Lines changed: 151 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,151 @@
1+
# Copyright (c) "Neo4j"
2+
# Neo4j Sweden AB [https://neo4j.com]
3+
# #
4+
# Licensed under the Apache License, Version 2.0 (the "License");
5+
# you may not use this file except in compliance with the License.
6+
# You may obtain a copy of the License at
7+
# #
8+
# https://www.apache.org/licenses/LICENSE-2.0
9+
# #
10+
# Unless required by applicable law or agreed to in writing, software
11+
# distributed under the License is distributed on an "AS IS" BASIS,
12+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
# See the License for the specific language governing permissions and
14+
# limitations under the License.
15+
#
16+
"""
17+
Example demonstrating how to create multiple domain-specific tools from retrievers.
18+
19+
This example shows:
20+
1. How to create multiple tools from the same retriever type for different use cases
21+
2. How to provide custom parameter descriptions for each tool
22+
3. How type inference works automatically while descriptions are explicit
23+
"""
24+
25+
import neo4j
26+
from typing import cast, Any, Optional
27+
from unittest.mock import MagicMock
28+
29+
from neo4j_graphrag.retrievers.base import Retriever
30+
from neo4j_graphrag.types import RawSearchResult
31+
32+
33+
class MockVectorRetriever(Retriever):
34+
"""A mock vector retriever for demonstration purposes."""
35+
36+
VERIFY_NEO4J_VERSION = False
37+
38+
def __init__(self, driver: neo4j.Driver, index_name: str):
39+
super().__init__(driver)
40+
self.index_name = index_name
41+
42+
def get_search_results(
43+
self,
44+
query_vector: Optional[list[float]] = None,
45+
query_text: Optional[str] = None,
46+
top_k: int = 5,
47+
effective_search_ratio: int = 1,
48+
filters: Optional[dict[str, Any]] = None,
49+
) -> RawSearchResult:
50+
"""Get vector search results (mocked for demonstration)."""
51+
# Return empty results for demo
52+
return RawSearchResult(records=[], metadata={"index": self.index_name})
53+
54+
55+
def main() -> None:
56+
"""Demonstrate creating multiple domain-specific tools from retrievers."""
57+
58+
# Create mock driver (in real usage, this would be actual Neo4j driver)
59+
driver = cast(Any, MagicMock())
60+
61+
# Create retrievers for different domains using the same retriever type
62+
# In practice, these would point to different vector indexes
63+
64+
# Movie recommendations retriever
65+
movie_retriever = MockVectorRetriever(driver=driver, index_name="movie_embeddings")
66+
67+
# Product search retriever
68+
product_retriever = MockVectorRetriever(
69+
driver=driver, index_name="product_embeddings"
70+
)
71+
72+
# Document search retriever
73+
document_retriever = MockVectorRetriever(
74+
driver=driver, index_name="document_embeddings"
75+
)
76+
77+
# Convert each retriever to a domain-specific tool with custom descriptions
78+
79+
# 1. Movie recommendation tool
80+
movie_tool = movie_retriever.convert_to_tool(
81+
name="movie_search",
82+
description="Find movie recommendations based on plot, genre, or actor preferences",
83+
parameter_descriptions={
84+
"query_text": "Movie title, plot description, genre, or actor name",
85+
"query_vector": "Pre-computed embedding vector for movie search",
86+
"top_k": "Number of movie recommendations to return (1-20)",
87+
"filters": "Optional filters for genre, year, rating, etc.",
88+
"effective_search_ratio": "Search pool multiplier for better accuracy",
89+
},
90+
)
91+
92+
# 2. Product search tool
93+
product_tool = product_retriever.convert_to_tool(
94+
name="product_search",
95+
description="Search for products matching customer needs and preferences",
96+
parameter_descriptions={
97+
"query_text": "Product name, description, or customer need",
98+
"query_vector": "Pre-computed embedding for product matching",
99+
"top_k": "Maximum number of product results (1-50)",
100+
"filters": "Filters for price range, brand, category, availability",
101+
"effective_search_ratio": "Breadth vs precision trade-off for search",
102+
},
103+
)
104+
105+
# 3. Document search tool
106+
document_tool = document_retriever.convert_to_tool(
107+
name="document_search",
108+
description="Find relevant documents and knowledge articles",
109+
parameter_descriptions={
110+
"query_text": "Question, keywords, or topic to search for",
111+
"query_vector": "Semantic embedding for document retrieval",
112+
"top_k": "Number of relevant documents to retrieve (1-10)",
113+
"filters": "Document type, date range, or department filters",
114+
},
115+
)
116+
117+
# Demonstrate that each tool has distinct, meaningful descriptions
118+
tools = [movie_tool, product_tool, document_tool]
119+
120+
for tool in tools:
121+
print(f"\n=== {tool.get_name().upper()} ===")
122+
print(f"Description: {tool.get_description()}")
123+
print("Parameters:")
124+
125+
params = tool.get_parameters()
126+
for param_name, param_def in params["properties"].items():
127+
required = (
128+
"required" if param_name in params.get("required", []) else "optional"
129+
)
130+
print(
131+
f" - {param_name} ({param_def['type']}, {required}): {param_def['description']}"
132+
)
133+
134+
# Show how the same parameter type gets different contextual descriptions
135+
print("\n=== PARAMETER COMPARISON ===")
136+
print("Same parameter 'query_text' with different contextual descriptions:")
137+
138+
for tool in tools:
139+
params = tool.get_parameters()
140+
query_text_desc = params["properties"]["query_text"]["description"]
141+
print(f" {tool.get_name()}: {query_text_desc}")
142+
143+
print("\nSame parameter 'top_k' with different contextual descriptions:")
144+
for tool in tools:
145+
params = tool.get_parameters()
146+
top_k_desc = params["properties"]["top_k"]["description"]
147+
print(f" {tool.get_name()}: {top_k_desc}")
148+
149+
150+
if __name__ == "__main__":
151+
main()

examples/retrieve/tools/retriever_to_tool_example.py

Lines changed: 19 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -17,8 +17,8 @@
1717
Example demonstrating how to convert a retriever to a tool.
1818
1919
This example shows:
20-
1. How to convert a custom StaticRetriever to a Tool
21-
2. How to define parameters for the tool
20+
1. How to convert a custom StaticRetriever to a Tool using the convert_to_tool method
21+
2. How to define parameters for the tool in the retriever class
2222
3. How to execute the tool
2323
"""
2424

@@ -28,11 +28,6 @@
2828

2929
from neo4j_graphrag.retrievers.base import Retriever
3030
from neo4j_graphrag.types import RawSearchResult
31-
from neo4j_graphrag.tools.tool import (
32-
StringParameter,
33-
ObjectParameter,
34-
)
35-
from neo4j_graphrag.tools.utils import convert_retriever_to_tool
3631

3732

3833
# Create a Retriever that returns static results about Neo4j
@@ -50,7 +45,15 @@ def __init__(self, driver: neo4j.Driver):
5045
def get_search_results(
5146
self, query_text: Optional[str] = None, **kwargs: Any
5247
) -> RawSearchResult:
53-
"""Return static information about Neo4j regardless of the query."""
48+
"""Return static information about Neo4j regardless of the query.
49+
50+
Args:
51+
query_text (Optional[str]): The query about Neo4j (any query will return general Neo4j information)
52+
**kwargs (Any): Additional keyword arguments (not used)
53+
54+
Returns:
55+
RawSearchResult: Static Neo4j information with metadata
56+
"""
5457
# Create formatted Neo4j information
5558
neo4j_info = (
5659
"# Neo4j Graph Database\n\n"
@@ -73,26 +76,16 @@ def get_search_results(
7376

7477

7578
def main() -> None:
76-
# Convert a StaticRetriever to a tool with specific parameters
79+
# Convert a StaticRetriever to a tool using the new convert_to_tool method
7780
static_retriever = StaticRetriever(driver=cast(Any, MagicMock()))
7881

79-
# Define parameters for the static retriever tool
80-
static_parameters = ObjectParameter(
81-
description="Parameters for the Neo4j information retriever",
82-
properties={
83-
"query_text": StringParameter(
84-
description="The query about Neo4j (any query will return general Neo4j information)",
85-
required=True,
86-
),
87-
},
88-
)
89-
90-
# Convert the retriever to a tool with specific parameters
91-
static_tool = convert_retriever_to_tool(
92-
retriever=static_retriever,
93-
description="Get general information about Neo4j graph database",
94-
parameters=static_parameters,
82+
# Convert the retriever to a tool with custom parameter descriptions
83+
static_tool = static_retriever.convert_to_tool(
9584
name="Neo4jInfoTool",
85+
description="Get general information about Neo4j graph database",
86+
parameter_descriptions={
87+
"query_text": "Any query about Neo4j (the tool returns general information regardless)"
88+
},
9689
)
9790

9891
# Print tool information
@@ -107,7 +100,7 @@ def main() -> None:
107100
# Execute the static retriever tool
108101
print("\nExecuting the static retriever tool...")
109102
static_result = static_tool.execute(
110-
query="What is Neo4j?",
103+
query_text="What is Neo4j?",
111104
)
112105
print("Static Search Results:")
113106
for i, item in enumerate(static_result):

0 commit comments

Comments
 (0)