Skip to content

Add ToolsRetriever class and Retriever.convert_to_tool() method #332

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 2 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 6 additions & 1 deletion examples/customize/llms/openai_tool_calls.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,12 @@

from neo4j_graphrag.llm import OpenAILLM
from neo4j_graphrag.llm.types import ToolCallResponse
from neo4j_graphrag.tool import Tool, ObjectParameter, StringParameter, IntegerParameter
from neo4j_graphrag.tools.tool import (
Tool,
ObjectParameter,
StringParameter,
IntegerParameter,
)

# Load environment variables from .env file (OPENAI_API_KEY required for this example)
load_dotenv()
Expand Down
7 changes: 6 additions & 1 deletion examples/customize/llms/vertexai_tool_calls.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,12 @@

from neo4j_graphrag.llm import VertexAILLM
from neo4j_graphrag.llm.types import ToolCallResponse
from neo4j_graphrag.tool import Tool, ObjectParameter, StringParameter, IntegerParameter
from neo4j_graphrag.tools.tool import (
Tool,
ObjectParameter,
StringParameter,
IntegerParameter,
)

# Load environment variables from .env file
load_dotenv()
Expand Down
151 changes: 151 additions & 0 deletions examples/retrieve/tools/multiple_tools_example.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,151 @@
# Copyright (c) "Neo4j"
# Neo4j Sweden AB [https://neo4j.com]
# #
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# #
# https://www.apache.org/licenses/LICENSE-2.0
# #
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
"""
Example demonstrating how to create multiple domain-specific tools from retrievers.

This example shows:
1. How to create multiple tools from the same retriever type for different use cases
2. How to provide custom parameter descriptions for each tool
3. How type inference works automatically while descriptions are explicit
"""

import neo4j
from typing import cast, Any, Optional
from unittest.mock import MagicMock

from neo4j_graphrag.retrievers.base import Retriever
from neo4j_graphrag.types import RawSearchResult


class MockVectorRetriever(Retriever):
"""A mock vector retriever for demonstration purposes."""

VERIFY_NEO4J_VERSION = False

def __init__(self, driver: neo4j.Driver, index_name: str):
super().__init__(driver)
self.index_name = index_name

def get_search_results(
self,
query_vector: Optional[list[float]] = None,
query_text: Optional[str] = None,
top_k: int = 5,
effective_search_ratio: int = 1,
filters: Optional[dict[str, Any]] = None,
) -> RawSearchResult:
"""Get vector search results (mocked for demonstration)."""
# Return empty results for demo
return RawSearchResult(records=[], metadata={"index": self.index_name})


def main() -> None:
"""Demonstrate creating multiple domain-specific tools from retrievers."""

# Create mock driver (in real usage, this would be actual Neo4j driver)
driver = cast(Any, MagicMock())

# Create retrievers for different domains using the same retriever type
# In practice, these would point to different vector indexes

# Movie recommendations retriever
movie_retriever = MockVectorRetriever(driver=driver, index_name="movie_embeddings")

# Product search retriever
product_retriever = MockVectorRetriever(
driver=driver, index_name="product_embeddings"
)

# Document search retriever
document_retriever = MockVectorRetriever(
driver=driver, index_name="document_embeddings"
)

# Convert each retriever to a domain-specific tool with custom descriptions

# 1. Movie recommendation tool
movie_tool = movie_retriever.convert_to_tool(
name="movie_search",
description="Find movie recommendations based on plot, genre, or actor preferences",
parameter_descriptions={
"query_text": "Movie title, plot description, genre, or actor name",
"query_vector": "Pre-computed embedding vector for movie search",
"top_k": "Number of movie recommendations to return (1-20)",
"filters": "Optional filters for genre, year, rating, etc.",
"effective_search_ratio": "Search pool multiplier for better accuracy",
},
)

# 2. Product search tool
product_tool = product_retriever.convert_to_tool(
name="product_search",
description="Search for products matching customer needs and preferences",
parameter_descriptions={
"query_text": "Product name, description, or customer need",
"query_vector": "Pre-computed embedding for product matching",
"top_k": "Maximum number of product results (1-50)",
"filters": "Filters for price range, brand, category, availability",
"effective_search_ratio": "Breadth vs precision trade-off for search",
},
)

# 3. Document search tool
document_tool = document_retriever.convert_to_tool(
name="document_search",
description="Find relevant documents and knowledge articles",
parameter_descriptions={
"query_text": "Question, keywords, or topic to search for",
"query_vector": "Semantic embedding for document retrieval",
"top_k": "Number of relevant documents to retrieve (1-10)",
"filters": "Document type, date range, or department filters",
},
)

# Demonstrate that each tool has distinct, meaningful descriptions
tools = [movie_tool, product_tool, document_tool]

for tool in tools:
print(f"\n=== {tool.get_name().upper()} ===")
print(f"Description: {tool.get_description()}")
print("Parameters:")

params = tool.get_parameters()
for param_name, param_def in params["properties"].items():
required = (
"required" if param_name in params.get("required", []) else "optional"
)
print(
f" - {param_name} ({param_def['type']}, {required}): {param_def['description']}"
)

# Show how the same parameter type gets different contextual descriptions
print("\n=== PARAMETER COMPARISON ===")
print("Same parameter 'query_text' with different contextual descriptions:")

for tool in tools:
params = tool.get_parameters()
query_text_desc = params["properties"]["query_text"]["description"]
print(f" {tool.get_name()}: {query_text_desc}")

print("\nSame parameter 'top_k' with different contextual descriptions:")
for tool in tools:
params = tool.get_parameters()
top_k_desc = params["properties"]["top_k"]["description"]
print(f" {tool.get_name()}: {top_k_desc}")


if __name__ == "__main__":
main()
114 changes: 114 additions & 0 deletions examples/retrieve/tools/retriever_to_tool_example.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,114 @@
# Copyright (c) "Neo4j"
# Neo4j Sweden AB [https://neo4j.com]
# #
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# #
# https://www.apache.org/licenses/LICENSE-2.0
# #
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
"""
Example demonstrating how to convert a retriever to a tool.

This example shows:
1. How to convert a custom StaticRetriever to a Tool using the convert_to_tool method
2. How to define parameters for the tool in the retriever class
3. How to execute the tool
"""

import neo4j
from typing import Optional, Any, cast
from unittest.mock import MagicMock

from neo4j_graphrag.retrievers.base import Retriever
from neo4j_graphrag.types import RawSearchResult


# Create a Retriever that returns static results about Neo4j
# This would illustrate the conversion process of any Retriever (Vector, Hybrid, etc.)
class StaticRetriever(Retriever):
"""A retriever that returns static results about Neo4j."""

# Disable Neo4j version verification
VERIFY_NEO4J_VERSION = False

def __init__(self, driver: neo4j.Driver):
# Call the parent class constructor with the driver
super().__init__(driver)

def get_search_results(
self, query_text: Optional[str] = None, **kwargs: Any
) -> RawSearchResult:
"""Return static information about Neo4j regardless of the query.

Args:
query_text (Optional[str]): The query about Neo4j (any query will return general Neo4j information)
**kwargs (Any): Additional keyword arguments (not used)

Returns:
RawSearchResult: Static Neo4j information with metadata
"""
# Create formatted Neo4j information
neo4j_info = (
"# Neo4j Graph Database\n\n"
"Neo4j is a graph database management system developed by Neo4j, Inc. "
"It is an ACID-compliant transactional database with native graph storage and processing.\n\n"
"## Key Features:\n\n"
"- **Cypher Query Language**: Neo4j's intuitive query language designed specifically for working with graph data\n"
"- **Property Graphs**: Both nodes and relationships can have properties (key-value pairs)\n"
"- **ACID Compliance**: Ensures data integrity with full transaction support\n"
"- **Native Graph Storage**: Optimized storage for graph data structures\n"
"- **High Availability**: Clustering for enterprise deployments\n"
"- **Scalability**: Handles billions of nodes and relationships"
)

# Create a Neo4j record with the information
records = [neo4j.Record({"result": neo4j_info})]

# Return a RawSearchResult with the records and metadata
return RawSearchResult(records=records, metadata={"query": query_text})


def main() -> None:
# Convert a StaticRetriever to a tool using the new convert_to_tool method
static_retriever = StaticRetriever(driver=cast(Any, MagicMock()))

# Convert the retriever to a tool with custom parameter descriptions
static_tool = static_retriever.convert_to_tool(
name="Neo4jInfoTool",
description="Get general information about Neo4j graph database",
parameter_descriptions={
"query_text": "Any query about Neo4j (the tool returns general information regardless)"
},
)

# Print tool information
print("Example: StaticRetriever with specific parameters")
print(f"Tool Name: {static_tool.get_name()}")
print(f"Tool Description: {static_tool.get_description()}")
print(f"Tool Parameters: {static_tool.get_parameters()}")
print()

# Execute the tools (in a real application, this would be done by instructions from an LLM)
try:
# Execute the static retriever tool
print("\nExecuting the static retriever tool...")
static_result = static_tool.execute(
query_text="What is Neo4j?",
)
print("Static Search Results:")
for i, item in enumerate(static_result):
print(f"{i + 1}. {str(item)[:100]}...")

except Exception as e:
print(f"Error executing tool: {e}")


if __name__ == "__main__":
main()
Loading